Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit c9fa7d4

Browse files
FIX KNeighbor classes correctly set positive_only tag (#30372)
Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent 0e0df36 commit c9fa7d4

File tree

4 files changed

+57
-1
lines changed

4 files changed

+57
-1
lines changed

sklearn/neighbors/_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,8 @@ def __sklearn_tags__(self):
709709
tags = super().__sklearn_tags__()
710710
# For cross-validation routines to split data correctly
711711
tags.input_tags.pairwise = self.metric == "precomputed"
712+
# when input is precomputed metric values, all those values need to be positive
713+
tags.input_tags.positive_only = tags.input_tags.pairwise
712714
tags.input_tags.allow_nan = self.metric == "nan_euclidean"
713715
return tags
714716

sklearn/utils/_tags.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ class InputTags:
5858
Specifically, this tag is used by
5959
`sklearn.utils.metaestimators._safe_split` to slice rows and
6060
columns.
61+
62+
Note that if setting this tag to ``True`` means the estimator can take only
63+
positive values, the `positive_only` tag must reflect it and also be set to
64+
``True``.
6165
"""
6266

6367
one_d_array: bool = False

sklearn/utils/estimator_checks.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def _yield_api_checks(estimator):
148148
yield check_do_not_raise_errors_in_init_or_set_params
149149
yield check_n_features_in_after_fitting
150150
yield check_mixin_order
151+
yield check_positive_only_tag_during_fit
151152

152153

153154
def _yield_checks(estimator):
@@ -3899,6 +3900,39 @@ def _enforce_estimator_tags_X(estimator, X, X_test=None, kernel=linear_kernel):
38993900
return X_res
39003901

39013902

3903+
@ignore_warnings(category=FutureWarning)
def check_positive_only_tag_during_fit(name, estimator_orig):
    """Check that `tags.input_tags.positive_only` agrees with what `fit` accepts.

    The estimator is fitted on data from which the global mean has been
    subtracted, so that `X` contains negative values.

    If the tag is True, `fit` is expected to raise a `ValueError` whose message
    matches "Negative values in data".

    If the tag is False, the estimator should accept negative input regardless of
    the `tags.input_tags.pairwise` flag; any exception raised by `fit` is turned
    into an `AssertionError`.
    """
    est = clone(estimator_orig)
    est_tags = get_tags(est)

    X, y = load_iris(return_X_y=True)
    y = _enforce_estimator_tags_y(est, y)
    set_random_state(est, 0)
    X = _enforce_estimator_tags_X(est, X)
    # Subtract the global mean so that X contains negative values.
    X -= X.mean()

    if est_tags.input_tags.positive_only:
        # The tag promises that negative input is rejected: fit must raise.
        with raises(ValueError, match="Negative values in data"):
            est.fit(X, y)
        return

    # The tag claims negative values are supported, so fitting must succeed.
    try:
        est.fit(X, y)
    except Exception as e:
        raise AssertionError(
            f"Estimator {repr(name)} raised {e.__class__.__name__} unexpectedly."
            " This happens when passing negative input values as X."
            " If negative values are not supported for this estimator instance,"
            " then the tags.input_tags.positive_only tag needs to be set to True."
        ) from e
3934+
3935+
39023936
@ignore_warnings(category=FutureWarning)
39033937
def check_non_transformer_estimators_n_iter(name, estimator_orig):
39043938
# Test that estimators that are not transformers with a parameter

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
check_outlier_contamination,
8686
check_outlier_corruption,
8787
check_parameters_default_constructible,
88+
check_positive_only_tag_during_fit,
8889
check_regressor_data_not_an_array,
8990
check_requires_y_none,
9091
check_sample_weights_pandas_series,
@@ -509,7 +510,7 @@ class RequiresPositiveXRegressor(LinearRegression):
509510
def fit(self, X, y):
510511
X, y = validate_data(self, X, y, multi_output=True)
511512
if (X < 0).any():
512-
raise ValueError("negative X values not supported!")
513+
raise ValueError("Negative values in data passed to X.")
513514
return super().fit(X, y)
514515

515516
def __sklearn_tags__(self):
@@ -1600,3 +1601,18 @@ def fit(self, X, y=None):
16001601
msg = "TransformerMixin comes before/left side of BaseEstimator"
16011602
with raises(AssertionError, match=re.escape(msg)):
16021603
check_mixin_order("BadEstimator", BadEstimator())
1604+
1605+
1606+
def test_check_positive_only_tag_during_fit():
    """The check must fail for an estimator whose `positive_only` tag is wrong.

    `RequiresPositiveXBadTag` inherits a `fit` that raises on negative `X` but
    sets `positive_only` to False, so the check is expected to raise an
    `AssertionError`.
    """

    class RequiresPositiveXBadTag(RequiresPositiveXRegressor):
        def __sklearn_tags__(self):
            # Deliberately contradict the behavior of the inherited `fit`.
            bad_tags = super().__sklearn_tags__()
            bad_tags.input_tags.positive_only = False
            return bad_tags

    expected_msg = "This happens when passing negative input values as X."
    with raises(AssertionError, match=expected_msg):
        check_positive_only_tag_during_fit(
            "RequiresPositiveXBadTag", RequiresPositiveXBadTag()
        )

0 commit comments

Comments
 (0)