While making the code of skrub compatible with scikit-learn 1.6, I found the following behaviour really surprising:
# %%
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin


class MyRegressor(RegressorMixin, BaseEstimator):
    def __init__(self, seed=None):
        self.seed = seed

    def fit(self, X, y):
        self.rng_ = np.random.default_rng(self.seed)
        return self

    def predict(self, X):
        return self.rng_.normal(size=X.shape[0])

    def _more_tags(self):
        return {"multioutput": True}


# %%
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=10, n_features=5, random_state=42)
regressor = MyRegressor(seed=42).fit(X, y)
regressor.predict(X)

# %%
from sklearn.utils import get_tags

tags = get_tags(regressor)  # does not warn because we inherit from BaseEstimator
tags.target_tags.multi_output  # no longer uses _more_tags and is thus wrong
In the code above, because we inherit from BaseEstimator and RegressorMixin, the default tags are set by their __sklearn_tags__ methods.
However, our previous code was relying on _more_tags.
Currently, get_tags does not warn that something is going wrong, because it falls back on the default tags from the base class and mixins.
I think that we should:
- use the values defined in _more_tags and warn about the future change
- in the future, raise an error if both _more_tags and __sklearn_tags__ are defined, to make sure that people stop using _more_tags
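The first proposal could look roughly like the sketch below. It uses only the standard library: `Tags`, `TargetTags`, `Base`, and `get_tags_with_fallback` are hypothetical, simplified stand-ins rather than scikit-learn's actual classes or functions, and only the `multioutput` key is handled.

```python
import warnings
from dataclasses import dataclass, field


@dataclass
class TargetTags:  # simplified stand-in, not scikit-learn's real class
    multi_output: bool = False


@dataclass
class Tags:  # simplified stand-in, not scikit-learn's real class
    target_tags: TargetTags = field(default_factory=TargetTags)


class Base:
    """Stand-in for BaseEstimator: provides the default tags."""

    def __sklearn_tags__(self):
        return Tags()


def get_tags_with_fallback(estimator):
    """Hypothetical helper: honour `_more_tags` and warn about it."""
    tags = estimator.__sklearn_tags__()
    if hasattr(estimator, "_more_tags"):
        warnings.warn(
            f"{type(estimator).__name__} defines `_more_tags`, which is "
            "deprecated; implement `__sklearn_tags__` instead.",
            FutureWarning,
        )
        old = estimator._more_tags()
        # Only one legacy key is translated in this sketch.
        if "multioutput" in old:
            tags.target_tags.multi_output = old["multioutput"]
    return tags


class LegacyRegressor(Base):
    """Estimator still written against the old tag API."""

    def _more_tags(self):
        return {"multioutput": True}
```

With this, `get_tags_with_fallback(LegacyRegressor())` emits a `FutureWarning` and reports `multi_output=True`, instead of silently returning the default.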
> use the values defined in _more_tags and warn for the future change
There are two paths to continue using _more_tags:
1. Bring back all the _get_tags + _more_tags logic from before. Bringing back _get_tags will introduce some complexity when both _get_tags and __sklearn_tags__ are defined.
2. Only support _more_tags. If _more_tags exists, then have __sklearn_tags__ create a Tags object that uses the configuration from _more_tags. (This means we'll need to translate from the old dict format to the new dataclasses.)
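To give an idea of what the dict-to-dataclass translation in the second path involves, here is a minimal standard-library sketch. The dataclasses are trimmed-down stand-ins for the real tag classes, and `OLD_TO_NEW` and `tags_from_more_tags` are hypothetical names covering only an illustrative subset of keys.

```python
from dataclasses import dataclass, field


@dataclass
class InputTags:  # simplified stand-in
    allow_nan: bool = False


@dataclass
class TargetTags:  # simplified stand-in
    required: bool = False
    multi_output: bool = False


@dataclass
class Tags:  # simplified stand-in
    input_tags: InputTags = field(default_factory=InputTags)
    target_tags: TargetTags = field(default_factory=TargetTags)
    non_deterministic: bool = False


# Old dict key -> attribute path in the new dataclasses (illustrative subset).
OLD_TO_NEW = {
    "allow_nan": ("input_tags", "allow_nan"),
    "requires_y": ("target_tags", "required"),
    "multioutput": ("target_tags", "multi_output"),
    "non_deterministic": ("non_deterministic",),
}


def tags_from_more_tags(more_tags, defaults=None):
    """Apply an old-style tag dict on top of a Tags object."""
    tags = defaults if defaults is not None else Tags()
    for key, value in more_tags.items():
        path = OLD_TO_NEW.get(key)
        if path is None:
            raise KeyError(f"no known translation for old tag {key!r}")
        # Walk down to the dataclass that owns the final attribute.
        target = tags
        for attr in path[:-1]:
            target = getattr(target, attr)
        setattr(target, path[-1], value)
    return tags
```

Passing the parent's tags as `defaults` keeps inherited values and only overwrites the keys that the old dict actually sets.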
So this issue really does happen for third-party developers: #30324
There, I could confirm that #30327 solves the issue.
#30327 reintroduces _get_tags because we would otherwise still break people's code. There is definitely some complexity, because you need to convert __sklearn_tags__, which relies on inheritance, into a _more_tags that only adds tags for the class itself. But I think I got it right, judging by the different tests. It really provides a fully backward-compatible solution.
And I'm not sure that option would be enough. Translating _more_tags to a Tags object means that you will create some default values. Since __sklearn_tags__ uses inheritance, I think having the default tags will wrongly override the tags of a child class.
Edit: actually @thomasjpfan, I was thinking of the MRO walk in reverse :).
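For reference, the direction of the walk matters as follows. The sketch below is standard-library only and written in the spirit of the old dict-based tag collection as I understand it. `collect_tags` and the three classes are hypothetical. It merges `_more_tags` starting from the least specific class, so values set by a child class are applied last and win over inherited defaults.

```python
import inspect


class Base:
    def _more_tags(self):
        return {"multioutput": False, "allow_nan": False}


class Parent(Base):
    def _more_tags(self):
        return {"multioutput": True}


class Child(Parent):
    def _more_tags(self):
        return {"allow_nan": True}


def collect_tags(estimator):
    """Merge `_more_tags` from every class, least specific class first."""
    tags = {}
    for klass in reversed(inspect.getmro(type(estimator))):
        # Look in the class's own namespace so that an inherited method is
        # not applied a second time on behalf of a subclass.
        if "_more_tags" in vars(klass):
            tags.update(klass._more_tags(estimator))
    return tags
```

Here `collect_tags(Child())` returns `{"multioutput": True, "allow_nan": True}`: the defaults from `Base` are applied first and then overridden by `Parent` and `Child`, not the other way around.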