From a2ce47c37bd2752224a55478d25ea17a41374300 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 7 Dec 2022 11:24:13 +0100 Subject: [PATCH 1/5] FIX Remove spurious warnings The condition was not the proper one: it must be a conjunction and not a disjunction. This is rewritten more naturally using a negation. --- .../metrics/_pairwise_distances_reduction/_argkmin.pyx.tp | 6 +++--- .../_pairwise_distances_reduction/_radius_neighbors.pyx.tp | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp index eec2e2aabdd06..cee7797f7511e 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp @@ -331,9 +331,9 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}): ): if ( metric_kwargs is not None and - len(metric_kwargs) > 0 and ( - "Y_norm_squared" not in metric_kwargs or - "X_norm_squared" not in metric_kwargs + len(metric_kwargs) > 0 and not ( + "Y_norm_squared" in metric_kwargs or + "X_norm_squared" in metric_kwargs ) ): warnings.warn( diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp index 0fdc3bb50203f..04525e7dee1c8 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp @@ -337,9 +337,9 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix} ): if ( metric_kwargs is not None and - len(metric_kwargs) > 0 and ( - "Y_norm_squared" not in metric_kwargs or - "X_norm_squared" not in metric_kwargs + len(metric_kwargs) > 0 and not ( + "Y_norm_squared" in metric_kwargs or + "X_norm_squared" in metric_kwargs ) ): warnings.warn( From 6be2e73b732d3ebdfe158cf802cdba57d03dacff Mon Sep 17 00:00:00 
2001 From: Julien Jerphanion Date: Wed, 7 Dec 2022 16:10:04 +0100 Subject: [PATCH 2/5] Add non-regression tests --- .../test_pairwise_distances_reduction.py | 66 ++++++++++++++++--- 1 file changed, 57 insertions(+), 9 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index c334087c65448..5df5cb9e06c8e 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1,5 +1,6 @@ import itertools import re +import warnings from collections import defaultdict import numpy as np @@ -620,19 +621,44 @@ def test_argkmin_factory_method_wrong_usages(): with pytest.raises(ValueError, match="ndarray is not C-contiguous"): ArgKmin.compute(X=np.asfortranarray(X), Y=Y, k=k, metric=metric) + # A UserWarning must be raised in this case. unused_metric_kwargs = {"p": 3} - message = ( - r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" - r" case \(" - r"EuclideanArgKmin64." - ) + message = r"Some metric_kwargs have been passed \({'p': 3}\) but" with pytest.warns(UserWarning, match=message): ArgKmin.compute( X=X, Y=Y, k=k, metric=metric, metric_kwargs=unused_metric_kwargs ) + # A UserWarning must be raised in this case. + metric_kwargs = { + "p": 3, # unused + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + + message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'" + + with pytest.warns(UserWarning, match=message): + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) + + # No user warning must be raised in this case. + metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) + + # No user warning must be raised in this case. 
+ metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) + def test_radius_neighbors_factory_method_wrong_usages(): rng = np.random.RandomState(1) @@ -683,16 +709,38 @@ def test_radius_neighbors_factory_method_wrong_usages(): unused_metric_kwargs = {"p": 3} - message = ( - r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" - r" case \(EuclideanRadiusNeighbors64" - ) + # A UserWarning must be raised in this case. + message = r"Some metric_kwargs have been passed \({'p': 3}\) but" with pytest.warns(UserWarning, match=message): RadiusNeighbors.compute( X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=unused_metric_kwargs ) + # A UserWarning must be raised in this case. + metric_kwargs = { + "p": 3, # unused + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + + message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'" + + with pytest.warns(UserWarning, match=message): + RadiusNeighbors.compute( + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs + ) + + # No user warning must be raised in this case. 
+ metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + RadiusNeighbors.compute( + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs + ) + @pytest.mark.parametrize( "n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)] From fc4b9df335b82fa7263e3c56c0b85dd9f5db9c98 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 7 Dec 2022 16:10:31 +0100 Subject: [PATCH 3/5] Use more natural conditions --- .../metrics/_pairwise_distances_reduction/_argkmin.pyx.tp | 7 ++----- .../_pairwise_distances_reduction/_radius_neighbors.pyx.tp | 7 ++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp index cee7797f7511e..088446dd6c0f4 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp @@ -330,11 +330,8 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}): metric_kwargs=None, ): if ( - metric_kwargs is not None and - len(metric_kwargs) > 0 and not ( - "Y_norm_squared" in metric_kwargs or - "X_norm_squared" in metric_kwargs - ) + isinstance(metric_kwargs, dict) and + len(metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) > 0 ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp index 04525e7dee1c8..ad4cef85f5684 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp @@ -336,11 +336,8 @@ cdef class 
EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix} metric_kwargs=None, ): if ( - metric_kwargs is not None and - len(metric_kwargs) > 0 and not ( - "Y_norm_squared" in metric_kwargs or - "X_norm_squared" in metric_kwargs - ) + isinstance(metric_kwargs, dict) and + len(metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) > 0 ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " From a1c533f9bc2a3bdcc06c3eee1c1784985e2ec9dd Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 7 Dec 2022 16:16:02 +0100 Subject: [PATCH 4/5] fixup! Add non-regression tests --- .../metrics/tests/test_pairwise_distances_reduction.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 5df5cb9e06c8e..4fe8013cd3602 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -741,6 +741,16 @@ def test_radius_neighbors_factory_method_wrong_usages(): X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs ) + # No user warning must be raised in this case. + metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + RadiusNeighbors.compute( + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs + ) + @pytest.mark.parametrize( "n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)] From f5dcc59d38f1ab19d95a045fd5b836f96913f84b Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 7 Dec 2022 17:26:14 +0100 Subject: [PATCH 5/5] Use the truthiness of a non-empty set for conditions Co-authored-by: Thomas J. 
Fan --- sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp | 2 +- .../_pairwise_distances_reduction/_radius_neighbors.pyx.tp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp index 088446dd6c0f4..b8afe5c3cd5f8 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp @@ -331,7 +331,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}): ): if ( isinstance(metric_kwargs, dict) and - len(metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) > 0 + (metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp index ad4cef85f5684..b3f20cac3ea08 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp @@ -337,7 +337,7 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix} ): if ( isinstance(metric_kwargs, dict) and - len(metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) > 0 + (metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "