diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index 68ea8ba0f7a72..2cfe6970dd7b1 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -48,6 +48,13 @@ Changelog `'use_encoded_value'` strategies. :pr:`19234` by `Guillaume Lemaitre `. +:mod:`sklearn.multioutput` +.......................... + +- |Fix| :class:`multioutput.MultiOutputRegressor` now works with estimators + that dynamically define `predict` during fitting, such as + :class:`ensemble.StackingRegressor`. :pr:`19308` by `Thomas Fan`_. + :mod:`sklearn.semi_supervised` .............................. diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index 9987c01b13187..4cb01c524d59d 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -198,7 +198,7 @@ def predict(self, X): Note: Separate models are generated for each predictor. """ check_is_fitted(self) - if not hasattr(self.estimator, "predict"): + if not hasattr(self.estimators_[0], "predict"): raise ValueError("The base estimator should implement" " a predict method") diff --git a/sklearn/tests/test_multioutput.py b/sklearn/tests/test_multioutput.py index 87e5218e08e22..c20db084aa664 100644 --- a/sklearn/tests/test_multioutput.py +++ b/sklearn/tests/test_multioutput.py @@ -10,6 +10,7 @@ from sklearn import datasets from sklearn.base import clone from sklearn.datasets import make_classification +from sklearn.datasets import load_linnerud from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier from sklearn.exceptions import NotFittedError from sklearn.linear_model import Lasso @@ -30,6 +31,7 @@ from sklearn.dummy import DummyRegressor, DummyClassifier from sklearn.pipeline import make_pipeline from sklearn.impute import SimpleImputer +from sklearn.ensemble import StackingRegressor def test_multi_target_regression(): @@ -658,3 +660,19 @@ def test_classifier_chain_tuple_invalid_order(): with pytest.raises(ValueError, match='invalid order'): chain.fit(X, y) + + +def test_multioutputregressor_ducktypes_fitted_estimator(): + """Test that MultiOutputRegressor checks the fitted estimator for + predict. Non-regression test for #16549.""" + X, y = load_linnerud(return_X_y=True) + stacker = StackingRegressor( + estimators=[("sgd", SGDRegressor(random_state=1))], + final_estimator=Ridge(), + cv=2 + ) + + reg = MultiOutputRegressor(estimator=stacker).fit(X, y) + + # Does not raise + reg.predict(X)