Thanks to visit codestin.com
Credit goes to github.com

Skip to content

FIX warn if an estimator does not have a concrete __sklearn_tags__ implementation #30516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 6, 2025
4 changes: 4 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.utils/30516.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- Raise a `DeprecationWarning` when there is no concrete implementation of `__sklearn_tags__`
in the MRO of the estimator. Estimators should inherit from `BaseEstimator`, which
implements `__sklearn_tags__`.
By :user:`Guillaume Lemaitre <glemaitre>`
27 changes: 26 additions & 1 deletion sklearn/utils/_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,32 @@ def get_tags(estimator) -> Tags:
tag_provider = _find_tags_provider(estimator)

if tag_provider == "__sklearn_tags__":
tags = estimator.__sklearn_tags__()
# TODO(1.7): turn the warning into an error
try:
tags = estimator.__sklearn_tags__()
except AttributeError as exc:
if str(exc) == "'super' object has no attribute '__sklearn_tags__'":
# workaround the regression reported in
# https://github.com/scikit-learn/scikit-learn/issues/30479
# `__sklearn_tags__` is implemented by calling
# `super().__sklearn_tags__()` but there is no `__sklearn_tags__`
# method in the base class.
warnings.warn(
f"The following error was raised: {str(exc)}. It seems that "
"there are no classes that implement `__sklearn_tags__` "
"in the MRO and/or all classes in the MRO call "
"`super().__sklearn_tags__()`. Make sure to inherit from "
"`BaseEstimator` which implements `__sklearn_tags__` (or "
"alternatively define `__sklearn_tags__` but we don't recommend "
"this approach). Note that `BaseEstimator` needs to be on the "
"right side of other Mixins in the inheritance order. The "
"default are now used instead since retrieving tags failed. "
"This warning will be replaced by an error in 1.7.",
category=DeprecationWarning,
)
tags = default_tags(estimator)
else:
raise
else:
# TODO(1.7): Remove this branch of the code
# Let's go through the MRO and patch each class implementing _more_tags
Expand Down
48 changes: 48 additions & 0 deletions sklearn/utils/tests/test_tags.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from dataclasses import dataclass, fields

import numpy as np
import pytest

from sklearn.base import (
BaseEstimator,
ClassifierMixin,
RegressorMixin,
TransformerMixin,
)
from sklearn.pipeline import Pipeline
from sklearn.utils import (
ClassifierTags,
InputTags,
Expand Down Expand Up @@ -629,3 +632,48 @@ def __sklearn_tags__(self):
}
assert old_tags == expected_tags
assert _to_new_tags(_to_old_tags(new_tags), estimator=estimator) == new_tags


# TODO(1.7): Remove this test
def test_tags_no_sklearn_tags_concrete_implementation():
    """Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/30479

    When every `__sklearn_tags__` in the MRO only delegates to
    `super().__sklearn_tags__()`, no class provides a concrete implementation.
    In that situation a `DeprecationWarning` is expected, asking the user to
    inherit from `BaseEstimator`, which implements `__sklearn_tags__`.
    """

    # Estimator that deliberately does NOT inherit from `BaseEstimator`.
    class MyEstimator(ClassifierMixin):
        def __init__(self, *, param=1):
            self.param = param

        def fit(self, X, y=None):
            self.is_fitted_ = True
            return self

        def predict(self, X):
            n_samples = X.shape[0]
            return np.full(n_samples, self.param)

    X_train = np.array([[1, 2], [2, 3], [3, 4]])
    y_train = np.array([1, 0, 1])

    pipe = Pipeline([("estimator", MyEstimator(param=1))])
    with pytest.warns(DeprecationWarning, match="The following error was raised"):
        fitted_pipe = pipe.fit(X_train, y_train)
        fitted_pipe.predict(X_train)

    # Any other failure must still propagate: either an exception that is not
    # an AttributeError, or an AttributeError unrelated to a missing
    # `__sklearn_tags__` on `super()`.
    class MyEstimator2(MyEstimator, BaseEstimator):
        def __init__(self, *, param=1, error_type=AttributeError):
            self.param = param
            self.error_type = error_type

        def __sklearn_tags__(self):
            # The parent lookup succeeds here thanks to `BaseEstimator`; the
            # error raised afterwards is the one under test.
            super().__sklearn_tags__()
            raise self.error_type("test")

    propagated_errors = (AttributeError, TypeError, ValueError)
    for error_type in propagated_errors:
        failing_estimator = MyEstimator2(param=1, error_type=error_type)
        with pytest.raises(error_type):
            get_tags(failing_estimator)
Loading