diff --git a/sklearn/base.py b/sklearn/base.py index d646f8d3e56bf..2c82cf05a6c5a 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -389,6 +389,33 @@ def __setstate__(self, state): except AttributeError: self.__dict__.update(state) + # TODO(1.7): Remove this method + def _more_tags(self): + """This code should never be reached since our `get_tags` will fallback on + `__sklearn_tags__` implemented below. We keep it for backward compatibility. + It is tested in `test_base_estimator_more_tags` in + `sklearn/utils/testing/test_tags.py`.""" + from sklearn.utils._tags import _to_old_tags, default_tags + + warnings.warn( + "The `_more_tags` method is deprecated in 1.6 and will be removed in " + "1.7. Please implement the `__sklearn_tags__` method.", + category=FutureWarning, + ) + return _to_old_tags(default_tags(self)) + + # TODO(1.7): Remove this method + def _get_tags(self): + from sklearn.utils._tags import _to_old_tags, get_tags + + warnings.warn( + "The `_get_tags` method is deprecated in 1.6 and will be removed in " + "1.7. Please implement the `__sklearn_tags__` method.", + category=FutureWarning, + ) + + return _to_old_tags(get_tags(self)) + def __sklearn_tags__(self): return Tags( estimator_type=None, diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index ccbc9d2438268..1ba1913c37234 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -1,7 +1,9 @@ from __future__ import annotations import warnings +from collections import OrderedDict from dataclasses import dataclass, field +from itertools import chain from .fixes import _dataclass_args @@ -290,6 +292,71 @@ def default_tags(estimator) -> Tags: ) +# TODO(1.7): Remove this function +def _find_tags_provider(estimator, warn=True): + """Find the tags provider for an estimator. + + Parameters + ---------- + estimator : estimator object + The estimator to find the tags provider for. + + warn : bool, default=True + Whether to warn if the tags provider is not found. + + Returns + ------- + tag_provider : str + The tags provider for the estimator. Can be one of: + - "_get_tags": to use the old tags infrastructure + - "__sklearn_tags__": to use the new tags infrastructure + """ + mro_model = type(estimator).mro() + tags_mro = OrderedDict() + for klass in mro_model: + tags_provider = [] + if "_more_tags" in vars(klass): + tags_provider.append("_more_tags") + if "_get_tags" in vars(klass): + tags_provider.append("_get_tags") + if "__sklearn_tags__" in vars(klass): + tags_provider.append("__sklearn_tags__") + tags_mro[klass.__name__] = tags_provider + + all_providers = set(chain.from_iterable(tags_mro.values())) + if "__sklearn_tags__" not in all_providers: + # default on the old tags infrastructure + return "_get_tags" + + tag_provider = "__sklearn_tags__" + for klass in tags_mro: + has_get_or_more_tags = any( + provider in tags_mro[klass] for provider in ("_get_tags", "_more_tags") + ) + has_sklearn_tags = "__sklearn_tags__" in tags_mro[klass] + + if tags_mro[klass] and tag_provider == "__sklearn_tags__": # is it empty + if has_get_or_more_tags and not has_sklearn_tags: + # Case where a class does not implement __sklearn_tags__ and we fallback + # to _get_tags. We should therefore warn for implementing + # __sklearn_tags__. + tag_provider = "_get_tags" + break + + if warn and tag_provider == "_get_tags": + warnings.warn( + f"The {estimator.__class__.__name__} or classes from which it inherits " + "use `_get_tags` and `_more_tags`. Please define the " + "`__sklearn_tags__` method, or inherit from `sklearn.base.BaseEstimator` " + "and/or other appropriate mixins such as `sklearn.base.TransformerMixin`, " + "`sklearn.base.ClassifierMixin`, `sklearn.base.RegressorMixin`, and " + "`sklearn.base.OutlierMixin`. From scikit-learn 1.7, not defining " + "`__sklearn_tags__` will raise an error.", + category=FutureWarning, + ) + return tag_provider + + def get_tags(estimator) -> Tags: """Get estimator tags. @@ -316,19 +383,201 @@ def get_tags(estimator) -> Tags: The estimator tags. """ - if hasattr(estimator, "__sklearn_tags__"): + tag_provider = _find_tags_provider(estimator) + + if tag_provider == "__sklearn_tags__": tags = estimator.__sklearn_tags__() else: - warnings.warn( - f"Estimator {estimator} has no __sklearn_tags__ attribute, which is " - "defined in `sklearn.base.BaseEstimator`. This will raise an error in " - "scikit-learn 1.8. Please define the __sklearn_tags__ method, or inherit " - "from `sklearn.base.BaseEstimator` and other appropriate mixins such as " - "`sklearn.base.TransformerMixin`, `sklearn.base.ClassifierMixin`, " - "`sklearn.base.RegressorMixin`, and `sklearn.base.ClusterMixin`, and " - "`sklearn.base.OutlierMixin`.", - category=FutureWarning, + # TODO(1.7): Remove this branch of the code + # Let's go through the MRO and patch each class implementing _more_tags + sklearn_tags_provider = {} + more_tags_provider = {} + class_order = [] + for klass in reversed(type(estimator).mro()): + if "__sklearn_tags__" in vars(klass): + sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator) # type: ignore[attr-defined] + class_order.append(klass) + elif "_more_tags" in vars(klass): + more_tags_provider[klass] = klass._more_tags(estimator) # type: ignore[attr-defined] + class_order.append(klass) + + # Find differences between consecutive in the case of __sklearn_tags__ + # inheritance + sklearn_tags_diff = {} + items = list(sklearn_tags_provider.items()) + for current_item, next_item in zip(items[:-1], items[1:]): + current_name, current_tags = current_item + next_name, next_tags = next_item + current_tags = _to_old_tags(current_tags) + next_tags = _to_old_tags(next_tags) + + # Compare tags and store differences + diff = {} + for key in current_tags: + if current_tags[key] != next_tags[key]: + diff[key] = next_tags[key] + + sklearn_tags_diff[next_name] = diff + + tags = {} + for klass in class_order: + if klass in sklearn_tags_diff: + tags.update(sklearn_tags_diff[klass]) + elif klass in more_tags_provider: + tags.update(more_tags_provider[klass]) + + tags = _to_new_tags( + {**_to_old_tags(default_tags(estimator)), **tags}, estimator ) - tags = default_tags(estimator) return tags + + +# TODO(1.7): Remove this function +def _safe_tags(estimator, key=None): + warnings.warn( + "The `_safe_tags` function is deprecated in 1.6 and will be removed in " + "1.7. Use the public `get_tags` function instead and make sure to implement " + "the `__sklearn_tags__` method.", + category=FutureWarning, + ) + tags = _to_old_tags(get_tags(estimator)) + + if key is not None: + if key not in tags: + raise ValueError( + f"The key {key} is not defined for the class " + f"{estimator.__class__.__name__}." + ) + return tags[key] + return tags + + +# TODO(1.7): Remove this function +def _to_new_tags(old_tags, estimator=None): + """Utility function convert old tags (dictionary) to new tags (dataclass).""" + input_tags = InputTags( + one_d_array="1darray" in old_tags["X_types"], + two_d_array="2darray" in old_tags["X_types"], + three_d_array="3darray" in old_tags["X_types"], + sparse="sparse" in old_tags["X_types"], + categorical="categorical" in old_tags["X_types"], + string="string" in old_tags["X_types"], + dict="dict" in old_tags["X_types"], + positive_only=old_tags["requires_positive_X"], + allow_nan=old_tags["allow_nan"], + pairwise=old_tags["pairwise"], + ) + target_tags = TargetTags( + required=old_tags["requires_y"], + one_d_labels="1dlabels" in old_tags["X_types"], + two_d_labels="2dlabels" in old_tags["X_types"], + positive_only=old_tags["requires_positive_y"], + multi_output=old_tags["multioutput"] or old_tags["multioutput_only"], + single_output=not old_tags["multioutput_only"], + ) + if estimator is not None and ( + hasattr(estimator, "transform") or hasattr(estimator, "fit_transform") + ): + transformer_tags = TransformerTags( + preserves_dtype=old_tags["preserves_dtype"], + ) + else: + transformer_tags = None + estimator_type = getattr(estimator, "_estimator_type", None) + if estimator_type == "classifier": + classifier_tags = ClassifierTags( + poor_score=old_tags["poor_score"], + multi_class=not old_tags["binary_only"], + multi_label=old_tags["multilabel"], + ) + else: + classifier_tags = None + if estimator_type == "regressor": + regressor_tags = RegressorTags( + poor_score=old_tags["poor_score"], + multi_label=old_tags["multilabel"], + ) + else: + regressor_tags = None + return Tags( + estimator_type=estimator_type, + target_tags=target_tags, + transformer_tags=transformer_tags, + classifier_tags=classifier_tags, + regressor_tags=regressor_tags, + input_tags=input_tags, + array_api_support=old_tags["array_api_support"], + no_validation=old_tags["no_validation"], + non_deterministic=old_tags["non_deterministic"], + requires_fit=old_tags["requires_fit"], + _skip_test=old_tags["_skip_test"], + ) + + +# TODO(1.7): Remove this function +def _to_old_tags(new_tags): + """Utility function convert old tags (dictionary) to new tags (dataclass).""" + if new_tags.classifier_tags: + binary_only = not new_tags.classifier_tags.multi_class + multilabel_clf = new_tags.classifier_tags.multi_label + poor_score_clf = new_tags.classifier_tags.poor_score + else: + binary_only = False + multilabel_clf = False + poor_score_clf = False + + if new_tags.regressor_tags: + multilabel_reg = new_tags.regressor_tags.multi_label + poor_score_reg = new_tags.regressor_tags.poor_score + else: + multilabel_reg = False + poor_score_reg = False + + if new_tags.transformer_tags: + preserves_dtype = new_tags.transformer_tags.preserves_dtype + else: + preserves_dtype = ["float64"] + + tags = { + "allow_nan": new_tags.input_tags.allow_nan, + "array_api_support": new_tags.array_api_support, + "binary_only": binary_only, + "multilabel": multilabel_clf or multilabel_reg, + "multioutput": new_tags.target_tags.multi_output, + "multioutput_only": ( + not new_tags.target_tags.single_output and new_tags.target_tags.multi_output + ), + "no_validation": new_tags.no_validation, + "non_deterministic": new_tags.non_deterministic, + "pairwise": new_tags.input_tags.pairwise, + "preserves_dtype": preserves_dtype, + "poor_score": poor_score_clf or poor_score_reg, + "requires_fit": new_tags.requires_fit, + "requires_positive_X": new_tags.input_tags.positive_only, + "requires_y": new_tags.target_tags.required, + "requires_positive_y": new_tags.target_tags.positive_only, + "_skip_test": new_tags._skip_test, + "stateless": new_tags.requires_fit, + } + X_types = [] + if new_tags.input_tags.one_d_array: + X_types.append("1darray") + if new_tags.input_tags.two_d_array: + X_types.append("2darray") + if new_tags.input_tags.three_d_array: + X_types.append("3darray") + if new_tags.input_tags.sparse: + X_types.append("sparse") + if new_tags.input_tags.categorical: + X_types.append("categorical") + if new_tags.input_tags.string: + X_types.append("string") + if new_tags.input_tags.dict: + X_types.append("dict") + if new_tags.target_tags.one_d_labels: + X_types.append("1dlabels") + if new_tags.target_tags.two_d_labels: + X_types.append("2dlabels") + tags["X_types"] = X_types + return tags diff --git a/sklearn/utils/tests/test_tags.py b/sklearn/utils/tests/test_tags.py index 413fbc6bbd3de..86e4e2d7c431e 100644 --- a/sklearn/utils/tests/test_tags.py +++ b/sklearn/utils/tests/test_tags.py @@ -7,7 +7,16 @@ RegressorMixin, TransformerMixin, ) -from sklearn.utils import Tags, get_tags +from sklearn.utils import ( + ClassifierTags, + InputTags, + RegressorTags, + Tags, + TargetTags, + TransformerTags, + get_tags, +) +from sklearn.utils._tags import _safe_tags, _to_new_tags, _to_old_tags, default_tags from sklearn.utils.estimator_checks import ( check_estimator_tags_renamed, check_valid_tag_types, @@ -78,3 +87,546 @@ def __sklearn_tags__(self): return tags check_valid_tag_types("MyEstimator", MyEstimator()) + + +######################################################################################## +# Test for the deprecation +# TODO(1.7): Remove this +######################################################################################## + + +class MixinAllowNanOldTags: + def _more_tags(self): + return {"allow_nan": True} + + +class MixinAllowNanNewTags: + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = True + return tags + + +class MixinAllowNanOldNewTags: + def _more_tags(self): + return {"allow_nan": True} # pragma: no cover + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = True + return tags + + +class MixinArrayApiSupportOldTags: + def _more_tags(self): + return {"array_api_support": True} + + +class MixinArrayApiSupportNewTags: + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.array_api_support = True + return tags + + +class MixinArrayApiSupportOldNewTags: + def _more_tags(self): + return {"array_api_support": True} # pragma: no cover + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.array_api_support = True + return tags + + +class PredictorOldTags(BaseEstimator): + def _more_tags(self): + return {"requires_fit": True} + + +class PredictorNewTags(BaseEstimator): + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.requires_fit = True + return tags + + +class PredictorOldNewTags(BaseEstimator): + def _more_tags(self): + return {"requires_fit": True} # pragma: no cover + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.requires_fit = True + return tags + + +def test_get_tags_backward_compatibility(): + warn_msg = "Please define the `__sklearn_tags__` method" + + #################################################################################### + # only predictor inheriting from BaseEstimator + predictor_classes = [PredictorNewTags, PredictorOldNewTags, PredictorOldTags] + for predictor_cls in predictor_classes: + if predictor_cls.__name__.endswith("OldTags"): + with pytest.warns(FutureWarning, match=warn_msg): + tags = get_tags(predictor_cls()) + else: + tags = get_tags(predictor_cls()) + assert tags.requires_fit + + #################################################################################### + # one mixin and one predictor all inheriting from BaseEstimator + predictor_classes = [PredictorNewTags, PredictorOldNewTags, PredictorOldTags] + allow_nan_classes = [ + MixinAllowNanNewTags, + MixinAllowNanOldNewTags, + MixinAllowNanOldTags, + ] + + for allow_nan_cls in allow_nan_classes: + for predictor_cls in predictor_classes: + + class ChildClass(allow_nan_cls, predictor_cls): + pass + + if any( + base_cls.__name__.endswith("OldTags") + for base_cls in (predictor_cls, allow_nan_cls) + ): + with pytest.warns(FutureWarning, match=warn_msg): + tags = get_tags(ChildClass()) + else: + tags = get_tags(ChildClass()) + + assert tags.input_tags.allow_nan + assert tags.requires_fit + + #################################################################################### + # two mixins and one predictor all inheriting from BaseEstimator + predictor_classes = [PredictorNewTags, PredictorOldNewTags, PredictorOldTags] + array_api_classes = [ + MixinArrayApiSupportNewTags, + MixinArrayApiSupportOldNewTags, + MixinArrayApiSupportOldTags, + ] + allow_nan_classes = [ + MixinAllowNanNewTags, + MixinAllowNanOldNewTags, + MixinAllowNanOldTags, + ] + + for predictor_cls in predictor_classes: + for array_api_cls in array_api_classes: + for allow_nan_cls in allow_nan_classes: + + class ChildClass(allow_nan_cls, array_api_cls, predictor_cls): + pass + + if any( + base_cls.__name__.endswith("OldTags") + for base_cls in (predictor_cls, array_api_cls, allow_nan_cls) + ): + with pytest.warns(FutureWarning, match=warn_msg): + tags = get_tags(ChildClass()) + else: + tags = get_tags(ChildClass()) + + assert tags.input_tags.allow_nan + assert tags.array_api_support + assert tags.requires_fit + + +@pytest.mark.filterwarnings( + "ignore:.*Please define the `__sklearn_tags__` method.*:FutureWarning" +) +def test_safe_tags_backward_compatibility(): + warn_msg = "The `_safe_tags` function is deprecated in 1.6" + + #################################################################################### + # only predictor inheriting from BaseEstimator + predictor_classes = [PredictorNewTags, PredictorOldNewTags, PredictorOldTags] + for predictor_cls in predictor_classes: + with pytest.warns(FutureWarning, match=warn_msg): + tags = _safe_tags(predictor_cls()) + assert tags["requires_fit"] + + #################################################################################### + # one mixin and one predictor all inheriting from BaseEstimator + predictor_classes = [PredictorNewTags, PredictorOldNewTags, PredictorOldTags] + allow_nan_classes = [ + MixinAllowNanNewTags, + MixinAllowNanOldNewTags, + MixinAllowNanOldTags, + ] + + for allow_nan_cls in allow_nan_classes: + for predictor_cls in predictor_classes: + + class ChildClass(allow_nan_cls, predictor_cls): + pass + + with pytest.warns(FutureWarning, match=warn_msg): + tags = _safe_tags(ChildClass()) + + assert tags["allow_nan"] + assert tags["requires_fit"] + + #################################################################################### + # two mixins and one predictor all inheriting from BaseEstimator + predictor_classes = [PredictorNewTags, PredictorOldNewTags, PredictorOldTags] + array_api_classes = [ + MixinArrayApiSupportNewTags, + MixinArrayApiSupportOldNewTags, + MixinArrayApiSupportOldTags, + ] + allow_nan_classes = [ + MixinAllowNanNewTags, + MixinAllowNanOldNewTags, + MixinAllowNanOldTags, + ] + + for predictor_cls in predictor_classes: + for array_api_cls in array_api_classes: + for allow_nan_cls in allow_nan_classes: + + class ChildClass(allow_nan_cls, array_api_cls, predictor_cls): + pass + + with pytest.warns(FutureWarning, match=warn_msg): + tags = _safe_tags(ChildClass()) + + assert tags["allow_nan"] + assert tags["array_api_support"] + assert tags["requires_fit"] + + +@pytest.mark.filterwarnings( + "ignore:.*Please define the `__sklearn_tags__` method.*:FutureWarning" +) +def test__get_tags_backward_compatibility(): + warn_msg = "The `_get_tags` method is deprecated in 1.6" + + #################################################################################### + # only predictor inheriting from BaseEstimator + predictor_classes = [PredictorNewTags, PredictorOldNewTags, PredictorOldTags] + for predictor_cls in predictor_classes: + with pytest.warns(FutureWarning, match=warn_msg): + tags = predictor_cls()._get_tags() + assert tags["requires_fit"] + + #################################################################################### + # one mixin and one predictor all inheriting from BaseEstimator + predictor_classes = [PredictorNewTags, PredictorOldNewTags, PredictorOldTags] + allow_nan_classes = [ + MixinAllowNanNewTags, + MixinAllowNanOldNewTags, + MixinAllowNanOldTags, + ] + + for allow_nan_cls in allow_nan_classes: + for predictor_cls in predictor_classes: + + class ChildClass(allow_nan_cls, predictor_cls): + pass + + with pytest.warns(FutureWarning, match=warn_msg): + tags = ChildClass()._get_tags() + + assert tags["allow_nan"] + assert tags["requires_fit"] + + #################################################################################### + # two mixins and one predictor all inheriting from BaseEstimator + predictor_classes = [PredictorNewTags, PredictorOldNewTags, PredictorOldTags] + array_api_classes = [ + MixinArrayApiSupportNewTags, + MixinArrayApiSupportOldNewTags, + MixinArrayApiSupportOldTags, + ] + allow_nan_classes = [ + MixinAllowNanNewTags, + MixinAllowNanOldNewTags, + MixinAllowNanOldTags, + ] + + for predictor_cls in predictor_classes: + for array_api_cls in array_api_classes: + for allow_nan_cls in allow_nan_classes: + + class ChildClass(allow_nan_cls, array_api_cls, predictor_cls): + pass + + with pytest.warns(FutureWarning, match=warn_msg): + tags = ChildClass()._get_tags() + + assert tags["allow_nan"] + assert tags["array_api_support"] + assert tags["requires_fit"] + + +def test_roundtrip_tags(): + estimator = PredictorNewTags() + tags = default_tags(estimator) + assert _to_new_tags(_to_old_tags(tags), estimator=estimator) == tags + + +def test_base_estimator_more_tags(): + """Test that the `_more_tags` and `_get_tags` methods are equivalent for + `BaseEstimator`. + """ + estimator = BaseEstimator() + with pytest.warns(FutureWarning, match="The `_more_tags` method is deprecated"): + more_tags = BaseEstimator._more_tags(estimator) + + with pytest.warns(FutureWarning, match="The `_get_tags` method is deprecated"): + get_tags = BaseEstimator._get_tags(estimator) + + assert more_tags == get_tags + + +def test_safe_tags(): + estimator = PredictorNewTags() + with pytest.warns(FutureWarning, match="The `_safe_tags` function is deprecated"): + tags = _safe_tags(estimator) + + with pytest.warns(FutureWarning, match="The `_safe_tags` function is deprecated"): + tags_requires_fit = _safe_tags(estimator, key="requires_fit") + + assert tags_requires_fit == tags["requires_fit"] + + err_msg = "The key unknown_key is not defined" + with pytest.raises(ValueError, match=err_msg): + with pytest.warns( + FutureWarning, match="The `_safe_tags` function is deprecated" + ): + _safe_tags(estimator, key="unknown_key") + + +def test_old_tags(): + """Set to non-default values and check that we get the expected old tags.""" + + class MyClass: + _estimator_type = "regressor" + + def __sklearn_tags__(self): + input_tags = InputTags( + one_d_array=True, + two_d_array=False, + three_d_array=True, + sparse=True, + categorical=True, + string=True, + dict=True, + positive_only=True, + allow_nan=True, + pairwise=True, + ) + target_tags = TargetTags( + required=False, + one_d_labels=True, + two_d_labels=True, + positive_only=True, + multi_output=True, + single_output=False, + ) + transformer_tags = None + classifier_tags = None + regressor_tags = RegressorTags( + poor_score=True, + multi_label=True, + ) + return Tags( + estimator_type=self._estimator_type, + input_tags=input_tags, + target_tags=target_tags, + transformer_tags=transformer_tags, + classifier_tags=classifier_tags, + regressor_tags=regressor_tags, + ) + + estimator = MyClass() + new_tags = get_tags(estimator) + old_tags = _to_old_tags(new_tags) + expected_tags = { + "allow_nan": True, + "array_api_support": False, + "binary_only": False, + "multilabel": True, + "multioutput": True, + "multioutput_only": True, + "no_validation": False, + "non_deterministic": False, + "pairwise": True, + "preserves_dtype": ["float64"], + "poor_score": True, + "requires_fit": True, + "requires_positive_X": True, + "requires_y": False, + "requires_positive_y": True, + "_skip_test": False, + "stateless": True, + "X_types": [ + "1darray", + "3darray", + "sparse", + "categorical", + "string", + "dict", + "1dlabels", + "2dlabels", + ], + } + assert old_tags == expected_tags + assert _to_new_tags(_to_old_tags(new_tags), estimator=estimator) == new_tags + + class MyClass: + _estimator_type = "classifier" + + def __sklearn_tags__(self): + input_tags = InputTags( + one_d_array=True, + two_d_array=False, + three_d_array=True, + sparse=True, + categorical=True, + string=True, + dict=True, + positive_only=True, + allow_nan=True, + pairwise=True, + ) + target_tags = TargetTags( + required=False, + one_d_labels=True, + two_d_labels=False, + positive_only=True, + multi_output=True, + single_output=False, + ) + transformer_tags = None + classifier_tags = ClassifierTags( + poor_score=True, + multi_class=False, + multi_label=True, + ) + regressor_tags = None + return Tags( + estimator_type=self._estimator_type, + input_tags=input_tags, + target_tags=target_tags, + transformer_tags=transformer_tags, + classifier_tags=classifier_tags, + regressor_tags=regressor_tags, + ) + + estimator = MyClass() + new_tags = get_tags(estimator) + old_tags = _to_old_tags(new_tags) + expected_tags = { + "allow_nan": True, + "array_api_support": False, + "binary_only": True, + "multilabel": True, + "multioutput": True, + "multioutput_only": True, + "no_validation": False, + "non_deterministic": False, + "pairwise": True, + "preserves_dtype": ["float64"], + "poor_score": True, + "requires_fit": True, + "requires_positive_X": True, + "requires_y": False, + "requires_positive_y": True, + "_skip_test": False, + "stateless": True, + "X_types": [ + "1darray", + "3darray", + "sparse", + "categorical", + "string", + "dict", + "1dlabels", + ], + } + assert old_tags == expected_tags + assert _to_new_tags(_to_old_tags(new_tags), estimator=estimator) == new_tags + + class MyClass: + + def fit(self, X, y=None): + return self # pragma: no cover + + def transform(self, X): + return X # pragma: no cover + + def __sklearn_tags__(self): + input_tags = InputTags( + one_d_array=True, + two_d_array=False, + three_d_array=True, + sparse=True, + categorical=True, + string=True, + dict=True, + positive_only=True, + allow_nan=True, + pairwise=True, + ) + target_tags = TargetTags( + required=False, + one_d_labels=True, + two_d_labels=False, + positive_only=True, + multi_output=True, + single_output=False, + ) + transformer_tags = TransformerTags( + preserves_dtype=["float64"], + ) + classifier_tags = None + regressor_tags = None + return Tags( + estimator_type=None, + input_tags=input_tags, + target_tags=target_tags, + transformer_tags=transformer_tags, + classifier_tags=classifier_tags, + regressor_tags=regressor_tags, + ) + + estimator = MyClass() + new_tags = get_tags(estimator) + old_tags = _to_old_tags(new_tags) + expected_tags = { + "allow_nan": True, + "array_api_support": False, + "binary_only": False, + "multilabel": False, + "multioutput": True, + "multioutput_only": True, + "no_validation": False, + "non_deterministic": False, + "pairwise": True, + "preserves_dtype": ["float64"], + "poor_score": False, + "requires_fit": True, + "requires_positive_X": True, + "requires_y": False, + "requires_positive_y": True, + "_skip_test": False, + "stateless": True, + "X_types": [ + "1darray", + "3darray", + "sparse", + "categorical", + "string", + "dict", + "1dlabels", + ], + } + assert old_tags == expected_tags + assert _to_new_tags(_to_old_tags(new_tags), estimator=estimator) == new_tags