Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

MAINT conversion old->new/new->old tags (bis) #30327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,33 @@ def __setstate__(self, state):
except AttributeError:
self.__dict__.update(state)

# TODO(1.7): Remove this method
def _more_tags(self):
    """Return the default tags of this estimator in the legacy dict format.

    This method is kept only for backward compatibility: `get_tags` falls back
    on `__sklearn_tags__`, so this code path is not expected to be reached by
    the library itself. It is covered by the `test_base_estimator_more_tags`
    test.

    Returns
    -------
    tags : dict
        The result of converting `default_tags(self)` with `_to_old_tags`.

    Warns
    -----
    FutureWarning
        Always emitted, since the method is deprecated.
    """
    # Imported lazily inside the method, as in the original implementation.
    from sklearn.utils._tags import _to_old_tags, default_tags

    deprecation_message = (
        "The `_more_tags` method is deprecated in 1.6 and will be removed in "
        "1.7. Please implement the `__sklearn_tags__` method."
    )
    warnings.warn(deprecation_message, category=FutureWarning)

    new_style_defaults = default_tags(self)
    return _to_old_tags(new_style_defaults)

# TODO(1.7): Remove this method
def _get_tags(self):
    """Return the tags of this estimator in the legacy dict format.

    Deprecated entry point: the tags are collected with `get_tags` and then
    converted to the old dictionary representation with `_to_old_tags`.

    Returns
    -------
    tags : dict
        The legacy tags dictionary.

    Warns
    -----
    FutureWarning
        Always emitted, since the method is deprecated.
    """
    # Imported lazily inside the method, as in the original implementation.
    from sklearn.utils._tags import _to_old_tags, get_tags

    deprecation_message = (
        "The `_get_tags` method is deprecated in 1.6 and will be removed in "
        "1.7. Please implement the `__sklearn_tags__` method."
    )
    warnings.warn(deprecation_message, category=FutureWarning)

    new_style_tags = get_tags(self)
    return _to_old_tags(new_style_tags)

def __sklearn_tags__(self):
return Tags(
estimator_type=None,
Expand Down
271 changes: 260 additions & 11 deletions sklearn/utils/_tags.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from itertools import chain

from .fixes import _dataclass_args

Expand Down Expand Up @@ -290,6 +292,71 @@ def default_tags(estimator) -> Tags:
)


# TODO(1.7): Remove this function
def _find_tags_provider(estimator, warn=True):
"""Find the tags provider for an estimator.

Parameters
----------
estimator : estimator object
The estimator to find the tags provider for.

warn : bool, default=True
Whether to warn if the tags provider is not found.

Returns
-------
tag_provider : str
The tags provider for the estimator. Can be one of:
- "_get_tags": to use the old tags infrastructure
- "__sklearn_tags__": to use the new tags infrastructure
"""
mro_model = type(estimator).mro()
tags_mro = OrderedDict()
for klass in mro_model:
tags_provider = []
if "_more_tags" in vars(klass):
tags_provider.append("_more_tags")
if "_get_tags" in vars(klass):
tags_provider.append("_get_tags")
if "__sklearn_tags__" in vars(klass):
tags_provider.append("__sklearn_tags__")
tags_mro[klass.__name__] = tags_provider

all_providers = set(chain.from_iterable(tags_mro.values()))
if "__sklearn_tags__" not in all_providers:
# default on the old tags infrastructure
return "_get_tags"

tag_provider = "__sklearn_tags__"
for klass in tags_mro:
has_get_or_more_tags = any(
provider in tags_mro[klass] for provider in ("_get_tags", "_more_tags")
)
has_sklearn_tags = "__sklearn_tags__" in tags_mro[klass]

if tags_mro[klass] and tag_provider == "__sklearn_tags__": # is it empty
if has_get_or_more_tags and not has_sklearn_tags:
# Case where a class does not implement __sklearn_tags__ and we fallback
# to _get_tags. We should therefore warn for implementing
# __sklearn_tags__.
tag_provider = "_get_tags"
break

if warn and tag_provider == "_get_tags":
warnings.warn(
f"The {estimator.__class__.__name__} or classes from which it inherits "
"use `_get_tags` and `_more_tags`. Please define the "
"`__sklearn_tags__` method, or inherit from `sklearn.base.BaseEstimator` "
"and/or other appropriate mixins such as `sklearn.base.TransformerMixin`, "
"`sklearn.base.ClassifierMixin`, `sklearn.base.RegressorMixin`, and "
"`sklearn.base.OutlierMixin`. From scikit-learn 1.7, not defining "
"`__sklearn_tags__` will raise an error.",
category=FutureWarning,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally, I would go with a DeprecationWarning for developer focused changes. It's unfortunate that users will see this warning too if a library has a custom _more_tags. I do not see a good way around it because users can also write their own estimators with custom _more_tags.

In any case, I am okay with the current implementation.

)
return tag_provider


def get_tags(estimator) -> Tags:
"""Get estimator tags.

Expand All @@ -316,19 +383,201 @@ def get_tags(estimator) -> Tags:
The estimator tags.
"""

if hasattr(estimator, "__sklearn_tags__"):
tag_provider = _find_tags_provider(estimator)

if tag_provider == "__sklearn_tags__":
tags = estimator.__sklearn_tags__()
else:
warnings.warn(
f"Estimator {estimator} has no __sklearn_tags__ attribute, which is "
"defined in `sklearn.base.BaseEstimator`. This will raise an error in "
"scikit-learn 1.8. Please define the __sklearn_tags__ method, or inherit "
"from `sklearn.base.BaseEstimator` and other appropriate mixins such as "
"`sklearn.base.TransformerMixin`, `sklearn.base.ClassifierMixin`, "
"`sklearn.base.RegressorMixin`, and `sklearn.base.ClusterMixin`, and "
"`sklearn.base.OutlierMixin`.",
category=FutureWarning,
# TODO(1.7): Remove this branch of the code
# Let's go through the MRO and patch each class implementing _more_tags
sklearn_tags_provider = {}
more_tags_provider = {}
class_order = []
for klass in reversed(type(estimator).mro()):
if "__sklearn_tags__" in vars(klass):
sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator) # type: ignore[attr-defined]
class_order.append(klass)
elif "_more_tags" in vars(klass):
more_tags_provider[klass] = klass._more_tags(estimator) # type: ignore[attr-defined]
class_order.append(klass)

# Find differences between consecutive classes in the case of __sklearn_tags__
# inheritance
sklearn_tags_diff = {}
items = list(sklearn_tags_provider.items())
for current_item, next_item in zip(items[:-1], items[1:]):
current_name, current_tags = current_item
next_name, next_tags = next_item
current_tags = _to_old_tags(current_tags)
next_tags = _to_old_tags(next_tags)

# Compare tags and store differences
diff = {}
for key in current_tags:
if current_tags[key] != next_tags[key]:
diff[key] = next_tags[key]

sklearn_tags_diff[next_name] = diff

tags = {}
for klass in class_order:
if klass in sklearn_tags_diff:
tags.update(sklearn_tags_diff[klass])
elif klass in more_tags_provider:
tags.update(more_tags_provider[klass])

tags = _to_new_tags(
{**_to_old_tags(default_tags(estimator)), **tags}, estimator
)
tags = default_tags(estimator)

return tags


# TODO(1.7): Remove this function
def _safe_tags(estimator, key=None):
    """Return the legacy tags of an estimator, or a single one of them.

    Deprecated helper: the tags are obtained with `get_tags` and converted to
    the old dictionary representation with `_to_old_tags`.

    Parameters
    ----------
    estimator : estimator object
        The estimator from which to get the tags.

    key : str, default=None
        Name of the tag to return. When `None`, the full dictionary is
        returned.

    Returns
    -------
    tags : dict or tag value
        The whole legacy tags dictionary if `key` is `None`, otherwise the
        value stored under `key`.

    Raises
    ------
    ValueError
        If `key` is not `None` and is not a key of the legacy tags dictionary.

    Warns
    -----
    FutureWarning
        Always emitted, since the function is deprecated.
    """
    warnings.warn(
        "The `_safe_tags` function is deprecated in 1.6 and will be removed in "
        "1.7. Use the public `get_tags` function instead and make sure to implement "
        "the `__sklearn_tags__` method.",
        category=FutureWarning,
    )
    legacy_tags = _to_old_tags(get_tags(estimator))

    if key is None:
        return legacy_tags
    if key in legacy_tags:
        return legacy_tags[key]
    raise ValueError(
        f"The key {key} is not defined for the class "
        f"{estimator.__class__.__name__}."
    )


# TODO(1.7): Remove this function
def _to_new_tags(old_tags, estimator=None):
    """Utility function to convert old tags (dictionary) to new tags (dataclass).

    Parameters
    ----------
    old_tags : dict
        Legacy tags dictionary. The keys read here are "X_types",
        "requires_positive_X", "allow_nan", "pairwise", "requires_y",
        "requires_positive_y", "multioutput", "multioutput_only",
        "array_api_support", "no_validation", "non_deterministic",
        "requires_fit" and "_skip_test", plus "preserves_dtype", "poor_score",
        "binary_only" and "multilabel" when the matching sub-tags are built.

    estimator : estimator object, default=None
        Used to decide which optional sub-tags are created: transformer tags
        when it exposes `transform` or `fit_transform`, classifier or
        regressor tags depending on its `_estimator_type` attribute.

    Returns
    -------
    tags : Tags
        The new-style tags.
    """
    # "X_types" is a collection of string markers covering both the input and
    # the target kinds in the old infrastructure.
    x_types = old_tags["X_types"]

    input_tags = InputTags(
        one_d_array="1darray" in x_types,
        two_d_array="2darray" in x_types,
        three_d_array="3darray" in x_types,
        sparse="sparse" in x_types,
        categorical="categorical" in x_types,
        string="string" in x_types,
        dict="dict" in x_types,
        positive_only=old_tags["requires_positive_X"],
        allow_nan=old_tags["allow_nan"],
        pairwise=old_tags["pairwise"],
    )
    target_tags = TargetTags(
        required=old_tags["requires_y"],
        one_d_labels="1dlabels" in x_types,
        two_d_labels="2dlabels" in x_types,
        positive_only=old_tags["requires_positive_y"],
        multi_output=old_tags["multioutput"] or old_tags["multioutput_only"],
        single_output=not old_tags["multioutput_only"],
    )

    looks_like_transformer = estimator is not None and (
        hasattr(estimator, "transform") or hasattr(estimator, "fit_transform")
    )
    transformer_tags = (
        TransformerTags(preserves_dtype=old_tags["preserves_dtype"])
        if looks_like_transformer
        else None
    )

    # `getattr` with a default also covers the `estimator=None` case.
    estimator_type = getattr(estimator, "_estimator_type", None)
    classifier_tags = (
        ClassifierTags(
            poor_score=old_tags["poor_score"],
            multi_class=not old_tags["binary_only"],
            multi_label=old_tags["multilabel"],
        )
        if estimator_type == "classifier"
        else None
    )
    regressor_tags = (
        RegressorTags(
            poor_score=old_tags["poor_score"],
            multi_label=old_tags["multilabel"],
        )
        if estimator_type == "regressor"
        else None
    )

    return Tags(
        estimator_type=estimator_type,
        target_tags=target_tags,
        transformer_tags=transformer_tags,
        classifier_tags=classifier_tags,
        regressor_tags=regressor_tags,
        input_tags=input_tags,
        array_api_support=old_tags["array_api_support"],
        no_validation=old_tags["no_validation"],
        non_deterministic=old_tags["non_deterministic"],
        requires_fit=old_tags["requires_fit"],
        _skip_test=old_tags["_skip_test"],
    )


# TODO(1.7): Remove this function
def _to_old_tags(new_tags):
"""Utility function convert old tags (dictionary) to new tags (dataclass)."""
if new_tags.classifier_tags:
binary_only = not new_tags.classifier_tags.multi_class
multilabel_clf = new_tags.classifier_tags.multi_label
poor_score_clf = new_tags.classifier_tags.poor_score
else:
binary_only = False
multilabel_clf = False
poor_score_clf = False

if new_tags.regressor_tags:
multilabel_reg = new_tags.regressor_tags.multi_label
poor_score_reg = new_tags.regressor_tags.poor_score
else:
multilabel_reg = False
poor_score_reg = False

if new_tags.transformer_tags:
preserves_dtype = new_tags.transformer_tags.preserves_dtype
else:
preserves_dtype = ["float64"]

tags = {
"allow_nan": new_tags.input_tags.allow_nan,
"array_api_support": new_tags.array_api_support,
"binary_only": binary_only,
"multilabel": multilabel_clf or multilabel_reg,
"multioutput": new_tags.target_tags.multi_output,
"multioutput_only": (
not new_tags.target_tags.single_output and new_tags.target_tags.multi_output
),
"no_validation": new_tags.no_validation,
"non_deterministic": new_tags.non_deterministic,
"pairwise": new_tags.input_tags.pairwise,
"preserves_dtype": preserves_dtype,
"poor_score": poor_score_clf or poor_score_reg,
"requires_fit": new_tags.requires_fit,
"requires_positive_X": new_tags.input_tags.positive_only,
"requires_y": new_tags.target_tags.required,
"requires_positive_y": new_tags.target_tags.positive_only,
"_skip_test": new_tags._skip_test,
"stateless": new_tags.requires_fit,
}
X_types = []
if new_tags.input_tags.one_d_array:
X_types.append("1darray")
if new_tags.input_tags.two_d_array:
X_types.append("2darray")
if new_tags.input_tags.three_d_array:
X_types.append("3darray")
if new_tags.input_tags.sparse:
X_types.append("sparse")
if new_tags.input_tags.categorical:
X_types.append("categorical")
if new_tags.input_tags.string:
X_types.append("string")
if new_tags.input_tags.dict:
X_types.append("dict")
if new_tags.target_tags.one_d_labels:
X_types.append("1dlabels")
if new_tags.target_tags.two_d_labels:
X_types.append("2dlabels")
tags["X_types"] = X_types
return tags
Loading