This PR adds metadata routing to BaggingClassifier and
BaggingRegressor (see scikit-learn#22893).
With this change, in addition to sample_weight, which was already
supported, it's now also possible to pass arbitrary fit_params to the
sub estimator.
Implementation
Most of the changes should be pretty straightforward with the existing
infrastructure for testing metadata routing. There was one aspect which
was not quite trivial though: The current implementation of bagging
works by inspecting the sub estimator's fit method. If the sub estimator
supports sample_weight, then subsampling is performed through the sample
weights. This happens even if the user does not explicitly pass
sample_weight.
At first, I wanted to change the implementation such that if sample
weights are requested, subsampling should use the sample weight
approach, otherwise it shouldn't. However, that breaks a couple of
tests, so I rolled back the change and stuck very closely to the
existing implementation. I can't judge whether this prevents the user
from doing certain things, or whether subsampling with and without
sample_weight is equivalent.
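To make the behaviour described above concrete, here is a minimal standard-library sketch of the two pieces involved: checking whether a fit method accepts sample_weight, and expressing a bootstrap draw as per-sample weights instead of copied rows. The helper names (supports_sample_weight, bootstrap_as_weights) and the toy estimators are hypothetical and written for illustration only; this is not the code in the PR.

```python
import inspect
import random
from collections import Counter


def supports_sample_weight(estimator):
    """Return True if estimator.fit declares a `sample_weight` parameter."""
    return "sample_weight" in inspect.signature(estimator.fit).parameters


def bootstrap_as_weights(n_samples, rng, sample_weight=None):
    """Encode a bootstrap draw as per-sample weights (hypothetical helper)."""
    # Draw n_samples indices with replacement.
    indices = [rng.randrange(n_samples) for _ in range(n_samples)]
    counts = Counter(indices)
    # Start from the user's weights, or from uniform weights if none were given.
    base = sample_weight if sample_weight is not None else [1.0] * n_samples
    # A sample drawn k times gets k times its base weight; undrawn samples get 0.
    return [base[i] * counts.get(i, 0) for i in range(n_samples)]


class WeightedToy:
    """Toy estimator whose fit accepts sample_weight."""

    def fit(self, X, y, sample_weight=None):
        self.weights_ = sample_weight
        return self


class PlainToy:
    """Toy estimator whose fit does not accept sample_weight."""

    def fit(self, X, y):
        return self
```

In this sketch the weight-based path is chosen purely from the signature check, so it applies even when the caller passes no weights, which mirrors the behaviour the description says was kept.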
Coincidental changes
The method _validate_estimator on the BaseEnsemble class used to
validate, and then set as attribute, the sub estimator. This was
inconvenient because for get_metadata_routing, we want to fetch the sub
estimator, which is not easily possible with this method. Therefore, the
method now returns the sub estimator and the caller is responsible for
setting it as an attribute. This has
the added advantages that the caller can now decide the attribute name
and that this method now more closely mirrors
_BaseHeterogeneousEnsemble._validate_estimators. Affected by this change
are random forests, extra trees, and AdaBoost.
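The shape of this refactoring can be sketched with a toy class. ToyEnsemble, get_sub_estimator, and the string placeholders are hypothetical names used only for illustration; the real method lives on BaseEnsemble and its exact signature is not reproduced here.

```python
class ToyEnsemble:
    """Hypothetical stand-in for an ensemble; not scikit-learn code."""

    def __init__(self, estimator=None):
        self.estimator = estimator

    def _validate_estimator(self, default=None):
        # Return the resolved sub estimator without touching instance state.
        return self.estimator if self.estimator is not None else default

    def fit(self, default="default-estimator"):
        # The caller, not the validator, chooses the attribute name.
        self.estimator_ = self._validate_estimator(default)
        return self

    def get_sub_estimator(self, default="default-estimator"):
        # Works before fit because validation has no side effects.
        return self._validate_estimator(default)
```

Because validation no longer stores anything, a method such as get_metadata_routing can look up the sub estimator without fitting or altering the ensemble.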
The function process_routing used to mutate the incoming param
dict (adding new items); it now creates a shallow copy first.
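The difference between the two behaviours can be shown with a pair of toy functions. Both function names and the "owner" key are hypothetical; they only illustrate mutating versus copying a dict and do not reflect the actual body of process_routing.

```python
def route_params_mutating(params, owner):
    """Old-style behaviour: new items leak into the caller's dict."""
    params["owner"] = owner
    return params


def route_params_copying(params, owner):
    """New-style behaviour: work on a shallow copy of the caller's dict."""
    routed = dict(params)  # shallow copy: new dict, same value objects
    routed["owner"] = owner
    return routed


caller_params = {"sample_weight": [1.0, 2.0]}
routed = route_params_copying(caller_params, owner="bagging")
# The caller's dict has no "owner" key, but the weight list itself is shared.
```

Since the copy is shallow, the values (for example a weight array) are still shared between the two dicts; only the set of keys is isolated from the caller.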
Extended docstring for check_input of BaseBagging._fit.
Testing
I noticed that the bagging tests didn't have a test case for sparse
input + using sample weights, so I extended an existing test to cover
it.
The test test_bagging_sample_weight_unsupported_but_passed now raises a
TypeError, not ValueError, when sample_weight is passed but not
supported.