From cde97aeb5c366c0c40a9643ab4300e1b4716aa04 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 18 Jan 2022 14:44:43 +0100 Subject: [PATCH 01/12] MAINT Fallback to 'brute' when using 'wminkowski' for 'kdtree' This is follow-up for #21741. --- sklearn/neighbors/_base.py | 28 ++- sklearn/neighbors/tests/test_neighbors.py | 208 +++++++++++----------- 2 files changed, 127 insertions(+), 109 deletions(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index c453ca84a4784..81f57fcfd187a 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -553,12 +553,28 @@ def _fit(self, X, y=None): **self.effective_metric_params_, ) elif self._fit_method == "kd_tree": - self._tree = KDTree( - X, - self.leaf_size, - metric=self.effective_metric_, - **self.effective_metric_params_, - ) + if ( + self.effective_metric_ == "minkowski" + and "w" in self.effective_metric_params_ + ): + # Be consistent with scipy 1.8 conventions: in scipy 1.8, + # 'wminkowski' was removed in favor of passing a + # weight vector directly to 'minkowski'. + # + # 'wminkowski' is not part of valid metrics for KDTree but + # the weights-less 'minkowski' is. + # + # Hence, we detect this case here and choose bruteforce instead. + self._fit_method = "brute" + self._tree = None + + else: + self._tree = KDTree( + X, + self.leaf_size, + metric=self.effective_metric_, + **self.effective_metric_params_, + ) elif self._fit_method == "brute": self._tree = None else: diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 2a4d500610051..76e81564b0ce2 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -22,11 +22,14 @@ from sklearn.metrics.pairwise import pairwise_distances from sklearn.model_selection import cross_val_score from sklearn.model_selection import train_test_split -from sklearn.neighbors import VALID_METRICS_SPARSE, VALID_METRICS +from sklearn.neighbors import VALID_METRICS_SPARSE from sklearn.neighbors._base import _is_sorted_by_data, _check_precomputed from sklearn.pipeline import make_pipeline -from sklearn.utils._testing import assert_array_almost_equal -from sklearn.utils._testing import assert_array_equal +from sklearn.utils._testing import ( + assert_allclose, + assert_array_almost_equal, + assert_array_equal, +) from sklearn.utils._testing import ignore_warnings from sklearn.utils.validation import check_random_state from sklearn.utils.fixes import sp_version, parse_version @@ -50,6 +53,9 @@ SPARSE_OR_DENSE = SPARSE_TYPES + (np.asarray,) ALGORITHMS = ("ball_tree", "brute", "kd_tree", "auto") +COMMON_VALID_METRICS = sorted( + set.intersection(*map(set, neighbors.VALID_METRICS.values())) +) P = (1, 2, 3, 4, np.inf) JOBLIB_BACKENDS = list(joblib.parallel.BACKENDS.keys()) @@ -58,6 +64,46 @@ neighbors.radius_neighbors_graph = ignore_warnings(neighbors.radius_neighbors_graph) +def _get_dummy_metric_params_list(metric: str, n_features: int): + """Return list of dummy DistanceMetric kwargs for tests.""" + + # Distinguishing on cases not to compute unneeded datastructures. + rng = np.random.RandomState(1) + + if metric == "minkowski": + minkowski_kwargs = [dict(p=1.5), dict(p=2), dict(p=3), dict(p=np.inf)] + if sp_version >= parse_version("1.8.0.dev0"): + # TODO: remove the test once we no longer support scipy < 1.8.0. + # Recent scipy versions accept weights in the Minkowski metric directly: + # type: ignore + minkowski_kwargs.append(dict(p=3, w=rng.rand(n_features))) + return minkowski_kwargs + + # TODO: remove this case for "wminkowski" once we no longer support scipy < 1.8.0. + if metric == "wminkowski": + weights = rng.random_sample(n_features) + weights /= weights.sum() + wminkowski_kwargs = [dict(p=1.5, w=weights)] + if sp_version < parse_version("1.8.0.dev0"): + # wminkowski was removed in scipy 1.8.0 but should work for previous + # versions. + wminkowski_kwargs.append(dict(p=3, w=rng.rand(n_features))) + return wminkowski_kwargs + + if metric == "seuclidean": + return [dict(V=rng.rand(n_features))] + + if metric == "mahalanobis": + A = rng.rand(n_features, n_features) + # Make the matrix symmetric positive definite + VI = A + A.T + 3 * np.eye(n_features) + return [dict(VI=VI)] + + # Case of: "euclidean", "manhattan", "chebyshev", "haversine" or any other metric. + # In those cases, no kwargs is needed. + return [{}] + + def _weight_func(dist): """Weight function to replace lambda d: d ** -2. The lambda function is not valid because: @@ -1275,57 +1321,22 @@ def test_neighbors_badargs(): nbrs.radius_neighbors_graph(X, mode="blah") -def test_neighbors_metrics(n_samples=20, n_features=3, n_query_pts=2, n_neighbors=5): +@pytest.mark.parametrize("metric", COMMON_VALID_METRICS) +def test_neighbors_metrics( + metric, n_samples=20, n_features=3, n_query_pts=2, n_neighbors=5 +): # Test computing the neighbors for various metrics # create a symmetric matrix - V = rng.rand(n_features, n_features) - VI = np.dot(V, V.T) - - metrics = [ - ("euclidean", {}), - ("manhattan", {}), - ("minkowski", dict(p=1)), - ("minkowski", dict(p=2)), - ("minkowski", dict(p=3)), - ("minkowski", dict(p=np.inf)), - ("chebyshev", {}), - ("seuclidean", dict(V=rng.rand(n_features))), - ("mahalanobis", dict(VI=VI)), - ("haversine", {}), - ] - if sp_version < parse_version("1.8.0.dev0"): - # TODO: remove once we no longer support scipy < 1.8.0. - # wminkowski was removed in scipy 1.8.0 but should work for previous - # versions. - metrics.append( - ("wminkowski", dict(p=3, w=rng.rand(n_features))), - ) - else: - # Recent scipy versions accept weights in the Minkowski metric directly: - metrics.append( - ("minkowski", dict(p=3, w=rng.rand(n_features))), - ) - algorithms = ["brute", "ball_tree", "kd_tree"] - X = rng.rand(n_samples, n_features) + X_train = rng.rand(n_samples, n_features) + X_test = rng.rand(n_query_pts, n_features) - test = rng.rand(n_query_pts, n_features) + metric_params_list = _get_dummy_metric_params_list(metric, n_features) - for metric, metric_params in metrics: + for metric_params in metric_params_list: results = {} p = metric_params.pop("p", 2) - w = metric_params.get("w", None) for algorithm in algorithms: - # KD tree doesn't support all metrics - if algorithm == "kd_tree" and ( - metric not in neighbors.KDTree.valid_metrics or w is not None - ): - est = neighbors.NearestNeighbors( - algorithm=algorithm, metric=metric, metric_params=metric_params - ) - with pytest.raises(ValueError): - est.fit(X) - continue neigh = neighbors.NearestNeighbors( n_neighbors=n_neighbors, algorithm=algorithm, @@ -1334,10 +1345,7 @@ def test_neighbors_metrics(n_samples=20, n_features=3, n_query_pts=2, n_neighbor metric_params=metric_params, ) - # Haversine distance only accepts 2D data - feature_sl = slice(None, 2) if metric == "haversine" else slice(None) - - neigh.fit(X[:, feature_sl]) + neigh.fit(X_train) # wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0 ExceptionToAssert = None @@ -1349,15 +1357,20 @@ def test_neighbors_metrics(n_samples=20, n_features=3, n_query_pts=2, n_neighbor ExceptionToAssert = DeprecationWarning with pytest.warns(ExceptionToAssert): - results[algorithm] = neigh.kneighbors( - test[:, feature_sl], return_distance=True - ) + results[algorithm] = neigh.kneighbors(X_test, return_distance=True) + + brute_dst, brute_idx = results["brute"] + kd_tree_dst, kd_tree_idx = results["kd_tree"] + ball_tree_dst, ball_tree_idx = results["ball_tree"] - assert_array_almost_equal(results["brute"][0], results["ball_tree"][0]) - assert_array_almost_equal(results["brute"][1], results["ball_tree"][1]) - if "kd_tree" in results: - assert_array_almost_equal(results["brute"][0], results["kd_tree"][0]) - assert_array_almost_equal(results["brute"][1], results["kd_tree"][1]) + assert_allclose(brute_dst, ball_tree_dst) + assert_array_equal(brute_idx, ball_tree_idx) + + assert_allclose(brute_dst, kd_tree_dst) + assert_array_equal(brute_idx, kd_tree_idx) + + assert_allclose(ball_tree_dst, kd_tree_dst) + assert_array_equal(ball_tree_idx, kd_tree_idx) def test_callable_metric(): @@ -1381,58 +1394,47 @@ def custom_metric(x1, x2): assert_array_almost_equal(dist1, dist2) -def test_valid_brute_metric_for_auto_algorithm(): - X = rng.rand(12, 12) +@pytest.mark.parametrize("metric", neighbors.VALID_METRICS["brute"]) +def test_valid_brute_metric_for_auto_algorithm(metric, n_samples=20, n_features=12): + X = rng.rand(n_samples, n_features) Xcsr = csr_matrix(X) - # check that there is a metric that is valid for brute - # but not ball_tree (so we actually test something) - assert "cosine" in VALID_METRICS["brute"] - assert "cosine" not in VALID_METRICS["ball_tree"] + metric_params_list = _get_dummy_metric_params_list(metric, n_features) - # Metric which don't required any additional parameter - require_params = ["mahalanobis", "wminkowski", "seuclidean"] - for metric in VALID_METRICS["brute"]: - if metric != "precomputed" and metric not in require_params: + if metric == "precomputed": + X_precomputed = rng.random_sample((10, 4)) + Y_precomputed = rng.random_sample((3, 4)) + DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean") + DYX = metrics.pairwise_distances( + Y_precomputed, X_precomputed, metric="euclidean" + ) + nb_p = neighbors.NearestNeighbors(n_neighbors=3, metric="precomputed") + nb_p.fit(DXX) + nb_p.kneighbors(DYX) + + else: + for metric_params in metric_params_list: nn = neighbors.NearestNeighbors( - n_neighbors=3, algorithm="auto", metric=metric + n_neighbors=3, + algorithm="auto", + metric=metric, + metric_params=metric_params, ) - if metric != "haversine": - nn.fit(X) - nn.kneighbors(X) + # Haversine distance only accepts 2D data + if metric == "haversine": + feature_sl = slice(None, 2) + X = np.ascontiguousarray(X[:, feature_sl]) else: - nn.fit(X[:, :2]) - nn.kneighbors(X[:, :2]) - elif metric == "precomputed": - X_precomputed = rng.random_sample((10, 4)) - Y_precomputed = rng.random_sample((3, 4)) - DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean") - DYX = metrics.pairwise_distances( - Y_precomputed, X_precomputed, metric="euclidean" - ) - nb_p = neighbors.NearestNeighbors(n_neighbors=3) - nb_p.fit(DXX) - nb_p.kneighbors(DYX) + X = X - for metric in VALID_METRICS_SPARSE["brute"]: - if metric != "precomputed" and metric not in require_params: - nn = neighbors.NearestNeighbors( - n_neighbors=3, algorithm="auto", metric=metric - ).fit(Xcsr) - nn.kneighbors(Xcsr) - - # Metric with parameter - VI = np.dot(X, X.T) - list_metrics = [ - ("seuclidean", dict(V=rng.rand(12))), - ("wminkowski", dict(w=rng.rand(12))), - ("mahalanobis", dict(VI=VI)), - ] - for metric, params in list_metrics: - nn = neighbors.NearestNeighbors( - n_neighbors=3, algorithm="auto", metric=metric, metric_params=params - ).fit(X) - nn.kneighbors(X) + nn.fit(X) + nn.kneighbors(X) + + if metric in VALID_METRICS_SPARSE["brute"]: + nn = neighbors.NearestNeighbors( + n_neighbors=3, algorithm="auto", metric=metric + ).fit(Xcsr) + nn.kneighbors(Xcsr) def test_metric_params_interface(): From 235ef6e69ce387cb12710e70d768c36b0760809e Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 18 Jan 2022 17:00:23 +0100 Subject: [PATCH 02/12] [scipy-dev] Testing coverage From 6f192318095a04f3ede1ee9930a02ef2a5350354 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 19 Jan 2022 14:49:38 +0100 Subject: [PATCH 03/12] Address review comments Co-authored-by: Olivier Grisel --- sklearn/neighbors/tests/test_neighbors.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 76e81564b0ce2..3d7c8d17e08a9 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -64,7 +64,7 @@ neighbors.radius_neighbors_graph = ignore_warnings(neighbors.radius_neighbors_graph) -def _get_dummy_metric_params_list(metric: str, n_features: int): +def _generate_test_params_for(metric: str, n_features: int): """Return list of dummy DistanceMetric kwargs for tests.""" # Distinguishing on cases not to compute unneeded datastructures. @@ -1331,7 +1331,7 @@ def test_neighbors_metrics( X_train = rng.rand(n_samples, n_features) X_test = rng.rand(n_query_pts, n_features) - metric_params_list = _get_dummy_metric_params_list(metric, n_features) + metric_params_list = _generate_test_params_for(metric, n_features) for metric_params in metric_params_list: results = {} @@ -1396,10 +1396,15 @@ def custom_metric(x1, x2): @pytest.mark.parametrize("metric", neighbors.VALID_METRICS["brute"]) def test_valid_brute_metric_for_auto_algorithm(metric, n_samples=20, n_features=12): + # Any valid metric for algorithm="brute" should be valid for algorithm="auto" + # it's the responsibility of the estimator to select which algorithm is likely + # to be the most efficient from the subset of the algorithm compatible with + # that metric (and params). If the worst case, algorithm="brute" is the ultimate + # fallback. X = rng.rand(n_samples, n_features) Xcsr = csr_matrix(X) - metric_params_list = _get_dummy_metric_params_list(metric, n_features) + metric_params_list = _generate_test_params_for(metric, n_features) if metric == "precomputed": X_precomputed = rng.random_sample((10, 4)) @@ -1422,10 +1427,7 @@ def test_valid_brute_metric_for_auto_algorithm(metric, n_samples=20, n_features= ) # Haversine distance only accepts 2D data if metric == "haversine": - feature_sl = slice(None, 2) - X = np.ascontiguousarray(X[:, feature_sl]) - else: - X = X + X = np.ascontiguousarray(X[:, :2]) nn.fit(X) nn.kneighbors(X) From 1feaf552580193716b09a751ee92112e5db77185 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 19 Jan 2022 14:59:29 +0100 Subject: [PATCH 04/12] Warn when falling back on BallTree --- sklearn/neighbors/_base.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 81f57fcfd187a..f04a37ebce8bf 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -535,7 +535,21 @@ def _fit(self, X, y=None): ): self._fit_method = "brute" else: - if self.effective_metric_ in VALID_METRICS["kd_tree"]: + if ( + self.effective_metric_ == "minkowski" + and "w" in self.effective_metric_params_ + ): + # Be consistent with scipy 1.8 conventions: in scipy 1.8, + # 'wminkowski' was removed in favor of passing a + # weight vector directly to 'minkowski'. + # + # 'wminkowski' is not part of valid metrics for KDTree but + # the weights-less 'minkowski' is. + # + # Hence, we detect this case here and choose BallTree + # which supports 'wminkowski'. + self._fit_method = "ball_tree" + elif self.effective_metric_ in VALID_METRICS["kd_tree"]: self._fit_method = "kd_tree" elif ( callable(self.effective_metric_) @@ -557,16 +571,17 @@ def _fit(self, X, y=None): self.effective_metric_ == "minkowski" and "w" in self.effective_metric_params_ ): - # Be consistent with scipy 1.8 conventions: in scipy 1.8, - # 'wminkowski' was removed in favor of passing a - # weight vector directly to 'minkowski'. - # - # 'wminkowski' is not part of valid metrics for KDTree but - # the weights-less 'minkowski' is. - # - # Hence, we detect this case here and choose bruteforce instead. - self._fit_method = "brute" - self._tree = None + warnings.warn( + "KDTree does not support Weighted Minkowski. " + "Falling back on BallTree." + ) + self._fit_method = "ball_tree" + self._tree = BallTree( + X, + self.leaf_size, + metric=self.effective_metric_, + **self.effective_metric_params_, + ) else: self._tree = KDTree( From bf489b71b374669b2f148170b1628c7c02803255 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 19 Jan 2022 15:52:21 +0100 Subject: [PATCH 05/12] Properly skip and raise Exception --- sklearn/neighbors/_base.py | 19 +++++++------------ sklearn/neighbors/tests/test_neighbors.py | 12 ++++++++++++ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index f04a37ebce8bf..5c57c28f8cfd1 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -537,7 +537,7 @@ def _fit(self, X, y=None): else: if ( self.effective_metric_ == "minkowski" - and "w" in self.effective_metric_params_ + and self.effective_metric_params_.get("w") is not None ): # Be consistent with scipy 1.8 conventions: in scipy 1.8, # 'wminkowski' was removed in favor of passing a @@ -569,18 +569,13 @@ def _fit(self, X, y=None): elif self._fit_method == "kd_tree": if ( self.effective_metric_ == "minkowski" - and "w" in self.effective_metric_params_ + and self.effective_metric_params_.get("w") is not None ): - warnings.warn( - "KDTree does not support Weighted Minkowski. " - "Falling back on BallTree." - ) - self._fit_method = "ball_tree" - self._tree = BallTree( - X, - self.leaf_size, - metric=self.effective_metric_, - **self.effective_metric_params_, + raise ValueError( + "algorithm='kd_tree' is not valid for " + "metric='minkowski' with a weight parameter 'w': " + "try algorithm='ball_tree' " + "or algorithm='brute' instead." ) else: diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 3d7c8d17e08a9..2616b9f110c7c 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -1345,6 +1345,18 @@ def test_neighbors_metrics( metric_params=metric_params, ) + if ( + metric == "minkowski" + and metric_params.get("w") is not None + and algorithm == "kd_tree" + ): + pytest.skip( + "algorithm='kd_tree' is not valid for " + "metric='minkowski' with a weight parameter 'w': " + "try algorithm='ball_tree' " + "or algorithm='brute' instead" + ) + neigh.fit(X_train) # wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0 From d523377bd6b2c9b1942cdb9541f9907c4dbb4cf2 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 19 Jan 2022 16:34:06 +0100 Subject: [PATCH 06/12] Revert test_neighbors_metrics and add case to pass --- sklearn/neighbors/tests/test_neighbors.py | 78 ++++++++++++++++------- 1 file changed, 55 insertions(+), 23 deletions(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 2616b9f110c7c..05314a1e1816d 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -26,7 +26,6 @@ from sklearn.neighbors._base import _is_sorted_by_data, _check_precomputed from sklearn.pipeline import make_pipeline from sklearn.utils._testing import ( - assert_allclose, assert_array_almost_equal, assert_array_equal, ) @@ -1321,22 +1320,57 @@ def test_neighbors_badargs(): nbrs.radius_neighbors_graph(X, mode="blah") -@pytest.mark.parametrize("metric", COMMON_VALID_METRICS) -def test_neighbors_metrics( - metric, n_samples=20, n_features=3, n_query_pts=2, n_neighbors=5 -): +def test_neighbors_metrics(n_samples=20, n_features=3, n_query_pts=2, n_neighbors=5): # Test computing the neighbors for various metrics # create a symmetric matrix + V = rng.rand(n_features, n_features) + VI = np.dot(V, V.T) + + metrics = [ + ("euclidean", {}), + ("manhattan", {}), + ("minkowski", dict(p=1)), + ("minkowski", dict(p=2)), + ("minkowski", dict(p=3)), + ("minkowski", dict(p=np.inf)), + ("chebyshev", {}), + ("seuclidean", dict(V=rng.rand(n_features))), + ("mahalanobis", dict(VI=VI)), + ("haversine", {}), + ] + if sp_version < parse_version("1.8.0.dev0"): + # TODO: remove once we no longer support scipy < 1.8.0. + # wminkowski was removed in scipy 1.8.0 but should work for previous + # versions. + metrics.append( + ("wminkowski", dict(p=3, w=rng.rand(n_features))), + ) + else: + # Recent scipy versions accept weights in the Minkowski metric directly: + metrics.append( + ("minkowski", dict(p=3, w=rng.rand(n_features))), + ) + algorithms = ["brute", "ball_tree", "kd_tree"] - X_train = rng.rand(n_samples, n_features) - X_test = rng.rand(n_query_pts, n_features) + X = rng.rand(n_samples, n_features) - metric_params_list = _generate_test_params_for(metric, n_features) + test = rng.rand(n_query_pts, n_features) - for metric_params in metric_params_list: + for metric, metric_params in metrics: results = {} p = metric_params.pop("p", 2) + w = metric_params.get("w", None) for algorithm in algorithms: + # KD tree doesn't support all metrics + if algorithm == "kd_tree" and ( + metric not in neighbors.KDTree.valid_metrics or w is not None + ): + est = neighbors.NearestNeighbors( + algorithm=algorithm, metric=metric, metric_params=metric_params + ) + with pytest.raises(ValueError): + est.fit(X) + continue neigh = neighbors.NearestNeighbors( n_neighbors=n_neighbors, algorithm=algorithm, @@ -1345,6 +1379,9 @@ def test_neighbors_metrics( metric_params=metric_params, ) + # Haversine distance only accepts 2D data + feature_sl = slice(None, 2) if metric == "haversine" else slice(None) + if ( metric == "minkowski" and metric_params.get("w") is not None @@ -1357,7 +1394,7 @@ def test_neighbors_metrics( "or algorithm='brute' instead" ) - neigh.fit(X_train) + neigh.fit(X[:, feature_sl]) # wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0 ExceptionToAssert = None @@ -1369,20 +1406,15 @@ def test_neighbors_metrics( ExceptionToAssert = DeprecationWarning with pytest.warns(ExceptionToAssert): - results[algorithm] = neigh.kneighbors(X_test, return_distance=True) - - brute_dst, brute_idx = results["brute"] - kd_tree_dst, kd_tree_idx = results["kd_tree"] - ball_tree_dst, ball_tree_idx = results["ball_tree"] - - assert_allclose(brute_dst, ball_tree_dst) - assert_array_equal(brute_idx, ball_tree_idx) - - assert_allclose(brute_dst, kd_tree_dst) - assert_array_equal(brute_idx, kd_tree_idx) + results[algorithm] = neigh.kneighbors( + test[:, feature_sl], return_distance=True + ) - assert_allclose(ball_tree_dst, kd_tree_dst) - assert_array_equal(ball_tree_idx, kd_tree_idx) + assert_array_almost_equal(results["brute"][0], results["ball_tree"][0]) + assert_array_almost_equal(results["brute"][1], results["ball_tree"][1]) + if "kd_tree" in results: + assert_array_almost_equal(results["brute"][0], results["kd_tree"][0]) + assert_array_almost_equal(results["brute"][1], results["kd_tree"][1]) def test_callable_metric(): From 693e633a8618f9a226e79416723e5605b1758906 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 19 Jan 2022 17:09:44 +0100 Subject: [PATCH 07/12] Remove unused common valid metrics --- sklearn/neighbors/tests/test_neighbors.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 05314a1e1816d..94446a42d3b7a 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -52,9 +52,6 @@ SPARSE_OR_DENSE = SPARSE_TYPES + (np.asarray,) ALGORITHMS = ("ball_tree", "brute", "kd_tree", "auto") -COMMON_VALID_METRICS = sorted( - set.intersection(*map(set, neighbors.VALID_METRICS.values())) -) P = (1, 2, 3, 4, np.inf) JOBLIB_BACKENDS = list(joblib.parallel.BACKENDS.keys()) From ecef93279afcd9304a2fffef2210c7a72b70cb1b Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 19 Jan 2022 17:21:46 +0100 Subject: [PATCH 08/12] Revert imports --- sklearn/neighbors/tests/test_neighbors.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 94446a42d3b7a..bd40b7d74cbb6 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -25,10 +25,8 @@ from sklearn.neighbors import VALID_METRICS_SPARSE from sklearn.neighbors._base import _is_sorted_by_data, _check_precomputed from sklearn.pipeline import make_pipeline -from sklearn.utils._testing import ( - assert_array_almost_equal, - assert_array_equal, -) +from sklearn.utils._testing import assert_array_almost_equal +from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import ignore_warnings from sklearn.utils.validation import check_random_state from sklearn.utils.fixes import sp_version, parse_version From 08b464e3e5b2fa7639e6f587b4611a0219a2d0ea Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 20 Jan 2022 10:43:08 +0100 Subject: [PATCH 09/12] Remove unnecessary test Co-authored-by: Olivier Grisel --- sklearn/neighbors/tests/test_neighbors.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index bd40b7d74cbb6..ab2fb3735ab4c 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -1377,18 +1377,6 @@ def test_neighbors_metrics(n_samples=20, n_features=3, n_query_pts=2, n_neighbor # Haversine distance only accepts 2D data feature_sl = slice(None, 2) if metric == "haversine" else slice(None) - if ( - metric == "minkowski" - and metric_params.get("w") is not None - and algorithm == "kd_tree" - ): - pytest.skip( - "algorithm='kd_tree' is not valid for " - "metric='minkowski' with a weight parameter 'w': " - "try algorithm='ball_tree' " - "or algorithm='brute' instead" - ) - neigh.fit(X[:, feature_sl]) # wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0 From 9df5672869ffd093d5344005d731a0baa4dd7ed4 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Sat, 22 Jan 2022 15:04:28 +0100 Subject: [PATCH 10/12] Update sklearn/neighbors/_base.py Co-authored-by: Christian Lorentzen --- sklearn/neighbors/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 5c57c28f8cfd1..4a2b2fb170ec1 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -544,7 +544,7 @@ def _fit(self, X, y=None): # weight vector directly to 'minkowski'. # # 'wminkowski' is not part of valid metrics for KDTree but - # the weights-less 'minkowski' is. + # the 'minkowski' without weights is. # # Hence, we detect this case here and choose BallTree # which supports 'wminkowski'. From b17f59c9293054d655b6daff3fc091603405f3c6 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Sat, 22 Jan 2022 15:15:40 +0100 Subject: [PATCH 11/12] Apply comments from reviews Co-authored-by: Christian Lorentzen Co-authored-by: Thomas J. Fan --- sklearn/neighbors/_base.py | 16 +++++++--------- sklearn/neighbors/tests/test_neighbors.py | 11 +++++------ 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 4a2b2fb170ec1..179a82e150f93 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -546,7 +546,7 @@ def _fit(self, X, y=None): # 'wminkowski' is not part of valid metrics for KDTree but # the 'minkowski' without weights is. # - # Hence, we detect this case here and choose BallTree + # Hence, we detect this case and choose BallTree # which supports 'wminkowski'. self._fit_method = "ball_tree" elif self.effective_metric_ in VALID_METRICS["kd_tree"]: @@ -577,14 +577,12 @@ def _fit(self, X, y=None): "try algorithm='ball_tree' " "or algorithm='brute' instead." ) - - else: - self._tree = KDTree( - X, - self.leaf_size, - metric=self.effective_metric_, - **self.effective_metric_params_, - ) + self._tree = KDTree( + X, + self.leaf_size, + metric=self.effective_metric_, + **self.effective_metric_params_, + ) elif self._fit_method == "brute": self._tree = None else: diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index d79528c670706..5ba940c0bd1a7 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -63,6 +63,7 @@ def _generate_test_params_for(metric: str, n_features: int): # Distinguishing on cases not to compute unneeded datastructures. rng = np.random.RandomState(1) + weights = rng.random_sample(n_features) if metric == "minkowski": minkowski_kwargs = [dict(p=1.5), dict(p=2), dict(p=3), dict(p=np.inf)] @@ -75,7 +76,6 @@ def _generate_test_params_for(metric: str, n_features: int): # TODO: remove this case for "wminkowski" once we no longer support scipy < 1.8.0. if metric == "wminkowski": - weights = rng.random_sample(n_features) weights /= weights.sum() wminkowski_kwargs = [dict(p=1.5, w=weights)] if sp_version < parse_version("1.8.0.dev0"): @@ -94,7 +94,7 @@ def _generate_test_params_for(metric: str, n_features: int): return [dict(VI=VI)] # Case of: "euclidean", "manhattan", "chebyshev", "haversine" or any other metric. - # In those cases, no kwargs is needed. + # In those cases, no kwargs are needed. return [{}] @@ -1427,11 +1427,10 @@ def custom_metric(x1, x2): @pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn") @pytest.mark.parametrize("metric", neighbors.VALID_METRICS["brute"]) def test_valid_brute_metric_for_auto_algorithm(metric, n_samples=20, n_features=12): - # Any valid metric for algorithm="brute" should be valid for algorithm="auto" - # it's the responsibility of the estimator to select which algorithm is likely + # Any valid metric for algorithm="brute" must be a valid for algorithm="auto". + # It's the responsibility of the estimator to select which algorithm is likely # to be the most efficient from the subset of the algorithm compatible with - # that metric (and params). If the worst case, algorithm="brute" is the ultimate - # fallback. + # that metric (and params). Worst case is to fallback to algorithm="brute". X = rng.rand(n_samples, n_features) Xcsr = csr_matrix(X) From a320e9ce3be95cf24372c549e95bdeaeefc234d4 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Sat, 22 Jan 2022 17:35:18 +0100 Subject: [PATCH 12/12] [scipy-dev] Testing coverage