Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 03e516f

Browse files
committed
Plug 'fast_sqeuclidean' strategy implementation and test for pairwise_distances_argmin_min
1 parent e1bb0a1 commit 03e516f

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

sklearn/metrics/pairwise.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -646,19 +646,25 @@ def pairwise_distances_argmin_min(
646646
"""
647647
X, Y = check_pairwise_arrays(X, Y)
648648

649-
if metric_kwargs is None:
650-
metric_kwargs = {}
649+
if metric == "fast_sqeuclidean":
650+
# TODO: generalise this simple plug here
651+
values, indices = _argkmin(X, Y, k=1, strategy="auto", return_distance=True)
652+
values = np.ndarray.flatten(values)
653+
indices = np.ndarray.flatten(indices)
654+
else:
655+
if metric_kwargs is None:
656+
metric_kwargs = {}
651657

652-
if axis == 0:
653-
X, Y = Y, X
658+
if axis == 0:
659+
X, Y = Y, X
654660

655-
indices, values = zip(
656-
*pairwise_distances_chunked(
657-
X, Y, reduce_func=_argmin_min_reduce, metric=metric, **metric_kwargs
661+
indices, values = zip(
662+
*pairwise_distances_chunked(
663+
X, Y, reduce_func=_argmin_min_reduce, metric=metric, **metric_kwargs
664+
)
658665
)
659-
)
660-
indices = np.concatenate(indices)
661-
values = np.concatenate(values)
666+
indices = np.concatenate(indices)
667+
values = np.concatenate(values)
662668

663669
return indices, values
664670

sklearn/metrics/tests/test_pairwise.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,3 +1477,25 @@ def test_numeric_pairwise_distances_datatypes(metric, dtype, y_is_x):
14771477
# and fails due to rounding errors
14781478
rtol = 1e-5 if dtype is np.float32 else 1e-7
14791479
assert_allclose(dist, expected_dist, rtol=rtol)
1480+
1481+
1482+
@pytest.mark.parametrize("n", [10 ** i for i in [2, 3, 4]])
1483+
@pytest.mark.parametrize("d", [5, 10, 100])
1484+
@pytest.mark.parametrize("X_translation", [10 ** i for i in [2, 3, 4, 5, 6, 7]])
1485+
@pytest.mark.parametrize("Y_translation", [10 ** i for i in [2, 3, 4, 5, 6, 7]])
1486+
@pytest.mark.parametrize("sign", [1, -1])
1487+
def test_fast_sqeuclidean_correctness(n, d, X_translation, Y_translation, sign):
1488+
1489+
rng = np.random.RandomState(1)
1490+
1491+
# Translating to test numerical stability
1492+
X = X_translation + rng.rand(int(n * d)).reshape((-1, d))
1493+
Y = sign * Y_translation + rng.rand(int(n * d)).reshape((-1, d))
1494+
1495+
argmins, distances = pairwise_distances_argmin_min(X, Y,
1496+
metric="euclidean")
1497+
fsq_argmins, fsq_distances = pairwise_distances_argmin_min(X, Y,
1498+
metric="fast_sqeuclidean")
1499+
1500+
np.testing.assert_array_equal(argmins, fsq_argmins)
1501+
np.testing.assert_almost_equal(distances, fsq_distances)

0 commit comments

Comments
 (0)