Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 42831e5

Browse files
glemaitreadrinjalalithomasjpfanjeremiedbb
committed
FIX warn if an estimator does have a concrete __sklearn_tags__ implementation (#30516)
Co-authored-by: Adrin Jalali <[email protected]> Co-authored-by: Thomas J. Fan <[email protected]> Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent 0d2ce43 commit 42831e5

File tree

3 files changed

+78
-1
lines changed

3 files changed

+78
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- Raise a `DeprecationWarning` when there is no concrete implementation of `__sklearn_tags__`
2+
in the MRO of the estimator. We request to inherit from `BaseEstimator` that
3+
implements `__sklearn_tags__`.
4+
By :user:`Guillaume Lemaitre <glemaitre>`

sklearn/utils/_tags.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,32 @@ def get_tags(estimator) -> Tags:
393393
tag_provider = _find_tags_provider(estimator)
394394

395395
if tag_provider == "__sklearn_tags__":
396-
tags = estimator.__sklearn_tags__()
396+
# TODO(1.7): turn the warning into an error
397+
try:
398+
tags = estimator.__sklearn_tags__()
399+
except AttributeError as exc:
400+
if str(exc) == "'super' object has no attribute '__sklearn_tags__'":
401+
# workaround the regression reported in
402+
# https://github.com/scikit-learn/scikit-learn/issues/30479
403+
# `__sklearn_tags__` is implemented by calling
404+
# `super().__sklearn_tags__()` but there is no `__sklearn_tags__`
405+
# method in the base class.
406+
warnings.warn(
407+
f"The following error was raised: {str(exc)}. It seems that "
408+
"there are no classes that implement `__sklearn_tags__` "
409+
"in the MRO and/or all classes in the MRO call "
410+
"`super().__sklearn_tags__()`. Make sure to inherit from "
411+
"`BaseEstimator` which implements `__sklearn_tags__` (or "
412+
"alternatively define `__sklearn_tags__` but we don't recommend "
413+
"this approach). Note that `BaseEstimator` needs to be on the "
414+
"right side of other Mixins in the inheritance order. The "
415+
"default are now used instead since retrieving tags failed. "
416+
"This warning will be replaced by an error in 1.7.",
417+
category=DeprecationWarning,
418+
)
419+
tags = default_tags(estimator)
420+
else:
421+
raise
397422
else:
398423
# TODO(1.7): Remove this branch of the code
399424
# Let's go through the MRO and patch each class implementing _more_tags

sklearn/utils/tests/test_tags.py

+48
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from dataclasses import dataclass, fields
22

3+
import numpy as np
34
import pytest
45

56
from sklearn.base import (
67
BaseEstimator,
8+
ClassifierMixin,
79
RegressorMixin,
810
TransformerMixin,
911
)
12+
from sklearn.pipeline import Pipeline
1013
from sklearn.utils import (
1114
ClassifierTags,
1215
InputTags,
@@ -637,3 +640,48 @@ def __sklearn_tags__(self):
637640
}
638641
assert old_tags == expected_tags
639642
assert _to_new_tags(_to_old_tags(new_tags), estimator=estimator) == new_tags
643+
644+
645+
# TODO(1.7): Remove this test
646+
def test_tags_no_sklearn_tags_concrete_implementation():
647+
"""Non-regression test for:
648+
https://github.com/scikit-learn/scikit-learn/issues/30479
649+
650+
There is no class implementing `__sklearn_tags__` without calling
651+
`super().__sklearn_tags__()`. Thus, we raise a warning and request to inherit from
652+
`BaseEstimator` that implements `__sklearn_tags__`.
653+
"""
654+
655+
class MyEstimator(ClassifierMixin):
656+
def __init__(self, *, param=1):
657+
self.param = param
658+
659+
def fit(self, X, y=None):
660+
self.is_fitted_ = True
661+
return self
662+
663+
def predict(self, X):
664+
return np.full(shape=X.shape[0], fill_value=self.param)
665+
666+
X = np.array([[1, 2], [2, 3], [3, 4]])
667+
y = np.array([1, 0, 1])
668+
669+
my_pipeline = Pipeline([("estimator", MyEstimator(param=1))])
670+
with pytest.warns(DeprecationWarning, match="The following error was raised"):
671+
my_pipeline.fit(X, y).predict(X)
672+
673+
# check that we still raise an error if it is not a AttributeError or related to
674+
# __sklearn_tags__
675+
class MyEstimator2(MyEstimator, BaseEstimator):
676+
def __init__(self, *, param=1, error_type=AttributeError):
677+
self.param = param
678+
self.error_type = error_type
679+
680+
def __sklearn_tags__(self):
681+
super().__sklearn_tags__()
682+
raise self.error_type("test")
683+
684+
for error_type in (AttributeError, TypeError, ValueError):
685+
estimator = MyEstimator2(param=1, error_type=error_type)
686+
with pytest.raises(error_type):
687+
get_tags(estimator)

0 commit comments

Comments
 (0)