diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py
index c453ca84a4784..179a82e150f93 100644
--- a/sklearn/neighbors/_base.py
+++ b/sklearn/neighbors/_base.py
@@ -535,7 +535,21 @@ def _fit(self, X, y=None):
             ):
                 self._fit_method = "brute"
             else:
-                if self.effective_metric_ in VALID_METRICS["kd_tree"]:
+                if (
+                    self.effective_metric_ == "minkowski"
+                    and self.effective_metric_params_.get("w") is not None
+                ):
+                    # Be consistent with scipy 1.8 conventions: in scipy 1.8,
+                    # 'wminkowski' was removed in favor of passing a
+                    # weight vector directly to 'minkowski'.
+                    #
+                    # 'wminkowski' is not among the valid metrics for KDTree,
+                    # but unweighted 'minkowski' is.
+                    #
+                    # Hence, we detect this case and choose BallTree,
+                    # which supports 'wminkowski'.
+                    self._fit_method = "ball_tree"
+                elif self.effective_metric_ in VALID_METRICS["kd_tree"]:
                     self._fit_method = "kd_tree"
                 elif (
                     callable(self.effective_metric_)
@@ -553,6 +567,16 @@ def _fit(self, X, y=None):
                 **self.effective_metric_params_,
             )
         elif self._fit_method == "kd_tree":
+            if (
+                self.effective_metric_ == "minkowski"
+                and self.effective_metric_params_.get("w") is not None
+            ):
+                raise ValueError(
+                    "algorithm='kd_tree' is not valid for "
+                    "metric='minkowski' with a weight parameter 'w': "
+                    "try algorithm='ball_tree' "
+                    "or algorithm='brute' instead."
+                )
             self._tree = KDTree(
                 X,
                 self.leaf_size,
diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py
index c528c3ab900f8..5ba940c0bd1a7 100644
--- a/sklearn/neighbors/tests/test_neighbors.py
+++ b/sklearn/neighbors/tests/test_neighbors.py
@@ -22,7 +22,7 @@
 from sklearn.metrics.pairwise import pairwise_distances
 from sklearn.model_selection import cross_val_score
 from sklearn.model_selection import train_test_split
-from sklearn.neighbors import VALID_METRICS_SPARSE, VALID_METRICS
+from sklearn.neighbors import VALID_METRICS_SPARSE
 from sklearn.neighbors._base import _is_sorted_by_data, _check_precomputed
 from sklearn.pipeline import make_pipeline
 from sklearn.utils._testing import assert_array_almost_equal
@@ -58,6 +58,46 @@
 neighbors.radius_neighbors_graph = ignore_warnings(neighbors.radius_neighbors_graph)
 
 
+def _generate_test_params_for(metric: str, n_features: int):
+    """Return a list of dummy DistanceMetric kwargs for tests."""
+
+    # Distinguish between cases to avoid computing unneeded data structures.
+    rng = np.random.RandomState(1)
+    weights = rng.random_sample(n_features)
+
+    if metric == "minkowski":
+        minkowski_kwargs = [dict(p=1.5), dict(p=2), dict(p=3), dict(p=np.inf)]
+        if sp_version >= parse_version("1.8.0.dev0"):
+            # TODO: remove this version check once we no longer support
+            # scipy < 1.8.0: recent scipy versions accept weights in the
+            # Minkowski metric directly.
+            minkowski_kwargs.append(dict(p=3, w=rng.rand(n_features)))
+        return minkowski_kwargs
+
+    # TODO: remove this case for "wminkowski" once we no longer support scipy < 1.8.0.
+    if metric == "wminkowski":
+        weights /= weights.sum()
+        wminkowski_kwargs = [dict(p=1.5, w=weights)]
+        if sp_version < parse_version("1.8.0.dev0"):
+            # wminkowski was removed in scipy 1.8.0 but should work for
+            # previous versions.
+            wminkowski_kwargs.append(dict(p=3, w=rng.rand(n_features)))
+        return wminkowski_kwargs
+
+    if metric == "seuclidean":
+        return [dict(V=rng.rand(n_features))]
+
+    if metric == "mahalanobis":
+        A = rng.rand(n_features, n_features)
+        # Make the matrix symmetric positive definite
+        VI = A + A.T + 3 * np.eye(n_features)
+        return [dict(VI=VI)]
+
+    # Case of: "euclidean", "manhattan", "chebyshev", "haversine" or any other metric.
+    # In those cases, no kwargs are needed.
+    return [{}]
+
+
 def _weight_func(dist):
     """Weight function to replace lambda d: d ** -2.
     The lambda function is not valid because:
@@ -1385,58 +1425,48 @@ def custom_metric(x1, x2):
 
 # TODO: Remove filterwarnings in 1.3 when wminkowski is removed
 @pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
-def test_valid_brute_metric_for_auto_algorithm():
-    X = rng.rand(12, 12)
+@pytest.mark.parametrize("metric", neighbors.VALID_METRICS["brute"])
+def test_valid_brute_metric_for_auto_algorithm(metric, n_samples=20, n_features=12):
+    # Any valid metric for algorithm="brute" must be valid for algorithm="auto":
+    # it is the responsibility of the estimator to select the algorithm that is
+    # likely to be the most efficient among those compatible with the metric
+    # (and its params). The worst case is to fall back to algorithm="brute".
+    X = rng.rand(n_samples, n_features)
     Xcsr = csr_matrix(X)
 
-    # check that there is a metric that is valid for brute
-    # but not ball_tree (so we actually test something)
-    assert "cosine" in VALID_METRICS["brute"]
-    assert "cosine" not in VALID_METRICS["ball_tree"]
+    metric_params_list = _generate_test_params_for(metric, n_features)
+
+    if metric == "precomputed":
+        X_precomputed = rng.random_sample((10, 4))
+        Y_precomputed = rng.random_sample((3, 4))
+        DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean")
+        DYX = metrics.pairwise_distances(
+            Y_precomputed, X_precomputed, metric="euclidean"
+        )
+        nb_p = neighbors.NearestNeighbors(n_neighbors=3, metric="precomputed")
+        nb_p.fit(DXX)
+        nb_p.kneighbors(DYX)
 
-    # Metric which don't required any additional parameter
-    require_params = ["mahalanobis", "wminkowski", "seuclidean"]
-    for metric in VALID_METRICS["brute"]:
-        if metric != "precomputed" and metric not in require_params:
+    else:
+        for metric_params in metric_params_list:
             nn = neighbors.NearestNeighbors(
-                n_neighbors=3, algorithm="auto", metric=metric
-            )
-            if metric != "haversine":
-                nn.fit(X)
-                nn.kneighbors(X)
-            else:
-                nn.fit(X[:, :2])
-                nn.kneighbors(X[:, :2])
-        elif metric == "precomputed":
-            X_precomputed = rng.random_sample((10, 4))
-            Y_precomputed = rng.random_sample((3, 4))
-            DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean")
-            DYX = metrics.pairwise_distances(
-                Y_precomputed, X_precomputed, metric="euclidean"
+                n_neighbors=3,
+                algorithm="auto",
+                metric=metric,
+                metric_params=metric_params,
             )
-            nb_p = neighbors.NearestNeighbors(n_neighbors=3)
-            nb_p.fit(DXX)
-            nb_p.kneighbors(DYX)
+            # Haversine distance only accepts 2D data
+            if metric == "haversine":
+                X = np.ascontiguousarray(X[:, :2])
 
-    for metric in VALID_METRICS_SPARSE["brute"]:
-        if metric != "precomputed" and metric not in require_params:
-            nn = neighbors.NearestNeighbors(
-                n_neighbors=3, algorithm="auto", metric=metric
-            ).fit(Xcsr)
-            nn.kneighbors(Xcsr)
-
-    # Metric with parameter
-    VI = np.dot(X, X.T)
-    list_metrics = [
-        ("seuclidean", dict(V=rng.rand(12))),
-        ("wminkowski", dict(w=rng.rand(12))),
-        ("mahalanobis", dict(VI=VI)),
-    ]
-    for metric, params in list_metrics:
-        nn = neighbors.NearestNeighbors(
-            n_neighbors=3, algorithm="auto", metric=metric, metric_params=params
-        ).fit(X)
-        nn.kneighbors(X)
+            nn.fit(X)
+            nn.kneighbors(X)
+
+    if metric in VALID_METRICS_SPARSE["brute"]:
+        nn = neighbors.NearestNeighbors(
+            n_neighbors=3, algorithm="auto", metric=metric
+        ).fit(Xcsr)
+        nn.kneighbors(Xcsr)
 
 
 def test_metric_params_interface():
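For reviewers, the intended user-facing behavior can be illustrated as follows. This is a minimal sketch, assuming the patch above is applied; the choice of `p=3` and the expected outputs in the comments are illustrative assumptions rather than captured output, and `_fit_method` is a private attribute inspected here only to show the algorithm selection.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.RandomState(0)
X = rng.rand(20, 5)
w = rng.rand(5)  # one non-negative weight per feature

# algorithm="auto" with a weighted Minkowski metric: KDTree does not support
# feature weights, so the estimator is expected to route to BallTree.
nn = NearestNeighbors(
    n_neighbors=3, algorithm="auto", metric="minkowski", p=3, metric_params={"w": w}
).fit(X)
print(nn._fit_method)  # expected: "ball_tree"
dist, ind = nn.kneighbors(X)

# algorithm="kd_tree" with a weight vector: expected to fail fast with the
# new, informative ValueError instead of erroring deep inside KDTree.
try:
    NearestNeighbors(
        n_neighbors=3, algorithm="kd_tree", metric="minkowski", metric_params={"w": w}
    ).fit(X)
except ValueError as exc:
    print(exc)
```

With `algorithm="brute"`, the weighted metric is delegated to scipy, which only accepts `w` for `"minkowski"` from scipy 1.8 onwards; hence the version guards in `_generate_test_params_for` above.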