Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 07560e4

Browse files
committed
FIX scikit-learn#3485: class_weight='auto' on SGDClassifier
1 parent 6c7f029 commit 07560e4

File tree

5 files changed

+59
-5
lines changed

5 files changed

+59
-5
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ Bug fixes
105105
and ``pandas.DataFrame`` in recent versions of pandas. By
106106
`Gael Varoquaux`_.
107107

108+
- Fixed a regression for :class:`linear_model.SGDClassifier` with
109+
``class_weight="auto"`` on data with non-contiguous labels. By
110+
`Olivier Grisel`_.
111+
108112

109113
.. _changes_0_15:
110114

sklearn/linear_model/stochastic_gradient.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,10 +327,8 @@ def _partial_fit(self, X, y, alpha, C,
327327
n_classes = self.classes_.shape[0]
328328

329329
# Allocate datastructures from input arguments
330-
y_ind = np.searchsorted(self.classes_, y) # XXX use a LabelBinarizer?
331330
self._expanded_class_weight = compute_class_weight(self.class_weight,
332-
self.classes_,
333-
y_ind)
331+
self.classes_, y)
334332
sample_weight = self._validate_sample_weight(sample_weight, n_samples)
335333

336334
if self.coef_ is None or coef_init is not None:

sklearn/linear_model/tests/test_sgd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ def test_underflow_or_overlow():
847847
assert_array_equal(np.unique(y), [0, 1])
848848

849849
model = SGDClassifier(alpha=0.1, loss='squared_hinge', n_iter=500)
850-
850+
851851
# smoke test: model is stable on scaled data
852852
model.fit(scale(X), y)
853853

sklearn/tests/test_common.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sklearn.datasets import make_classification
2727

2828
from sklearn.cross_validation import train_test_split
29+
from sklearn.linear_model.base import LinearClassifierMixin
2930
from sklearn.utils.estimator_checks import (
3031
check_parameters_default_constructible,
3132
check_regressors_classifiers_sparse_data,
@@ -44,6 +45,7 @@
4445
check_classifiers_pickle,
4546
check_class_weight_classifiers,
4647
check_class_weight_auto_classifiers,
48+
check_class_weight_auto_linear_classifier,
4749
check_estimators_overwrite_params,
4850
check_cluster_overwrite_params,
4951
check_sparsify_binary_classifier,
@@ -214,7 +216,7 @@ def test_class_weight_classifiers():
214216
yield check_class_weight_classifiers, name, Classifier
215217

216218

217-
def test_class_weight_auto_classifies():
219+
def test_class_weight_auto_classifiers():
218220
"""Test that class_weight="auto" improves f1-score"""
219221

220222
# This test is broken; its success depends on:
@@ -251,6 +253,26 @@ def test_class_weight_auto_classifies():
251253
X_train, y_train, X_test, y_test, weights)
252254

253255

256+
def test_class_weight_auto_linear_classifiers():
257+
classifiers = all_estimators(type_filter='classifier')
258+
259+
with warnings.catch_warnings(record=True):
260+
linear_classifiers = [
261+
(name, clazz)
262+
for name, clazz in classifiers
263+
if 'class_weight' in clazz().get_params().keys()
264+
and issubclass(clazz, LinearClassifierMixin)]
265+
266+
for name, Classifier in linear_classifiers:
267+
if name == "LogisticRegressionCV":
268+
# Contrary to RidgeClassifierCV, LogisticRegressionCV uses actual
269+
# CV folds and fits a model for each CV iteration before averaging
270+
# the coef. Therefore it is not expected to behave exactly like the
271+
# other linear models.
272+
continue
273+
yield check_class_weight_auto_linear_classifier, name, Classifier
274+
275+
254276
def test_estimators_overwrite_params():
255277
# test whether any classifier overwrites its init parameters during fit
256278
for est_type in ["classifier", "regressor", "transformer"]:

sklearn/utils/estimator_checks.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,36 @@ def check_class_weight_auto_classifiers(name, Classifier, X_train, y_train,
725725
f1_score(y_test, y_pred))
726726

727727

728+
def check_class_weight_auto_linear_classifier(name, Classifier):
729+
"""Test class weights with non-contiguous class labels."""
730+
X = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0],
731+
[1.0, 1.0], [1.0, 0.0]])
732+
y = [1, 1, 1, -1, -1]
733+
734+
with warnings.catch_warnings(record=True):
735+
classifier = Classifier()
736+
if hasattr(classifier, "n_iter"):
737+
# This is a very small dataset; the default n_iter is likely to prevent
738+
# convergence
739+
classifier.set_params(n_iter=1000)
740+
set_random_state(classifier)
741+
742+
# Let the model compute the class frequencies
743+
classifier.set_params(class_weight='auto')
744+
coef_auto = classifier.fit(X, y).coef_.copy()
745+
746+
# Count each label occurrence to reweight manually
747+
mean_weight = (1. / 3 + 1. / 2) / 2
748+
class_weight = {
749+
1: 1. / 3 / mean_weight,
750+
-1: 1. / 2 / mean_weight,
751+
}
752+
classifier.set_params(class_weight=class_weight)
753+
coef_manual = classifier.fit(X, y).coef_.copy()
754+
755+
assert_array_almost_equal(coef_auto, coef_manual)
756+
757+
728758
def check_estimators_overwrite_params(name, Estimator):
729759
X, y = make_blobs(random_state=0, n_samples=9)
730760
y = multioutput_estimator_convert_y_2d(name, y)

0 commit comments

Comments
 (0)