FIX Fallback to ball_tree using minkowski with w for kd_tree #22241

Merged · 13 commits · Jan 22, 2022
26 changes: 25 additions & 1 deletion sklearn/neighbors/_base.py
@@ -535,7 +535,21 @@ def _fit(self, X, y=None):
             ):
                 self._fit_method = "brute"
             else:
-                if self.effective_metric_ in VALID_METRICS["kd_tree"]:
+                if (
+                    self.effective_metric_ == "minkowski"
+                    and self.effective_metric_params_.get("w") is not None
+                ):
+                    # Be consistent with scipy 1.8 conventions: in scipy 1.8,
+                    # 'wminkowski' was removed in favor of passing a weight
+                    # vector directly to 'minkowski'.
+                    #
+                    # 'wminkowski' is not among the valid metrics for KDTree,
+                    # but 'minkowski' without weights is.
+                    #
+                    # Hence, we detect this case and choose BallTree, which
+                    # supports 'wminkowski'.
+                    self._fit_method = "ball_tree"
+                elif self.effective_metric_ in VALID_METRICS["kd_tree"]:
                     self._fit_method = "kd_tree"
                 elif (
                     callable(self.effective_metric_)
@@ -553,6 +567,16 @@ def _fit(self, X, y=None):
                 **self.effective_metric_params_,
             )
         elif self._fit_method == "kd_tree":
+            if (
+                self.effective_metric_ == "minkowski"
+                and self.effective_metric_params_.get("w") is not None
+            ):
+                raise ValueError(
+                    "algorithm='kd_tree' is not valid for "
+                    "metric='minkowski' with a weight parameter 'w': "
+                    "try algorithm='ball_tree' "
+                    "or algorithm='brute' instead."
+                )
             self._tree = KDTree(
                 X,
                 self.leaf_size,
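
In practice, the two hunks above mean that algorithm="auto" silently routes weighted Minkowski queries to BallTree, while an explicit algorithm="kd_tree" request now fails fast at fit time. Below is a minimal sketch of that behavior, not part of the diff: it assumes a scikit-learn build containing this fix together with scipy >= 1.8, and it reads the private _fit_method attribute purely for illustration.

import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.RandomState(0)
X = rng.rand(20, 3)
w = rng.rand(3)  # per-feature weights for the Minkowski metric

# algorithm="auto" now falls back to BallTree for weighted minkowski.
nn = NearestNeighbors(
    n_neighbors=3, algorithm="auto", metric="minkowski", metric_params={"p": 3, "w": w}
).fit(X)
print(nn._fit_method)  # expected: "ball_tree" (private attribute, shown for illustration)

# An explicit KDTree request is rejected at fit time with the new error.
try:
    NearestNeighbors(
        n_neighbors=3, algorithm="kd_tree", metric="minkowski", metric_params={"p": 3, "w": w}
    ).fit(X)
except ValueError as exc:
    print(exc)  # algorithm='kd_tree' is not valid for metric='minkowski' ...
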
124 changes: 77 additions & 47 deletions sklearn/neighbors/tests/test_neighbors.py
@@ -22,7 +22,7 @@
 from sklearn.metrics.pairwise import pairwise_distances
 from sklearn.model_selection import cross_val_score
 from sklearn.model_selection import train_test_split
-from sklearn.neighbors import VALID_METRICS_SPARSE, VALID_METRICS
+from sklearn.neighbors import VALID_METRICS_SPARSE
 from sklearn.neighbors._base import _is_sorted_by_data, _check_precomputed
 from sklearn.pipeline import make_pipeline
 from sklearn.utils._testing import assert_array_almost_equal
@@ -58,6 +58,46 @@
 neighbors.radius_neighbors_graph = ignore_warnings(neighbors.radius_neighbors_graph)


+def _generate_test_params_for(metric: str, n_features: int):
+    """Return a list of dummy DistanceMetric kwargs for tests."""
+
+    # Branch on the metric so we don't build data structures we don't need.
+    rng = np.random.RandomState(1)
+    weights = rng.random_sample(n_features)
+
+    if metric == "minkowski":
+        minkowski_kwargs = [dict(p=1.5), dict(p=2), dict(p=3), dict(p=np.inf)]
+        if sp_version >= parse_version("1.8.0.dev0"):
+            # TODO: remove this check once we no longer support scipy < 1.8.0.
+            # Recent scipy versions accept weights in the Minkowski metric
+            # directly.
+            minkowski_kwargs.append(dict(p=3, w=rng.rand(n_features)))
+        return minkowski_kwargs
+
+    # TODO: remove this case for "wminkowski" once we no longer support scipy < 1.8.0.
+    if metric == "wminkowski":
+        weights /= weights.sum()
+        wminkowski_kwargs = [dict(p=1.5, w=weights)]
+        if sp_version < parse_version("1.8.0.dev0"):
+            # 'wminkowski' was removed in scipy 1.8.0, but it still works for
+            # earlier versions.
+            wminkowski_kwargs.append(dict(p=3, w=rng.rand(n_features)))
+        return wminkowski_kwargs
+
+    if metric == "seuclidean":
+        return [dict(V=rng.rand(n_features))]
+
+    if metric == "mahalanobis":
+        A = rng.rand(n_features, n_features)
+        # Make the matrix symmetric positive definite.
+        VI = A + A.T + 3 * np.eye(n_features)
+        return [dict(VI=VI)]
+
+    # Case of: "euclidean", "manhattan", "chebyshev", "haversine" or any other
+    # metric. In those cases, no kwargs are needed.
+    return [{}]
+
+
 def _weight_func(dist):
     """Weight function to replace lambda d: d ** -2.
     The lambda function is not valid because:
@@ -1385,58 +1425,48 @@ def custom_metric(x1, x2):

 # TODO: Remove filterwarnings in 1.3 when wminkowski is removed
 @pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
-def test_valid_brute_metric_for_auto_algorithm():
-    X = rng.rand(12, 12)
+@pytest.mark.parametrize("metric", neighbors.VALID_METRICS["brute"])
+def test_valid_brute_metric_for_auto_algorithm(metric, n_samples=20, n_features=12):
+    # Any valid metric for algorithm="brute" must be valid for algorithm="auto".
+    # It's the responsibility of the estimator to select which algorithm is
+    # likely to be the most efficient from the subset of algorithms compatible
+    # with that metric (and params). Worst case is to fall back to
+    # algorithm="brute".
+    X = rng.rand(n_samples, n_features)
     Xcsr = csr_matrix(X)

-    # check that there is a metric that is valid for brute
-    # but not ball_tree (so we actually test something)
-    assert "cosine" in VALID_METRICS["brute"]
-    assert "cosine" not in VALID_METRICS["ball_tree"]
+    metric_params_list = _generate_test_params_for(metric, n_features)

-    # Metric which don't required any additional parameter
-    require_params = ["mahalanobis", "wminkowski", "seuclidean"]
-    for metric in VALID_METRICS["brute"]:
-        if metric != "precomputed" and metric not in require_params:
-            nn = neighbors.NearestNeighbors(
-                n_neighbors=3, algorithm="auto", metric=metric
-            )
-            if metric != "haversine":
-                nn.fit(X)
-                nn.kneighbors(X)
-            else:
-                nn.fit(X[:, :2])
-                nn.kneighbors(X[:, :2])
-        elif metric == "precomputed":
-            X_precomputed = rng.random_sample((10, 4))
-            Y_precomputed = rng.random_sample((3, 4))
-            DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean")
-            DYX = metrics.pairwise_distances(
-                Y_precomputed, X_precomputed, metric="euclidean"
-            )
-            nb_p = neighbors.NearestNeighbors(n_neighbors=3)
-            nb_p.fit(DXX)
-            nb_p.kneighbors(DYX)
+    if metric == "precomputed":
+        X_precomputed = rng.random_sample((10, 4))
+        Y_precomputed = rng.random_sample((3, 4))
+        DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean")
+        DYX = metrics.pairwise_distances(
+            Y_precomputed, X_precomputed, metric="euclidean"
+        )
+        nb_p = neighbors.NearestNeighbors(n_neighbors=3, metric="precomputed")
+        nb_p.fit(DXX)
+        nb_p.kneighbors(DYX)
+    else:
+        for metric_params in metric_params_list:
+            nn = neighbors.NearestNeighbors(
+                n_neighbors=3,
+                algorithm="auto",
+                metric=metric,
+                metric_params=metric_params,
+            )
+            # Haversine distance only accepts 2D data
+            if metric == "haversine":
+                X = np.ascontiguousarray(X[:, :2])

-    for metric in VALID_METRICS_SPARSE["brute"]:
-        if metric != "precomputed" and metric not in require_params:
-            nn = neighbors.NearestNeighbors(
-                n_neighbors=3, algorithm="auto", metric=metric
-            ).fit(Xcsr)
-            nn.kneighbors(Xcsr)
+            nn.fit(X)
+            nn.kneighbors(X)

-    # Metric with parameter
-    VI = np.dot(X, X.T)
-    list_metrics = [
-        ("seuclidean", dict(V=rng.rand(12))),
-        ("wminkowski", dict(w=rng.rand(12))),
-        ("mahalanobis", dict(VI=VI)),
-    ]
-    for metric, params in list_metrics:
-        nn = neighbors.NearestNeighbors(
-            n_neighbors=3, algorithm="auto", metric=metric, metric_params=params
-        ).fit(X)
-        nn.kneighbors(X)
+            if metric in VALID_METRICS_SPARSE["brute"]:
+                nn = neighbors.NearestNeighbors(
+                    n_neighbors=3, algorithm="auto", metric=metric
+                ).fit(Xcsr)
+                nn.kneighbors(Xcsr)


 def test_metric_params_interface():
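
For background on the scipy convention change this PR adapts to, here is a short sketch (assuming scipy >= 1.8) of how the removed 'wminkowski' maps onto 'minkowski' with a weight vector. Note that the semantics differ slightly: 'wminkowski' applied the weights to the coordinate differences before raising them to the power p, so wminkowski(u, v, p, w) corresponds to minkowski(u, v, p, w=w**p).

import numpy as np
from scipy.spatial.distance import minkowski

u = np.array([0.0, 1.0, 2.0])
v = np.array([2.0, 1.0, 0.0])
w = np.array([0.5, 1.0, 2.0])

# scipy >= 1.8: weighted Minkowski, (sum_i w_i * |u_i - v_i|**p) ** (1/p)
d = minkowski(u, v, p=3, w=w)

# scipy < 1.8 equivalent (removed in scipy 1.8):
#     from scipy.spatial.distance import wminkowski
#     d = wminkowski(u, v, p=3, w=w ** (1 / 3.0))
print(d)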