Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e99cd11

Browse files
authored
FIX Remove spurious UserWarning (#25129)
1 parent 929b3dc commit e99cd11

File tree

3 files changed

+71
-19
lines changed

3 files changed

+71
-19
lines changed

sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,8 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
330330
metric_kwargs=None,
331331
):
332332
if (
333-
metric_kwargs is not None and
334-
len(metric_kwargs) > 0 and (
335-
"Y_norm_squared" not in metric_kwargs or
336-
"X_norm_squared" not in metric_kwargs
337-
)
333+
isinstance(metric_kwargs, dict) and
334+
(metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"})
338335
):
339336
warnings.warn(
340337
f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "

sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -336,11 +336,8 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix}
336336
metric_kwargs=None,
337337
):
338338
if (
339-
metric_kwargs is not None and
340-
len(metric_kwargs) > 0 and (
341-
"Y_norm_squared" not in metric_kwargs or
342-
"X_norm_squared" not in metric_kwargs
343-
)
339+
isinstance(metric_kwargs, dict) and
340+
(metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"})
344341
):
345342
warnings.warn(
346343
f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "

sklearn/metrics/tests/test_pairwise_distances_reduction.py

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import itertools
22
import re
3+
import warnings
34
from collections import defaultdict
45

56
import numpy as np
@@ -620,19 +621,44 @@ def test_argkmin_factory_method_wrong_usages():
620621
with pytest.raises(ValueError, match="ndarray is not C-contiguous"):
621622
ArgKmin.compute(X=np.asfortranarray(X), Y=Y, k=k, metric=metric)
622623

624+
# A UserWarning must be raised in this case.
623625
unused_metric_kwargs = {"p": 3}
624626

625-
message = (
626-
r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this"
627-
r" case \("
628-
r"EuclideanArgKmin64."
629-
)
627+
message = r"Some metric_kwargs have been passed \({'p': 3}\) but"
630628

631629
with pytest.warns(UserWarning, match=message):
632630
ArgKmin.compute(
633631
X=X, Y=Y, k=k, metric=metric, metric_kwargs=unused_metric_kwargs
634632
)
635633

634+
# A UserWarning must be raised in this case.
635+
metric_kwargs = {
636+
"p": 3, # unused
637+
"Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2),
638+
}
639+
640+
message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'"
641+
642+
with pytest.warns(UserWarning, match=message):
643+
ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs)
644+
645+
# No user warning must be raised in this case.
646+
metric_kwargs = {
647+
"X_norm_squared": sqeuclidean_row_norms(X, num_threads=2),
648+
}
649+
with warnings.catch_warnings():
650+
warnings.simplefilter("error", category=UserWarning)
651+
ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs)
652+
653+
# No user warning must be raised in this case.
654+
metric_kwargs = {
655+
"X_norm_squared": sqeuclidean_row_norms(X, num_threads=2),
656+
"Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2),
657+
}
658+
with warnings.catch_warnings():
659+
warnings.simplefilter("error", category=UserWarning)
660+
ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs)
661+
636662

637663
def test_radius_neighbors_factory_method_wrong_usages():
638664
rng = np.random.RandomState(1)
@@ -683,16 +709,48 @@ def test_radius_neighbors_factory_method_wrong_usages():
683709

684710
unused_metric_kwargs = {"p": 3}
685711

686-
message = (
687-
r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this"
688-
r" case \(EuclideanRadiusNeighbors64"
689-
)
712+
# A UserWarning must be raised in this case.
713+
message = r"Some metric_kwargs have been passed \({'p': 3}\) but"
690714

691715
with pytest.warns(UserWarning, match=message):
692716
RadiusNeighbors.compute(
693717
X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=unused_metric_kwargs
694718
)
695719

720+
# A UserWarning must be raised in this case.
721+
metric_kwargs = {
722+
"p": 3, # unused
723+
"Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2),
724+
}
725+
726+
message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'"
727+
728+
with pytest.warns(UserWarning, match=message):
729+
RadiusNeighbors.compute(
730+
X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs
731+
)
732+
733+
# No user warning must be raised in this case.
734+
metric_kwargs = {
735+
"X_norm_squared": sqeuclidean_row_norms(X, num_threads=2),
736+
"Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2),
737+
}
738+
with warnings.catch_warnings():
739+
warnings.simplefilter("error", category=UserWarning)
740+
RadiusNeighbors.compute(
741+
X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs
742+
)
743+
744+
# No user warning must be raised in this case.
745+
metric_kwargs = {
746+
"X_norm_squared": sqeuclidean_row_norms(X, num_threads=2),
747+
}
748+
with warnings.catch_warnings():
749+
warnings.simplefilter("error", category=UserWarning)
750+
RadiusNeighbors.compute(
751+
X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs
752+
)
753+
696754

697755
@pytest.mark.parametrize(
698756
"n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)]

0 commit comments

Comments
 (0)