From d192333d2f7e4a6bc4267be127dd03e163995b8f Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 21 Apr 2023 09:59:21 -0400 Subject: [PATCH 01/11] Minimal changes --- sklearn/metrics/__init__.py | 3 ++- .../_argkmin.pyx.tp | 2 +- .../_datasets_pair.pyx.tp | 4 ++-- .../_dispatcher.py | 6 +++++- sklearn/neighbors/_base.py | 19 +++++++++++++++---- sklearn/neighbors/_regression.py | 5 ++++- 6 files changed, 29 insertions(+), 10 deletions(-) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 4224bfbb9c04c..7acc995c3aa21 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -37,7 +37,7 @@ from ._classification import brier_score_loss from ._classification import multilabel_confusion_matrix -from ._dist_metrics import DistanceMetric +from ._dist_metrics import DistanceMetric, DistanceMetric32 from . import cluster from .cluster import adjusted_mutual_info_score @@ -122,6 +122,7 @@ "DetCurveDisplay", "det_curve", "DistanceMetric", + "DistanceMetric32", "euclidean_distances", "explained_variance_score", "f1_score", diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp index 7edc64c59a050..dd66299223efe 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp @@ -36,7 +36,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): X, Y, intp_t k, - str metric="euclidean", + metric="euclidean", chunk_size=None, dict metric_kwargs=None, str strategy=None, diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp index 5569c1f231d62..489ba7e85b814 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp @@ -57,7 +57,7 @@ cdef class DatasetsPair{{name_suffix}}: cls, X, Y, - str metric="euclidean", + metric="euclidean", dict metric_kwargs=None, ) -> DatasetsPair{{name_suffix}}: """Return the DatasetsPair implementation for the given arguments. @@ -74,7 +74,7 @@ cdef class DatasetsPair{{name_suffix}}: If provided as a ndarray, it must be C-contiguous. If provided as a sparse matrix, it must be in CSR format. - metric : str, default='euclidean' + metric : str or DistanceMetric object, default='euclidean' The distance metric to compute between rows of X and Y. The default metric is a fast implementation of the Euclidean metric. For a list of available metrics, see the documentation diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 73d98f2ebe6b2..1ab6a38f6ba0c 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -7,6 +7,7 @@ from scipy.sparse import isspmatrix_csr, issparse from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING +from .. import DistanceMetric, DistanceMetric32 from ._base import _sqeuclidean_row_norms32, _sqeuclidean_row_norms64 from ._argkmin import ( @@ -122,7 +123,10 @@ def is_valid_sparse_matrix(X): and (is_numpy_c_ordered(Y) or is_valid_sparse_matrix(Y)) and X.dtype == Y.dtype and X.dtype in (np.float32, np.float64) - and metric in cls.valid_metrics() + and ( + metric in cls.valid_metrics() + or isinstance(metric, (DistanceMetric, DistanceMetric32)) + ) ) return is_usable diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 9af85a38f0b6c..00b64a5d0c0d1 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -23,6 +23,7 @@ from ..base import BaseEstimator, MultiOutputMixin from ..base import is_classifier from ..metrics import pairwise_distances_chunked +from ..metrics import DistanceMetric, DistanceMetric32 from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS from ..metrics._pairwise_distances_reduction import ( ArgKmin, @@ -390,7 +391,12 @@ class NeighborsBase(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): "algorithm": [StrOptions({"auto", "ball_tree", "kd_tree", "brute"})], "leaf_size": [Interval(Integral, 1, None, closed="left")], "p": [Interval(Real, 0, None, closed="right"), None], - "metric": [StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable], + "metric": [ + StrOptions(set(itertools.chain(*VALID_METRICS.values()))), + callable, + DistanceMetric, + DistanceMetric32, + ], "metric_params": [dict, None], "n_jobs": [Integral, None], } @@ -420,7 +426,11 @@ def _check_algorithm_metric(self): if self.algorithm == "auto": if self.metric == "precomputed": alg_check = "brute" - elif callable(self.metric) or self.metric in VALID_METRICS["ball_tree"]: + elif ( + callable(self.metric) + or self.metric in VALID_METRICS["ball_tree"] + or isinstance(self.metric, (DistanceMetric, DistanceMetric32)) + ): alg_check = "ball_tree" else: alg_check = "brute" @@ -436,7 +446,9 @@ def _check_algorithm_metric(self): "in very poor performance." % self.metric ) - elif self.metric not in VALID_METRICS[alg_check]: + elif self.metric not in VALID_METRICS[alg_check] and not isinstance( + self.metric, (DistanceMetric, DistanceMetric32) + ): raise ValueError( "Metric '%s' not valid. Use " "sorted(sklearn.neighbors.VALID_METRICS['%s']) " @@ -805,7 +817,6 @@ class from an array representing our data set and ask who's "n_neighbors does not take %s value, enter integer value" % type(n_neighbors) ) - query_is_train = X is None if query_is_train: X = self._fit_X diff --git a/sklearn/neighbors/_regression.py b/sklearn/neighbors/_regression.py index 003b534074ecd..4d98766ca104c 100644 --- a/sklearn/neighbors/_regression.py +++ b/sklearn/neighbors/_regression.py @@ -72,7 +72,7 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase): equivalent to using manhattan_distance (l1), and euclidean_distance (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used. - metric : str or callable, default='minkowski' + metric : str, DistanceMetric object or callable, default='minkowski' Metric to use for distance computation. Default is "minkowski", which results in the standard Euclidean distance when p = 2. See the documentation of `scipy.spatial.distance @@ -90,6 +90,9 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase): between those vectors. This works for Scipy's metrics, but is less efficient than passing the metric name as a string. + If metric is a DistanceMetric object, it will be passed directly to + the underlying computation routines. + metric_params : dict, default=None Additional keyword arguments for the metric function. From e32cd0d265d4576e16b935cc764ee22e95d3e1ea Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 4 May 2023 22:18:03 -0400 Subject: [PATCH 02/11] Updated tests --- sklearn/neighbors/_base.py | 17 +++- sklearn/neighbors/tests/test_neighbors.py | 110 ++++++++++++++++++++-- 2 files changed, 114 insertions(+), 13 deletions(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 00b64a5d0c0d1..f455b15fd13d8 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -581,15 +581,24 @@ def _fit(self, X, y=None): if self.algorithm not in ("auto", "brute"): warnings.warn("cannot use tree with sparse input: using brute force") - if self.effective_metric_ not in VALID_METRICS_SPARSE[ - "brute" - ] and not callable(self.effective_metric_): + if ( + self.effective_metric_ not in VALID_METRICS_SPARSE["brute"] + and not callable(self.effective_metric_) + and not isinstance( + self.effective_metric_, (DistanceMetric, DistanceMetric32) + ) + ): raise ValueError( "Metric '%s' not valid for sparse input. " "Use sorted(sklearn.neighbors." "VALID_METRICS_SPARSE['brute']) " "to get valid options. " - "Metric can also be a callable function." % (self.effective_metric_) + "Metric can also be a callable function." + % ( + self.effective_metric_ + if isinstance(self.effective_metric_, str) + else self.effective_metric_.__class__.__name__ + ) ) self._fit_X = X.copy() self._tree = None diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 092b85ad9dcd0..5ea7831b969b5 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -27,6 +27,12 @@ from sklearn.exceptions import EfficiencyWarning from sklearn.exceptions import NotFittedError from sklearn.metrics.pairwise import pairwise_distances +from sklearn.metrics._dist_metrics import ( + METRIC_MAPPING, + METRIC_MAPPING32, + DistanceMetric, + DistanceMetric32, +) from sklearn.metrics.tests.test_dist_metrics import BOOL_METRICS from sklearn.metrics.tests.test_pairwise_distances_reduction import ( assert_radius_neighbors_results_equality, @@ -74,6 +80,20 @@ COMMON_VALID_METRICS = sorted( set.intersection(*map(set, neighbors.VALID_METRICS.values())) ) # type: ignore + +# This can be extended to cover all distance metric objects, however that is +# probably unnecessary and would slow down tests significantly. +DISTANCE_METRIC_OBJS = [] +for m in ("euclidean", "manhattan"): + d = {} + for dtype, MAPPING in zip( + (np.float64, np.float32), (METRIC_MAPPING, METRIC_MAPPING32) + ): + metric = MAPPING.get(m, None) + if metric is not None: + d[dtype] = metric + DISTANCE_METRIC_OBJS.append(d) + P = (1, 2, 3, 4, np.inf) JOBLIB_BACKENDS = list(joblib.parallel.BACKENDS.keys()) @@ -82,6 +102,14 @@ neighbors.radius_neighbors_graph = ignore_warnings(neighbors.radius_neighbors_graph) +def _parse_metric(metric, dtype=None): + if isinstance(metric, str): + return metric + if isinstance(metric, dict): + return metric[dtype]() + return -1 + + def _generate_test_params_for(metric: str, n_features: int): """Return list of DistanceMetric kwargs for tests.""" @@ -145,7 +173,7 @@ def _weight_func(dist): ], ) @pytest.mark.parametrize("query_is_train", [False, True]) -@pytest.mark.parametrize("metric", COMMON_VALID_METRICS) +@pytest.mark.parametrize("metric", COMMON_VALID_METRICS + DISTANCE_METRIC_OBJS) # type: ignore # noqa def test_unsupervised_kneighbors( global_dtype, n_samples, @@ -159,6 +187,10 @@ def test_unsupervised_kneighbors( # on their common metrics, with and without returning # distances + # Handle the case where metric is a dict containing mappings from `dtype` + # to the corresponding `DistanceMetric` objects + metric = _parse_metric(metric, global_dtype) + # Redefining the rng locally to use the same generated X local_rng = np.random.RandomState(0) X = local_rng.rand(n_samples, n_features).astype(global_dtype, copy=False) @@ -173,6 +205,12 @@ def test_unsupervised_kneighbors( results = [] for algorithm in ALGORITHMS: + if isinstance(metric, DistanceMetric32): + if "tree" in algorithm: + pytest.skip( + "Neither KDTree nor BallTree support 32-bit distance metric objects" + " (DistanceMetric32)." + ) neigh = neighbors.NearestNeighbors( n_neighbors=n_neighbors, algorithm=algorithm, metric=metric ) @@ -222,7 +260,7 @@ def test_unsupervised_kneighbors( (1000, 5, 100), ], ) -@pytest.mark.parametrize("metric", COMMON_VALID_METRICS) +@pytest.mark.parametrize("metric", COMMON_VALID_METRICS + DISTANCE_METRIC_OBJS) # type: ignore # noqa @pytest.mark.parametrize("n_neighbors, radius", [(1, 100), (50, 500), (100, 1000)]) @pytest.mark.parametrize( "NeighborsMixinSubclass", @@ -246,6 +284,21 @@ def test_neigh_predictions_algorithm_agnosticity( # The different algorithms must return identical predictions results # on their common metrics. + # Handle the case where metric is a dict containing mappings from `dtype` + # to the corresponding `DistanceMetric` objects + metric = _parse_metric(metric, global_dtype) + if isinstance(metric, (DistanceMetric, DistanceMetric32)): + if "Classifier" in NeighborsMixinSubclass.__name__: + pytest.skip( + "Metrics of type `DistanceMetric` are not yet supported for" + " classifiers." + ) + if "Radius" in NeighborsMixinSubclass.__name__: + pytest.skip( + "Metrics of type `DistanceMetric` are not yet supported for" + " radius-neighbor estimators." + ) + # Redefining the rng locally to use the same generated X local_rng = np.random.RandomState(0) X = local_rng.rand(n_samples, n_features).astype(global_dtype, copy=False) @@ -260,6 +313,12 @@ def test_neigh_predictions_algorithm_agnosticity( ) for algorithm in ALGORITHMS: + if isinstance(metric, DistanceMetric32): + if "tree" in algorithm: + pytest.skip( + "Neither KDTree nor BallTree support 32-bit distance metric objects" + " (DistanceMetric32)." + ) neigh = NeighborsMixinSubclass(parameter, algorithm=algorithm, metric=metric) neigh.fit(X, y) @@ -1001,15 +1060,28 @@ def test_query_equidistant_kth_nn(algorithm): @pytest.mark.parametrize( ["algorithm", "metric"], - [ - ("ball_tree", "euclidean"), - ("kd_tree", "euclidean"), + list( + product( + ("kd_tree", "ball_tree", "brute"), + ("euclidean", *DISTANCE_METRIC_OBJS), + ) + ) + + [ ("brute", "euclidean"), ("brute", "precomputed"), ], ) def test_radius_neighbors_sort_results(algorithm, metric): # Test radius_neighbors[_graph] output when sort_result is True + + # Handle the case where metric is a dict containing mappings from `dtype` + # to the corresponding `DistanceMetric` objects + metric = _parse_metric(metric, np.float64) + if isinstance(metric, (DistanceMetric, DistanceMetric32)): + pytest.skip( + "Metrics of type `DistanceMetric` are not yet supported for radius-neighbor" + " estimators." + ) n_samples = 10 rng = np.random.RandomState(42) X = rng.random_sample((n_samples, 4)) @@ -1578,11 +1650,16 @@ def test_nearest_neighbors_validate_params(): neighbors.VALID_METRICS["brute"] ) - set(["pyfunc", *BOOL_METRICS]) - ), + ) + + DISTANCE_METRIC_OBJS, ) def test_neighbors_metrics( global_dtype, metric, n_samples=20, n_features=3, n_query_pts=2, n_neighbors=5 ): + # Handle the case where metric is a dict containing mappings from `dtype` + # to the corresponding `DistanceMetric` objects + metric = _parse_metric(metric, global_dtype) + # Test computing the neighbors for various metrics algorithms = ["brute", "ball_tree", "kd_tree"] X_train = rng.rand(n_samples, n_features).astype(global_dtype, copy=False) @@ -1592,12 +1669,21 @@ def test_neighbors_metrics( for metric_params in metric_params_list: # Some metric (e.g. Weighted minkowski) are not supported by KDTree - exclude_kd_tree = metric not in neighbors.VALID_METRICS["kd_tree"] or ( - "minkowski" in metric and "w" in metric_params + exclude_kd_tree = ( + False + if isinstance(metric, DistanceMetric) + else metric not in neighbors.VALID_METRICS["kd_tree"] + or ("minkowski" in metric and "w" in metric_params) ) results = {} p = metric_params.pop("p", 2) for algorithm in algorithms: + if isinstance(metric, DistanceMetric32): + if "tree" in algorithm: + pytest.skip( + "Neither KDTree nor BallTree support 32-bit distance metric" + " objects (DistanceMetric32)." + ) neigh = neighbors.NearestNeighbors( n_neighbors=n_neighbors, algorithm=algorithm, @@ -1726,10 +1812,16 @@ def custom_metric(x1, x2): # TODO: Remove filterwarnings in 1.3 when wminkowski is removed @pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn") -@pytest.mark.parametrize("metric", neighbors.VALID_METRICS["brute"]) +@pytest.mark.parametrize( + "metric", neighbors.VALID_METRICS["brute"] + DISTANCE_METRIC_OBJS +) def test_valid_brute_metric_for_auto_algorithm( global_dtype, metric, n_samples=20, n_features=12 ): + # Handle the case where metric is a dict containing mappings from `dtype` + # to the corresponding `DistanceMetric` objects + metric = _parse_metric(metric, global_dtype) + X = rng.rand(n_samples, n_features).astype(global_dtype, copy=False) Xcsr = csr_matrix(X) From 0923b93e8250cc6666a7f701dadf32686899c834 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 4 May 2023 23:19:44 -0400 Subject: [PATCH 03/11] Removed extraneous return --- sklearn/neighbors/tests/test_neighbors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 5ea7831b969b5..de1a6fcaa63c1 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -107,7 +107,6 @@ def _parse_metric(metric, dtype=None): return metric if isinstance(metric, dict): return metric[dtype]() - return -1 def _generate_test_params_for(metric: str, n_features: int): From 1312e0315a983e10a6e3d5379726f8ea92ac991f Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 11 May 2023 22:19:41 -0400 Subject: [PATCH 04/11] Added changelog entry --- doc/whats_new/v1.3.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 12f50ca7fc2b5..bf8b035c2c698 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -437,6 +437,11 @@ Changelog when `n_neighbors` is large and `algorithm="brute"` with non Euclidean metrics. :pr:`24076` by :user:`Meekail Zain `, :user:`Julien Jerphanion `. +- |API| The :class:`neighbors.KNeighborsRegressor` can now accept `DistanceMetric{32}` + objects directly via the `metric` keyword argument allowing for the use of + accelerated third-party `DistanceMetric{32}` objects. + :pr:`26267` by :user:`Meekail Zain ` + :mod:`sklearn.neural_network` ............................. From 3e6622b2c8930623d63b51d99eea5ae951539d74 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 16 May 2023 13:39:29 -0400 Subject: [PATCH 05/11] Updated test to appease coverage gods --- sklearn/metrics/pairwise.py | 2 +- sklearn/neighbors/tests/test_neighbors.py | 13 +++++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 0605c99c9536d..088dd4529bfaf 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -1735,7 +1735,7 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds): return out -_VALID_METRICS = [ +_VALID_METRICS: list = [ "euclidean", "l2", "l1", diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index de1a6fcaa63c1..f54c02bd1ff34 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -89,9 +89,7 @@ for dtype, MAPPING in zip( (np.float64, np.float32), (METRIC_MAPPING, METRIC_MAPPING32) ): - metric = MAPPING.get(m, None) - if metric is not None: - d[dtype] = metric + d[dtype] = MAPPING[m] DISTANCE_METRIC_OBJS.append(d) P = (1, 2, 3, 4, np.inf) @@ -105,8 +103,7 @@ def _parse_metric(metric, dtype=None): if isinstance(metric, str): return metric - if isinstance(metric, dict): - return metric[dtype]() + return metric[dtype]() def _generate_test_params_for(metric: str, n_features: int): @@ -205,7 +202,7 @@ def test_unsupervised_kneighbors( for algorithm in ALGORITHMS: if isinstance(metric, DistanceMetric32): - if "tree" in algorithm: + if "tree" in algorithm: # pragma: nocover pytest.skip( "Neither KDTree nor BallTree support 32-bit distance metric objects" " (DistanceMetric32)." @@ -313,7 +310,7 @@ def test_neigh_predictions_algorithm_agnosticity( for algorithm in ALGORITHMS: if isinstance(metric, DistanceMetric32): - if "tree" in algorithm: + if "tree" in algorithm: # pragma: nocover pytest.skip( "Neither KDTree nor BallTree support 32-bit distance metric objects" " (DistanceMetric32)." @@ -1678,7 +1675,7 @@ def test_neighbors_metrics( p = metric_params.pop("p", 2) for algorithm in algorithms: if isinstance(metric, DistanceMetric32): - if "tree" in algorithm: + if "tree" in algorithm: # pragma: nocover pytest.skip( "Neither KDTree nor BallTree support 32-bit distance metric" " objects (DistanceMetric32)." From 71326f2e8281b900a3bb80896f7e6fa6bed3b161 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 7 Jul 2023 14:03:36 -0400 Subject: [PATCH 06/11] Updated with new DistanceMetric API --- doc/whats_new/v1.4.rst | 7 ++++--- sklearn/metrics/__init__.py | 3 +-- .../_dispatcher.py | 6 +----- sklearn/neighbors/_base.py | 12 +++++------- sklearn/neighbors/tests/test_neighbors.py | 17 ++++++++--------- 5 files changed, 19 insertions(+), 26 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index bd3429f1354ab..c37177f900c1d 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -93,9 +93,10 @@ Changelog :mod:`sklearn.neighbors` ............................ -- |API| The :class:`neighbors.KNeighborsRegressor` can now accept `DistanceMetric{32}` - objects directly via the `metric` keyword argument allowing for the use of - accelerated third-party `DistanceMetric{32}` objects. +- |API| The :class:`neighbors.KNeighborsRegressor` can now accept + :class:`metric.DistanceMetric` objects directly via the `metric` keyword + argument allowing for the use of accelerated third-party + :class:`metric.DistanceMetric` objects. :pr:`26267` by :user:`Meekail Zain ` Code and Documentation Contributors diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index e492a42d390cc..488c776ae9a86 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -26,7 +26,7 @@ recall_score, zero_one_loss, ) -from ._dist_metrics import DistanceMetric, DistanceMetric32 +from ._dist_metrics import DistanceMetric from ._plot.confusion_matrix import ConfusionMatrixDisplay from ._plot.det_curve import DetCurveDisplay from ._plot.precision_recall_curve import PrecisionRecallDisplay @@ -118,7 +118,6 @@ "DetCurveDisplay", "det_curve", "DistanceMetric", - "DistanceMetric32", "euclidean_distances", "explained_variance_score", "f1_score", diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 27feefdfd04f3..cc51bed7803c8 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -9,7 +9,6 @@ BOOL_METRICS, METRIC_MAPPING64, DistanceMetric, - DistanceMetric32, ) from ._argkmin import ( ArgKmin32, @@ -121,10 +120,7 @@ def is_valid_sparse_matrix(X): and (is_numpy_c_ordered(Y) or is_valid_sparse_matrix(Y)) and X.dtype == Y.dtype and X.dtype in (np.float32, np.float64) - and ( - metric in cls.valid_metrics() - or isinstance(metric, (DistanceMetric, DistanceMetric32)) - ) + and (metric in cls.valid_metrics() or isinstance(metric, DistanceMetric)) ) return is_usable diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 95e894e833e0c..478fb9980a86f 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -19,7 +19,7 @@ from ..base import BaseEstimator, MultiOutputMixin, is_classifier from ..exceptions import DataConversionWarning, EfficiencyWarning -from ..metrics import DistanceMetric, DistanceMetric32, pairwise_distances_chunked +from ..metrics import DistanceMetric, pairwise_distances_chunked from ..metrics._pairwise_distances_reduction import ( ArgKmin, RadiusNeighbors, @@ -388,7 +388,6 @@ class NeighborsBase(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable, DistanceMetric, - DistanceMetric32, ], "metric_params": [dict, None], "n_jobs": [Integral, None], @@ -422,7 +421,7 @@ def _check_algorithm_metric(self): elif ( callable(self.metric) or self.metric in VALID_METRICS["ball_tree"] - or isinstance(self.metric, (DistanceMetric, DistanceMetric32)) + or isinstance(self.metric, DistanceMetric) ): alg_check = "ball_tree" else: @@ -440,7 +439,7 @@ def _check_algorithm_metric(self): % self.metric ) elif self.metric not in VALID_METRICS[alg_check] and not isinstance( - self.metric, (DistanceMetric, DistanceMetric32) + self.metric, DistanceMetric ): raise ValueError( "Metric '%s' not valid. Use " @@ -577,9 +576,7 @@ def _fit(self, X, y=None): if ( self.effective_metric_ not in VALID_METRICS_SPARSE["brute"] and not callable(self.effective_metric_) - and not isinstance( - self.effective_metric_, (DistanceMetric, DistanceMetric32) - ) + and not isinstance(self.effective_metric_, DistanceMetric) ): raise ValueError( "Metric '%s' not valid for sparse input. " @@ -810,6 +807,7 @@ class from an array representing our data set and ask who's "n_neighbors does not take %s value, enter integer value" % type(n_neighbors) ) + query_is_train = X is None if query_is_train: X = self._fit_X diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index d3f27f94ea4d0..1c32d977c72df 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -28,7 +28,6 @@ METRIC_MAPPING32, METRIC_MAPPING64, DistanceMetric, - DistanceMetric32, ) from sklearn.metrics.pairwise import pairwise_distances from sklearn.metrics.tests.test_dist_metrics import BOOL_METRICS @@ -185,11 +184,11 @@ def test_unsupervised_kneighbors( results = [] for algorithm in ALGORITHMS: - if isinstance(metric, DistanceMetric32): + if isinstance(metric, DistanceMetric) and global_dtype == np.float32: if "tree" in algorithm: # pragma: nocover pytest.skip( - "Neither KDTree nor BallTree support 32-bit distance metric objects" - " (DistanceMetric32)." + "Neither KDTree nor BallTree support 32-bit distance metric" + " objects." ) neigh = neighbors.NearestNeighbors( n_neighbors=n_neighbors, algorithm=algorithm, metric=metric @@ -293,11 +292,11 @@ def test_neigh_predictions_algorithm_agnosticity( ) for algorithm in ALGORITHMS: - if isinstance(metric, DistanceMetric32): + if isinstance(metric, DistanceMetric) and global_dtype == np.float32: if "tree" in algorithm: # pragma: nocover pytest.skip( - "Neither KDTree nor BallTree support 32-bit distance metric objects" - " (DistanceMetric32)." + "Neither KDTree nor BallTree support 32-bit distance metric" + " objects." ) neigh = NeighborsMixinSubclass(parameter, algorithm=algorithm, metric=metric) neigh.fit(X, y) @@ -1656,11 +1655,11 @@ def test_neighbors_metrics( results = {} p = metric_params.pop("p", 2) for algorithm in algorithms: - if isinstance(metric, DistanceMetric32): + if isinstance(metric, DistanceMetric) and global_dtype == np.float32: if "tree" in algorithm: # pragma: nocover pytest.skip( "Neither KDTree nor BallTree support 32-bit distance metric" - " objects (DistanceMetric32)." + " objects." ) neigh = neighbors.NearestNeighbors( n_neighbors=n_neighbors, From cf3b5ebf0c0be5eb3265e690b1a39ee832c3b88a Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 18 Jul 2023 19:39:47 -0400 Subject: [PATCH 07/11] Moved acceptence of DistanceMetric to specific estimator --- sklearn/metrics/pairwise.py | 2 +- sklearn/neighbors/_base.py | 6 +----- sklearn/neighbors/_regression.py | 2 ++ 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 64f21f6dcfb88..3fc7795876814 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -619,7 +619,7 @@ def _argmin_reduce(dist, start): return dist.argmin(axis=1) -_VALID_METRICS: list = [ +_VALID_METRICS = [ "euclidean", "l2", "l1", diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 478fb9980a86f..b6afe121f7066 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -384,11 +384,7 @@ class NeighborsBase(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): "algorithm": [StrOptions({"auto", "ball_tree", "kd_tree", "brute"})], "leaf_size": [Interval(Integral, 1, None, closed="left")], "p": [Interval(Real, 0, None, closed="right"), None], - "metric": [ - StrOptions(set(itertools.chain(*VALID_METRICS.values()))), - callable, - DistanceMetric, - ], + "metric": [StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable], "metric_params": [dict, None], "n_jobs": [Integral, None], } diff --git a/sklearn/neighbors/_regression.py b/sklearn/neighbors/_regression.py index dd66a950c9c7a..2897c1ce409e8 100644 --- a/sklearn/neighbors/_regression.py +++ b/sklearn/neighbors/_regression.py @@ -15,6 +15,7 @@ import numpy as np from ..base import RegressorMixin, _fit_context +from ..metrics import DistanceMetric from ..utils._param_validation import StrOptions from ._base import KNeighborsMixin, NeighborsBase, RadiusNeighborsMixin, _get_weights @@ -167,6 +168,7 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase): **NeighborsBase._parameter_constraints, "weights": [StrOptions({"uniform", "distance"}), callable, None], } + _parameter_constraints["metric"].append(DistanceMetric) _parameter_constraints.pop("radius") def __init__( From 41c6257764744e1069d4b32729eaa9979c1ae0d9 Mon Sep 17 00:00:00 2001 From: Meekail Zain <34613774+Micky774@users.noreply.github.com> Date: Thu, 20 Jul 2023 12:01:52 -0400 Subject: [PATCH 08/11] Apply suggestions from code review Co-authored-by: Julien Jerphanion --- doc/whats_new/v1.4.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 39694ed31a02d..565ff23bf37a4 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -133,9 +133,9 @@ Changelog :mod:`sklearn.neighbors` -............................ +........................ -- |API| The :class:`neighbors.KNeighborsRegressor` can now accept +- |API| :class:`neighbors.KNeighborsRegressor` now accepts :class:`metric.DistanceMetric` objects directly via the `metric` keyword argument allowing for the use of accelerated third-party :class:`metric.DistanceMetric` objects. From 4c132a3f7f09b0db4818adde2cd0524deb0b8f58 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 20 Jul 2023 12:17:46 -0400 Subject: [PATCH 09/11] Updated metric parsing for new DistanceMetric API --- sklearn/neighbors/tests/test_neighbors.py | 34 +++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index f9554550ffbe3..6d9877b4f9a98 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -25,8 +25,6 @@ from sklearn.base import clone from sklearn.exceptions import DataConversionWarning, EfficiencyWarning, NotFittedError from sklearn.metrics._dist_metrics import ( - METRIC_MAPPING32, - METRIC_MAPPING64, DistanceMetric, ) from sklearn.metrics.pairwise import pairwise_distances @@ -75,17 +73,6 @@ set.intersection(*map(set, neighbors.VALID_METRICS.values())) ) # type: ignore -# This can be extended to cover all distance metric objects, however that is -# probably unnecessary and would slow down tests significantly. -DISTANCE_METRIC_OBJS = [] -for m in ("euclidean", "manhattan"): - d = {} - for dtype, MAPPING in zip( - (np.float64, np.float32), (METRIC_MAPPING64, METRIC_MAPPING32) - ): - d[dtype] = MAPPING[m] - DISTANCE_METRIC_OBJS.append(d) - P = (1, 2, 3, 4, np.inf) JOBLIB_BACKENDS = list(joblib.parallel.BACKENDS.keys()) @@ -93,11 +80,24 @@ neighbors.kneighbors_graph = ignore_warnings(neighbors.kneighbors_graph) neighbors.radius_neighbors_graph = ignore_warnings(neighbors.radius_neighbors_graph) +# A list containing metrics where the string specifies the use of the +# DistanceMetric object directly (as resolved in _parse_metric) +DISTANCE_METRIC_OBJS = ["DM_euclidean", "DM_manhattan"] -def _parse_metric(metric, dtype=None): - if isinstance(metric, str): - return metric - return metric[dtype]() + +def _parse_metric(metric: str, dtype=None): + """ + Helper function for properly building a type-specialized DistanceMetric instances. + + Constructs a type-specialized DistanceMetric instance from a string + beginning with "DM_" while allowing a pass-through for other metric-specifying + strings. This is necessary since we wish to parameterize dtype independent of + metric, yet DistanceMetric requires it for construction. + + """ + if metric[:3] == "DM_": + return DistanceMetric.get_metric(metric[3:], dtype=dtype) + return metric def _generate_test_params_for(metric: str, n_features: int): From 85d18ea31038afbf81ee073d498c4b148a3246a8 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 20 Jul 2023 12:25:16 -0400 Subject: [PATCH 10/11] Added DistanceMetric __repr__ and removed old inline comments --- sklearn/neighbors/_base.py | 7 +------ sklearn/neighbors/tests/test_neighbors.py | 10 ---------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index b6afe121f7066..519db9bead3d3 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -579,12 +579,7 @@ def _fit(self, X, y=None): "Use sorted(sklearn.neighbors." "VALID_METRICS_SPARSE['brute']) " "to get valid options. " - "Metric can also be a callable function." - % ( - self.effective_metric_ - if isinstance(self.effective_metric_, str) - else self.effective_metric_.__class__.__name__ - ) + "Metric can also be a callable function." % (self.effective_metric_) ) self._fit_X = X.copy() self._tree = None diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 6d9877b4f9a98..6af45dfbc8633 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -166,8 +166,6 @@ def test_unsupervised_kneighbors( # on their common metrics, with and without returning # distances - # Handle the case where metric is a dict containing mappings from `dtype` - # to the corresponding `DistanceMetric` objects metric = _parse_metric(metric, global_dtype) # Redefining the rng locally to use the same generated X @@ -263,8 +261,6 @@ def test_neigh_predictions_algorithm_agnosticity( # The different algorithms must return identical predictions results # on their common metrics. - # Handle the case where metric is a dict containing mappings from `dtype` - # to the corresponding `DistanceMetric` objects metric = _parse_metric(metric, global_dtype) if isinstance(metric, DistanceMetric): if "Classifier" in NeighborsMixinSubclass.__name__: @@ -1053,8 +1049,6 @@ def test_query_equidistant_kth_nn(algorithm): def test_radius_neighbors_sort_results(algorithm, metric): # Test radius_neighbors[_graph] output when sort_result is True - # Handle the case where metric is a dict containing mappings from `dtype` - # to the corresponding `DistanceMetric` objects metric = _parse_metric(metric, np.float64) if isinstance(metric, DistanceMetric): pytest.skip( @@ -1633,8 +1627,6 @@ def test_nearest_neighbors_validate_params(): def test_neighbors_metrics( global_dtype, metric, n_samples=20, n_features=3, n_query_pts=2, n_neighbors=5 ): - # Handle the case where metric is a dict containing mappings from `dtype` - # to the corresponding `DistanceMetric` objects metric = _parse_metric(metric, global_dtype) # Test computing the neighbors for various metrics @@ -1771,8 +1763,6 @@ def custom_metric(x1, x2): def test_valid_brute_metric_for_auto_algorithm( global_dtype, metric, n_samples=20, n_features=12 ): - # Handle the case where metric is a dict containing mappings from `dtype` - # to the corresponding `DistanceMetric` objects metric = _parse_metric(metric, global_dtype) X = rng.rand(n_samples, n_features).astype(global_dtype, copy=False) From e4d8413a402f9fb41f3f4290907086ee4f748be0 Mon Sep 17 00:00:00 2001 From: Meekail Zain <34613774+Micky774@users.noreply.github.com> Date: Tue, 1 Aug 2023 12:40:16 -0400 Subject: [PATCH 11/11] Update sklearn/neighbors/tests/test_neighbors.py Co-authored-by: Thomas J. Fan --- sklearn/neighbors/tests/test_neighbors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 6af45dfbc8633..c81132d795f56 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -82,7 +82,7 @@ # A list containing metrics where the string specifies the use of the # DistanceMetric object directly (as resolved in _parse_metric) -DISTANCE_METRIC_OBJS = ["DM_euclidean", "DM_manhattan"] +DISTANCE_METRIC_OBJS = ["DM_euclidean"] def _parse_metric(metric: str, dtype=None):