From c9a8723e60a825a07448104d990e8d9ccca37030 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 4 Feb 2022 10:56:10 +0100 Subject: [PATCH 1/3] TST Complete tests for pairwise_distance_{argmin,argmin_min} --- sklearn/metrics/tests/test_pairwise.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index e8b9e24a280c1..4212567d2f0e4 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -505,6 +505,30 @@ def test_pairwise_distances_argmin_min(): np.testing.assert_almost_equal(dist_orig_ind, dist_chunked_ind, decimal=7) np.testing.assert_almost_equal(dist_orig_val, dist_chunked_val, decimal=7) + # Changing the axe and permuting datasets must give the same results + argmin_0, dist_0 = pairwise_distances_argmin_min(X, Y, axis=0) + argmin_1, dist_1 = pairwise_distances_argmin_min(Y, X, axis=1) + + assert_array_equal(dist_0, dist_1) + assert_array_equal(argmin_0, argmin_1) + + argmin_0, dist_0 = pairwise_distances_argmin_min(X, X, axis=0) + argmin_1, dist_1 = pairwise_distances_argmin_min(X, X, axis=1) + + assert_array_equal(dist_0, dist_1) + assert_array_equal(argmin_0, argmin_1) + + # Changing the axe and permuting datasets must give the same results + argmin_0 = pairwise_distances_argmin(X, Y, axis=0) + argmin_1 = pairwise_distances_argmin(Y, X, axis=1) + + assert_array_equal(argmin_0, argmin_1) + + argmin_0 = pairwise_distances_argmin_min(X, X, axis=0) + argmin_1 = pairwise_distances_argmin_min(X, X, axis=1) + + assert_array_equal(argmin_0, argmin_1) + def _reduce_func(dist, start): return dist[:, :100] From f6caee07bc4572c06ab3fa5c4ab1c3226c599c0c Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 4 Feb 2022 17:38:03 +0100 Subject: [PATCH 2/3] CLN Address review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger Co-authored-by: Thomas J. Fan --- sklearn/metrics/tests/test_pairwise.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index 4212567d2f0e4..f63bc3e155ff6 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -505,20 +505,20 @@ def test_pairwise_distances_argmin_min(): np.testing.assert_almost_equal(dist_orig_ind, dist_chunked_ind, decimal=7) np.testing.assert_almost_equal(dist_orig_val, dist_chunked_val, decimal=7) - # Changing the axe and permuting datasets must give the same results + # Changing the axis and permuting datasets must give the same results argmin_0, dist_0 = pairwise_distances_argmin_min(X, Y, axis=0) argmin_1, dist_1 = pairwise_distances_argmin_min(Y, X, axis=1) - assert_array_equal(dist_0, dist_1) + assert_allclose(dist_0, dist_1) assert_array_equal(argmin_0, argmin_1) argmin_0, dist_0 = pairwise_distances_argmin_min(X, X, axis=0) argmin_1, dist_1 = pairwise_distances_argmin_min(X, X, axis=1) - assert_array_equal(dist_0, dist_1) + assert_allclose(dist_0, dist_1) assert_array_equal(argmin_0, argmin_1) - # Changing the axe and permuting datasets must give the same results + # Changing the axis and permuting datasets must give the same results argmin_0 = pairwise_distances_argmin(X, Y, axis=0) argmin_1 = pairwise_distances_argmin(Y, X, axis=1) From 6679d86c65d2c3be163d95b6ee81e00b9560f178 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 8 Feb 2022 13:13:50 +0100 Subject: [PATCH 3/3] Use appropriate function to test Co-authored-by: Olivier Grisel --- sklearn/metrics/tests/test_pairwise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index f63bc3e155ff6..e2549cbde0807 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -524,8 +524,8 @@ def test_pairwise_distances_argmin_min(): assert_array_equal(argmin_0, argmin_1) - argmin_0 = pairwise_distances_argmin_min(X, X, axis=0) - argmin_1 = pairwise_distances_argmin_min(X, X, axis=1) + argmin_0 = pairwise_distances_argmin(X, X, axis=0) + argmin_1 = pairwise_distances_argmin(X, X, axis=1) assert_array_equal(argmin_0, argmin_1)