Estimator creating _more_tags and inheriting from BaseEstimator will not warn about old tag infrastructure #30257

Closed
glemaitre opened this issue Nov 9, 2024 · 4 comments · Fixed by #30327

@glemaitre
Member

While making the code of skrub compatible with scikit-learn 1.6, I found the following behavior really surprising:

# %%
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin

class MyRegressor(RegressorMixin, BaseEstimator):
    def __init__(self, seed=None):
        self.seed = seed

    def fit(self, X, y):
        self.rng_ = np.random.default_rng(self.seed)
        return self

    def predict(self, X):
        return self.rng_.normal(size=X.shape[0])

    def _more_tags(self):
        return {
            "multioutput": True
        }


# %%
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=10, n_features=5, random_state=42)
regressor = MyRegressor(seed=42).fit(X, y)
regressor.predict(X)

# %%
from sklearn.utils import get_tags

tags = get_tags(regressor)  # does not warn because we inherit from BaseEstimator
tags.target_tags.multi_output  # no longer uses _more_tags, so the value is wrong

In the code above, because we inherit from BaseEstimator and RegressorMixin, the default tags are set by their __sklearn_tags__ methods.

However, our previous code relied on _more_tags.

Currently, get_tags will not warn that something is going wrong, because it falls back on the default tags from the base class and mixins.
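
The silent fallback described above can be reproduced without scikit-learn. The following is a minimal sketch, assuming simplified stand-ins: `Tags`, `TargetTags`, `Base`, and this `get_tags` are hypothetical and only mimic the shape of the attribute path `tags.target_tags.multi_output` used in the snippet. It shows that a subclass defining only `_more_tags` keeps the inherited default, while a subclass overriding `__sklearn_tags__` gets the intended value.

```python
from dataclasses import dataclass, field


@dataclass
class TargetTags:
    """Hypothetical stand-in, not scikit-learn's class."""
    multi_output: bool = False


@dataclass
class Tags:
    """Hypothetical stand-in, not scikit-learn's class."""
    target_tags: TargetTags = field(default_factory=TargetTags)


class Base:
    """Stand-in for a base class that provides default tags."""

    def __sklearn_tags__(self):
        return Tags()


class OldStyleRegressor(Base):
    def _more_tags(self):
        # Nothing in the new mechanism ever calls this method.
        return {"multioutput": True}


class NewStyleRegressor(Base):
    def __sklearn_tags__(self):
        # Start from the inherited defaults, then adjust what differs.
        tags = super().__sklearn_tags__()
        tags.target_tags.multi_output = True
        return tags


def get_tags(estimator):
    """Simplified stand-in: only consults __sklearn_tags__."""
    return estimator.__sklearn_tags__()


old_tags = get_tags(OldStyleRegressor())
new_tags = get_tags(NewStyleRegressor())
print(old_tags.target_tags.multi_output)  # False: _more_tags was ignored
print(new_tags.target_tags.multi_output)  # True
```

The `NewStyleRegressor` pattern (call `super().__sklearn_tags__()`, then modify the returned object) is the shape of the migration that avoids the problem on the estimator side.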

I think that we should:

  • use the values defined in _more_tags and warn for the future change
  • in the future we should error if we have both _more_tags and __sklearn_tags__ to be sure that people stop using _more_tags
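
The two bullets above could be sketched as a single check. This is only an illustration under stated assumptions: `check_tag_api`, its `strict` flag, the `Base` class, the warning category, and the messages are all hypothetical and are not part of scikit-learn.

```python
import warnings


class Base:
    """Hypothetical stand-in for a base class that ships default tags."""

    def __sklearn_tags__(self):
        return {"multi_output": False}


def check_tag_api(estimator, strict=False):
    """Warn for classes defining _more_tags; with strict=True, reject classes
    that define both the old and the new method."""
    for klass in type(estimator).__mro__:
        has_old = "_more_tags" in vars(klass)
        has_new = "__sklearn_tags__" in vars(klass)
        if strict and has_old and has_new:
            raise TypeError(
                f"{klass.__name__} defines both _more_tags and __sklearn_tags__."
            )
        if has_old:
            warnings.warn(
                f"{klass.__name__} defines _more_tags; "
                "define __sklearn_tags__ instead.",
                FutureWarning,
            )


class OldStyle(Base):
    def _more_tags(self):
        return {"multioutput": True}


class Mixed(Base):
    def _more_tags(self):
        return {"multioutput": True}

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags["multi_output"] = True
        return tags


# First bullet: the old method triggers a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_tag_api(OldStyle())

# Second bullet: in the stricter future mode, defining both is an error.
try:
    check_tag_api(Mixed(), strict=True)
    mixed_rejected = False
except TypeError:
    mixed_rejected = True
```
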
@thomasjpfan
Member

use the values defined in _more_tags and warn for the future change

There are two paths to continue using _more_tags:

  1. Bring back all the _get_tags + _more_tags logic from before. Bringing back _get_tags will introduce some complexity when both _get_tags and __sklearn_tags__ are defined.
  2. Only support _more_tags. If _more_tags exists, then have __sklearn_tags__ create a Tags object that uses the configuration from _more_tags. (This means we'll need to translate from the old dict format to the new dataclasses.)

I'm in favor of 2.
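
Option 2 might look roughly like the sketch below. It is not the scikit-learn implementation: the `Tags` and `TargetTags` dataclasses, the `Base` class, the helper `tags_from_more_tags`, and the mapping table are hypothetical stand-ins. The only mapping taken from this thread is the old key `"multioutput"` to `target_tags.multi_output`. The helper overwrites only the keys present in the old dict, so untouched defaults stay as inherited.

```python
import warnings
from dataclasses import dataclass, field


@dataclass
class TargetTags:
    """Hypothetical stand-in, not scikit-learn's class."""
    multi_output: bool = False


@dataclass
class Tags:
    """Hypothetical stand-in, not scikit-learn's class."""
    target_tags: TargetTags = field(default_factory=TargetTags)


# Old dict key -> (sub-object on Tags, field name). Illustrative and incomplete.
_OLD_TO_NEW = {"multioutput": ("target_tags", "multi_output")}


def tags_from_more_tags(tags, more_tags):
    """Apply an old-style tag dict on top of an existing Tags object."""
    for key, value in more_tags.items():
        if key not in _OLD_TO_NEW:
            raise KeyError(f"No translation known for old tag {key!r}")
        group, name = _OLD_TO_NEW[key]
        setattr(getattr(tags, group), name, value)
    return tags


class Base:
    """Stand-in base class whose __sklearn_tags__ honours _more_tags."""

    def __sklearn_tags__(self):
        tags = Tags()
        if hasattr(self, "_more_tags"):
            warnings.warn(
                "_more_tags is translated for now; define __sklearn_tags__ instead.",
                FutureWarning,
            )
            tags = tags_from_more_tags(tags, self._more_tags())
        return tags


class OldStyleRegressor(Base):
    def _more_tags(self):
        return {"multioutput": True}


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    translated = OldStyleRegressor().__sklearn_tags__()

print(translated.target_tags.multi_output)  # True
```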

@ogrisel
Member

ogrisel commented Nov 22, 2024

I am also in favor of 2.

@glemaitre
Member Author

Oh I forgot about this issue.

@glemaitre
Member Author

glemaitre commented Nov 22, 2024

So this issue really does happen for third-party developers: #30324

There, I could confirm that #30327 solves the issue.

#30327 is reintroducing _get_tags because otherwise we would still break people's code. So there is definitely some complexity, because you need to convert __sklearn_tags__, which relies on inheritance, into a _more_tags that just adds the tags for the class itself. But I think that I got it right, given the different tests. It really provides a fully backward-compatible solution.

And I'm not sure that option 2 would be enough. Translating _more_tags to a Tags object means that you will create some default values. Since __sklearn_tags__ uses inheritance, I think the default tags would wrongly override the tags of a child class.

Edit: actually @thomasjpfan, I was thinking of the MRO walk in the reverse order :).
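
The ordering question raised in the edit above can be illustrated with a small sketch. It assumes per-class `_more_tags` dicts merged while walking the MRO. The `collect_tags` helper, its `base_first` flag, the class names, and the tag keys are illustrative only and are not the code from #30327. Walking from the base class to the most derived class lets a child override inherited defaults. Walking in the other direction lets the base defaults win, which is the failure mode worried about above.

```python
import inspect


class Base:
    def _more_tags(self):
        # Defaults provided by the base class.
        return {"multioutput": False, "allow_nan": False}


class Parent(Base):
    def _more_tags(self):
        # Each class only declares the tags it changes.
        return {"multioutput": True}


class Child(Parent):
    def _more_tags(self):
        return {"allow_nan": True}


def collect_tags(estimator, base_first=True):
    """Merge the _more_tags dict of every class in the MRO.

    base_first=True walks object -> ... -> most derived class, so later
    (more derived) classes override earlier ones.
    """
    mro = inspect.getmro(type(estimator))
    order = reversed(mro) if base_first else mro
    collected = {}
    for klass in order:
        if "_more_tags" in vars(klass):
            # Call this class's own definition, not the inherited one.
            collected.update(klass._more_tags(estimator))
    return collected


base_first_tags = collect_tags(Child(), base_first=True)
derived_first_tags = collect_tags(Child(), base_first=False)
print(base_first_tags)     # {'multioutput': True, 'allow_nan': True}
print(derived_first_tags)  # {'allow_nan': False, 'multioutput': False}
```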
