-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
MAINT conversion old->new/new->old tags (bis) #30327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3f0a4cb
f1ec5fa
c34bf25
64c772e
b3fb575
ea5ccb2
c2e6372
57d0415
8153589
ad4fe79
6e89afe
1764ea9
97a4114
4149013
2357fcd
b515f1a
9c5a363
82367e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
from __future__ import annotations | ||
|
||
import warnings | ||
from collections import OrderedDict | ||
from dataclasses import dataclass, field | ||
from itertools import chain | ||
|
||
from .fixes import _dataclass_args | ||
|
||
|
@@ -290,6 +292,71 @@ def default_tags(estimator) -> Tags: | |
) | ||
|
||
|
||
# TODO(1.7): Remove this function | ||
def _find_tags_provider(estimator, warn=True): | ||
"""Find the tags provider for an estimator. | ||
|
||
Parameters | ||
---------- | ||
estimator : estimator object | ||
The estimator to find the tags provider for. | ||
|
||
warn : bool, default=True | ||
Whether to warn if the tags provider is not found. | ||
|
||
Returns | ||
------- | ||
tag_provider : str | ||
The tags provider for the estimator. Can be one of: | ||
- "_get_tags": to use the old tags infrastructure | ||
- "__sklearn_tags__": to use the new tags infrastructure | ||
""" | ||
mro_model = type(estimator).mro() | ||
tags_mro = OrderedDict() | ||
for klass in mro_model: | ||
tags_provider = [] | ||
if "_more_tags" in vars(klass): | ||
tags_provider.append("_more_tags") | ||
if "_get_tags" in vars(klass): | ||
tags_provider.append("_get_tags") | ||
if "__sklearn_tags__" in vars(klass): | ||
tags_provider.append("__sklearn_tags__") | ||
tags_mro[klass.__name__] = tags_provider | ||
|
||
all_providers = set(chain.from_iterable(tags_mro.values())) | ||
if "__sklearn_tags__" not in all_providers: | ||
# default on the old tags infrastructure | ||
return "_get_tags" | ||
|
||
tag_provider = "__sklearn_tags__" | ||
for klass in tags_mro: | ||
has_get_or_more_tags = any( | ||
provider in tags_mro[klass] for provider in ("_get_tags", "_more_tags") | ||
) | ||
has_sklearn_tags = "__sklearn_tags__" in tags_mro[klass] | ||
|
||
if tags_mro[klass] and tag_provider == "__sklearn_tags__": # is it empty | ||
if has_get_or_more_tags and not has_sklearn_tags: | ||
# Case where a class does not implement __sklearn_tags__ and we fallback | ||
# to _get_tags. We should therefore warn for implementing | ||
# __sklearn_tags__. | ||
tag_provider = "_get_tags" | ||
glemaitre marked this conversation as resolved.
Show resolved
Hide resolved
|
||
break | ||
|
||
if warn and tag_provider == "_get_tags": | ||
warnings.warn( | ||
f"The {estimator.__class__.__name__} or classes from which it inherits " | ||
"use `_get_tags` and `_more_tags`. Please define the " | ||
"`__sklearn_tags__` method, or inherit from `sklearn.base.BaseEstimator` " | ||
"and/or other appropriate mixins such as `sklearn.base.TransformerMixin`, " | ||
"`sklearn.base.ClassifierMixin`, `sklearn.base.RegressorMixin`, and " | ||
"`sklearn.base.OutlierMixin`. From scikit-learn 1.7, not defining " | ||
"`__sklearn_tags__` will raise an error.", | ||
category=FutureWarning, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Normally, I would go with a In any case, I am okay with the current implementation. |
||
) | ||
return tag_provider | ||
|
||
|
||
def get_tags(estimator) -> Tags: | ||
"""Get estimator tags. | ||
|
||
|
@@ -316,19 +383,201 @@ def get_tags(estimator) -> Tags: | |
The estimator tags. | ||
""" | ||
|
||
if hasattr(estimator, "__sklearn_tags__"): | ||
tag_provider = _find_tags_provider(estimator) | ||
|
||
if tag_provider == "__sklearn_tags__": | ||
tags = estimator.__sklearn_tags__() | ||
else: | ||
warnings.warn( | ||
f"Estimator {estimator} has no __sklearn_tags__ attribute, which is " | ||
"defined in `sklearn.base.BaseEstimator`. This will raise an error in " | ||
"scikit-learn 1.8. Please define the __sklearn_tags__ method, or inherit " | ||
"from `sklearn.base.BaseEstimator` and other appropriate mixins such as " | ||
"`sklearn.base.TransformerMixin`, `sklearn.base.ClassifierMixin`, " | ||
"`sklearn.base.RegressorMixin`, and `sklearn.base.ClusterMixin`, and " | ||
"`sklearn.base.OutlierMixin`.", | ||
category=FutureWarning, | ||
# TODO(1.7): Remove this branch of the code | ||
# Let's go through the MRO and patch each class implementing _more_tags | ||
sklearn_tags_provider = {} | ||
more_tags_provider = {} | ||
class_order = [] | ||
for klass in reversed(type(estimator).mro()): | ||
if "__sklearn_tags__" in vars(klass): | ||
sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator) # type: ignore[attr-defined] | ||
class_order.append(klass) | ||
elif "_more_tags" in vars(klass): | ||
more_tags_provider[klass] = klass._more_tags(estimator) # type: ignore[attr-defined] | ||
class_order.append(klass) | ||
|
||
# Find differences between consecutive in the case of __sklearn_tags__ | ||
# inheritance | ||
sklearn_tags_diff = {} | ||
items = list(sklearn_tags_provider.items()) | ||
for current_item, next_item in zip(items[:-1], items[1:]): | ||
current_name, current_tags = current_item | ||
next_name, next_tags = next_item | ||
current_tags = _to_old_tags(current_tags) | ||
next_tags = _to_old_tags(next_tags) | ||
|
||
# Compare tags and store differences | ||
diff = {} | ||
for key in current_tags: | ||
if current_tags[key] != next_tags[key]: | ||
diff[key] = next_tags[key] | ||
|
||
sklearn_tags_diff[next_name] = diff | ||
|
||
tags = {} | ||
for klass in class_order: | ||
if klass in sklearn_tags_diff: | ||
tags.update(sklearn_tags_diff[klass]) | ||
elif klass in more_tags_provider: | ||
tags.update(more_tags_provider[klass]) | ||
|
||
tags = _to_new_tags( | ||
{**_to_old_tags(default_tags(estimator)), **tags}, estimator | ||
) | ||
tags = default_tags(estimator) | ||
|
||
return tags | ||
|
||
|
||
# TODO(1.7): Remove this function | ||
def _safe_tags(estimator, key=None): | ||
warnings.warn( | ||
"The `_safe_tags` function is deprecated in 1.6 and will be removed in " | ||
"1.7. Use the public `get_tags` function instead and make sure to implement " | ||
"the `__sklearn_tags__` method.", | ||
category=FutureWarning, | ||
) | ||
tags = _to_old_tags(get_tags(estimator)) | ||
|
||
if key is not None: | ||
if key not in tags: | ||
raise ValueError( | ||
f"The key {key} is not defined for the class " | ||
f"{estimator.__class__.__name__}." | ||
) | ||
return tags[key] | ||
return tags | ||
|
||
|
||
# TODO(1.7): Remove this function | ||
def _to_new_tags(old_tags, estimator=None): | ||
"""Utility function convert old tags (dictionary) to new tags (dataclass).""" | ||
input_tags = InputTags( | ||
one_d_array="1darray" in old_tags["X_types"], | ||
two_d_array="2darray" in old_tags["X_types"], | ||
three_d_array="3darray" in old_tags["X_types"], | ||
sparse="sparse" in old_tags["X_types"], | ||
categorical="categorical" in old_tags["X_types"], | ||
string="string" in old_tags["X_types"], | ||
dict="dict" in old_tags["X_types"], | ||
positive_only=old_tags["requires_positive_X"], | ||
allow_nan=old_tags["allow_nan"], | ||
pairwise=old_tags["pairwise"], | ||
) | ||
target_tags = TargetTags( | ||
required=old_tags["requires_y"], | ||
one_d_labels="1dlabels" in old_tags["X_types"], | ||
two_d_labels="2dlabels" in old_tags["X_types"], | ||
positive_only=old_tags["requires_positive_y"], | ||
multi_output=old_tags["multioutput"] or old_tags["multioutput_only"], | ||
single_output=not old_tags["multioutput_only"], | ||
) | ||
if estimator is not None and ( | ||
hasattr(estimator, "transform") or hasattr(estimator, "fit_transform") | ||
): | ||
transformer_tags = TransformerTags( | ||
preserves_dtype=old_tags["preserves_dtype"], | ||
) | ||
else: | ||
transformer_tags = None | ||
estimator_type = getattr(estimator, "_estimator_type", None) | ||
if estimator_type == "classifier": | ||
classifier_tags = ClassifierTags( | ||
poor_score=old_tags["poor_score"], | ||
multi_class=not old_tags["binary_only"], | ||
multi_label=old_tags["multilabel"], | ||
) | ||
else: | ||
classifier_tags = None | ||
if estimator_type == "regressor": | ||
regressor_tags = RegressorTags( | ||
poor_score=old_tags["poor_score"], | ||
multi_label=old_tags["multilabel"], | ||
) | ||
else: | ||
regressor_tags = None | ||
return Tags( | ||
estimator_type=estimator_type, | ||
target_tags=target_tags, | ||
transformer_tags=transformer_tags, | ||
classifier_tags=classifier_tags, | ||
regressor_tags=regressor_tags, | ||
input_tags=input_tags, | ||
array_api_support=old_tags["array_api_support"], | ||
no_validation=old_tags["no_validation"], | ||
non_deterministic=old_tags["non_deterministic"], | ||
requires_fit=old_tags["requires_fit"], | ||
_skip_test=old_tags["_skip_test"], | ||
) | ||
|
||
|
||
# TODO(1.7): Remove this function | ||
def _to_old_tags(new_tags): | ||
"""Utility function convert old tags (dictionary) to new tags (dataclass).""" | ||
if new_tags.classifier_tags: | ||
binary_only = not new_tags.classifier_tags.multi_class | ||
multilabel_clf = new_tags.classifier_tags.multi_label | ||
poor_score_clf = new_tags.classifier_tags.poor_score | ||
else: | ||
binary_only = False | ||
multilabel_clf = False | ||
poor_score_clf = False | ||
|
||
if new_tags.regressor_tags: | ||
multilabel_reg = new_tags.regressor_tags.multi_label | ||
poor_score_reg = new_tags.regressor_tags.poor_score | ||
else: | ||
multilabel_reg = False | ||
poor_score_reg = False | ||
|
||
if new_tags.transformer_tags: | ||
preserves_dtype = new_tags.transformer_tags.preserves_dtype | ||
else: | ||
preserves_dtype = ["float64"] | ||
|
||
tags = { | ||
"allow_nan": new_tags.input_tags.allow_nan, | ||
"array_api_support": new_tags.array_api_support, | ||
"binary_only": binary_only, | ||
"multilabel": multilabel_clf or multilabel_reg, | ||
"multioutput": new_tags.target_tags.multi_output, | ||
"multioutput_only": ( | ||
not new_tags.target_tags.single_output and new_tags.target_tags.multi_output | ||
), | ||
"no_validation": new_tags.no_validation, | ||
"non_deterministic": new_tags.non_deterministic, | ||
"pairwise": new_tags.input_tags.pairwise, | ||
"preserves_dtype": preserves_dtype, | ||
"poor_score": poor_score_clf or poor_score_reg, | ||
"requires_fit": new_tags.requires_fit, | ||
"requires_positive_X": new_tags.input_tags.positive_only, | ||
"requires_y": new_tags.target_tags.required, | ||
"requires_positive_y": new_tags.target_tags.positive_only, | ||
"_skip_test": new_tags._skip_test, | ||
"stateless": new_tags.requires_fit, | ||
} | ||
X_types = [] | ||
if new_tags.input_tags.one_d_array: | ||
X_types.append("1darray") | ||
if new_tags.input_tags.two_d_array: | ||
X_types.append("2darray") | ||
if new_tags.input_tags.three_d_array: | ||
X_types.append("3darray") | ||
if new_tags.input_tags.sparse: | ||
X_types.append("sparse") | ||
if new_tags.input_tags.categorical: | ||
X_types.append("categorical") | ||
if new_tags.input_tags.string: | ||
X_types.append("string") | ||
if new_tags.input_tags.dict: | ||
X_types.append("dict") | ||
if new_tags.target_tags.one_d_labels: | ||
X_types.append("1dlabels") | ||
if new_tags.target_tags.two_d_labels: | ||
X_types.append("2dlabels") | ||
tags["X_types"] = X_types | ||
return tags |
Uh oh!
There was an error while loading. Please reload this page.