diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 3ef8a7653b5f7..31ebe9e2d21d6 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -136,6 +136,10 @@ Changelog scikit-learn 1.3: retraining with scikit-learn 1.3 is required. :pr:`25186` by :user:`Felipe Breve Siola `. +- |Enhancement| :class:`ensemble.BaggingClassifier` and + :class:`ensemble.BaggingRegressor` expose the `allow_nan` tag from the + underlying estimator. :pr:`25506` by `Thomas Fan`_. + :mod:`sklearn.exception` ........................ - |Feature| Added :class:`exception.InconsistentVersionWarning` which is raised diff --git a/sklearn/ensemble/_bagging.py b/sklearn/ensemble/_bagging.py index 4586e55a59f97..d10f89102ea82 100644 --- a/sklearn/ensemble/_bagging.py +++ b/sklearn/ensemble/_bagging.py @@ -23,6 +23,7 @@ from ..utils.random import sample_without_replacement from ..utils._param_validation import Interval, HasMethods, StrOptions from ..utils.validation import has_fit_parameter, check_is_fitted, _check_sample_weight +from ..utils._tags import _safe_tags from ..utils.parallel import delayed, Parallel @@ -981,6 +982,14 @@ def decision_function(self, X): return decisions + def _more_tags(self): + if self.estimator is None: + estimator = DecisionTreeClassifier() + else: + estimator = self.estimator + + return {"allow_nan": _safe_tags(estimator, "allow_nan")} + class BaggingRegressor(RegressorMixin, BaseBagging): """A Bagging regressor. @@ -1261,3 +1270,10 @@ def _set_oob_score(self, X, y): self.oob_prediction_ = predictions self.oob_score_ = r2_score(y, predictions) + + def _more_tags(self): + if self.estimator is None: + estimator = DecisionTreeRegressor() + else: + estimator = self.estimator + return {"allow_nan": _safe_tags(estimator, "allow_nan")} diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index 330287cefef37..ebe21a594e8eb 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -25,6 +25,8 @@ from sklearn.pipeline import make_pipeline from sklearn.feature_selection import SelectKBest from sklearn.model_selection import train_test_split +from sklearn.ensemble import HistGradientBoostingClassifier +from sklearn.ensemble import HistGradientBoostingRegressor from sklearn.datasets import load_diabetes, load_iris, make_hastie_10_2 from sklearn.utils import check_random_state from sklearn.preprocessing import FunctionTransformer, scale @@ -980,3 +982,17 @@ def test_deprecated_base_estimator_has_decision_function(): with pytest.warns(FutureWarning, match=warn_msg): y_decision = clf.fit(X, y).decision_function(X) assert y_decision.shape == (150, 3) + + +@pytest.mark.parametrize( + "bagging, expected_allow_nan", + [ + (BaggingClassifier(HistGradientBoostingClassifier(max_iter=1)), True), + (BaggingRegressor(HistGradientBoostingRegressor(max_iter=1)), True), + (BaggingClassifier(LogisticRegression()), False), + (BaggingRegressor(SVR()), False), + ], +) +def test_bagging_allow_nan_tag(bagging, expected_allow_nan): + """Check that bagging inherits allow_nan tag.""" + assert bagging._get_tags()["allow_nan"] == expected_allow_nan