Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cf8cfd1

Browse files
author
Bill DeRose
committed
pass sample_weight when predicting on stacked folds
1 parent 136ef79 commit cf8cfd1

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

sklearn/ensemble/_stacking.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,12 @@ def fit(self, X, y, sample_weight=None):
165165
self._method_name(name, est, meth)
166166
for name, est, meth in zip(names, all_estimators, stack_method)
167167
]
168-
169168
predictions = Parallel(n_jobs=self.n_jobs)(
170169
delayed(cross_val_predict)(clone(est), X, y, cv=deepcopy(cv),
171170
method=meth, n_jobs=self.n_jobs,
171+
fit_params={
172+
'sample_weight': sample_weight
173+
},
172174
verbose=self.verbose)
173175
for est, meth in zip(all_estimators, self.stack_method_)
174176
if est != 'drop'

sklearn/ensemble/tests/test_stacking.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from sklearn.model_selection import StratifiedKFold
3939
from sklearn.model_selection import KFold
4040

41+
from sklearn.utils._mocking import CheckingClassifier
4142
from sklearn.utils._testing import assert_allclose
4243
from sklearn.utils._testing import assert_allclose_dense_sparse
4344
from sklearn.utils._testing import ignore_warnings
@@ -439,6 +440,19 @@ def test_stacking_with_sample_weight(stacker, X, y):
439440
assert np.abs(y_pred_no_weight - y_pred_biased).sum() > 0
440441

441442

443+
def test_fit_stacking_with_sample_weight_passed():
444+
# check sample_weight is passed to all invokations of fit
445+
stacker = StackingClassifier(
446+
estimators=[
447+
('lr', CheckingClassifier(expected_fit_params=['sample_weight']))
448+
],
449+
final_estimator=CheckingClassifier(
450+
expected_fit_params=['sample_weight']
451+
)
452+
)
453+
stacker.fit(X_iris, y_iris, sample_weight=np.ones(X_iris.shape[0]))
454+
455+
442456
@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
443457
@pytest.mark.parametrize(
444458
"stacker, X, y",

0 commit comments

Comments
 (0)