diff --git a/sklearn/ensemble/tests/test_voting_classifier.py b/sklearn/ensemble/tests/test_voting_classifier.py index 4765d0e32d0bb..023be79912d12 100644 --- a/sklearn/ensemble/tests/test_voting_classifier.py +++ b/sklearn/ensemble/tests/test_voting_classifier.py @@ -17,6 +17,7 @@ from sklearn.svm import SVC from sklearn.multiclass import OneVsRestClassifier from sklearn.neighbors import KNeighborsClassifier +from sklearn.base import BaseEstimator, ClassifierMixin # Load the iris dataset and randomly permute it @@ -274,6 +275,20 @@ def test_sample_weight(): assert_raise_message(ValueError, msg, eclf3.fit, X, y, sample_weight) +def test_sample_weight_kwargs(): + """Check that VotingClassifier passes sample_weight as kwargs""" + class MockClassifier(BaseEstimator, ClassifierMixin): + """Mock Classifier to check that sample_weight is received as kwargs""" + def fit(self, X, y, *args, **sample_weight): + assert_true('sample_weight' in sample_weight) + + clf = MockClassifier() + eclf = VotingClassifier(estimators=[('mock', clf)], voting='soft') + + # Should not raise an error. + eclf.fit(X, y, sample_weight=np.ones((len(y),))) + + def test_set_params(): """set_params should be able to set estimators""" clf1 = LogisticRegression(random_state=123, C=1.0) diff --git a/sklearn/ensemble/voting_classifier.py b/sklearn/ensemble/voting_classifier.py index 88b329d836978..ad6c0125dd664 100644 --- a/sklearn/ensemble/voting_classifier.py +++ b/sklearn/ensemble/voting_classifier.py @@ -23,10 +23,10 @@ from ..utils.metaestimators import _BaseComposition -def _parallel_fit_estimator(estimator, X, y, sample_weight): +def _parallel_fit_estimator(estimator, X, y, sample_weight=None): """Private function used to fit an estimator within a job.""" if sample_weight is not None: - estimator.fit(X, y, sample_weight) + estimator.fit(X, y, sample_weight=sample_weight) else: estimator.fit(X, y) return estimator @@ -185,7 +185,7 @@ def fit(self, X, y, sample_weight=None): self.estimators_ = Parallel(n_jobs=self.n_jobs)( delayed(_parallel_fit_estimator)(clone(clf), X, transformed_y, - sample_weight) + sample_weight=sample_weight) for clf in clfs if clf is not None) return self