@@ -1477,3 +1477,25 @@ def test_numeric_pairwise_distances_datatypes(metric, dtype, y_is_x):
1477
1477
# and fails due to rounding errors
1478
1478
rtol = 1e-5 if dtype is np .float32 else 1e-7
1479
1479
assert_allclose (dist , expected_dist , rtol = rtol )
1480
+
1481
+
1482
+ @pytest .mark .parametrize ("n" , [10 ** i for i in [2 , 3 , 4 ]])
1483
+ @pytest .mark .parametrize ("d" , [5 , 10 , 100 ])
1484
+ @pytest .mark .parametrize ("X_translation" , [10 ** i for i in [2 , 3 , 4 , 5 , 6 , 7 ]])
1485
+ @pytest .mark .parametrize ("Y_translation" , [10 ** i for i in [2 , 3 , 4 , 5 , 6 , 7 ]])
1486
+ @pytest .mark .parametrize ("sign" , [1 , - 1 ])
1487
+ def test_fast_sqeuclidean_correctness (n , d , X_translation , Y_translation , sign ):
1488
+
1489
+ rng = np .random .RandomState (1 )
1490
+
1491
+ # Translating to test numerical stability
1492
+ X = X_translation + rng .rand (int (n * d )).reshape ((- 1 , d ))
1493
+ Y = sign * Y_translation + rng .rand (int (n * d )).reshape ((- 1 , d ))
1494
+
1495
+ argmins , distances = pairwise_distances_argmin_min (X , Y ,
1496
+ metric = "euclidean" )
1497
+ fsq_argmins , fsq_distances = pairwise_distances_argmin_min (X , Y ,
1498
+ metric = "fast_sqeuclidean" )
1499
+
1500
+ np .testing .assert_array_equal (argmins , fsq_argmins )
1501
+ np .testing .assert_almost_equal (distances , fsq_distances )
0 commit comments