@@ -22,7 +22,7 @@
 from sklearn.metrics.pairwise import pairwise_distances
 from sklearn.model_selection import cross_val_score
 from sklearn.model_selection import train_test_split
-from sklearn.neighbors import VALID_METRICS_SPARSE, VALID_METRICS
+from sklearn.neighbors import VALID_METRICS_SPARSE
 from sklearn.neighbors._base import _is_sorted_by_data, _check_precomputed
 from sklearn.pipeline import make_pipeline
 from sklearn.utils._testing import assert_array_almost_equal
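A note on the import change above: the test no longer needs a direct `VALID_METRICS` import because the new parametrization reaches it as `neighbors.VALID_METRICS`. Both registries are plain dicts mapping a backend name to the metric names it supports; a quick illustrative sketch (printed values are indicative, not pinned to a version):

```python
from sklearn import neighbors
from sklearn.neighbors import VALID_METRICS_SPARSE

# Dicts mapping backend name -> supported metric names.
print(sorted(neighbors.VALID_METRICS))  # ['ball_tree', 'brute', 'kd_tree']

# "cosine" is supported by the brute-force backend but not by ball_tree,
# which is the asymmetry the old test asserted explicitly.
print("cosine" in neighbors.VALID_METRICS["brute"])      # True
print("cosine" in neighbors.VALID_METRICS["ball_tree"])  # False

# Sparse input has its own registry, used by the rewritten test below.
print("cosine" in VALID_METRICS_SPARSE["brute"])  # True
```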
@@ -58,6 +58,46 @@
 neighbors.radius_neighbors_graph = ignore_warnings(neighbors.radius_neighbors_graph)
 
 
+def _generate_test_params_for(metric: str, n_features: int):
+    """Return a list of dummy DistanceMetric kwargs for tests."""
+
+    # Distinguish between cases to avoid computing unneeded data structures.
+    rng = np.random.RandomState(1)
+    weights = rng.random_sample(n_features)
+
+    if metric == "minkowski":
+        minkowski_kwargs = [dict(p=1.5), dict(p=2), dict(p=3), dict(p=np.inf)]
+        if sp_version >= parse_version("1.8.0.dev0"):
+            # TODO: remove this check once we no longer support scipy < 1.8.0.
+            # Recent scipy versions accept weights in the Minkowski metric
+            # directly:
+            minkowski_kwargs.append(dict(p=3, w=rng.rand(n_features)))
+        return minkowski_kwargs
+
+    # TODO: remove this case for "wminkowski" once we no longer support scipy < 1.8.0.
+    if metric == "wminkowski":
+        weights /= weights.sum()
+        wminkowski_kwargs = [dict(p=1.5, w=weights)]
+        if sp_version < parse_version("1.8.0.dev0"):
+            # wminkowski was removed in scipy 1.8.0 but should work for
+            # earlier versions.
+            wminkowski_kwargs.append(dict(p=3, w=rng.rand(n_features)))
+        return wminkowski_kwargs
+
+    if metric == "seuclidean":
+        return [dict(V=rng.rand(n_features))]
+
+    if metric == "mahalanobis":
+        A = rng.rand(n_features, n_features)
+        # Make the matrix symmetric positive definite.
+        VI = A + A.T + 3 * np.eye(n_features)
+        return [dict(VI=VI)]
+
+    # Case of: "euclidean", "manhattan", "chebyshev", "haversine" or any other
+    # metric. In those cases, no kwargs are needed.
+    return [{}]
+
+
 def _weight_func(dist):
     """Weight function to replace lambda d: d ** -2.
     The lambda function is not valid because:
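To make the helper's contract concrete, here is a minimal usage sketch. It is illustrative only (not part of the diff) and assumes `_generate_test_params_for` plus the test module's `neighbors` and `numpy` imports are in scope:

```python
import numpy as np
from sklearn import neighbors

X = np.random.RandomState(0).rand(20, 12)

# Each metric maps to a list of metric_params dicts to exercise: e.g.
# "seuclidean" yields [{"V": <12 per-feature scales>}], "mahalanobis"
# yields [{"VI": <12x12 SPD matrix>}], and "euclidean" yields [{}].
for metric in ("euclidean", "seuclidean", "mahalanobis"):
    for metric_params in _generate_test_params_for(metric, n_features=12):
        nn = neighbors.NearestNeighbors(
            n_neighbors=3, algorithm="auto", metric=metric, metric_params=metric_params
        ).fit(X)
        dist, ind = nn.kneighbors(X)  # both of shape (20, 3)
```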
@@ -1385,58 +1425,48 @@ def custom_metric(x1, x2):
 
 # TODO: Remove filterwarnings in 1.3 when wminkowski is removed
 @pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
-def test_valid_brute_metric_for_auto_algorithm():
-    X = rng.rand(12, 12)
+@pytest.mark.parametrize("metric", neighbors.VALID_METRICS["brute"])
+def test_valid_brute_metric_for_auto_algorithm(metric, n_samples=20, n_features=12):
+    # Any metric valid for algorithm="brute" must also be valid for
+    # algorithm="auto": the estimator is responsible for selecting, from the
+    # subset of algorithms compatible with that metric (and params), the one
+    # likely to be most efficient; the worst case is algorithm="brute".
+    X = rng.rand(n_samples, n_features)
     Xcsr = csr_matrix(X)
 
-    # check that there is a metric that is valid for brute
-    # but not ball_tree (so we actually test something)
-    assert "cosine" in VALID_METRICS["brute"]
-    assert "cosine" not in VALID_METRICS["ball_tree"]
+    metric_params_list = _generate_test_params_for(metric, n_features)
+
+    if metric == "precomputed":
+        X_precomputed = rng.random_sample((10, 4))
+        Y_precomputed = rng.random_sample((3, 4))
+        DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean")
+        DYX = metrics.pairwise_distances(
+            Y_precomputed, X_precomputed, metric="euclidean"
+        )
+        nb_p = neighbors.NearestNeighbors(n_neighbors=3, metric="precomputed")
+        nb_p.fit(DXX)
+        nb_p.kneighbors(DYX)
 
-    # Metric which don't required any additional parameter
-    require_params = ["mahalanobis", "wminkowski", "seuclidean"]
-    for metric in VALID_METRICS["brute"]:
-        if metric != "precomputed" and metric not in require_params:
+    else:
+        for metric_params in metric_params_list:
             nn = neighbors.NearestNeighbors(
-                n_neighbors=3, algorithm="auto", metric=metric
-            )
-            if metric != "haversine":
-                nn.fit(X)
-                nn.kneighbors(X)
-            else:
-                nn.fit(X[:, :2])
-                nn.kneighbors(X[:, :2])
-        elif metric == "precomputed":
-            X_precomputed = rng.random_sample((10, 4))
-            Y_precomputed = rng.random_sample((3, 4))
-            DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean")
-            DYX = metrics.pairwise_distances(
-                Y_precomputed, X_precomputed, metric="euclidean"
+                n_neighbors=3,
+                algorithm="auto",
+                metric=metric,
+                metric_params=metric_params,
             )
-            nb_p = neighbors.NearestNeighbors(n_neighbors=3)
-            nb_p.fit(DXX)
-            nb_p.kneighbors(DYX)
+            # Haversine distance only accepts 2D data
+            if metric == "haversine":
+                X = np.ascontiguousarray(X[:, :2])
 
-    for metric in VALID_METRICS_SPARSE["brute"]:
-        if metric != "precomputed" and metric not in require_params:
-            nn = neighbors.NearestNeighbors(
-                n_neighbors=3, algorithm="auto", metric=metric
-            ).fit(Xcsr)
-            nn.kneighbors(Xcsr)
-
-    # Metric with parameter
-    VI = np.dot(X, X.T)
-    list_metrics = [
-        ("seuclidean", dict(V=rng.rand(12))),
-        ("wminkowski", dict(w=rng.rand(12))),
-        ("mahalanobis", dict(VI=VI)),
-    ]
-    for metric, params in list_metrics:
-        nn = neighbors.NearestNeighbors(
-            n_neighbors=3, algorithm="auto", metric=metric, metric_params=params
-        ).fit(X)
-        nn.kneighbors(X)
+            nn.fit(X)
+            nn.kneighbors(X)
+
+        if metric in VALID_METRICS_SPARSE["brute"]:
+            nn = neighbors.NearestNeighbors(
+                n_neighbors=3, algorithm="auto", metric=metric
+            ).fit(Xcsr)
+            nn.kneighbors(Xcsr)
 
 
 def test_metric_params_interface():
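Aside on the `precomputed` branch above: with `metric="precomputed"`, `fit` expects the square `(n_fit, n_fit)` distance matrix and `kneighbors` expects `(n_queries, n_fit)` distances to the fitted points. A minimal standalone sketch of that convention (sample sizes chosen arbitrarily):

```python
import numpy as np
from sklearn import metrics, neighbors

rng = np.random.RandomState(0)
X_fit = rng.random_sample((10, 4))   # reference points
Y_query = rng.random_sample((3, 4))  # query points

# Square distances between fitted points; rectangular query-to-fit distances.
DXX = metrics.pairwise_distances(X_fit, metric="euclidean")           # (10, 10)
DYX = metrics.pairwise_distances(Y_query, X_fit, metric="euclidean")  # (3, 10)

nn = neighbors.NearestNeighbors(n_neighbors=3, metric="precomputed").fit(DXX)
dist, ind = nn.kneighbors(DYX)  # neighbors of each query among the fitted points
print(dist.shape, ind.shape)    # (3, 3) (3, 3)
```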