diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 0ce08c6eea751..08b2561e8d630 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -143,6 +143,16 @@ Changelog :pr:`13649` by :user:`Samuel Ronsin <samronsin>`, initiated by :user:`Patrick O'Reilly <pat-oreilly>`. + +:mod:`sklearn.neighbors` +........................ + +- |API| :class:`neighbors.KNeighborsRegressor` now accepts + :class:`metrics.DistanceMetric` objects directly via the `metric` keyword + argument allowing for the use of accelerated third-party + :class:`metrics.DistanceMetric` objects. + :pr:`26267` by :user:`Meekail Zain <micky774>` + :mod:`sklearn.metrics` ...................... diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp index 7edc64c59a050..dd66299223efe 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp @@ -36,7 +36,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): X, Y, intp_t k, - str metric="euclidean", + metric="euclidean", chunk_size=None, dict metric_kwargs=None, str strategy=None, diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp index 2fb258741c555..8899f49330440 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp @@ -53,7 +53,7 @@ cdef class DatasetsPair{{name_suffix}}: cls, X, Y, - str metric="euclidean", + metric="euclidean", dict metric_kwargs=None, ) -> DatasetsPair{{name_suffix}}: """Return the DatasetsPair implementation for the given arguments. @@ -70,7 +70,7 @@ cdef class DatasetsPair{{name_suffix}}: If provided as a ndarray, it must be C-contiguous. If provided as a sparse matrix, it must be in CSR format. 
- metric : str, default='euclidean' + metric : str or DistanceMetric object, default='euclidean' The distance metric to compute between rows of X and Y. The default metric is a fast implementation of the Euclidean metric. For a list of available metrics, see the documentation diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 796f15ab6fca0..df13f31dc2e51 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -5,7 +5,11 @@ from scipy.sparse import issparse from ... import get_config -from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING64 +from .._dist_metrics import ( + BOOL_METRICS, + METRIC_MAPPING64, + DistanceMetric, +) from ._argkmin import ( ArgKmin32, ArgKmin64, @@ -117,7 +121,7 @@ def is_valid_sparse_matrix(X): and (is_numpy_c_ordered(Y) or is_valid_sparse_matrix(Y)) and X.dtype == Y.dtype and X.dtype in (np.float32, np.float64) - and metric in cls.valid_metrics() + and (metric in cls.valid_metrics() or isinstance(metric, DistanceMetric)) ) return is_usable diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index dcff18e10fa48..519db9bead3d3 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -19,7 +19,7 @@ from ..base import BaseEstimator, MultiOutputMixin, is_classifier from ..exceptions import DataConversionWarning, EfficiencyWarning -from ..metrics import pairwise_distances_chunked +from ..metrics import DistanceMetric, pairwise_distances_chunked from ..metrics._pairwise_distances_reduction import ( ArgKmin, RadiusNeighbors, @@ -414,7 +414,11 @@ def _check_algorithm_metric(self): if self.algorithm == "auto": if self.metric == "precomputed": alg_check = "brute" - elif callable(self.metric) or self.metric in VALID_METRICS["ball_tree"]: + elif ( + callable(self.metric) + or self.metric in VALID_METRICS["ball_tree"] + or 
isinstance(self.metric, DistanceMetric) + ): alg_check = "ball_tree" else: alg_check = "brute" @@ -430,7 +434,9 @@ def _check_algorithm_metric(self): "in very poor performance." % self.metric ) - elif self.metric not in VALID_METRICS[alg_check]: + elif self.metric not in VALID_METRICS[alg_check] and not isinstance( + self.metric, DistanceMetric + ): raise ValueError( "Metric '%s' not valid. Use " "sorted(sklearn.neighbors.VALID_METRICS['%s']) " @@ -563,9 +569,11 @@ def _fit(self, X, y=None): if self.algorithm not in ("auto", "brute"): warnings.warn("cannot use tree with sparse input: using brute force") - if self.effective_metric_ not in VALID_METRICS_SPARSE[ - "brute" - ] and not callable(self.effective_metric_): + if ( + self.effective_metric_ not in VALID_METRICS_SPARSE["brute"] + and not callable(self.effective_metric_) + and not isinstance(self.effective_metric_, DistanceMetric) + ): raise ValueError( "Metric '%s' not valid for sparse input. " "Use sorted(sklearn.neighbors." diff --git a/sklearn/neighbors/_regression.py b/sklearn/neighbors/_regression.py index b9b7f4030d02c..2897c1ce409e8 100644 --- a/sklearn/neighbors/_regression.py +++ b/sklearn/neighbors/_regression.py @@ -15,6 +15,7 @@ import numpy as np from ..base import RegressorMixin, _fit_context +from ..metrics import DistanceMetric from ..utils._param_validation import StrOptions from ._base import KNeighborsMixin, NeighborsBase, RadiusNeighborsMixin, _get_weights @@ -71,7 +72,7 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase): equivalent to using manhattan_distance (l1), and euclidean_distance (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used. - metric : str or callable, default='minkowski' + metric : str, DistanceMetric object or callable, default='minkowski' Metric to use for distance computation. Default is "minkowski", which results in the standard Euclidean distance when p = 2. 
See the documentation of `scipy.spatial.distance @@ -89,6 +90,9 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase): between those vectors. This works for Scipy's metrics, but is less efficient than passing the metric name as a string. + If metric is a DistanceMetric object, it will be passed directly to + the underlying computation routines. + metric_params : dict, default=None Additional keyword arguments for the metric function. @@ -164,6 +168,7 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase): **NeighborsBase._parameter_constraints, "weights": [StrOptions({"uniform", "distance"}), callable, None], } + _parameter_constraints["metric"].append(DistanceMetric) _parameter_constraints.pop("radius") def __init__( diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 405ac3a6d0847..c81132d795f56 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -24,6 +24,9 @@ ) from sklearn.base import clone from sklearn.exceptions import DataConversionWarning, EfficiencyWarning, NotFittedError +from sklearn.metrics._dist_metrics import ( + DistanceMetric, +) from sklearn.metrics.pairwise import pairwise_distances from sklearn.metrics.tests.test_dist_metrics import BOOL_METRICS from sklearn.metrics.tests.test_pairwise_distances_reduction import ( @@ -69,6 +72,7 @@ COMMON_VALID_METRICS = sorted( set.intersection(*map(set, neighbors.VALID_METRICS.values())) ) # type: ignore + P = (1, 2, 3, 4, np.inf) JOBLIB_BACKENDS = list(joblib.parallel.BACKENDS.keys()) @@ -76,6 +80,25 @@ neighbors.kneighbors_graph = ignore_warnings(neighbors.kneighbors_graph) neighbors.radius_neighbors_graph = ignore_warnings(neighbors.radius_neighbors_graph) +# A list containing metrics where the string specifies the use of the +# DistanceMetric object directly (as resolved in _parse_metric) +DISTANCE_METRIC_OBJS = ["DM_euclidean"] + + +def 
_parse_metric(metric: str, dtype=None): + """ + Helper function for properly building type-specialized DistanceMetric instances. + + Constructs a type-specialized DistanceMetric instance from a string + beginning with "DM_" while allowing a pass-through for other metric-specifying + strings. This is necessary since we wish to parameterize dtype independent of + metric, yet DistanceMetric requires it for construction. + + """ + if metric[:3] == "DM_": + return DistanceMetric.get_metric(metric[3:], dtype=dtype) + return metric + def _generate_test_params_for(metric: str, n_features: int): """Return list of DistanceMetric kwargs for tests.""" @@ -129,7 +152,7 @@ def _weight_func(dist): ], ) @pytest.mark.parametrize("query_is_train", [False, True]) -@pytest.mark.parametrize("metric", COMMON_VALID_METRICS) +@pytest.mark.parametrize("metric", COMMON_VALID_METRICS + DISTANCE_METRIC_OBJS) # type: ignore # noqa def test_unsupervised_kneighbors( global_dtype, n_samples, @@ -143,6 +166,8 @@ def test_unsupervised_kneighbors( # on their common metrics, with and without returning # distances + metric = _parse_metric(metric, global_dtype) + # Redefining the rng locally to use the same generated X local_rng = np.random.RandomState(0) X = local_rng.rand(n_samples, n_features).astype(global_dtype, copy=False) @@ -157,6 +182,12 @@ def test_unsupervised_kneighbors( results = [] for algorithm in ALGORITHMS: + if isinstance(metric, DistanceMetric) and global_dtype == np.float32: + if "tree" in algorithm: # pragma: nocover + pytest.skip( + "Neither KDTree nor BallTree support 32-bit distance metric" + " objects." 
+ ) neigh = neighbors.NearestNeighbors( n_neighbors=n_neighbors, algorithm=algorithm, metric=metric ) @@ -206,7 +237,7 @@ def test_unsupervised_kneighbors( (1000, 5, 100), ], ) -@pytest.mark.parametrize("metric", COMMON_VALID_METRICS) +@pytest.mark.parametrize("metric", COMMON_VALID_METRICS + DISTANCE_METRIC_OBJS) # type: ignore # noqa @pytest.mark.parametrize("n_neighbors, radius", [(1, 100), (50, 500), (100, 1000)]) @pytest.mark.parametrize( "NeighborsMixinSubclass", @@ -230,6 +261,19 @@ def test_neigh_predictions_algorithm_agnosticity( # The different algorithms must return identical predictions results # on their common metrics. + metric = _parse_metric(metric, global_dtype) + if isinstance(metric, DistanceMetric): + if "Classifier" in NeighborsMixinSubclass.__name__: + pytest.skip( + "Metrics of type `DistanceMetric` are not yet supported for" + " classifiers." + ) + if "Radius" in NeighborsMixinSubclass.__name__: + pytest.skip( + "Metrics of type `DistanceMetric` are not yet supported for" + " radius-neighbor estimators." + ) + # Redefining the rng locally to use the same generated X local_rng = np.random.RandomState(0) X = local_rng.rand(n_samples, n_features).astype(global_dtype, copy=False) @@ -244,6 +288,12 @@ def test_neigh_predictions_algorithm_agnosticity( ) for algorithm in ALGORITHMS: + if isinstance(metric, DistanceMetric) and global_dtype == np.float32: + if "tree" in algorithm: # pragma: nocover + pytest.skip( + "Neither KDTree nor BallTree support 32-bit distance metric" + " objects." 
+ ) neigh = NeighborsMixinSubclass(parameter, algorithm=algorithm, metric=metric) neigh.fit(X, y) @@ -985,15 +1035,26 @@ def test_query_equidistant_kth_nn(algorithm): @pytest.mark.parametrize( ["algorithm", "metric"], - [ - ("ball_tree", "euclidean"), - ("kd_tree", "euclidean"), + list( + product( + ("kd_tree", "ball_tree", "brute"), + ("euclidean", *DISTANCE_METRIC_OBJS), + ) + ) + + [ ("brute", "euclidean"), ("brute", "precomputed"), ], ) def test_radius_neighbors_sort_results(algorithm, metric): # Test radius_neighbors[_graph] output when sort_result is True + + metric = _parse_metric(metric, np.float64) + if isinstance(metric, DistanceMetric): + pytest.skip( + "Metrics of type `DistanceMetric` are not yet supported for radius-neighbor" + " estimators." + ) n_samples = 10 rng = np.random.RandomState(42) X = rng.random_sample((n_samples, 4)) @@ -1560,11 +1621,14 @@ def test_nearest_neighbors_validate_params(): neighbors.VALID_METRICS["brute"] ) - set(["pyfunc", *BOOL_METRICS]) - ), + ) + + DISTANCE_METRIC_OBJS, ) def test_neighbors_metrics( global_dtype, metric, n_samples=20, n_features=3, n_query_pts=2, n_neighbors=5 ): + metric = _parse_metric(metric, global_dtype) + # Test computing the neighbors for various metrics algorithms = ["brute", "ball_tree", "kd_tree"] X_train = rng.rand(n_samples, n_features).astype(global_dtype, copy=False) @@ -1574,12 +1638,21 @@ def test_neighbors_metrics( for metric_params in metric_params_list: # Some metric (e.g. 
Weighted minkowski) are not supported by KDTree - exclude_kd_tree = metric not in neighbors.VALID_METRICS["kd_tree"] or ( - "minkowski" in metric and "w" in metric_params + exclude_kd_tree = ( + False + if isinstance(metric, DistanceMetric) + else metric not in neighbors.VALID_METRICS["kd_tree"] + or ("minkowski" in metric and "w" in metric_params) ) results = {} p = metric_params.pop("p", 2) for algorithm in algorithms: + if isinstance(metric, DistanceMetric) and global_dtype == np.float32: + if "tree" in algorithm: # pragma: nocover + pytest.skip( + "Neither KDTree nor BallTree support 32-bit distance metric" + " objects." + ) neigh = neighbors.NearestNeighbors( n_neighbors=n_neighbors, algorithm=algorithm, @@ -1684,10 +1757,14 @@ def custom_metric(x1, x2): assert_allclose(dist1, dist2) -@pytest.mark.parametrize("metric", neighbors.VALID_METRICS["brute"]) +@pytest.mark.parametrize( + "metric", neighbors.VALID_METRICS["brute"] + DISTANCE_METRIC_OBJS +) def test_valid_brute_metric_for_auto_algorithm( global_dtype, metric, n_samples=20, n_features=12 ): + metric = _parse_metric(metric, global_dtype) + X = rng.rand(n_samples, n_features).astype(global_dtype, copy=False) Xcsr = csr_matrix(X)