diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp index eec2e2aabdd06..b8afe5c3cd5f8 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp @@ -330,11 +330,8 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}): metric_kwargs=None, ): if ( - metric_kwargs is not None and - len(metric_kwargs) > 0 and ( - "Y_norm_squared" not in metric_kwargs or - "X_norm_squared" not in metric_kwargs - ) + isinstance(metric_kwargs, dict) and + (metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp index 0fdc3bb50203f..b3f20cac3ea08 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp @@ -336,11 +336,8 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix} metric_kwargs=None, ): if ( - metric_kwargs is not None and - len(metric_kwargs) > 0 and ( - "Y_norm_squared" not in metric_kwargs or - "X_norm_squared" not in metric_kwargs - ) + isinstance(metric_kwargs, dict) and + (metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index c334087c65448..4fe8013cd3602 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1,5 +1,6 @@ import itertools import re +import warnings from collections import defaultdict import numpy as np @@ -620,19 +621,44 @@ def test_argkmin_factory_method_wrong_usages(): with pytest.raises(ValueError, match="ndarray is not C-contiguous"): ArgKmin.compute(X=np.asfortranarray(X), Y=Y, k=k, metric=metric) + # A UserWarning must be raised in this case. unused_metric_kwargs = {"p": 3} - message = ( - r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" - r" case \(" - r"EuclideanArgKmin64." - ) + message = r"Some metric_kwargs have been passed \({'p': 3}\) but" with pytest.warns(UserWarning, match=message): ArgKmin.compute( X=X, Y=Y, k=k, metric=metric, metric_kwargs=unused_metric_kwargs ) + # A UserWarning must be raised in this case. + metric_kwargs = { + "p": 3, # unused + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + + message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'" + + with pytest.warns(UserWarning, match=message): + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) + + # No user warning must be raised in this case. + metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) + + # No user warning must be raised in this case. + metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) + def test_radius_neighbors_factory_method_wrong_usages(): rng = np.random.RandomState(1) @@ -683,16 +709,48 @@ def test_radius_neighbors_factory_method_wrong_usages(): unused_metric_kwargs = {"p": 3} - message = ( - r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" - r" case \(EuclideanRadiusNeighbors64" - ) + # A UserWarning must be raised in this case. + message = r"Some metric_kwargs have been passed \({'p': 3}\) but" with pytest.warns(UserWarning, match=message): RadiusNeighbors.compute( X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=unused_metric_kwargs ) + # A UserWarning must be raised in this case. + metric_kwargs = { + "p": 3, # unused + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + + message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'" + + with pytest.warns(UserWarning, match=message): + RadiusNeighbors.compute( + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs + ) + + # No user warning must be raised in this case. + metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + RadiusNeighbors.compute( + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs + ) + + # No user warning must be raised in this case. + metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + RadiusNeighbors.compute( + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs + ) + @pytest.mark.parametrize( "n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)]