Commit 8dc7d56
SLEP006: Metadata routing for bagging
This PR adds metadata routing to BaggingClassifier and BaggingRegressor (see scikit-learn#22893). With this change, in addition to sample_weight, which was already supported, it is now also possible to pass arbitrary fit_params to the sub estimator.

Implementation

Most of the changes should be pretty straightforward with the existing infrastructure for testing metadata routing. One aspect was not quite trivial, though: the current implementation of bagging works by inspecting the sub estimator's fit method. If the sub estimator supports sample_weight, subsampling is performed by making use of sample weights, even if the user does not explicitly pass them. At first, I wanted to change the implementation so that subsampling uses the sample-weight approach only when sample weights are requested. However, that broke a couple of tests, so I rolled back the change and stuck very closely to the existing implementation. I can't judge whether this prevents the user from doing certain things, or whether subsampling with and without sample_weight is equivalent.

Coincidental changes

- The method _validate_estimator on the BaseEnsemble class used to validate the sub estimator and then set it as an attribute. This was inconvenient because get_metadata_routing needs to fetch the sub estimator, which is not easily possible with that method. The method now returns the sub estimator, and the caller is responsible for setting it as an attribute. This has the added advantages that the caller can decide the attribute name and that the method now more closely mirrors _BaseHeterogeneousEnsemble._validate_estimators. Affected by this change are random forests, extra trees, and AdaBoost.
- The function process_routing used to mutate the incoming param dict (adding new items); it now creates a shallow copy first.
- Extended the docstring for check_input of BaseBagging._fit.

Testing

I noticed that the bagging tests didn't have a test case for sparse input combined with sample weights, so I extended an existing test to cover it. The test test_bagging_sample_weight_unsupported_but_passed now raises a TypeError, not a ValueError, when sample_weight is passed but not supported.
1 parent 0afaa63 commit 8dc7d56
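A minimal usage sketch of what this enables, assuming the SLEP006 request API as eventually released (set_config(enable_metadata_routing=True) and the generated set_fit_request method); this in-progress branch may wire things up differently, and RecordingClassifier is a made-up toy sub estimator:

import numpy as np

from sklearn import set_config
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import BaggingClassifier

set_config(enable_metadata_routing=True)


class RecordingClassifier(ClassifierMixin, BaseEstimator):
    """Toy sub estimator that accepts an extra fit parameter."""

    def fit(self, X, y, sample_weight=None, groups=None):
        self.classes_ = np.unique(y)
        self.groups_ = groups  # the routed metadata lands here
        return self

    def predict(self, X):
        return np.full(X.shape[0], self.classes_[0])


rng = np.random.RandomState(0)
X, y = rng.rand(20, 3), np.array([0, 1] * 10)

# The sub estimator requests `groups`, so the bagging meta-estimator
# routes it through to every fitted copy of the sub estimator.
clf = BaggingClassifier(RecordingClassifier().set_fit_request(groups=True))
clf.fit(X, y, groups=rng.randint(0, 3, 20))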

File tree

7 files changed: +177 −42

sklearn/ensemble/_bagging.py

+70 −18
@@ -24,6 +24,7 @@
 from ..utils.multiclass import check_classification_targets
 from ..utils.random import sample_without_replacement
 from ..utils._param_validation import Interval
+from ..utils.metadata_routing import MetadataRouter, MethodMapping, process_routing
 from ..utils.validation import has_fit_parameter, check_is_fitted, _check_sample_weight
 from ..utils.fixes import delayed

@@ -75,11 +76,11 @@ def _parallel_build_estimators(
     ensemble,
     X,
     y,
-    sample_weight,
     seeds,
     total_n_estimators,
     verbose,
     check_input,
+    fit_params,
 ):
     """Private function used to build a batch of estimators within a job."""
     # Retrieve settings

@@ -88,13 +89,9 @@ def _parallel_build_estimators(
     max_samples = ensemble._max_samples
     bootstrap = ensemble.bootstrap
     bootstrap_features = ensemble.bootstrap_features
-    support_sample_weight = has_fit_parameter(ensemble.base_estimator_, "sample_weight")
     has_check_input = has_fit_parameter(ensemble.base_estimator_, "check_input")
     requires_feature_indexing = bootstrap_features or max_features != n_features

-    if not support_sample_weight and sample_weight is not None:
-        raise ValueError("The base estimator doesn't support sample weight")
-
     # Build estimators
     estimators = []
     estimators_features = []

@@ -126,7 +123,11 @@ def _parallel_build_estimators(
         )

         # Draw samples, using sample weights, and then fit
+        support_sample_weight = has_fit_parameter(
+            ensemble.base_estimator_, "sample_weight"
+        )
         if support_sample_weight:
+            sample_weight = fit_params.get("sample_weight")
             if sample_weight is None:
                 curr_sample_weight = np.ones((n_samples,))
             else:

@@ -139,8 +140,11 @@ def _parallel_build_estimators(
                 not_indices_mask = ~indices_to_mask(indices, n_samples)
                 curr_sample_weight[not_indices_mask] = 0

+            fit_params = {
+                key: val for key, val in fit_params.items() if key != "sample_weight"
+            }
             X_ = X[:, features] if requires_feature_indexing else X
-            estimator_fit(X_, y, sample_weight=curr_sample_weight)
+            estimator_fit(X_, y, sample_weight=curr_sample_weight, **fit_params)
         else:
             X_ = X[indices][:, features] if requires_feature_indexing else X[indices]
             estimator_fit(X_, y[indices])
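For context on the two code paths kept above, here is a standalone sketch (toy data, not scikit-learn internals) contrasting weight-based and index-based subsampling; whether the two are fully equivalent is exactly the open question noted in the commit message:

import numpy as np

from sklearn.tree import DecisionTreeClassifier

rng = np.random.RandomState(0)
X, y = rng.rand(100, 4), rng.randint(0, 2, 100)
indices = rng.randint(0, 100, 100)  # one bootstrap draw

# Weight-based path: taken when the sub estimator accepts sample_weight;
# out-of-bag rows get weight 0, in-bag rows their draw counts.
weights = np.bincount(indices, minlength=100).astype(float)
tree_w = DecisionTreeClassifier(random_state=0).fit(X, y, sample_weight=weights)

# Index-based path: taken otherwise; the drawn rows are materialized instead.
tree_i = DecisionTreeClassifier(random_state=0).fit(X[indices], y[indices])

# One quick empirical check: compare the two fits' predictions.
print(np.mean(tree_w.predict(X) == tree_i.predict(X)))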
@@ -287,7 +291,7 @@ def __init__(
         self.random_state = random_state
         self.verbose = verbose

-    def fit(self, X, y, sample_weight=None):
+    def fit(self, X, y, sample_weight=None, **fit_params):
         """Build a Bagging ensemble of estimators from the training set (X, y).

         Parameters

@@ -322,11 +326,17 @@ def fit(self, X, y, sample_weight=None):
             force_all_finite=False,
             multi_output=True,
         )
-        return self._fit(X, y, self.max_samples, sample_weight=sample_weight)
+        return self._fit(
+            X, y, self.max_samples, sample_weight=sample_weight, **fit_params
+        )

     def _parallel_args(self):
         return {}

+    def _get_estimator(self):
+        # should be overridden by child classes
+        return None
+
     def _fit(
         self,
         X,

@@ -335,6 +345,7 @@ def _fit(
         max_depth=None,
         sample_weight=None,
         check_input=True,
+        **fit_params,
     ):
         """Build a Bagging ensemble of estimators from the training
         set (X, y).

@@ -364,6 +375,11 @@ def _fit(
         check_input : bool, default=True
             Override value used when fitting base estimator. Only supported
             if the base estimator has a check_input parameter for fit function.
+            If the metaestimator already checks the input, set this value to
+            False to prevent redundant input checking (#23149).
+
+        fit_params : dict, default=None
+            Parameters to pass to the `fit` method of the underlying estimator.

         Returns
         -------

@@ -381,7 +397,14 @@ def _fit(
         y = self._validate_y(y)

         # Check parameters
-        self._validate_estimator()
+        self.base_estimator_ = self._validate_estimator(self._get_estimator())
+
+        routed_params = process_routing(
+            obj=self,
+            method="fit",
+            sample_weight=sample_weight,
+            other_params=fit_params,
+        )

         if max_depth is not None:
             self.base_estimator_.max_depth = max_depth
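The call above also touches one of the coincidental changes: process_routing no longer mutates the caller's dict. A generic sketch of that pattern (the real process_routing signature is internal API and may differ):

def process(params):
    # shallow copy first, so added defaults never leak back to the caller
    params = dict(params)
    params.setdefault("sample_weight", None)
    return params


user_params = {"groups": [0, 1]}
process(user_params)
assert "sample_weight" not in user_params  # the caller's dict is untouched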
@@ -465,11 +488,11 @@ def _fit(
                 self,
                 X,
                 y,
-                sample_weight,
-                seeds[starts[i] : starts[i + 1]],
-                total_n_estimators,
+                seeds=seeds[starts[i] : starts[i + 1]],
+                total_n_estimators=total_n_estimators,
                 verbose=self.verbose,
                 check_input=check_input,
+                fit_params=routed_params.base_estimator.fit,
             )
             for i in range(n_jobs)
         )

@@ -538,6 +561,33 @@ def estimators_samples_(self):
     def n_features_(self):
         return self.n_features_in_

+    def get_metadata_routing(self):
+        """Get metadata routing of this object.
+
+        Please check :ref:`User Guide <metadata_routing>` on how the routing
+        mechanism works.
+
+        Returns
+        -------
+        routing : MetadataRouter
+            A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
+            routing information.
+        """
+        base_estimator = self._validate_estimator(self._get_estimator())
+        router = (
+            MetadataRouter(owner=self.__class__.__name__)
+            # no add_self(self) because the bagging metaestimator does not use
+            # fit_params
+            .add(
+                base_estimator=base_estimator,
+                method_mapping=MethodMapping().add(callee="fit", caller="fit"),
+            )
+            # warn on sample_weight because that was already supported, the rest
+            # raises
+            .warn_on(child="base_estimator", method="fit", params=["sample_weight"])
+        )
+        return router
+

 class BaggingClassifier(ClassifierMixin, BaseBagging):
     """A Bagging classifier.
@@ -738,9 +788,10 @@ def __init__(
             verbose=verbose,
         )

-    def _validate_estimator(self):
-        """Check the estimator and set the base_estimator_ attribute."""
-        super()._validate_estimator(default=DecisionTreeClassifier())
+    def _get_estimator(self):
+        if self.base_estimator is None:
+            return DecisionTreeClassifier()
+        return self.base_estimator

     def _set_oob_score(self, X, y):
         n_samples = y.shape[0]

@@ -1198,9 +1249,10 @@ def predict(self, X):

         return y_hat

-    def _validate_estimator(self):
-        """Check the estimator and set the base_estimator_ attribute."""
-        super()._validate_estimator(default=DecisionTreeRegressor())
+    def _get_estimator(self):
+        if self.base_estimator is None:
+            return DecisionTreeRegressor()
+        return self.base_estimator

     def _set_oob_score(self, X, y):
         n_samples = y.shape[0]

sklearn/ensemble/_base.py

+6 −4
@@ -130,7 +130,7 @@ def __init__(self, base_estimator, *, n_estimators=10, estimator_params=tuple())
     def _validate_estimator(self, default=None):
         """Check the estimator and the n_estimator attribute.

-        Sets the base_estimator_` attributes.
+        Returns the base estimator instance.
         """
         if not isinstance(self.n_estimators, numbers.Integral):
             raise ValueError(

@@ -147,13 +147,15 @@ def _validate_estimator(self, default=None):
         )

         if self.base_estimator is not None:
-            self.base_estimator_ = self.base_estimator
+            base_estimator = self.base_estimator
         else:
-            self.base_estimator_ = default
+            base_estimator = default

-        if self.base_estimator_ is None:
+        if base_estimator is None:
             raise ValueError("base_estimator cannot be None")

+        return base_estimator
+
     def _make_estimator(self, append=True, random_state=None):
         """Make and configure a copy of the `base_estimator_` attribute.
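The new validate-then-return contract can be mirrored in a self-contained toy (illustrative names only, not scikit-learn code): validation returns the instance, and each caller decides which attribute to set:

class ToyEnsemble:
    def __init__(self, base_estimator=None):
        self.base_estimator = base_estimator

    def _validate_estimator(self, default=None):
        # return the instance instead of setting the attribute as a side effect
        base_estimator = (
            self.base_estimator if self.base_estimator is not None else default
        )
        if base_estimator is None:
            raise ValueError("base_estimator cannot be None")
        return base_estimator

    def fit(self):
        # the caller now owns the attribute name
        self.base_estimator_ = self._validate_estimator(default="stub")
        return self


assert ToyEnsemble().fit().base_estimator_ == "stub"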

sklearn/ensemble/_forest.py

+1 −1
@@ -405,7 +405,7 @@ def fit(self, X, y, sample_weight=None):
         else:
             n_samples_bootstrap = None

-        self._validate_estimator()
+        self.base_estimator_ = self._validate_estimator()
         # TODO(1.2): Remove "mse" and "mae"
         if isinstance(self, (RandomForestRegressor, ExtraTreesRegressor)):
             if self.criterion == "mse":

sklearn/ensemble/_weight_boosting.py

+9 −7
@@ -136,7 +136,7 @@ def fit(self, X, y, sample_weight=None):
         sample_weight /= sample_weight.sum()

         # Check parameters
-        self._validate_estimator()
+        self.base_estimator_ = self._validate_estimator()

         # Clear any previous fit results
         self.estimators_ = []

@@ -473,23 +473,25 @@ def __init__(

     def _validate_estimator(self):
         """Check the estimator and set the base_estimator_ attribute."""
-        super()._validate_estimator(default=DecisionTreeClassifier(max_depth=1))
+        base_estimator = super()._validate_estimator(
+            default=DecisionTreeClassifier(max_depth=1)
+        )

         # SAMME-R requires predict_proba-enabled base estimators
         if self.algorithm == "SAMME.R":
-            if not hasattr(self.base_estimator_, "predict_proba"):
+            if not hasattr(base_estimator, "predict_proba"):
                 raise TypeError(
                     "AdaBoostClassifier with algorithm='SAMME.R' requires "
                     "that the weak learner supports the calculation of class "
                     "probabilities with a predict_proba method.\n"
                     "Please change the base estimator or set "
                     "algorithm='SAMME' instead."
                 )
-        if not has_fit_parameter(self.base_estimator_, "sample_weight"):
+        if not has_fit_parameter(base_estimator, "sample_weight"):
             raise ValueError(
-                "%s doesn't support sample_weight."
-                % self.base_estimator_.__class__.__name__
+                "%s doesn't support sample_weight." % base_estimator.__class__.__name__
             )
+        return base_estimator

     def _boost(self, iboost, X, y, sample_weight, random_state):
         """Implement a single boost.

@@ -1031,7 +1033,7 @@ def __init__(

     def _validate_estimator(self):
         """Check the estimator and set the base_estimator_ attribute."""
-        super()._validate_estimator(default=DecisionTreeRegressor(max_depth=3))
+        return super()._validate_estimator(default=DecisionTreeRegressor(max_depth=3))

     def _boost(self, iboost, X, y, sample_weight, random_state):
         """Implement a single boost for regression
