|
1 | 1 | from dataclasses import dataclass, fields
|
2 | 2 |
|
| 3 | +import numpy as np |
3 | 4 | import pytest
|
4 | 5 |
|
5 | 6 | from sklearn.base import (
|
6 | 7 | BaseEstimator,
|
| 8 | + ClassifierMixin, |
7 | 9 | RegressorMixin,
|
8 | 10 | TransformerMixin,
|
9 | 11 | )
|
| 12 | +from sklearn.pipeline import Pipeline |
10 | 13 | from sklearn.utils import (
|
11 | 14 | ClassifierTags,
|
12 | 15 | InputTags,
|
@@ -637,3 +640,48 @@ def __sklearn_tags__(self):
|
637 | 640 | }
|
638 | 641 | assert old_tags == expected_tags
|
639 | 642 | assert _to_new_tags(_to_old_tags(new_tags), estimator=estimator) == new_tags
|
| 643 | + |
| 644 | + |
| 645 | +# TODO(1.7): Remove this test |
| 646 | +def test_tags_no_sklearn_tags_concrete_implementation(): |
| 647 | + """Non-regression test for: |
| 648 | + https://github.com/scikit-learn/scikit-learn/issues/30479 |
| 649 | +
|
| 650 | + There is no class implementing `__sklearn_tags__` without calling |
| 651 | + `super().__sklearn_tags__()`. Thus, we raise a warning and request to inherit from |
| 652 | + `BaseEstimator` that implements `__sklearn_tags__`. |
| 653 | + """ |
| 654 | + |
| 655 | + class MyEstimator(ClassifierMixin): |
| 656 | + def __init__(self, *, param=1): |
| 657 | + self.param = param |
| 658 | + |
| 659 | + def fit(self, X, y=None): |
| 660 | + self.is_fitted_ = True |
| 661 | + return self |
| 662 | + |
| 663 | + def predict(self, X): |
| 664 | + return np.full(shape=X.shape[0], fill_value=self.param) |
| 665 | + |
| 666 | + X = np.array([[1, 2], [2, 3], [3, 4]]) |
| 667 | + y = np.array([1, 0, 1]) |
| 668 | + |
| 669 | + my_pipeline = Pipeline([("estimator", MyEstimator(param=1))]) |
| 670 | + with pytest.warns(DeprecationWarning, match="The following error was raised"): |
| 671 | + my_pipeline.fit(X, y).predict(X) |
| 672 | + |
| 673 | + # check that we still raise an error if it is not a AttributeError or related to |
| 674 | + # __sklearn_tags__ |
| 675 | + class MyEstimator2(MyEstimator, BaseEstimator): |
| 676 | + def __init__(self, *, param=1, error_type=AttributeError): |
| 677 | + self.param = param |
| 678 | + self.error_type = error_type |
| 679 | + |
| 680 | + def __sklearn_tags__(self): |
| 681 | + super().__sklearn_tags__() |
| 682 | + raise self.error_type("test") |
| 683 | + |
| 684 | + for error_type in (AttributeError, TypeError, ValueError): |
| 685 | + estimator = MyEstimator2(param=1, error_type=error_type) |
| 686 | + with pytest.raises(error_type): |
| 687 | + get_tags(estimator) |
0 commit comments