diff --git a/sklearn/linear_model/stochastic_gradient.py b/sklearn/linear_model/stochastic_gradient.py index 4577f32fd7761..e0f5c76dc7bb3 100644 --- a/sklearn/linear_model/stochastic_gradient.py +++ b/sklearn/linear_model/stochastic_gradient.py @@ -6,6 +6,7 @@ import numpy as np import scipy.sparse as sp +import warnings from abc import ABCMeta, abstractmethod @@ -543,12 +544,19 @@ def fit(self, X, y, coef_init=None, intercept_init=None, sample_weight : array-like, shape (n_samples,), optional Weights applied to individual samples. - If not provided, uniform weights are assumed. + If not provided, uniform weights are assumed. These weights will + be multiplied with class_weight (passed through the + contructor) if class_weight is specified Returns ------- self : returns an instance of self. """ + if class_weight is not None: + warnings.warn("You are trying to set class_weight through the fit " + "method, which will be deprecated in version " + "v0.17 of scikit-learn. Pass the class_weight into " + "the constructor instead.", DeprecationWarning) return self._fit(X, y, alpha=self.alpha, C=1.0, loss=self.loss, learning_rate=self.learning_rate, coef_init=coef_init, intercept_init=intercept_init, diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index 6085ecb31bbb4..a817088fb983d 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -14,6 +14,7 @@ from sklearn.utils.testing import assert_false, assert_true from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_raises_regexp +from sklearn.utils.testing import assert_warns_message from sklearn import linear_model, datasets, metrics from sklearn.base import clone @@ -597,6 +598,37 @@ def test_wrong_class_weight_format(self): clf = self.factory(alpha=0.1, n_iter=1000, class_weight=[0.5]) clf.fit(X, Y) + def test_class_weight_warning(self): + """Tests that class_weight passed through fit raises warning. + This test should be removed after deprecating support for this""" + + clf = self.factory() + warning_message = ("You are trying to set class_weight through the " + "fit " + "method, which will be deprecated in version " + "v0.17 of scikit-learn. Pass the class_weight into " + "the constructor instead.") + assert_warns_message(DeprecationWarning, + warning_message, + clf.fit, X4, Y4, + class_weight=1) + + def test_weights_multiplied(self): + """Tests that class_weight and sample_weight are multiplicative""" + class_weights = {1: .6, 2: .3} + sample_weights = np.random.random(Y4.shape[0]) + multiplied_together = np.copy(sample_weights) + multiplied_together[Y4 == 1] *= class_weights[1] + multiplied_together[Y4 == 2] *= class_weights[2] + + clf1 = self.factory(alpha=0.1, n_iter=20, class_weight=class_weights) + clf2 = self.factory(alpha=0.1, n_iter=20) + + clf1.fit(X4, Y4, sample_weight=sample_weights) + clf2.fit(X4, Y4, sample_weight=multiplied_together) + + assert_array_equal(clf1.coef_, clf2.coef_) + def test_auto_weight(self): """Test class weights for imbalanced data""" # compute reference metrics on iris dataset that is quite balanced by