Thanks to visit codestin.com
Credit goes to github.com

Skip to content

TST introduce _safe_tags for estimator not inheriting from BaseEstimator #18797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 53 commits into from
Dec 2, 2020

Conversation

glemaitre
Copy link
Member

@glemaitre glemaitre commented Nov 9, 2020

closes #18820

This PR reintroduce _safe_tags avoiding third-party libraries to either inherit from BaseEstimator or implement the tags.

@glemaitre glemaitre marked this pull request as draft November 9, 2020 16:00
@ogrisel
Copy link
Member

ogrisel commented Nov 9, 2020

We would need a test to check that check_estimator passes on a minimal estimator class that does not inherit from the scikit-learn base classes (at least for mode="compatible" in #18582).

@rth
Copy link
Member

rth commented Nov 9, 2020

I remember we removed _safe_tags some time ago and I was really sure we had a test for an estimator not inheriting from BaseEstimator in sklearn/utils/tests/test_estimator_checks.py. Maybe we "fixed" the test at the same time, not sure...

@rth
Copy link
Member

rth commented Nov 9, 2020

#16950 but it looks like there was no test indeed.

@ogrisel
Copy link
Member

ogrisel commented Nov 9, 2020

The goal would be to honor our hold duck typing contract which I find fundamental to the spirit of the scikit-learn API design:

https://scikit-learn.org/dev/developers/develop.html#rolling-your-own-estimator

BaseEstimator and mixins:

We tend to use “duck typing”, so building an estimator which follows the API suffices for compatibility, without needing to inherit from or even import any scikit-learn classes.

However, if a dependency on scikit-learn is acceptable in your code, you can prevent a lot of boilerplate code by deriving a class from BaseEstimator and optionally the mixin classes in sklearn.base. For example, below is a custom classifier, with more examples included in the scikit-learn-contrib project template.

@NicolasHug
Copy link
Member

NicolasHug commented Nov 9, 2020

The goal would be to honor our hold duck typing contract which I find fundamental to the spirit of the scikit-learn API design

IMHO, having merged #16950 does not violate our duck typing contract: estimators implementing fit, predict etc will work perfectly with the internal tools like cross_validate, etc.

What #16950 does is that it forces estimators to inherit from BaseEstimator in order to run check_estimator. Which is quite different, and I would say it's a reasonable constraint.

We don't force a dependency on sklearn for third-party libraries, we just force a dependency on sklearn for their developers.

So I'd be +0 to bring this back. If we do, let's please explicitly write a comment indicating that this should not be removed, possibly with a link here

@rth
Copy link
Member

rth commented Nov 9, 2020

is that it forces estimators to inherit from BaseEstimator in order to run check_estimator.

A lot of projects wouldn't want to add scikit-learn as a dependency (e.g. lightfgbm, scorch) which in turns means they have no way to way to check that their API is compliant.

+1 to put it back, especially that's it's a minor change on our side.

@jeremiedbb
Copy link
Member

@NicolasHug here (https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator) it says that in order to be a scikit-learn compatible estimator, an estimator must pass check_estimator

@rth
Copy link
Member

rth commented Nov 9, 2020

Though yes, they could actually define a _get_tags manually. If they don't set tags and don't inherit from sklearn, the determination of regressor vs classifier for tests won't work anyway, would it?

Edit: indeed, is_classifier or is_regressor helpers won't produce anything meanigful with arbitrary third party python classes at present. And then the risk is that check_estimator might pass, but not actually run any of the relevant checks which is another reason I think check_estimator should be avoided in favor of parametrize_with_checks #18750 (comment)

@ogrisel
Copy link
Member

ogrisel commented Nov 9, 2020

Edit: indeed, is_classifier or is_regressor helpers won't produce anything meanigful with arbitrary third party python classes at present. And then the risk is that check_estimator might pass, but not actually run any of the relevant checks..

We could add a check that fails if the estimator has a predict method but both is_classifier and is_regressor return False, WDYT?

@NicolasHug
Copy link
Member

@NicolasHug here (scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator) it says that in order to be a scikit-learn compatible estimator, an estimator must pass check_estimator

I understand that but I'm not sure what your point is @jeremiedbb ?

Mine is that scikit-learn is not a dependency for users of third-party libraries. It's only a dependency for their developers. And as the comments above suggest, I doubt that one can get anything meaningful out of check_estimator without inheriting from at least some of our Mixins. (It's the same for parametrize_with_checks BTW)

But I'm not opposing to adding _safe_tags back anyway.

@ogrisel
Copy link
Member

ogrisel commented Nov 9, 2020

which is another reason I think check_estimator should be avoided in favor of parametrize_with_checks

+1 for pushing more 3rd party libraries to use parametrize_with_checks if they use pytest but check_estimator is still useful for interactive checks / demos in ipython sessions or jupyter notebooks.

@jeremiedbb
Copy link
Member

I understand that but I'm not sure what your point is @jeremiedbb ?

@NicolasHug I'm feeling that neither us and our doc is sure about what it requires for a third party estimator to be scikit-learn compat. In the doc it says it needs to pass check_estimator. But in order check_estimator kind of require a dependency on sklearn which we don't want to enforce. My comment was just that I'm a bit confused here :)

@adrinjalali
Copy link
Member

I'm -0 on this one. I see the tags as a part of our API, and estimators should implement them. I'd say if users want to pass check_estimator, they should also inherit from the right classes. They can choose not to pass these tests and their estimators may work w/o an issue in most usecases, but they are not completely scikit-learn compatible.

This PR adds IMO unnecessary complexity to our codebase. This means we should always use _safe_tags instead of getting the tags, everywhere in the codebase, and that I don't see why that's necessary where we can avoid if if we require users to inherit from the right class or implement everything they should, themselves.

@glemaitre
Copy link
Member Author

This means we should always use _safe_tags instead of getting the tags, everywhere in the codebase

This a very good point. If we were limiting the use of tags in the common test, we would not have this issue. However, we started to use the tag elsewhere in the estimators themselves.

@glemaitre
Copy link
Member Author

As I mentioned there, #18798 (comment), if we want solely to make tags part of our API, we might want to isolate this functionality in a Mixin. I am not sure this a good solution though; it will force to already make multiple inheritances in the common case (i.e. BaseEstimator, TagsMixin)

@adrinjalali
Copy link
Member

Do you have a case where the user would want to inherit TagsMixin but not BaseEstimator?

@rth
Copy link
Member

rth commented Nov 10, 2020

I'd say if users want to pass check_estimator, they should also inherit from the right classes
[..] or implement everything they should, themselves.

I don't think we should suggest that they inherit from scikit-learn. Being compatible with scikit-learn API is and should be unrelated to depending on scikit-learn.

However yes, _safe_tags also has its issues. We can provide default tags, but will they be appropriate for the estimator in question? Possibly but not sure until someone actually looks into tags, so they might as well re-implement them. However then for contrib projects the issue is that they would need to also implement tag inheritance via _get_tags, if they have many estimators to avoid having repeated N tags in each estimator.

So in that sense maybe _safe_tags is still useful as it would minimize the amount of work a contrib project would need to do (only tags that are different from defaults).

This means we should always use _safe_tags instead of getting the tags, everywhere in the codebase,

Not everywhere, just in common tests and meta-estimators. We did it before. It's certainly not ideal, but also not such a big deal maybe? Though the issue there is that is can start to be used in contrib projects as well, since they would find it as the way of getting tags in our code base :/

@glemaitre
Copy link
Member Author

It would be someone that needs to redefine set/get/params/states and does not want to check_X_y within _validate_data. Of course, they could inherit their base class from it and overwrite these methods but the easiest way would be to implement your base from scratch for your use case. This looks like the use-case of cuml but I agree that it is not the most common and expected use case.

@adrinjalali
Copy link
Member

Even if we were to accept a third-party estimator as scikit-learn compatible w/o them inheriting from the right classes, I don't see why we'd need to do extra work for them to be compatible while they don't implement the API we require. This puts severe burden on us moving forward while designing our API.

@NicolasHug
Copy link
Member

I don't understand why we introduce this PR: the goal seems to be that we don't want to force third-party libraries to depend on scikit-learn to pass the check suite.

But to pass check_estimator, one needs to call check_estimator, so they need to have scikit-learn, right? Now if their estimator MyEst doesn't inherit from BaseEstimator, all they need to do is

class MyWrappedEst(BaseEstimator, MyEst):
     pass

and call check_estimator on MyWrappedEst instead of MyEst


BTW, I agree with most of what @adrinjalali said but I don't think this is true:

This means we should always use _safe_tags instead of getting the tags, everywhere in the codebase

This is only related to the check suite, not the rest of the code-base

@ogrisel
Copy link
Member

ogrisel commented Nov 10, 2020

But to pass check_estimator, one needs to call check_estimator, so they need to have scikit-learn, right?

test dependencies are not necessarily runtime dependencies.

@NicolasHug
Copy link
Member

NicolasHug commented Nov 10, 2020

test dependencies are not necessarily runtime dependencies.

Yes this is what I'm trying to say since the beginning: currently in master, we don't force third party libraries to have scikit-learn as a runtime dependency.

This PR doesn't change anything w.r.t. dependency. All it does is removing the need for estimators to inherit from BaseEstimator in order to pass check_estimator.

@jeremiedbb
Copy link
Member

My understanding is that check_estimator is a tool that can be used by third party estimators to check that such an estimator follows the scikit-learn api and can be used in scikit-learn model evaluation and model selection tools. So we don't want to force an estimator to inherit from BaseEstimator to pass check_estimator.

On the other hand, if a third party estimator doesn't want to inherit from BaseEstimator, it needs to implement all important parts of the scikit-learn api. tags is a part of that (as listed here https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator).

So I join Nicolas and Adrin and I don't think this PR is necessary since it's clear in the docs that implementing tags is mandatory to be a compatible estimator.

Comment on lines +32 to +35
For scikit-learn built-in estimators, we should still rely on
`self._get_tags()`. `_safe_tags(est)` should be used when we are not sure
where `est` comes from: typically `_safe_tags(self.base_estimator)` where
`self` is a meta-estimator, or in the common checks.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit but using backquotes on doc that will never be rendered as html adds noise imho

default : list of {str, dtype} or bool, default=None
When `esimator.get_tags()` is not implemented, default` allows to
define the default value of a tag if it is not present in
`_DEFAULT_TAGS` or to overwrite the value in `_DEFAULT_TAGS` if it the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
`_DEFAULT_TAGS` or to overwrite the value in `_DEFAULT_TAGS` if it the
`_DEFAULT_TAGS` or to overwrite the value in `_DEFAULT_TAGS` if the

Comment on lines 42 to 43
define the default value of a tag if it is not present in
`_DEFAULT_TAGS` or to overwrite the value in `_DEFAULT_TAGS` if it the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed at the moment IMO. Maybe we can leave this feature for later, in order to keep minimal here?

if hasattr(estimator, "_get_tags"):
if key is not None:
try:
return estimator._get_tags().get(key, _DEFAULT_TAGS[key])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this goes against our docs:

Note however that all tags must be present in the dict

What bothers me here is that there's no difference anymore between _get_tags() and _more_tags() from a third-party point of view.

@jeremiedbb
Copy link
Member

this goes against our docs:
> Note however that all tags must be present in the dict

I think it's coherent with the proposed changes to the doc:

 you will need to implement a `_get_tags()` method which returns a dict that
  should contains all the necessary tags for that estimator, including the
  default tags typically defined in :class:`~sklearn.base.BaseEstimator` and
  other scikit-learn mixin classes. Note however that **all tags must be
  present in the dict**. If any of the keys documented above is not present in
  the output of `_get_tags()`, an error might occur.

@NicolasHug
Copy link
Member

NicolasHug commented Dec 1, 2020

I don't think I agree: when doing estimator._get_tags().get(key, _DEFAULT_TAGS[key]), the following isn't true anymore:

Note however that all tags must be
present in the dict
. If any of the keys documented above is not present in
the output of _get_tags(), an error might occur

No error will ever be raised if a tag isn't returned by _get_tags(), and so we don't require all keys to exist. In effect, implementing _get_tags() is exactly the same as implementing _more_tags() for 3rd parties

Strictly following the docs would mean doing estimator._get_tags()[key] (which is what I would prefer).

@jeremiedbb
Copy link
Member

Right I did not understand what you were reffering to in the first place. I agree that the code does not reflect the doc currently. Guillaume is working on it :)

@@ -1985,15 +1985,3 @@ def _more_tags(self):
"Set the estimator tags of your estimator instead")
with pytest.warns(FutureWarning, match=msg):
cross_validate(svm, linear_kernel, y, cv=2)

# the _pairwise attribute is present and set to True while the pairwise
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug by not being permissive (getting default with _get_tags), we need to remove this test. What do you think about this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not sure that this case is actually possible in practice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't followed the introduction of the pairwise tag but since the test assumes that the tag doesn't exist and since we're telling 3rd parties that all tags should exist, I'd say it makes sense to remove the test

@@ -558,24 +558,9 @@ class IncorrectTagPCA(KernelPCA):
with pytest.warns(FutureWarning, match=msg):
assert not _is_pairwise(pca)

# the _pairwise attribute is present and set to False while the pairwise
# tag is not present
class FalsePairwise(BaseEstimator):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a second test with the same issue.

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the new version of _safe_tags that takes _more_tags into account if _get_tags is not present.

The code is now simpler and it's more natural for third party estimators that do no inherit from scikit-learn base classes to incrementally define new tags without having to re-implement the for _get_tags machinery from scratch.

The documentation is now simpler to follow as well.

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the new version of _safe_tags that takes _more_tags into account if _get_tags is not present.

I guess I'm fine with the code but regarding the docs: with that in place, a non-inheriting 3d part lib has no reason to ever implement _get_tags(), does it? Unless they want to use the tags in their own code... In which case they'll have to switch from _more_tags() to _get_tags(), which will be annoying to them. But why would a library use the tags machinery in its code while still not inheriting...?

In other words, do we even want to document "you can also override _get_tags()"? We could just say "you need to define _more_tags() if you want to override the defaults, and if you want to access tags values that you don't override (i.e. that are not in your own-defined _more_tags()), you'll need to inherit from BaseEstimator."

To override the tags of a child class, one must define the `_more_tags()`
method and return a dict with the desired tags, e.g::
It is unlikely that the default values for each tag will suit the needs of your
specific estimator. Additional tags can be created or default tags can be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Additionnal tags can be created"

I thought we agreed not to support that #18797 (comment)? (or that's how I interpret @ogrisel's +1)

Copy link
Member Author

@glemaitre glemaitre Dec 2, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a difference between supporting in _safe_tags and people creating their own tags within their libraries using _more_tags. This is a real need here:

https://github.com/rapidsai/cuml/pull/3113/files#diff-e4bd6eee2eca2b0619b03a5f6ba7b471b4ca03080a6619d0079105d5f13c2165R34-R35

We have something similar in imbalanced-learn since the introduction of tags.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My +1 was to remove the default param to the _safe_tags. I think third-party implementers are free to add other tags in their own estimators if they which. cuML is already doing in in their master branch apparently:

https://github.com/rapidsai/cuml/pull/3113/files

@glemaitre
Copy link
Member Author

In other words, do we even want to document "you can also override _get_tags()"? We could just say "you need to define _more_tags() if you want to override the defaults, and if you want to access tags values that you don't override (i.e. that are not in your own-defined _more_tags()), you'll need to inherit from BaseEstimator."

How do you deal with CuML case: inheriting is not an option. If they want to use tags (for new checks for instance) internally, we are forcing them to call _more_tags instead of their own implementation of _get_tags.

IMO, it is not a burden to mention that if you want to access your tags by implementing _get_tags you need to have all scikit-learn defaults because we are going to raise error otherwise.

@ogrisel
Copy link
Member

ogrisel commented Dec 2, 2020

@NicolasHug would it be fine with you if we merge this PR as you are fine with the code. This would allow us to branch 0.24.X and start the release PR for 0.24.0rc1.

We can always fine tune the doc before 0.24.0 final if needed.

Comment on lines +50 to +55
if hasattr(estimator, "_get_tags"):
tags_provider = "_get_tags()"
tags = estimator._get_tags()
elif hasattr(estimator, "_more_tags"):
tags_provider = "_more_tags()"
tags = {**_DEFAULT_TAGS, **estimator._more_tags()}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we rely on _more_tags regardless of inheritance, what's the rationale for defaulting to _DEFAULT_TAGS with _more_tags but not with _get_tags?

I admit I'm a bit lost on all the possible code paths and use-cases here. It seems that we're overly permissive in some cases while being restrictive in others, with no obvious reason. Things were clearer to me when the logic was "with inheritance -> define _more_tags, no inheritance -> define _get_tags".

But anyway, feel free to merge if we need to move with the release. This is still experimental after all.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the message is simpler to always recommend to define _more_tags for whether or not you inherit from BaseEstimator.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea now always implements _more_tags and it will work and it should cover 99% of the use case.

The remaining 1% is no inheritance and people that want to use tags -> implement _get_tags with strong requirements on our side regarding defaults.

@ogrisel ogrisel merged commit 255718b into scikit-learn:master Dec 2, 2020
@ogrisel
Copy link
Member

ogrisel commented Dec 2, 2020

Merged thanks all!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants