diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index e8b9e24a280c1..e2549cbde0807 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -505,6 +505,30 @@ def test_pairwise_distances_argmin_min(): np.testing.assert_almost_equal(dist_orig_ind, dist_chunked_ind, decimal=7) np.testing.assert_almost_equal(dist_orig_val, dist_chunked_val, decimal=7) + # Changing the axis and permuting datasets must give the same results + argmin_0, dist_0 = pairwise_distances_argmin_min(X, Y, axis=0) + argmin_1, dist_1 = pairwise_distances_argmin_min(Y, X, axis=1) + + assert_allclose(dist_0, dist_1) + assert_array_equal(argmin_0, argmin_1) + + argmin_0, dist_0 = pairwise_distances_argmin_min(X, X, axis=0) + argmin_1, dist_1 = pairwise_distances_argmin_min(X, X, axis=1) + + assert_allclose(dist_0, dist_1) + assert_array_equal(argmin_0, argmin_1) + + # Changing the axis and permuting datasets must give the same results + argmin_0 = pairwise_distances_argmin(X, Y, axis=0) + argmin_1 = pairwise_distances_argmin(Y, X, axis=1) + + assert_array_equal(argmin_0, argmin_1) + + argmin_0 = pairwise_distances_argmin(X, X, axis=0) + argmin_1 = pairwise_distances_argmin(X, X, axis=1) + + assert_array_equal(argmin_0, argmin_1) + def _reduce_func(dist, start): return dist[:, :100]