From 0412969c20dae43b127b5d813e7cef70feb9d132 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Carlos=20Alfaro=20Jim=C3=A9nez?= Date: Wed, 1 Jul 2020 19:04:40 +0200 Subject: [PATCH 01/42] MNT Deprecate _estimator_type and replace by estimator_tags --- sklearn/base.py | 57 ++++++++++++++++---- sklearn/feature_selection/_rfe.py | 6 ++- sklearn/feature_selection/tests/test_rfe.py | 11 +++- sklearn/model_selection/_search.py | 8 +++ sklearn/model_selection/tests/test_search.py | 9 ++++ sklearn/pipeline.py | 9 ++++ sklearn/tests/test_base.py | 17 ++++++ sklearn/tests/test_pipeline.py | 8 +++ 8 files changed, 113 insertions(+), 12 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 46398baabfd3a..f6dcde5c1a1c2 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -19,6 +19,7 @@ from .utils.validation import check_array from .utils._estimator_html_repr import estimator_html_repr from .utils.validation import _deprecate_positional_args +from .utils.deprecation import deprecated _DEFAULT_TAGS = { 'non_deterministic': False, @@ -37,6 +38,7 @@ 'binary_only': False, 'requires_fit': True, 'requires_y': False, + 'estimator_type': None } @@ -465,7 +467,12 @@ def _repr_mimebundle_(self, **kwargs): class ClassifierMixin: """Mixin class for all classifiers in scikit-learn.""" - _estimator_type = "classifier" + # mypy error: Decorated property not supported + @deprecated("_estimator_type is deprecated in " + "0.24 and will be removed in 0.26") + @property + def _estimator_type(self): + return "classifier" def score(self, X, y, sample_weight=None): """ @@ -495,12 +502,18 @@ def score(self, X, y, sample_weight=None): return accuracy_score(y, self.predict(X), sample_weight=sample_weight) def _more_tags(self): - return {'requires_y': True} + return {'requires_y': True, 'estimator_type': 'classifier'} class RegressorMixin: """Mixin class for all regression estimators in scikit-learn.""" - _estimator_type = "regressor" + + # mypy error: Decorated property not supported + @deprecated("_estimator_type is deprecated in " + "0.24 and will be removed in 0.26") + @property + def _estimator_type(self): + return "regressor" def score(self, X, y, sample_weight=None): """Return the coefficient of determination R^2 of the prediction. @@ -548,12 +561,18 @@ def score(self, X, y, sample_weight=None): return r2_score(y, y_pred, sample_weight=sample_weight) def _more_tags(self): - return {'requires_y': True} + return {'requires_y': True, 'estimator_type': 'regressor'} class ClusterMixin: """Mixin class for all cluster estimators in scikit-learn.""" - _estimator_type = "clusterer" + + # mypy error: Decorated property not supported + @deprecated("_estimator_type is deprecated in " + "0.24 and will be removed in 0.26") + @property + def _estimator_type(self): + return "clusterer" def fit_predict(self, X, y=None): """ @@ -692,7 +711,13 @@ def fit_transform(self, X, y=None, **fit_params): class DensityMixin: """Mixin class for all density estimators in scikit-learn.""" - _estimator_type = "DensityEstimator" + + # mypy error: Decorated property not supported + @deprecated("_estimator_type is deprecated in " + "0.24 and will be removed in 0.26") + @property + def _estimator_type(self): + return "DensityEstimator" def score(self, X, y=None): """Return the score of the model on the data X @@ -710,10 +735,19 @@ def score(self, X, y=None): """ pass + def _more_tags(self): + return {'estimator_type': 'DensityEstimator'} + class OutlierMixin: """Mixin class for all outlier detection estimators in scikit-learn.""" - _estimator_type = "outlier_detector" + + # mypy error: Decorated property not supported + @deprecated("_estimator_type is deprecated in " + "0.24 and will be removed in 0.26") + @property + def _estimator_type(self): + return "outlier_detector" def fit_predict(self, X, y=None): """Perform fit on X and returns labels for X. @@ -736,6 +770,9 @@ def fit_predict(self, X, y=None): # override for transductive outlier detectors like LocalOulierFactor return self.fit(X).predict(X) + def _more_tags(self): + return {'estimator_type': 'outlier_detector'} + class MetaEstimatorMixin: _required_parameters = ["estimator"] @@ -768,7 +805,7 @@ def is_classifier(estimator): out : bool True if estimator is a classifier and False otherwise. """ - return getattr(estimator, "_estimator_type", None) == "classifier" + return estimator._get_tags()["estimator_type"] == "classifier" def is_regressor(estimator): @@ -784,7 +821,7 @@ def is_regressor(estimator): out : bool True if estimator is a regressor and False otherwise. """ - return getattr(estimator, "_estimator_type", None) == "regressor" + return estimator._get_tags()["estimator_type"] == "regressor" def is_outlier_detector(estimator): @@ -800,4 +837,4 @@ def is_outlier_detector(estimator): out : bool True if estimator is an outlier detector and False otherwise. """ - return getattr(estimator, "_estimator_type", None) == "outlier_detector" + return estimator._get_tags()["estimator_type"] == "outlier_detector" diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index 5e0e8ffc6f6d8..af6e4d0ffcd39 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -14,6 +14,7 @@ from ..utils.metaestimators import _safe_split from ..utils.validation import check_is_fitted from ..utils.validation import _deprecate_positional_args +from ..utils.deprecation import deprecated from ..base import BaseEstimator from ..base import MetaEstimatorMixin from ..base import clone @@ -149,6 +150,9 @@ def __init__(self, estimator, *, n_features_to_select=None, step=1, self.importance_getter = importance_getter self.verbose = verbose + # mypy error: Decorated property not supported + @deprecated("_estimator_type is deprecated in " + "0.24 and will be removed in 0.26") @property def _estimator_type(self): return self.estimator._estimator_type @@ -365,7 +369,7 @@ def _more_tags(self): return {'poor_score': True, 'allow_nan': estimator_tags.get('allow_nan', True), 'requires_y': True, - } + 'estimator_type': estimator_tags['estimator_type']} class RFECV(RFE): diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py index adb371f5fc006..1b19ff677397e 100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -57,7 +57,7 @@ def set_params(self, **params): return self def _get_tags(self): - return {} + return {'estimator_type': 'classifier'} def test_rfe_features_importance(): @@ -280,6 +280,7 @@ def test_rfecv_grid_scores_size(): assert rfecv.n_features_ >= min_features_to_select +@ignore_warnings(category=FutureWarning) def test_rfe_estimator_tags(): rfe = RFE(SVC(kernel='linear')) assert rfe._estimator_type == "classifier" @@ -491,3 +492,11 @@ def test_multioutput(ClsRFE): clf = RandomForestClassifier(n_estimators=5) rfe_test = ClsRFE(clf) rfe_test.fit(X, y) + + +# TODO: Remove in version 0.26 +def test_deprecated_estimator_type(): + # Assert that deprecated _estimator_type warns FutureWarning + rfe = RFE(SVC()) + with pytest.warns(FutureWarning): + hasattr(rfe, "_estimator_type") diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index bcdbcdbc498fb..6007408c6631b 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -419,6 +419,8 @@ def __init__(self, estimator, *, scoring=None, n_jobs=None, self.return_train_score = return_train_score @property + @deprecated("_estimator_type is deprecated in " + "0.24 and will be removed in 0.26") def _estimator_type(self): return self.estimator._estimator_type @@ -867,6 +869,12 @@ def _store(key_name, array, weights=None, splits=False, rank=False): return results + def _more_tags(self): + estimator_tags = self.estimator._get_tags() + return { + 'estimator_type': estimator_tags['estimator_type'] + } + class GridSearchCV(BaseSearchCV): """Exhaustive search over specified parameter values for an estimator. diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index cd01916d28ea9..337837849a160 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -1885,3 +1885,12 @@ def _fit_param_callable(): 'scalar_param': 42, } model.fit(X_train, y_train, **fit_params) + + +# TODO: Remove in version 0.26 +@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV]) +def test_deprecated_estimator_type(SearchCV): + # Assert that deprecated _estimator_type warns FutureWarning + search = SearchCV(DecisionTreeClassifier(), {'max_depth': [5, 10]}) + with pytest.warns(FutureWarning): + hasattr(search, "_estimator_type") diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index cd85f664afde4..1ea54870b14f4 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -22,6 +22,7 @@ from .utils import Bunch, _print_elapsed_time from .utils.validation import check_memory from .utils.validation import _deprecate_positional_args +from .utils.deprecation import deprecated from .utils.metaestimators import _BaseComposition @@ -212,6 +213,8 @@ def __getitem__(self, ind): return est @property + @deprecated("_estimator_type is deprecated in " + "0.24 and will be removed in 0.26") def _estimator_type(self): return self.steps[-1][1]._estimator_type @@ -638,6 +641,12 @@ def _get_name(name, est): name_details=name_details, dash_wrapped=False) + def _more_tags(self): + estimator_tags = self.steps[-1][1]._get_tags() + return { + 'estimator_type': estimator_tags['estimator_type'] + } + def _name_estimators(estimators): """Generate names for estimators.""" diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index db5c88051346a..fe006aa66a33a 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -21,6 +21,10 @@ from sklearn.tree import DecisionTreeRegressor from sklearn import datasets +from sklearn.cluster import KMeans +from sklearn.mixture import BayesianGaussianMixture +from sklearn.ensemble import IsolationForest + from sklearn.base import TransformerMixin from sklearn.utils._mocking import MockDataFrame from sklearn import config_context @@ -537,3 +541,16 @@ def test_repr_html_wraps(): with config_context(display='diagram'): output = tree._repr_html_() assert "