Thanks to visit codestin.com
Credit goes to github.com

Skip to content

API Allow users to pass DistanceMetric objects to metric keyword argument in neighbors.KNeighborsRegressor #26267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ Changelog
:pr:`13649` by :user:`Samuel Ronsin <samronsin>`, initiated by
:user:`Patrick O'Reilly <pat-oreilly>`.


:mod:`sklearn.neighbors`
........................

- |API| :class:`neighbors.KNeighborsRegressor` now accepts
  :class:`metrics.DistanceMetric` objects directly via the `metric` keyword
  argument allowing for the use of accelerated third-party
  :class:`metrics.DistanceMetric` objects.
  :pr:`26267` by :user:`Meekail Zain <micky774>`.

:mod:`sklearn.metrics`
......................

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
X,
Y,
intp_t k,
str metric="euclidean",
metric="euclidean",
chunk_size=None,
dict metric_kwargs=None,
str strategy=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ cdef class DatasetsPair{{name_suffix}}:
cls,
X,
Y,
str metric="euclidean",
metric="euclidean",
dict metric_kwargs=None,
) -> DatasetsPair{{name_suffix}}:
"""Return the DatasetsPair implementation for the given arguments.
Expand All @@ -70,7 +70,7 @@ cdef class DatasetsPair{{name_suffix}}:
If provided as a ndarray, it must be C-contiguous.
If provided as a sparse matrix, it must be in CSR format.

metric : str, default='euclidean'
metric : str or DistanceMetric object, default='euclidean'
The distance metric to compute between rows of X and Y.
The default metric is a fast implementation of the Euclidean
metric. For a list of available metrics, see the documentation
Expand Down
8 changes: 6 additions & 2 deletions sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from scipy.sparse import issparse

from ... import get_config
from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING64
from .._dist_metrics import (
BOOL_METRICS,
METRIC_MAPPING64,
DistanceMetric,
)
from ._argkmin import (
ArgKmin32,
ArgKmin64,
Expand Down Expand Up @@ -117,7 +121,7 @@ def is_valid_sparse_matrix(X):
and (is_numpy_c_ordered(Y) or is_valid_sparse_matrix(Y))
and X.dtype == Y.dtype
and X.dtype in (np.float32, np.float64)
and metric in cls.valid_metrics()
and (metric in cls.valid_metrics() or isinstance(metric, DistanceMetric))
)

return is_usable
Expand Down
20 changes: 14 additions & 6 deletions sklearn/neighbors/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ..base import BaseEstimator, MultiOutputMixin, is_classifier
from ..exceptions import DataConversionWarning, EfficiencyWarning
from ..metrics import pairwise_distances_chunked
from ..metrics import DistanceMetric, pairwise_distances_chunked
from ..metrics._pairwise_distances_reduction import (
ArgKmin,
RadiusNeighbors,
Expand Down Expand Up @@ -414,7 +414,11 @@ def _check_algorithm_metric(self):
if self.algorithm == "auto":
if self.metric == "precomputed":
alg_check = "brute"
elif callable(self.metric) or self.metric in VALID_METRICS["ball_tree"]:
elif (
callable(self.metric)
or self.metric in VALID_METRICS["ball_tree"]
or isinstance(self.metric, DistanceMetric)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

API wise, self.algorithm is selecting ball_tree all the time if self.metric is a DistanceMetric. Is this the preferred default for self.algorithm="auto"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This preserves the standard behavior, e.g. metric="euclidean" and metric=DistanceMetric.get_metric("euclidean") will both result in a BallTree with a EuclideanDistance object.

):
alg_check = "ball_tree"
else:
alg_check = "brute"
Expand All @@ -430,7 +434,9 @@ def _check_algorithm_metric(self):
"in very poor performance."
% self.metric
)
elif self.metric not in VALID_METRICS[alg_check]:
elif self.metric not in VALID_METRICS[alg_check] and not isinstance(
self.metric, DistanceMetric
):
raise ValueError(
"Metric '%s' not valid. Use "
"sorted(sklearn.neighbors.VALID_METRICS['%s']) "
Expand Down Expand Up @@ -563,9 +569,11 @@ def _fit(self, X, y=None):
if self.algorithm not in ("auto", "brute"):
warnings.warn("cannot use tree with sparse input: using brute force")

if self.effective_metric_ not in VALID_METRICS_SPARSE[
"brute"
] and not callable(self.effective_metric_):
if (
self.effective_metric_ not in VALID_METRICS_SPARSE["brute"]
and not callable(self.effective_metric_)
and not isinstance(self.effective_metric_, DistanceMetric)
):
raise ValueError(
"Metric '%s' not valid for sparse input. "
"Use sorted(sklearn.neighbors."
Expand Down
7 changes: 6 additions & 1 deletion sklearn/neighbors/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np

from ..base import RegressorMixin, _fit_context
from ..metrics import DistanceMetric
from ..utils._param_validation import StrOptions
from ._base import KNeighborsMixin, NeighborsBase, RadiusNeighborsMixin, _get_weights

Expand Down Expand Up @@ -71,7 +72,7 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase):
equivalent to using manhattan_distance (l1), and euclidean_distance
(l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.

metric : str or callable, default='minkowski'
metric : str, DistanceMetric object or callable, default='minkowski'
Metric to use for distance computation. Default is "minkowski", which
results in the standard Euclidean distance when p = 2. See the
documentation of `scipy.spatial.distance
Expand All @@ -89,6 +90,9 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase):
between those vectors. This works for Scipy's metrics, but is less
efficient than passing the metric name as a string.

If metric is a DistanceMetric object, it will be passed directly to
the underlying computation routines.

metric_params : dict, default=None
Additional keyword arguments for the metric function.

Expand Down Expand Up @@ -164,6 +168,7 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase):
**NeighborsBase._parameter_constraints,
"weights": [StrOptions({"uniform", "distance"}), callable, None],
}
_parameter_constraints["metric"].append(DistanceMetric)
_parameter_constraints.pop("radius")

def __init__(
Expand Down
95 changes: 86 additions & 9 deletions sklearn/neighbors/tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
)
from sklearn.base import clone
from sklearn.exceptions import DataConversionWarning, EfficiencyWarning, NotFittedError
from sklearn.metrics._dist_metrics import (
DistanceMetric,
)
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics.tests.test_dist_metrics import BOOL_METRICS
from sklearn.metrics.tests.test_pairwise_distances_reduction import (
Expand Down Expand Up @@ -69,13 +72,33 @@
COMMON_VALID_METRICS = sorted(
set.intersection(*map(set, neighbors.VALID_METRICS.values()))
) # type: ignore

P = (1, 2, 3, 4, np.inf)
JOBLIB_BACKENDS = list(joblib.parallel.BACKENDS.keys())

# Filter deprecation warnings.
neighbors.kneighbors_graph = ignore_warnings(neighbors.kneighbors_graph)
neighbors.radius_neighbors_graph = ignore_warnings(neighbors.radius_neighbors_graph)

# A list containing metrics where the string specifies the use of the
# DistanceMetric object directly (as resolved in _parse_metric)
DISTANCE_METRIC_OBJS = ["DM_euclidean"]


def _parse_metric(metric: str, dtype=None):
"""
Helper function for properly building a type-specialized DistanceMetric instances.

Constructs a type-specialized DistanceMetric instance from a string
beginning with "DM_" while allowing a pass-through for other metric-specifying
strings. This is necessary since we wish to parameterize dtype independent of
metric, yet DistanceMetric requires it for construction.

"""
if metric[:3] == "DM_":
return DistanceMetric.get_metric(metric[3:], dtype=dtype)
return metric


def _generate_test_params_for(metric: str, n_features: int):
"""Return list of DistanceMetric kwargs for tests."""
Expand Down Expand Up @@ -129,7 +152,7 @@ def _weight_func(dist):
],
)
@pytest.mark.parametrize("query_is_train", [False, True])
@pytest.mark.parametrize("metric", COMMON_VALID_METRICS)
@pytest.mark.parametrize("metric", COMMON_VALID_METRICS + DISTANCE_METRIC_OBJS) # type: ignore # noqa
def test_unsupervised_kneighbors(
global_dtype,
n_samples,
Expand All @@ -143,6 +166,8 @@ def test_unsupervised_kneighbors(
# on their common metrics, with and without returning
# distances

metric = _parse_metric(metric, global_dtype)

# Redefining the rng locally to use the same generated X
local_rng = np.random.RandomState(0)
X = local_rng.rand(n_samples, n_features).astype(global_dtype, copy=False)
Expand All @@ -157,6 +182,12 @@ def test_unsupervised_kneighbors(
results = []

for algorithm in ALGORITHMS:
if isinstance(metric, DistanceMetric) and global_dtype == np.float32:
if "tree" in algorithm: # pragma: nocover
pytest.skip(
"Neither KDTree nor BallTree support 32-bit distance metric"
" objects."
)
neigh = neighbors.NearestNeighbors(
n_neighbors=n_neighbors, algorithm=algorithm, metric=metric
)
Expand Down Expand Up @@ -206,7 +237,7 @@ def test_unsupervised_kneighbors(
(1000, 5, 100),
],
)
@pytest.mark.parametrize("metric", COMMON_VALID_METRICS)
@pytest.mark.parametrize("metric", COMMON_VALID_METRICS + DISTANCE_METRIC_OBJS) # type: ignore # noqa
@pytest.mark.parametrize("n_neighbors, radius", [(1, 100), (50, 500), (100, 1000)])
@pytest.mark.parametrize(
"NeighborsMixinSubclass",
Expand All @@ -230,6 +261,19 @@ def test_neigh_predictions_algorithm_agnosticity(
# The different algorithms must return identical predictions results
# on their common metrics.

metric = _parse_metric(metric, global_dtype)
if isinstance(metric, DistanceMetric):
if "Classifier" in NeighborsMixinSubclass.__name__:
pytest.skip(
"Metrics of type `DistanceMetric` are not yet supported for"
" classifiers."
)
if "Radius" in NeighborsMixinSubclass.__name__:
pytest.skip(
"Metrics of type `DistanceMetric` are not yet supported for"
" radius-neighbor estimators."
)

# Redefining the rng locally to use the same generated X
local_rng = np.random.RandomState(0)
X = local_rng.rand(n_samples, n_features).astype(global_dtype, copy=False)
Expand All @@ -244,6 +288,12 @@ def test_neigh_predictions_algorithm_agnosticity(
)

for algorithm in ALGORITHMS:
if isinstance(metric, DistanceMetric) and global_dtype == np.float32:
if "tree" in algorithm: # pragma: nocover
pytest.skip(
"Neither KDTree nor BallTree support 32-bit distance metric"
" objects."
)
neigh = NeighborsMixinSubclass(parameter, algorithm=algorithm, metric=metric)
neigh.fit(X, y)

Expand Down Expand Up @@ -985,15 +1035,26 @@ def test_query_equidistant_kth_nn(algorithm):

@pytest.mark.parametrize(
["algorithm", "metric"],
[
("ball_tree", "euclidean"),
("kd_tree", "euclidean"),
list(
product(
("kd_tree", "ball_tree", "brute"),
("euclidean", *DISTANCE_METRIC_OBJS),
)
)
+ [
("brute", "euclidean"),
("brute", "precomputed"),
],
)
def test_radius_neighbors_sort_results(algorithm, metric):
# Test radius_neighbors[_graph] output when sort_result is True

metric = _parse_metric(metric, np.float64)
if isinstance(metric, DistanceMetric):
pytest.skip(
"Metrics of type `DistanceMetric` are not yet supported for radius-neighbor"
" estimators."
)
n_samples = 10
rng = np.random.RandomState(42)
X = rng.random_sample((n_samples, 4))
Expand Down Expand Up @@ -1560,11 +1621,14 @@ def test_nearest_neighbors_validate_params():
neighbors.VALID_METRICS["brute"]
)
- set(["pyfunc", *BOOL_METRICS])
),
)
+ DISTANCE_METRIC_OBJS,
)
def test_neighbors_metrics(
global_dtype, metric, n_samples=20, n_features=3, n_query_pts=2, n_neighbors=5
):
metric = _parse_metric(metric, global_dtype)

# Test computing the neighbors for various metrics
algorithms = ["brute", "ball_tree", "kd_tree"]
X_train = rng.rand(n_samples, n_features).astype(global_dtype, copy=False)
Expand All @@ -1574,12 +1638,21 @@ def test_neighbors_metrics(

for metric_params in metric_params_list:
# Some metric (e.g. Weighted minkowski) are not supported by KDTree
exclude_kd_tree = metric not in neighbors.VALID_METRICS["kd_tree"] or (
"minkowski" in metric and "w" in metric_params
exclude_kd_tree = (
False
if isinstance(metric, DistanceMetric)
else metric not in neighbors.VALID_METRICS["kd_tree"]
or ("minkowski" in metric and "w" in metric_params)
)
results = {}
p = metric_params.pop("p", 2)
for algorithm in algorithms:
if isinstance(metric, DistanceMetric) and global_dtype == np.float32:
if "tree" in algorithm: # pragma: nocover
pytest.skip(
"Neither KDTree nor BallTree support 32-bit distance metric"
" objects."
)
neigh = neighbors.NearestNeighbors(
n_neighbors=n_neighbors,
algorithm=algorithm,
Expand Down Expand Up @@ -1684,10 +1757,14 @@ def custom_metric(x1, x2):
assert_allclose(dist1, dist2)


@pytest.mark.parametrize("metric", neighbors.VALID_METRICS["brute"])
@pytest.mark.parametrize(
"metric", neighbors.VALID_METRICS["brute"] + DISTANCE_METRIC_OBJS
)
def test_valid_brute_metric_for_auto_algorithm(
global_dtype, metric, n_samples=20, n_features=12
):
metric = _parse_metric(metric, global_dtype)

X = rng.rand(n_samples, n_features).astype(global_dtype, copy=False)
Xcsr = csr_matrix(X)

Expand Down