Commit 0bfe10d

jschendel authored and jnothman committed

FIX Pass sample_weight as kwargs in VotingClassifier (#9493)

1 parent 00358b1 · commit 0bfe10d

2 files changed: 18 additions & 3 deletions

sklearn/ensemble/tests/test_voting_classifier.py

Lines changed: 15 additions & 0 deletions

@@ -17,6 +17,7 @@
 from sklearn.svm import SVC
 from sklearn.multiclass import OneVsRestClassifier
 from sklearn.neighbors import KNeighborsClassifier
+from sklearn.base import BaseEstimator, ClassifierMixin


 # Load the iris dataset and randomly permute it
@@ -274,6 +275,20 @@ def test_sample_weight():
     assert_raise_message(ValueError, msg, eclf3.fit, X, y, sample_weight)


+def test_sample_weight_kwargs():
+    """Check that VotingClassifier passes sample_weight as kwargs"""
+    class MockClassifier(BaseEstimator, ClassifierMixin):
+        """Mock Classifier to check that sample_weight is received as kwargs"""
+        def fit(self, X, y, *args, **sample_weight):
+            assert_true('sample_weight' in sample_weight)
+
+    clf = MockClassifier()
+    eclf = VotingClassifier(estimators=[('mock', clf)], voting='soft')
+
+    # Should not raise an error.
+    eclf.fit(X, y, sample_weight=np.ones((len(y),)))
+
+
 def test_set_params():
     """set_params should be able to set estimators"""
     clf1 = LogisticRegression(random_state=123, C=1.0)
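A note on how the new test works: the mock's var-keyword parameter is deliberately named sample_weight, which appears to be what keeps VotingClassifier's has_fit_parameter check satisfied while still proving that the weights arrive by keyword rather than positionally. Below is a standalone sketch of the same idea, runnable outside the test suite; the MiniMock name, the toy data, and the plain assert are illustrative assumptions, not part of this commit.

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import VotingClassifier


class MiniMock(BaseEstimator, ClassifierMixin):
    """Checks that sample_weight arrives as a keyword argument."""

    # The var-keyword parameter is named sample_weight so that
    # VotingClassifier's support check still sees a sample_weight parameter.
    def fit(self, X, y, *args, **sample_weight):
        assert 'sample_weight' in sample_weight
        self.classes_ = np.unique(y)
        return self


X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

eclf = VotingClassifier(estimators=[('mock', MiniMock())], voting='soft')
eclf.fit(X, y, sample_weight=np.ones(len(y)))  # should not raise with the fix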

sklearn/ensemble/voting_classifier.py

Lines changed: 3 additions & 3 deletions

@@ -23,10 +23,10 @@
 from ..utils.metaestimators import _BaseComposition


-def _parallel_fit_estimator(estimator, X, y, sample_weight):
+def _parallel_fit_estimator(estimator, X, y, sample_weight=None):
     """Private function used to fit an estimator within a job."""
     if sample_weight is not None:
-        estimator.fit(X, y, sample_weight)
+        estimator.fit(X, y, sample_weight=sample_weight)
     else:
         estimator.fit(X, y)
     return estimator
@@ -185,7 +185,7 @@ def fit(self, X, y, sample_weight=None):

        self.estimators_ = Parallel(n_jobs=self.n_jobs)(
                delayed(_parallel_fit_estimator)(clone(clf), X, transformed_y,
-                                                sample_weight)
+                                                sample_weight=sample_weight)
                for clf in clfs if clf is not None)

        return self
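To see why the keyword form matters, consider an estimator whose fit signature has another optional argument ahead of sample_weight (SGDClassifier's fit(X, y, coef_init=None, intercept_init=None, sample_weight=None) has this shape). The sketch below uses a hypothetical toy estimator, not taken from the commit: passed positionally, the weights silently bind to the wrong parameter; passed by keyword, as the fixed _parallel_fit_estimator now does, they land where intended.

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin


class WeightSensitive(BaseEstimator, ClassifierMixin):
    """Hypothetical estimator with an optional argument before sample_weight."""

    def fit(self, X, y, coef_init=None, sample_weight=None):
        self.coef_init_ = coef_init
        self.sample_weight_ = sample_weight
        self.classes_ = np.unique(y)
        return self


X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])
w = np.ones(len(y))

est = WeightSensitive()

# Old positional call: the weights are silently bound to coef_init.
est.fit(X, y, w)
assert est.coef_init_ is w and est.sample_weight_ is None

# Keyword call, as in the fixed _parallel_fit_estimator: weights land correctly.
est.fit(X, y, sample_weight=w)
assert est.sample_weight_ is w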
