diff --git a/doc/modules/neighbors.rst b/doc/modules/neighbors.rst index dfd6791d9a3d3..7112c2a697651 100644 --- a/doc/modules/neighbors.rst +++ b/doc/modules/neighbors.rst @@ -136,9 +136,13 @@ have the same interface; we'll show an example of using the KD Tree here: Refer to the :class:`KDTree` and :class:`BallTree` class documentation for more information on the options available for nearest neighbors searches, including specification of query strategies, distance metrics, etc. For a list -of available metrics, see the documentation of the :class:`DistanceMetric` class -and the metrics listed in `sklearn.metrics.pairwise.PAIRWISE_DISTANCE_FUNCTIONS`. -Note that the "cosine" metric uses :func:`~sklearn.metrics.pairwise.cosine_distances`. +of valid metrics use :meth:`KDTree.valid_metrics` and :meth:`BallTree.valid_metrics`: + + >>> from sklearn.neighbors import KDTree, BallTree + >>> KDTree.valid_metrics() + ['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock', 'l1', 'chebyshev', 'infinity'] + >>> BallTree.valid_metrics() + ['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock', 'l1', 'chebyshev', 'infinity', 'seuclidean', 'mahalanobis', 'wminkowski', 'hamming', 'canberra', 'braycurtis', 'matching', 'jaccard', 'dice', 'rogerstanimoto', 'russellrao', 'sokalmichener', 'sokalsneath', 'haversine', 'pyfunc'] .. _classification: @@ -476,7 +480,7 @@ A list of valid metrics for any of the above algorithms can be obtained by using ``valid_metric`` attribute. For example, valid metrics for ``KDTree`` can be generated by: >>> from sklearn.neighbors import KDTree - >>> print(sorted(KDTree.valid_metrics)) + >>> print(sorted(KDTree.valid_metrics())) ['chebyshev', 'cityblock', 'euclidean', 'infinity', 'l1', 'l2', 'manhattan', 'minkowski', 'p'] diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 35bfbc00137e4..88febbd9a3aea 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -66,8 +66,8 @@ SCIPY_METRICS += ["kulsinski"] VALID_METRICS = dict( - ball_tree=BallTree.valid_metrics, - kd_tree=KDTree.valid_metrics, + ball_tree=BallTree._valid_metrics, + kd_tree=KDTree._valid_metrics, # The following list comes from the # sklearn.metrics.pairwise doc string brute=sorted(set(PAIRWISE_DISTANCE_FUNCTIONS).union(SCIPY_METRICS)), diff --git a/sklearn/neighbors/_binary_tree.pxi b/sklearn/neighbors/_binary_tree.pxi index 00b5b3c2758d3..1251932ab73f9 100644 --- a/sklearn/neighbors/_binary_tree.pxi +++ b/sklearn/neighbors/_binary_tree.pxi @@ -234,11 +234,11 @@ leaf_size : positive int, default=40 metric : str or DistanceMetric object, default='minkowski' Metric to use for distance computation. Default is "minkowski", which results in the standard Euclidean distance when p = 2. - {binary_tree}.valid_metrics gives a list of the metrics which are valid for - {BinaryTree}. See the documentation of `scipy.spatial.distance - `_ and the - metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics` for - more information. + A list of valid metrics for {BinaryTree} is given by + :meth:`{BinaryTree}.valid_metrics`. + See the documentation of `scipy.spatial.distance + `_ and the metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics` for + more information on any distance metric. Additional keywords are passed to the distance metric class. Note: Callable functions in the metric parameter are NOT supported for KDTree @@ -791,7 +791,7 @@ cdef class BinaryTree: cdef int n_splits cdef int n_calls - valid_metrics = VALID_METRIC_IDS + _valid_metrics = VALID_METRIC_IDS # Use cinit to initialize all arrays to empty: this will prevent memory # errors and seg-faults in rare cases where __init__ is not called @@ -979,6 +979,19 @@ cdef class BinaryTree: self.node_bounds.base, ) + @classmethod + def valid_metrics(cls): + """Get list of valid distance metrics. + + .. versionadded:: 1.3 + + Returns + ------- + valid_metrics: list of str + List of valid distance metrics. + """ + return cls._valid_metrics + cdef inline DTYPE_t dist(self, DTYPE_t* x1, DTYPE_t* x2, ITYPE_t size) nogil except -1: """Compute the distance between arrays x1 and x2""" diff --git a/sklearn/neighbors/_kde.py b/sklearn/neighbors/_kde.py index d7ffed501b1ae..8aa6e8c8ffc0d 100644 --- a/sklearn/neighbors/_kde.py +++ b/sklearn/neighbors/_kde.py @@ -174,12 +174,12 @@ def _choose_algorithm(self, algorithm, metric): # algorithm to compute the result. if algorithm == "auto": # use KD Tree if possible - if metric in KDTree.valid_metrics: + if metric in KDTree.valid_metrics(): return "kd_tree" - elif metric in BallTree.valid_metrics: + elif metric in BallTree.valid_metrics(): return "ball_tree" else: # kd_tree or ball_tree - if metric not in TREE_DICT[algorithm].valid_metrics: + if metric not in TREE_DICT[algorithm].valid_metrics(): raise ValueError( "invalid metric for {0}: '{1}'".format(TREE_DICT[algorithm], metric) ) diff --git a/sklearn/neighbors/tests/test_kde.py b/sklearn/neighbors/tests/test_kde.py index 23fa12a3c3a56..69cd3c8f5693f 100644 --- a/sklearn/neighbors/tests/test_kde.py +++ b/sklearn/neighbors/tests/test_kde.py @@ -114,7 +114,7 @@ def test_kde_algorithm_metric_choice(algorithm, metric): kde = KernelDensity(algorithm=algorithm, metric=metric) - if algorithm == "kd_tree" and metric not in KDTree.valid_metrics: + if algorithm == "kd_tree" and metric not in KDTree.valid_metrics(): with pytest.raises(ValueError, match="invalid metric"): kde.fit(X) else: @@ -165,7 +165,7 @@ def test_kde_sample_weights(): test_points = rng.rand(n_samples_test, d) for algorithm in ["auto", "ball_tree", "kd_tree"]: for metric in ["euclidean", "minkowski", "manhattan", "chebyshev"]: - if algorithm != "kd_tree" or metric in KDTree.valid_metrics: + if algorithm != "kd_tree" or metric in KDTree.valid_metrics(): kde = KernelDensity(algorithm=algorithm, metric=metric) # Test that adding a constant sample weight has no effect