diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst index b420c0b40abef..440300b7e44bb 100644 --- a/doc/metadata_routing.rst +++ b/doc/metadata_routing.rst @@ -315,6 +315,7 @@ Meta-estimators and functions supporting metadata routing: - :class:`sklearn.multioutput.RegressorChain` - :class:`sklearn.pipeline.FeatureUnion` - :class:`sklearn.pipeline.Pipeline` +- :class:`sklearn.semi_supervised.SelfTrainingClassifier` Meta-estimators and tools not supporting metadata routing yet: @@ -324,4 +325,3 @@ Meta-estimators and tools not supporting metadata routing yet: - :class:`sklearn.feature_selection.RFECV` - :class:`sklearn.feature_selection.SequentialFeatureSelector` - :class:`sklearn.model_selection.permutation_test_score` -- :class:`sklearn.semi_supervised.SelfTrainingClassifier` diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index c98314d5ca1de..4f1ee132d95b5 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -71,7 +71,7 @@ more details. :class:`ensemble.StackingRegressor` now support metadata routing and pass ``**fit_params`` to the underlying estimators via their `fit` methods. :pr:`28701` by :user:`Stefanie Senger <StefanieSenger>`. - + - |Feature| :class:`compose.TransformedTargetRegressor` now supports metadata routing in its `fit` and `predict` methods and routes the corresponding params to the underlying regressor. @@ -81,6 +81,14 @@ more details. the `fit` method of its estimator and for its underlying CV splitter and scorer. :pr:`29329` by :user:`Stefanie Senger <StefanieSenger>`. +- |Feature| :class:`semi_supervised.SelfTrainingClassifier` + now supports metadata routing. The fit method now accepts ``**fit_params`` + which are passed to the underlying estimators via their `fit` methods. + In addition, the `predict`, `predict_proba`, `predict_log_proba`, `score` + and `decision_function` methods also accept ``**params`` which are + passed to the underlying estimators via their respective methods. + :pr:`28494` by :user:`Adam Li <adam2392>`. 
+ Dropping official support for PyPy ---------------------------------- @@ -189,6 +197,13 @@ Changelog when duplicate values in the training data lead to inaccurate outlier detection. :pr:`28773` by :user:`Henrique Caroço <HenriqueProj>`. +:mod:`sklearn.semi_supervised` +.............................. + +- |API| :class:`semi_supervised.SelfTrainingClassifier` + deprecated the `base_estimator` parameter in favor of `estimator`. + :pr:`28494` by :user:`Adam Li <adam2392>`. + Thanks to everyone who has contributed to the maintenance and improvement of the project since version 1.5, including: diff --git a/sklearn/semi_supervised/_self_training.py b/sklearn/semi_supervised/_self_training.py index 647f48204414a..b1ebea1061e4c 100644 --- a/sklearn/semi_supervised/_self_training.py +++ b/sklearn/semi_supervised/_self_training.py @@ -1,12 +1,19 @@ import warnings from numbers import Integral, Real +from warnings import warn import numpy as np from ..base import BaseEstimator, MetaEstimatorMixin, _fit_context, clone -from ..utils import safe_mask -from ..utils._param_validation import HasMethods, Interval, StrOptions -from ..utils.metadata_routing import _RoutingNotSupportedMixin +from ..utils import Bunch, safe_mask +from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions +from ..utils.metadata_routing import ( + MetadataRouter, + MethodMapping, + _raise_for_params, + _routing_enabled, + process_routing, +) from ..utils.metaestimators import available_if from ..utils.validation import check_is_fitted @@ -19,25 +26,23 @@ def _estimator_has(attr): """Check if we can delegate a method to the underlying estimator. - First, we check the fitted `base_estimator_` if available, otherwise we check - the unfitted `base_estimator`. We raise the original `AttributeError` if + First, we check the fitted `estimator_` if available, otherwise we check + the unfitted `estimator`. We raise the original `AttributeError` if `attr` does not exist. 
This function is used together with `available_if`. """ def check(self): - if hasattr(self, "base_estimator_"): - getattr(self.base_estimator_, attr) + if hasattr(self, "estimator_"): + getattr(self.estimator_, attr) else: - getattr(self.base_estimator, attr) + getattr(self.estimator, attr) return True return check -class SelfTrainingClassifier( - _RoutingNotSupportedMixin, MetaEstimatorMixin, BaseEstimator -): +class SelfTrainingClassifier(MetaEstimatorMixin, BaseEstimator): """Self-training classifier. This :term:`metaestimator` allows a given supervised classifier to function as a @@ -52,10 +57,22 @@ class SelfTrainingClassifier( Parameters ---------- + estimator : estimator object + An estimator object implementing `fit` and `predict_proba`. + Invoking the `fit` method will fit a clone of the passed estimator, + which will be stored in the `estimator_` attribute. + + .. versionadded:: 1.6 + `estimator` was added to replace `base_estimator`. + base_estimator : estimator object An estimator object implementing `fit` and `predict_proba`. Invoking the `fit` method will fit a clone of the passed estimator, - which will be stored in the `base_estimator_` attribute. + which will be stored in the `estimator_` attribute. + + .. deprecated:: 1.6 + `base_estimator` was deprecated in 1.6 and will be removed in 1.8. + Use `estimator` instead. threshold : float, default=0.75 The decision threshold for use with `criterion='threshold'`. @@ -85,12 +102,12 @@ class SelfTrainingClassifier( Attributes ---------- - base_estimator_ : estimator object + estimator_ : estimator object The fitted estimator. classes_ : ndarray or list of ndarray of shape (n_classes,) Class labels for each output. (Taken from the trained - `base_estimator_`). + `estimator_`). 
transduction_ : ndarray of shape (n_samples,) The labels used for the final fit of the classifier, including @@ -159,7 +176,13 @@ class SelfTrainingClassifier( _parameter_constraints: dict = { # We don't require `predic_proba` here to allow passing a meta-estimator # that only exposes `predict_proba` after fitting. - "base_estimator": [HasMethods(["fit"])], + # TODO(1.8) remove None option + "estimator": [None, HasMethods(["fit"])], + # TODO(1.8) remove + "base_estimator": [ + HasMethods(["fit"]), + Hidden(StrOptions({"deprecated"})), + ], "threshold": [Interval(Real, 0.0, 1.0, closed="left")], "criterion": [StrOptions({"threshold", "k_best"})], "k_best": [Interval(Integral, 1, None, closed="left")], @@ -169,25 +192,63 @@ class SelfTrainingClassifier( def __init__( self, - base_estimator, + estimator=None, + base_estimator="deprecated", threshold=0.75, criterion="threshold", k_best=10, max_iter=10, verbose=False, ): - self.base_estimator = base_estimator + self.estimator = estimator self.threshold = threshold self.criterion = criterion self.k_best = k_best self.max_iter = max_iter self.verbose = verbose + # TODO(1.8) remove + self.base_estimator = base_estimator + + def _get_estimator(self): + """Get the estimator. + + Returns + ------- + estimator_ : estimator object + The cloned estimator object. + """ + # TODO(1.8): remove and only keep clone(self.estimator) + if self.estimator is None and self.base_estimator != "deprecated": + estimator_ = clone(self.base_estimator) + + warn( + ( + "`base_estimator` has been deprecated in 1.6 and will be removed" + " in 1.8. Please use `estimator` instead." + ), + FutureWarning, + ) + # TODO(1.8) remove + elif self.estimator is None and self.base_estimator == "deprecated": + raise ValueError( + "You must pass an estimator to SelfTrainingClassifier." + " Use `estimator`." 
+ ) + elif self.estimator is not None and self.base_estimator != "deprecated": + raise ValueError( + "You must pass only one estimator to SelfTrainingClassifier." + " Use `estimator`." + ) + else: + estimator_ = clone(self.estimator) + return estimator_ + @_fit_context( - # SelfTrainingClassifier.base_estimator is not validated yet + # SelfTrainingClassifier.estimator is not validated yet prefer_skip_nested_validation=False ) - def fit(self, X, y): + def fit(self, X, y, **params): """ Fit self-training classifier using `X`, `y` as training data. @@ -200,19 +261,31 @@ def fit(self, X, y): Array representing the labels. Unlabeled samples should have the label -1. + **params : dict + Parameters to pass to the underlying estimators. + + .. versionadded:: 1.6 + Only available if `enable_metadata_routing=True`, + which can be set by using + ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide <metadata_routing>` for + more details. + Returns ------- self : object Fitted estimator. """ + _raise_for_params(params, self, "fit") + + self.estimator_ = self._get_estimator() + # we need row slicing support for sparse matrices, but costly finiteness check # can be delegated to the base estimator. X, y = self._validate_data( X, y, accept_sparse=["csr", "csc", "lil", "dok"], force_all_finite=False ) - self.base_estimator_ = clone(self.base_estimator) - if y.dtype.kind in ["U", "S"]: raise ValueError( "y has dtype string. 
If you wish to predict on " @@ -237,6 +310,11 @@ def fit(self, X, y): UserWarning, ) + if _routing_enabled(): + routed_params = process_routing(self, "fit", **params) + else: + routed_params = Bunch(estimator=Bunch(fit={})) + self.transduction_ = np.copy(y) self.labeled_iter_ = np.full_like(y, -1) self.labeled_iter_[has_label] = 0 @@ -247,13 +325,15 @@ def fit(self, X, y): self.max_iter is None or self.n_iter_ < self.max_iter ): self.n_iter_ += 1 - self.base_estimator_.fit( - X[safe_mask(X, has_label)], self.transduction_[has_label] + self.estimator_.fit( + X[safe_mask(X, has_label)], + self.transduction_[has_label], + **routed_params.estimator.fit, ) # Predict on the unlabeled samples - prob = self.base_estimator_.predict_proba(X[safe_mask(X, ~has_label)]) - pred = self.base_estimator_.classes_[np.argmax(prob, axis=1)] + prob = self.estimator_.predict_proba(X[safe_mask(X, ~has_label)]) + pred = self.estimator_.classes_[np.argmax(prob, axis=1)] max_proba = np.max(prob, axis=1) # Select new labeled samples @@ -291,14 +371,16 @@ def fit(self, X, y): if np.all(has_label): self.termination_condition_ = "all_labeled" - self.base_estimator_.fit( - X[safe_mask(X, has_label)], self.transduction_[has_label] + self.estimator_.fit( + X[safe_mask(X, has_label)], + self.transduction_[has_label], + **routed_params.estimator.fit, ) - self.classes_ = self.base_estimator_.classes_ + self.classes_ = self.estimator_.classes_ return self @available_if(_estimator_has("predict")) - def predict(self, X): + def predict(self, X, **params): """Predict the classes of `X`. Parameters @@ -306,22 +388,40 @@ def predict(self, X): X : {array-like, sparse matrix} of shape (n_samples, n_features) Array representing the data. + **params : dict of str -> object + Parameters to pass to the underlying estimator's ``predict`` method. + + .. versionadded:: 1.6 + Only available if `enable_metadata_routing=True`, + which can be set by using + ``sklearn.set_config(enable_metadata_routing=True)``. 
+ See :ref:`Metadata Routing User Guide <metadata_routing>` for + more details. + Returns ------- y : ndarray of shape (n_samples,) Array with predicted labels. """ check_is_fitted(self) + _raise_for_params(params, self, "predict") + + if _routing_enabled(): + # metadata routing is enabled. + routed_params = process_routing(self, "predict", **params) + else: + routed_params = Bunch(estimator=Bunch(predict={})) + X = self._validate_data( X, accept_sparse=True, force_all_finite=False, reset=False, ) - return self.base_estimator_.predict(X) + return self.estimator_.predict(X, **routed_params.estimator.predict) @available_if(_estimator_has("predict_proba")) - def predict_proba(self, X): + def predict_proba(self, X, **params): """Predict probability for each possible outcome. Parameters @@ -329,45 +429,85 @@ def predict_proba(self, X): X : {array-like, sparse matrix} of shape (n_samples, n_features) Array representing the data. + **params : dict of str -> object + Parameters to pass to the underlying estimator's + ``predict_proba`` method. + + .. versionadded:: 1.6 + Only available if `enable_metadata_routing=True`, + which can be set by using + ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide <metadata_routing>` for + more details. + Returns ------- y : ndarray of shape (n_samples, n_features) Array with prediction probabilities. """ check_is_fitted(self) + _raise_for_params(params, self, "predict_proba") + + if _routing_enabled(): + # metadata routing is enabled. + routed_params = process_routing(self, "predict_proba", **params) + else: + routed_params = Bunch(estimator=Bunch(predict_proba={})) + X = self._validate_data( X, accept_sparse=True, force_all_finite=False, reset=False, ) - return self.base_estimator_.predict_proba(X) + return self.estimator_.predict_proba(X, **routed_params.estimator.predict_proba) @available_if(_estimator_has("decision_function")) - def decision_function(self, X): - """Call decision function of the `base_estimator`. 
+ def decision_function(self, X, **params): + """Call decision function of the `estimator`. Parameters ---------- X : {array-like, sparse matrix} of shape (n_samples, n_features) Array representing the data. + **params : dict of str -> object + Parameters to pass to the underlying estimator's + ``decision_function`` method. + + .. versionadded:: 1.6 + Only available if `enable_metadata_routing=True`, + which can be set by using + ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide <metadata_routing>` for + more details. + Returns ------- y : ndarray of shape (n_samples, n_features) - Result of the decision function of the `base_estimator`. + Result of the decision function of the `estimator`. """ check_is_fitted(self) + _raise_for_params(params, self, "decision_function") + + if _routing_enabled(): + # metadata routing is enabled. + routed_params = process_routing(self, "decision_function", **params) + else: + routed_params = Bunch(estimator=Bunch(decision_function={})) + X = self._validate_data( X, accept_sparse=True, force_all_finite=False, reset=False, ) - return self.base_estimator_.decision_function(X) + return self.estimator_.decision_function( + X, **routed_params.estimator.decision_function + ) @available_if(_estimator_has("predict_log_proba")) - def predict_log_proba(self, X): + def predict_log_proba(self, X, **params): """Predict log probability for each possible outcome. Parameters @@ -375,23 +515,44 @@ def predict_log_proba(self, X): X : {array-like, sparse matrix} of shape (n_samples, n_features) Array representing the data. + **params : dict of str -> object + Parameters to pass to the underlying estimator's + ``predict_log_proba`` method. + + .. versionadded:: 1.6 + Only available if `enable_metadata_routing=True`, + which can be set by using + ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide <metadata_routing>` for + more details. 
+ Returns ------- y : ndarray of shape (n_samples, n_features) Array with log prediction probabilities. """ check_is_fitted(self) + _raise_for_params(params, self, "predict_log_proba") + + if _routing_enabled(): + # metadata routing is enabled. + routed_params = process_routing(self, "predict_log_proba", **params) + else: + routed_params = Bunch(estimator=Bunch(predict_log_proba={})) + X = self._validate_data( X, accept_sparse=True, force_all_finite=False, reset=False, ) - return self.base_estimator_.predict_log_proba(X) + return self.estimator_.predict_log_proba( + X, **routed_params.estimator.predict_log_proba + ) @available_if(_estimator_has("score")) - def score(self, X, y): - """Call score on the `base_estimator`. + def score(self, X, y, **params): + """Call score on the `estimator`. Parameters ---------- @@ -401,16 +562,64 @@ def score(self, X, y): y : array-like of shape (n_samples,) Array representing the labels. + **params : dict of str -> object + Parameters to pass to the underlying estimator's ``score`` method. + + .. versionadded:: 1.6 + Only available if `enable_metadata_routing=True`, + which can be set by using + ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide <metadata_routing>` for + more details. + Returns ------- score : float - Result of calling score on the `base_estimator`. + Result of calling score on the `estimator`. """ check_is_fitted(self) + _raise_for_params(params, self, "score") + + if _routing_enabled(): + # metadata routing is enabled. + routed_params = process_routing(self, "score", **params) + else: + routed_params = Bunch(estimator=Bunch(score={})) + X = self._validate_data( X, accept_sparse=True, force_all_finite=False, reset=False, ) - return self.base_estimator_.score(X, y) + return self.estimator_.score(X, y, **routed_params.estimator.score) + + def get_metadata_routing(self): + """Get metadata routing of this object. + + Please check :ref:`User Guide <metadata_routing>` on how the routing + mechanism works. + + .. 
versionadded:: 1.6 + + Returns + ------- + routing : MetadataRouter + A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating + routing information. + """ + router = MetadataRouter(owner=self.__class__.__name__) + router.add( + estimator=self.estimator, + method_mapping=( + MethodMapping() + .add(callee="fit", caller="fit") + .add(callee="score", caller="fit") + .add(callee="predict", caller="predict") + .add(callee="predict_proba", caller="predict_proba") + .add(callee="decision_function", caller="decision_function") + .add(callee="predict_log_proba", caller="predict_log_proba") + .add(callee="score", caller="score") + ), + ) + return router diff --git a/sklearn/semi_supervised/tests/test_self_training.py b/sklearn/semi_supervised/tests/test_self_training.py index 29b8f1ac6e87c..02244063994d5 100644 --- a/sklearn/semi_supervised/tests/test_self_training.py +++ b/sklearn/semi_supervised/tests/test_self_training.py @@ -12,6 +12,7 @@ from sklearn.neighbors import KNeighborsClassifier from sklearn.semi_supervised import SelfTrainingClassifier from sklearn.svm import SVC +from sklearn.tests.test_pipeline import SimpleEstimator from sklearn.tree import DecisionTreeClassifier # Authors: The scikit-learn developers @@ -43,25 +44,25 @@ def test_warns_k_best(): @pytest.mark.parametrize( - "base_estimator", + "estimator", [KNeighborsClassifier(), SVC(gamma="scale", probability=True, random_state=0)], ) @pytest.mark.parametrize("selection_crit", ["threshold", "k_best"]) -def test_classification(base_estimator, selection_crit): +def test_classification(estimator, selection_crit): # Check classification for various parameter settings. # Also assert that predictions for strings and numerical labels are equal. 
# Also test for multioutput classification threshold = 0.75 max_iter = 10 st = SelfTrainingClassifier( - base_estimator, max_iter=max_iter, threshold=threshold, criterion=selection_crit + estimator, max_iter=max_iter, threshold=threshold, criterion=selection_crit ) st.fit(X_train, y_train_missing_labels) pred = st.predict(X_test) proba = st.predict_proba(X_test) st_string = SelfTrainingClassifier( - base_estimator, max_iter=max_iter, criterion=selection_crit, threshold=threshold + estimator, max_iter=max_iter, criterion=selection_crit, threshold=threshold ) st_string.fit(X_train, y_train_missing_strings) pred_string = st_string.predict(X_test) @@ -112,15 +113,15 @@ def test_k_best(): def test_sanity_classification(): - base_estimator = SVC(gamma="scale", probability=True) - base_estimator.fit(X_train[n_labeled_samples:], y_train[n_labeled_samples:]) + estimator = SVC(gamma="scale", probability=True) + estimator.fit(X_train[n_labeled_samples:], y_train[n_labeled_samples:]) - st = SelfTrainingClassifier(base_estimator) + st = SelfTrainingClassifier(estimator) st.fit(X_train, y_train_missing_labels) - pred1, pred2 = base_estimator.predict(X_test), st.predict(X_test) + pred1, pred2 = estimator.predict(X_test), st.predict(X_test) assert not np.array_equal(pred1, pred2) - score_supervised = accuracy_score(base_estimator.predict(X_test), y_test) + score_supervised = accuracy_score(estimator.predict(X_test), y_test) score_self_training = accuracy_score(st.predict(X_test), y_test) assert score_self_training > score_supervised @@ -137,21 +138,21 @@ def test_none_iter(): @pytest.mark.parametrize( - "base_estimator", + "estimator", [KNeighborsClassifier(), SVC(gamma="scale", probability=True, random_state=0)], ) @pytest.mark.parametrize("y", [y_train_missing_labels, y_train_missing_strings]) -def test_zero_iterations(base_estimator, y): +def test_zero_iterations(estimator, y): # Check classification for zero iterations. 
# Fitting a SelfTrainingClassifier with zero iterations should give the # same results as fitting a supervised classifier. # This also asserts that string arrays work as expected. - clf1 = SelfTrainingClassifier(base_estimator, max_iter=0) + clf1 = SelfTrainingClassifier(estimator, max_iter=0) clf1.fit(X_train, y) - clf2 = base_estimator.fit(X_train[:n_labeled_samples], y[:n_labeled_samples]) + clf2 = estimator.fit(X_train[:n_labeled_samples], y[:n_labeled_samples]) assert_array_equal(clf1.predict(X_test), clf2.predict(X_test)) assert clf1.termination_condition_ == "max_iter" @@ -280,14 +281,14 @@ def test_k_best_selects_best(): assert row in added_by_st -def test_base_estimator_meta_estimator(): +def test_estimator_meta_estimator(): # Check that a meta-estimator relying on an estimator implementing # `predict_proba` will work even if it does not expose this method before being # fitted. # Non-regression test for: # https://github.com/scikit-learn/scikit-learn/issues/19119 - base_estimator = StackingClassifier( + estimator = StackingClassifier( estimators=[ ("svc_1", SVC(probability=True)), ("svc_2", SVC(probability=True)), @@ -296,12 +297,12 @@ def test_base_estimator_meta_estimator(): cv=2, ) - assert hasattr(base_estimator, "predict_proba") - clf = SelfTrainingClassifier(base_estimator=base_estimator) + assert hasattr(estimator, "predict_proba") + clf = SelfTrainingClassifier(estimator=estimator) clf.fit(X_train, y_train_missing_labels) clf.predict_proba(X_test) - base_estimator = StackingClassifier( + estimator = StackingClassifier( estimators=[ ("svc_1", SVC(probability=False)), ("svc_2", SVC(probability=False)), @@ -310,14 +311,14 @@ def test_base_estimator_meta_estimator(): cv=2, ) - assert not hasattr(base_estimator, "predict_proba") - clf = SelfTrainingClassifier(base_estimator=base_estimator) + assert not hasattr(estimator, "predict_proba") + clf = SelfTrainingClassifier(estimator=estimator) with pytest.raises(AttributeError): clf.fit(X_train, 
y_train_missing_labels) def test_self_training_estimator_attribute_error(): - """Check that we raise the proper AttributeErrors when the `base_estimator` + """Check that we raise the proper AttributeErrors when the `estimator` does not implement the `predict_proba` method, which is called from within `fit`, or `decision_function`, which is decorated with `available_if`. @@ -327,15 +328,15 @@ def test_self_training_estimator_attribute_error(): # `SVC` with `probability=False` does not implement 'predict_proba' that # is required internally in `fit` of `SelfTrainingClassifier`. We expect # an AttributeError to be raised. - base_estimator = SVC(probability=False, gamma="scale") - self_training = SelfTrainingClassifier(base_estimator) + estimator = SVC(probability=False, gamma="scale") + self_training = SelfTrainingClassifier(estimator) with pytest.raises(AttributeError, match="has no attribute 'predict_proba'"): self_training.fit(X_train, y_train_missing_labels) # `DecisionTreeClassifier` does not implement 'decision_function' and # should raise an AttributeError - self_training = SelfTrainingClassifier(base_estimator=DecisionTreeClassifier()) + self_training = SelfTrainingClassifier(estimator=DecisionTreeClassifier()) outer_msg = "This 'SelfTrainingClassifier' has no attribute 'decision_function'" inner_msg = "'DecisionTreeClassifier' object has no attribute 'decision_function'" @@ -343,3 +344,52 @@ def test_self_training_estimator_attribute_error(): self_training.fit(X_train, y_train_missing_labels).decision_function(X_train) assert isinstance(exec_info.value.__cause__, AttributeError) assert inner_msg in str(exec_info.value.__cause__) + + +# TODO(1.8): remove in 1.8 +def test_deprecation_warning_base_estimator(): + warn_msg = "`base_estimator` has been deprecated in 1.6 and will be removed" + with pytest.warns(FutureWarning, match=warn_msg): + SelfTrainingClassifier(base_estimator=DecisionTreeClassifier()).fit( + X_train, y_train_missing_labels + ) + + error_msg = 
"You must pass an estimator to SelfTrainingClassifier" + with pytest.raises(ValueError, match=error_msg): + SelfTrainingClassifier().fit(X_train, y_train_missing_labels) + + error_msg = "You must pass only one estimator to SelfTrainingClassifier." + with pytest.raises(ValueError, match=error_msg): + SelfTrainingClassifier( + base_estimator=DecisionTreeClassifier(), estimator=DecisionTreeClassifier() + ).fit(X_train, y_train_missing_labels) + + +# Metadata routing tests +# ================================================================= + + +@pytest.mark.filterwarnings("ignore:y contains no unlabeled samples:UserWarning") +@pytest.mark.parametrize( + "method", ["decision_function", "predict_log_proba", "predict_proba", "predict"] +) +def test_routing_passed_metadata_not_supported(method): + """Test that the right error message is raised when metadata is passed while + not supported when `enable_metadata_routing=False`.""" + est = SelfTrainingClassifier(estimator=SimpleEstimator()) + with pytest.raises( + ValueError, match="is only supported if enable_metadata_routing=True" + ): + est.fit([[1], [1]], [1, 1], sample_weight=[1], prop="a") + + est = SelfTrainingClassifier(estimator=SimpleEstimator()) + with pytest.raises( + ValueError, match="is only supported if enable_metadata_routing=True" + ): + # make sure that the estimator thinks it is already fitted + est.fitted_params_ = True + getattr(est, method)([[1]], sample_weight=[1], prop="a") + + +# End of routing tests +# ==================== diff --git a/sklearn/tests/metadata_routing_common.py b/sklearn/tests/metadata_routing_common.py index 0af522f9f9342..5fffec8fccecf 100644 --- a/sklearn/tests/metadata_routing_common.py +++ b/sklearn/tests/metadata_routing_common.py @@ -215,6 +215,17 @@ def predict(self, X): y_pred[len(X) // 2 :] = 1 return y_pred + def predict_proba(self, X): + # dummy probabilities to support predict_proba + y_proba = np.empty(shape=(len(X), 2)) + y_proba[: len(X) // 2, :] = np.asarray([1.0, 
0.0]) + y_proba[len(X) // 2 :, :] = np.asarray([0.0, 1.0]) + return y_proba + + def predict_log_proba(self, X): + # dummy probabilities to support predict_log_proba + return self.predict_proba(X) + class NonConsumingRegressor(RegressorMixin, BaseEstimator): """A classifier which accepts no metadata on any method.""" @@ -291,13 +302,10 @@ def predict_proba(self, X, sample_weight="default", metadata="default"): return y_proba def predict_log_proba(self, X, sample_weight="default", metadata="default"): - pass # pragma: no cover - - # uncomment when needed - # record_metadata_not_default( - # self, sample_weight=sample_weight, metadata=metadata - # ) - # return np.zeros(shape=(len(X), 2)) + record_metadata_not_default( + self, sample_weight=sample_weight, metadata=metadata + ) + return np.zeros(shape=(len(X), 2)) def decision_function(self, X, sample_weight="default", metadata="default"): record_metadata_not_default( @@ -308,12 +316,11 @@ def decision_function(self, X, sample_weight="default", metadata="default"): y_score[: len(X) // 2] = 1 return y_score - # uncomment when needed - # def score(self, X, y, sample_weight="default", metadata="default"): - # record_metadata_not_default( - # self, sample_weight=sample_weight, metadata=metadata - # ) - # return 1 + def score(self, X, y, sample_weight="default", metadata="default"): + record_metadata_not_default( + self, sample_weight=sample_weight, metadata=metadata + ) + return 1 class ConsumingTransformer(TransformerMixin, BaseEstimator): diff --git a/sklearn/tests/test_metaestimators_metadata_routing.py b/sklearn/tests/test_metaestimators_metadata_routing.py index cf2bb130267a3..9aca241521ca0 100644 --- a/sklearn/tests/test_metaestimators_metadata_routing.py +++ b/sklearn/tests/test_metaestimators_metadata_routing.py @@ -390,6 +390,23 @@ def enable_slep006(): "y": y, "estimator_routing_methods": ["fit", "predict"], }, + { + "metaestimator": SelfTrainingClassifier, + "estimator_name": "estimator", + "estimator": 
"classifier", + "X": X, + "y": y, + "preserves_metadata": True, + "estimator_routing_methods": [ + "fit", + "predict", + "predict_proba", + "predict_log_proba", + "decision_function", + "score", + ], + "method_mapping": {"fit": ["fit", "score"]}, + }, ] """List containing all metaestimators to be tested and their settings @@ -433,7 +450,6 @@ def enable_slep006(): AdaBoostRegressor(), RFE(ConsumingClassifier()), RFECV(ConsumingClassifier()), - SelfTrainingClassifier(ConsumingClassifier()), SequentialFeatureSelector(ConsumingClassifier()), ] @@ -640,7 +656,7 @@ def test_error_on_missing_requests_for_sub_estimator(metaestimator): value=None, ) try: - # `fit` and `partial_fit` accept y, others don't. + # `fit`, `partial_fit`, 'score' accept y, others don't. method(X, y, **method_kwargs) except TypeError: method(X, **method_kwargs) diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 273aa4e9d36e4..b9fba86d01e9b 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -1822,8 +1822,8 @@ class SimpleEstimator(BaseEstimator): # This class is used in this section for testing routing in the pipeline. # This class should have every set_{method}_request def fit(self, X, y, sample_weight=None, prop=None): - assert sample_weight is not None - assert prop is not None + assert sample_weight is not None, sample_weight + assert prop is not None, prop return self def fit_transform(self, X, y, sample_weight=None, prop=None):