Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 6145cae

Browse files
jjerphanogrisellorentzenchrthomasjpfan
authored
FIX Fallback to ball_tree using minkowski with w for kd_tree (#22241)
Co-authored-by: Olivier Grisel <[email protected]> Co-authored-by: Christian Lorentzen <[email protected]> Co-authored-by: Thomas J. Fan <[email protected]>
1 parent d7fc1df commit 6145cae

File tree

2 files changed

+102
-48
lines changed

2 files changed

+102
-48
lines changed

sklearn/neighbors/_base.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,21 @@ def _fit(self, X, y=None):
535535
):
536536
self._fit_method = "brute"
537537
else:
538-
if self.effective_metric_ in VALID_METRICS["kd_tree"]:
538+
if (
539+
self.effective_metric_ == "minkowski"
540+
and self.effective_metric_params_.get("w") is not None
541+
):
542+
# Be consistent with scipy 1.8 conventions: in scipy 1.8,
543+
# 'wminkowski' was removed in favor of passing a
544+
# weight vector directly to 'minkowski'.
545+
#
546+
# 'wminkowski' is not part of valid metrics for KDTree but
547+
# the 'minkowski' without weights is.
548+
#
549+
# Hence, we detect this case and choose BallTree
550+
# which supports 'wminkowski'.
551+
self._fit_method = "ball_tree"
552+
elif self.effective_metric_ in VALID_METRICS["kd_tree"]:
539553
self._fit_method = "kd_tree"
540554
elif (
541555
callable(self.effective_metric_)
@@ -553,6 +567,16 @@ def _fit(self, X, y=None):
553567
**self.effective_metric_params_,
554568
)
555569
elif self._fit_method == "kd_tree":
570+
if (
571+
self.effective_metric_ == "minkowski"
572+
and self.effective_metric_params_.get("w") is not None
573+
):
574+
raise ValueError(
575+
"algorithm='kd_tree' is not valid for "
576+
"metric='minkowski' with a weight parameter 'w': "
577+
"try algorithm='ball_tree' "
578+
"or algorithm='brute' instead."
579+
)
556580
self._tree = KDTree(
557581
X,
558582
self.leaf_size,

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 77 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sklearn.metrics.pairwise import pairwise_distances
2323
from sklearn.model_selection import cross_val_score
2424
from sklearn.model_selection import train_test_split
25-
from sklearn.neighbors import VALID_METRICS_SPARSE, VALID_METRICS
25+
from sklearn.neighbors import VALID_METRICS_SPARSE
2626
from sklearn.neighbors._base import _is_sorted_by_data, _check_precomputed
2727
from sklearn.pipeline import make_pipeline
2828
from sklearn.utils._testing import assert_array_almost_equal
@@ -58,6 +58,46 @@
5858
neighbors.radius_neighbors_graph = ignore_warnings(neighbors.radius_neighbors_graph)
5959

6060

61+
def _generate_test_params_for(metric: str, n_features: int):
    """Return a list of dummy DistanceMetric kwargs for tests.

    Parameters
    ----------
    metric : str
        Name of the metric for which keyword arguments are generated.
    n_features : int
        Number of features; it sizes the generated weight vectors and
        matrices.

    Returns
    -------
    list of dict
        One dict of keyword arguments per parameter combination to try.
        Metrics that need no extra parameter yield ``[{}]``.
    """
    rng = np.random.RandomState(1)
    # Drawn before branching although only the "wminkowski" case reads it:
    # moving the draw would shift the values returned for the other metrics.
    weights = rng.random_sample(n_features)

    if metric == "minkowski":
        minkowski_kwargs = [dict(p=1.5), dict(p=2), dict(p=3), dict(p=np.inf)]
        # TODO: drop this version check once scipy < 1.8.0 is no longer
        # supported.
        if sp_version >= parse_version("1.8.0.dev0"):
            # Recent scipy versions accept weights in the Minkowski metric
            # directly.
            minkowski_kwargs.append(
                dict(p=3, w=rng.rand(n_features))
            )  # type: ignore
        return minkowski_kwargs

    # TODO: remove this case for "wminkowski" once we no longer support
    # scipy < 1.8.0.
    if metric == "wminkowski":
        weights /= weights.sum()
        wminkowski_kwargs = [dict(p=1.5, w=weights)]
        if sp_version < parse_version("1.8.0.dev0"):
            # wminkowski was removed in scipy 1.8.0 but should work for
            # previous versions.
            wminkowski_kwargs.append(dict(p=3, w=rng.rand(n_features)))
        return wminkowski_kwargs

    if metric == "seuclidean":
        return [dict(V=rng.rand(n_features))]

    if metric == "mahalanobis":
        A = rng.rand(n_features, n_features)
        # A @ A.T is symmetric positive semi-definite; adding the identity
        # makes it positive definite whatever the value of n_features.
        # (A + A.T + 3 * I is symmetric but its smallest eigenvalue can
        # become negative when n_features grows.)
        VI = A @ A.T + np.eye(n_features)
        return [dict(VI=VI)]

    # Case of: "euclidean", "manhattan", "chebyshev", "haversine" or any
    # other metric. In those cases, no kwargs are needed.
    return [{}]
99+
100+
61101
def _weight_func(dist):
62102
"""Weight function to replace lambda d: d ** -2.
63103
The lambda function is not valid because:
@@ -1385,58 +1425,48 @@ def custom_metric(x1, x2):
13851425

13861426
# TODO: Remove filterwarnings in 1.3 when wminkowski is removed
13871427
@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
1388-
def test_valid_brute_metric_for_auto_algorithm():
1389-
X = rng.rand(12, 12)
1428+
@pytest.mark.parametrize("metric", neighbors.VALID_METRICS["brute"])
1429+
def test_valid_brute_metric_for_auto_algorithm(metric, n_samples=20, n_features=12):
1430+
# Any metric valid for algorithm="brute" must also be valid for algorithm="auto".
1431+
# It's the responsibility of the estimator to select which algorithm is likely
1432+
# to be the most efficient among the subset of algorithms compatible with
1433+
# that metric (and params). Worst case is to fall back to algorithm="brute".
1434+
X = rng.rand(n_samples, n_features)
13901435
Xcsr = csr_matrix(X)
13911436

1392-
# check that there is a metric that is valid for brute
1393-
# but not ball_tree (so we actually test something)
1394-
assert "cosine" in VALID_METRICS["brute"]
1395-
assert "cosine" not in VALID_METRICS["ball_tree"]
1437+
metric_params_list = _generate_test_params_for(metric, n_features)
1438+
1439+
if metric == "precomputed":
1440+
X_precomputed = rng.random_sample((10, 4))
1441+
Y_precomputed = rng.random_sample((3, 4))
1442+
DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean")
1443+
DYX = metrics.pairwise_distances(
1444+
Y_precomputed, X_precomputed, metric="euclidean"
1445+
)
1446+
nb_p = neighbors.NearestNeighbors(n_neighbors=3, metric="precomputed")
1447+
nb_p.fit(DXX)
1448+
nb_p.kneighbors(DYX)
13961449

1397-
# Metric which don't required any additional parameter
1398-
require_params = ["mahalanobis", "wminkowski", "seuclidean"]
1399-
for metric in VALID_METRICS["brute"]:
1400-
if metric != "precomputed" and metric not in require_params:
1450+
else:
1451+
for metric_params in metric_params_list:
14011452
nn = neighbors.NearestNeighbors(
1402-
n_neighbors=3, algorithm="auto", metric=metric
1403-
)
1404-
if metric != "haversine":
1405-
nn.fit(X)
1406-
nn.kneighbors(X)
1407-
else:
1408-
nn.fit(X[:, :2])
1409-
nn.kneighbors(X[:, :2])
1410-
elif metric == "precomputed":
1411-
X_precomputed = rng.random_sample((10, 4))
1412-
Y_precomputed = rng.random_sample((3, 4))
1413-
DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean")
1414-
DYX = metrics.pairwise_distances(
1415-
Y_precomputed, X_precomputed, metric="euclidean"
1453+
n_neighbors=3,
1454+
algorithm="auto",
1455+
metric=metric,
1456+
metric_params=metric_params,
14161457
)
1417-
nb_p = neighbors.NearestNeighbors(n_neighbors=3)
1418-
nb_p.fit(DXX)
1419-
nb_p.kneighbors(DYX)
1458+
# Haversine distance only accepts 2D data
1459+
if metric == "haversine":
1460+
X = np.ascontiguousarray(X[:, :2])
14201461

1421-
for metric in VALID_METRICS_SPARSE["brute"]:
1422-
if metric != "precomputed" and metric not in require_params:
1423-
nn = neighbors.NearestNeighbors(
1424-
n_neighbors=3, algorithm="auto", metric=metric
1425-
).fit(Xcsr)
1426-
nn.kneighbors(Xcsr)
1427-
1428-
# Metric with parameter
1429-
VI = np.dot(X, X.T)
1430-
list_metrics = [
1431-
("seuclidean", dict(V=rng.rand(12))),
1432-
("wminkowski", dict(w=rng.rand(12))),
1433-
("mahalanobis", dict(VI=VI)),
1434-
]
1435-
for metric, params in list_metrics:
1436-
nn = neighbors.NearestNeighbors(
1437-
n_neighbors=3, algorithm="auto", metric=metric, metric_params=params
1438-
).fit(X)
1439-
nn.kneighbors(X)
1462+
nn.fit(X)
1463+
nn.kneighbors(X)
1464+
1465+
if metric in VALID_METRICS_SPARSE["brute"]:
1466+
nn = neighbors.NearestNeighbors(
1467+
n_neighbors=3, algorithm="auto", metric=metric
1468+
).fit(Xcsr)
1469+
nn.kneighbors(Xcsr)
14401470

14411471

14421472
def test_metric_params_interface():

0 commit comments

Comments
 (0)