From 7e13f9a9833ecbc44c6be3cd0f6bf428178aec70 Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Tue, 11 Oct 2016 17:57:20 +0200 Subject: [PATCH 01/12] TST if LogisticRegressionCV handles string labels properly --- sklearn/linear_model/tests/test_logistic.py | 33 +++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index 8d35bb220c958..9431b4b5786f8 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -27,6 +27,7 @@ from sklearn.model_selection import StratifiedKFold from sklearn.datasets import load_iris, make_classification from sklearn.metrics import log_loss +from sklearn.preprocessing import LabelEncoder X = [[-1, 0], [0, 1], [1, 1]] X_sp = sp.csr_matrix(X) @@ -398,6 +399,38 @@ def test_logistic_cv(): assert_array_equal(scores.shape, (1, 3, 1)) +def test_multinomial_logistic_regression_string_inputs(): + # Test with string labels for LogisticRegression(CV) + n_samples, n_features, n_classes = 50, 5, 3 + X_ref, y = make_classification(n_samples=n_samples, n_features=n_features, + n_classes=n_classes, n_informative=3) + y_str = LabelEncoder().fit(['bar', 'baz', 'foo']).inverse_transform(y) + # For numerical labels, let y values be taken from set (-1, 0, 1) + y = np.array(y) - 1 + # Test for string labels + lr = LogisticRegression(solver='lbfgs', multi_class='multinomial') + lr_cv = LogisticRegressionCV(solver='lbfgs', multi_class='multinomial') + lr_str = LogisticRegression(solver='lbfgs', multi_class='multinomial') + lr_cv_str = LogisticRegressionCV(solver='lbfgs', multi_class='multinomial') + + lr.fit(X_ref, y) + lr_cv.fit(X_ref, y) + lr_str.fit(X_ref, y_str) + lr_cv_str.fit(X_ref, y_str) + + assert_array_almost_equal(lr.coef_, lr_str.coef_) + assert_equal(sorted(lr_str.classes_), ['bar', 'baz', 'foo']) + assert_array_almost_equal(lr_cv.coef_, lr_cv_str.coef_) + assert_equal(sorted(lr_str.classes_), ['bar', 'baz', 'foo']) + assert_equal(sorted(lr_cv_str.classes_), ['bar', 'baz', 'foo']) + + # The predictions should be in original labels + assert_equal(sorted(np.unique(lr_str.predict(X_ref))), + ['bar', 'baz', 'foo']) + assert_equal(sorted(np.unique(lr_cv_str.predict(X_ref))), + ['bar', 'baz', 'foo']) + + def test_logistic_cv_sparse(): X, y = make_classification(n_samples=50, n_features=5, random_state=0) From c9777e2b49facc8017f3d33af2cddeddc926c47b Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Tue, 11 Oct 2016 23:36:30 +0200 Subject: [PATCH 02/12] TST Add a test with class_weight dict --- sklearn/linear_model/tests/test_logistic.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index 9431b4b5786f8..a5e9e212c7cf7 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -430,6 +430,12 @@ def test_multinomial_logistic_regression_string_inputs(): assert_equal(sorted(np.unique(lr_cv_str.predict(X_ref))), ['bar', 'baz', 'foo']) + # Make sure class weights can be given with string labels + lr_cv_str = LogisticRegression( + solver='lbfgs', class_weight={'bar': 1, 'baz': 2, 'foo': 0}, + multi_class='multinomial').fit(X_ref, y_str) + assert_equal(sorted(np.unique(lr_cv_str.predict(X_ref))), ['bar', 'baz']) + def test_logistic_cv_sparse(): X, y = make_classification(n_samples=50, n_features=5, From 3306f76b8d2ed859bc8ad678c753bf3338182830 Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Thu, 13 Oct 2016 06:13:45 +0200 Subject: [PATCH 03/12] ENH Encode y and class_weight dict --- sklearn/linear_model/logistic.py | 78 +++++++++++++++++--------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index 98a7b5a558bc2..9c6b22094eaee 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -1,4 +1,3 @@ - """ Logistic Regression """ @@ -28,7 +27,6 @@ from ..utils.extmath import row_norms from ..utils.optimize import newton_cg from ..utils.validation import check_X_y -from ..exceptions import DataConversionWarning from ..exceptions import NotFittedError from ..utils.fixes import expit from ..utils.multiclass import check_classification_targets @@ -925,9 +923,6 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10, y_test = np.ones(y_test.shape, dtype=np.float64) y_test[~mask] = -1. - # To deal with object dtypes, we need to convert into an array of floats. - y_test = check_array(y_test, dtype=np.float64, ensure_2d=False) - scores = list() if isinstance(scoring, six.string_types): @@ -1561,64 +1556,67 @@ def fit(self, X, y, sample_weight=None): X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64, order="C") - + check_classification_targets(y) if self.solver == 'sag': max_squared_sum = row_norms(X, squared=True).max() else: max_squared_sum = None - check_classification_targets(y) + # Encode for string labels + label_encoder = LabelEncoder().fit(y) + y = label_encoder.transform(y) + self.classes_ = label_encoder.classes_ - if y.ndim == 2 and y.shape[1] == 1: - warnings.warn( - "A column-vector y was passed when a 1d array was" - " expected. Please change the shape of y to " - "(n_samples, ), for example using ravel().", - DataConversionWarning) - y = np.ravel(y) + enc_labels = label_encoder.transform(label_encoder.classes_) + cls_labels = self.classes_ # The original class labels - check_consistent_length(X, y) + class_weight = self.class_weight + if isinstance(class_weight, dict): + old_keys = list(class_weight.keys()) + new_keys = label_encoder.transform(old_keys) + # Don't modify the original class_weight dict. + class_weight = dict() + for new_key, old_key in zip(new_keys, old_keys): + class_weight[new_key] = self.class_weight[old_key] # init cross-validation generator cv = check_cv(self.cv, y, classifier=True) folds = list(cv.split(X, y)) - self._enc = LabelEncoder() - self._enc.fit(y) - - labels = self.classes_ = np.unique(y) - n_classes = len(labels) + # Use the label encoded classes + n_classes = len(enc_labels) if n_classes < 2: raise ValueError("This solver needs samples of at least 2 classes" " in the data, but the data contains only one" " class: %r" % self.classes_[0]) + if n_classes == 2: # OvR in case of binary problems is as good as fitting # the higher label n_classes = 1 - labels = labels[1:] + enc_labels = enc_labels[1:] + cls_labels = cls_labels[1:] # We need this hack to iterate only once over labels, in the case of # multi_class = multinomial, without changing the value of the labels. - iter_labels = labels if self.multi_class == 'multinomial': - iter_labels = [None] + iter_labels = iter_classes = [None] + else: + iter_labels = enc_labels + iter_classes = cls_labels - if self.class_weight and not(isinstance(self.class_weight, dict) or - self.class_weight in - ['balanced', 'auto']): + if class_weight and not(isinstance(class_weight, dict) or + class_weight in ['balanced', 'auto']): # 'auto' is deprecated and will be removed in 0.19 raise ValueError("class_weight provided should be a " "dict or 'balanced'") # compute the class weights for the entire dataset y - if self.class_weight in ("auto", "balanced"): - classes = np.unique(y) - class_weight = compute_class_weight(self.class_weight, classes, y) + if class_weight in ("auto", "balanced"): + classes = np.arange(len(self.classes_)) + class_weight = compute_class_weight(class_weight, classes, y) class_weight = dict(zip(classes, class_weight)) - else: - class_weight = self.class_weight path_func = delayed(_log_reg_scoring_path) @@ -1669,9 +1667,9 @@ def fit(self, X, y, sample_weight=None): self.n_iter_ = np.reshape(n_iter_, (n_classes, len(folds), len(self.Cs_))) - self.coefs_paths_ = dict(zip(labels, coefs_paths)) + self.coefs_paths_ = dict(zip(cls_labels, coefs_paths)) scores = np.reshape(scores, (n_classes, len(folds), -1)) - self.scores_ = dict(zip(labels, scores)) + self.scores_ = dict(zip(cls_labels, scores)) self.C_ = list() self.coef_ = np.empty((n_classes, X.shape[1])) @@ -1682,10 +1680,14 @@ def fit(self, X, y, sample_weight=None): scores = multi_scores coefs_paths = multi_coefs_paths - for index, label in enumerate(iter_labels): + for index, (cls_lbl, enc_lbl) in enumerate( + zip(iter_classes, iter_labels)): + if self.multi_class == 'ovr': - scores = self.scores_[label] - coefs_paths = self.coefs_paths_[label] + # The scores_ / coefs_paths_ dict have unencoded class + # labels as their keys + scores = self.scores_[cls_lbl] + coefs_paths = self.coefs_paths_[cls_lbl] if self.refit: best_index = scores.sum(axis=0).argmax() @@ -1698,8 +1700,10 @@ def fit(self, X, y, sample_weight=None): else: coef_init = np.mean(coefs_paths[:, best_index, :], axis=0) + # Note that y is label encoded and hence pos_class must be + # the encoded label / None (for 'multinomial') w, _, _ = logistic_regression_path( - X, y, pos_class=label, Cs=[C_], solver=self.solver, + X, y, pos_class=enc_lbl, Cs=[C_], solver=self.solver, fit_intercept=self.fit_intercept, coef=coef_init, max_iter=self.max_iter, tol=self.tol, penalty=self.penalty, copy=False, From 9bd4c76c9aabf7acfa2692c0b311f24031511bba Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Sat, 22 Oct 2016 01:36:10 +0200 Subject: [PATCH 04/12] Better variable names --- sklearn/linear_model/logistic.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index 9c6b22094eaee..7cd19bf4eddd5 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -1567,8 +1567,8 @@ def fit(self, X, y, sample_weight=None): y = label_encoder.transform(y) self.classes_ = label_encoder.classes_ - enc_labels = label_encoder.transform(label_encoder.classes_) - cls_labels = self.classes_ # The original class labels + encoded_labels = label_encoder.transform(label_encoder.classes_) + classes_labels = self.classes_ # The original class labels class_weight = self.class_weight if isinstance(class_weight, dict): @@ -1584,7 +1584,7 @@ def fit(self, X, y, sample_weight=None): folds = list(cv.split(X, y)) # Use the label encoded classes - n_classes = len(enc_labels) + n_classes = len(encoded_labels) if n_classes < 2: raise ValueError("This solver needs samples of at least 2 classes" @@ -1595,16 +1595,16 @@ def fit(self, X, y, sample_weight=None): # OvR in case of binary problems is as good as fitting # the higher label n_classes = 1 - enc_labels = enc_labels[1:] - cls_labels = cls_labels[1:] + encoded_labels = encoded_labels[1:] + classes_labels = classes_labels[1:] # We need this hack to iterate only once over labels, in the case of # multi_class = multinomial, without changing the value of the labels. if self.multi_class == 'multinomial': - iter_labels = iter_classes = [None] + iter_encoded_labels = iter_classes_labels = [None] else: - iter_labels = enc_labels - iter_classes = cls_labels + iter_encoded_labels = encoded_labels + iter_classes_labels = classes_labels if class_weight and not(isinstance(class_weight, dict) or class_weight in ['balanced', 'auto']): @@ -1636,7 +1636,7 @@ def fit(self, X, y, sample_weight=None): max_squared_sum=max_squared_sum, sample_weight=sample_weight ) - for label in iter_labels + for label in iter_encoded_labels for train, test in folds) if self.multi_class == 'multinomial': @@ -1667,9 +1667,9 @@ def fit(self, X, y, sample_weight=None): self.n_iter_ = np.reshape(n_iter_, (n_classes, len(folds), len(self.Cs_))) - self.coefs_paths_ = dict(zip(cls_labels, coefs_paths)) + self.coefs_paths_ = dict(zip(classes_labels, coefs_paths)) scores = np.reshape(scores, (n_classes, len(folds), -1)) - self.scores_ = dict(zip(cls_labels, scores)) + self.scores_ = dict(zip(classes_labels, scores)) self.C_ = list() self.coef_ = np.empty((n_classes, X.shape[1])) @@ -1680,14 +1680,14 @@ def fit(self, X, y, sample_weight=None): scores = multi_scores coefs_paths = multi_coefs_paths - for index, (cls_lbl, enc_lbl) in enumerate( - zip(iter_classes, iter_labels)): + for index, (casses_label, encoded_label) in enumerate( + zip(iter_classes_labels, iter_encoded_labels)): if self.multi_class == 'ovr': # The scores_ / coefs_paths_ dict have unencoded class # labels as their keys - scores = self.scores_[cls_lbl] - coefs_paths = self.coefs_paths_[cls_lbl] + scores = self.scores_[casses_label] + coefs_paths = self.coefs_paths_[casses_label] if self.refit: best_index = scores.sum(axis=0).argmax() @@ -1703,7 +1703,7 @@ def fit(self, X, y, sample_weight=None): # Note that y is label encoded and hence pos_class must be # the encoded label / None (for 'multinomial') w, _, _ = logistic_regression_path( - X, y, pos_class=enc_lbl, Cs=[C_], solver=self.solver, + X, y, pos_class=encoded_label, Cs=[C_], solver=self.solver, fit_intercept=self.fit_intercept, coef=coef_init, max_iter=self.max_iter, tol=self.tol, penalty=self.penalty, copy=False, From 599aa10849eb55de841b36ecb859f7f112541163 Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Mon, 24 Oct 2016 14:15:02 +0200 Subject: [PATCH 05/12] TYPO casses --> classes --- sklearn/linear_model/logistic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index 7cd19bf4eddd5..092518f6e3686 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -1680,14 +1680,14 @@ def fit(self, X, y, sample_weight=None): scores = multi_scores coefs_paths = multi_coefs_paths - for index, (casses_label, encoded_label) in enumerate( + for index, (classes_label, encoded_label) in enumerate( zip(iter_classes_labels, iter_encoded_labels)): if self.multi_class == 'ovr': # The scores_ / coefs_paths_ dict have unencoded class # labels as their keys - scores = self.scores_[casses_label] - coefs_paths = self.coefs_paths_[casses_label] + scores = self.scores_[classes_label] + coefs_paths = self.coefs_paths_[classes_label] if self.refit: best_index = scores.sum(axis=0).argmax() From 41e4e19e5f6b592849dd9a471bdde19b1901d783 Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Tue, 25 Oct 2016 18:34:43 +0200 Subject: [PATCH 06/12] FIX Use dict comprehension; classes_labels --> classes --- sklearn/linear_model/logistic.py | 35 +++++++++++++++----------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index 092518f6e3686..258aabc0963f4 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -1568,16 +1568,12 @@ def fit(self, X, y, sample_weight=None): self.classes_ = label_encoder.classes_ encoded_labels = label_encoder.transform(label_encoder.classes_) - classes_labels = self.classes_ # The original class labels + classes = self.classes_ # The original class labels class_weight = self.class_weight if isinstance(class_weight, dict): - old_keys = list(class_weight.keys()) - new_keys = label_encoder.transform(old_keys) - # Don't modify the original class_weight dict. - class_weight = dict() - for new_key, old_key in zip(new_keys, old_keys): - class_weight[new_key] = self.class_weight[old_key] + class_weight = {label_encoder.transform([cls])[0]: v + for cls, v in class_weight.items()} # init cross-validation generator cv = check_cv(self.cv, y, classifier=True) @@ -1596,15 +1592,15 @@ def fit(self, X, y, sample_weight=None): # the higher label n_classes = 1 encoded_labels = encoded_labels[1:] - classes_labels = classes_labels[1:] + classes = classes[1:] # We need this hack to iterate only once over labels, in the case of # multi_class = multinomial, without changing the value of the labels. if self.multi_class == 'multinomial': - iter_encoded_labels = iter_classes_labels = [None] + iter_encoded_labels = iter_classes = [None] else: iter_encoded_labels = encoded_labels - iter_classes_labels = classes_labels + iter_classes = classes if class_weight and not(isinstance(class_weight, dict) or class_weight in ['balanced', 'auto']): @@ -1614,9 +1610,10 @@ def fit(self, X, y, sample_weight=None): # compute the class weights for the entire dataset y if class_weight in ("auto", "balanced"): - classes = np.arange(len(self.classes_)) - class_weight = compute_class_weight(class_weight, classes, y) - class_weight = dict(zip(classes, class_weight)) + all_encoded_labels = np.arange(len(self.classes_)) + class_weight = compute_class_weight(class_weight, + all_encoded_labels, y) + class_weight = dict(zip(all_encoded_labels, class_weight)) path_func = delayed(_log_reg_scoring_path) @@ -1667,9 +1664,9 @@ def fit(self, X, y, sample_weight=None): self.n_iter_ = np.reshape(n_iter_, (n_classes, len(folds), len(self.Cs_))) - self.coefs_paths_ = dict(zip(classes_labels, coefs_paths)) + self.coefs_paths_ = dict(zip(classes, coefs_paths)) scores = np.reshape(scores, (n_classes, len(folds), -1)) - self.scores_ = dict(zip(classes_labels, scores)) + self.scores_ = dict(zip(classes, scores)) self.C_ = list() self.coef_ = np.empty((n_classes, X.shape[1])) @@ -1680,14 +1677,14 @@ def fit(self, X, y, sample_weight=None): scores = multi_scores coefs_paths = multi_coefs_paths - for index, (classes_label, encoded_label) in enumerate( - zip(iter_classes_labels, iter_encoded_labels)): + for index, (cls, encoded_label) in enumerate( + zip(iter_classes, iter_encoded_labels)): if self.multi_class == 'ovr': # The scores_ / coefs_paths_ dict have unencoded class # labels as their keys - scores = self.scores_[classes_label] - coefs_paths = self.coefs_paths_[classes_label] + scores = self.scores_[cls] + coefs_paths = self.coefs_paths_[cls] if self.refit: best_index = scores.sum(axis=0).argmax() From f8e3e96a912ea49ec173e27911abb90c77615dc4 Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Tue, 25 Oct 2016 19:17:01 +0200 Subject: [PATCH 07/12] Don't use dict comprehension --- sklearn/linear_model/logistic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index 258aabc0963f4..5083b631d53ba 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -1572,8 +1572,8 @@ def fit(self, X, y, sample_weight=None): class_weight = self.class_weight if isinstance(class_weight, dict): - class_weight = {label_encoder.transform([cls])[0]: v - for cls, v in class_weight.items()} + class_weight = dict((label_encoder.transform([cls])[0], v) + for cls, v in class_weight.items()) # init cross-validation generator cv = check_cv(self.cv, y, classifier=True) From 026e97fa0dbb523cffbcd0416edaf9507f67f0a3 Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Mon, 7 Nov 2016 14:38:16 +0100 Subject: [PATCH 08/12] MNT reorder validation to improve clarity --- sklearn/linear_model/logistic.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index 5083b631d53ba..b25edf96855ef 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -1557,24 +1557,31 @@ def fit(self, X, y, sample_weight=None): X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64, order="C") check_classification_targets(y) - if self.solver == 'sag': - max_squared_sum = row_norms(X, squared=True).max() - else: - max_squared_sum = None + + if class_weight and not(isinstance(class_weight, dict) or + class_weight in ['balanced', 'auto']): + # 'auto' is deprecated and will be removed in 0.19 + raise ValueError("class_weight provided should be a " + "dict or 'balanced'") # Encode for string labels label_encoder = LabelEncoder().fit(y) y = label_encoder.transform(y) - self.classes_ = label_encoder.classes_ - - encoded_labels = label_encoder.transform(label_encoder.classes_) - classes = self.classes_ # The original class labels class_weight = self.class_weight if isinstance(class_weight, dict): class_weight = dict((label_encoder.transform([cls])[0], v) for cls, v in class_weight.items()) + # The original class labels + classes = self.classes_ = label_encoder.classes_ + encoded_labels = label_encoder.transform(label_encoder.classes_) + + if self.solver == 'sag': + max_squared_sum = row_norms(X, squared=True).max() + else: + max_squared_sum = None + # init cross-validation generator cv = check_cv(self.cv, y, classifier=True) folds = list(cv.split(X, y)) @@ -1585,7 +1592,7 @@ def fit(self, X, y, sample_weight=None): if n_classes < 2: raise ValueError("This solver needs samples of at least 2 classes" " in the data, but the data contains only one" - " class: %r" % self.classes_[0]) + " class: %r" % classes[0]) if n_classes == 2: # OvR in case of binary problems is as good as fitting @@ -1602,12 +1609,6 @@ def fit(self, X, y, sample_weight=None): iter_encoded_labels = encoded_labels iter_classes = classes - if class_weight and not(isinstance(class_weight, dict) or - class_weight in ['balanced', 'auto']): - # 'auto' is deprecated and will be removed in 0.19 - raise ValueError("class_weight provided should be a " - "dict or 'balanced'") - # compute the class weights for the entire dataset y if class_weight in ("auto", "balanced"): all_encoded_labels = np.arange(len(self.classes_)) From 427eb41570b6fcc90ac561d9e24c07e5301c48ce Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Mon, 7 Nov 2016 14:49:17 +0100 Subject: [PATCH 09/12] Use enumerate --- sklearn/linear_model/logistic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index b25edf96855ef..1d8cab5d810d7 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -1611,10 +1611,10 @@ def fit(self, X, y, sample_weight=None): # compute the class weights for the entire dataset y if class_weight in ("auto", "balanced"): - all_encoded_labels = np.arange(len(self.classes_)) class_weight = compute_class_weight(class_weight, - all_encoded_labels, y) - class_weight = dict(zip(all_encoded_labels, class_weight)) + np.arange(len(self.classes_)), + y) + class_weight = dict(zip(enumerate(class_weight))) path_func = delayed(_log_reg_scoring_path) From 7f3795e98f184d8b9bd8f82634fc14bc81b3ada1 Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Mon, 7 Nov 2016 15:06:04 +0100 Subject: [PATCH 10/12] BUGFIX --- sklearn/linear_model/logistic.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index 1d8cab5d810d7..83433df7ce435 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -1558,6 +1558,7 @@ def fit(self, X, y, sample_weight=None): order="C") check_classification_targets(y) + class_weight = self.class_weight if class_weight and not(isinstance(class_weight, dict) or class_weight in ['balanced', 'auto']): # 'auto' is deprecated and will be removed in 0.19 @@ -1567,8 +1568,6 @@ def fit(self, X, y, sample_weight=None): # Encode for string labels label_encoder = LabelEncoder().fit(y) y = label_encoder.transform(y) - - class_weight = self.class_weight if isinstance(class_weight, dict): class_weight = dict((label_encoder.transform([cls])[0], v) for cls, v in class_weight.items()) From cc70ff970088969bd568aa9e8ea129de33cbd387 Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Mon, 7 Nov 2016 15:55:31 +0100 Subject: [PATCH 11/12] zip was not needed --- sklearn/linear_model/logistic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index 83433df7ce435..e792371383228 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -1613,7 +1613,7 @@ def fit(self, X, y, sample_weight=None): class_weight = compute_class_weight(class_weight, np.arange(len(self.classes_)), y) - class_weight = dict(zip(enumerate(class_weight))) + class_weight = dict(enumerate(class_weight)) path_func = delayed(_log_reg_scoring_path) From 4961d976aaf5100afae9ab77a86ba30c97359b77 Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Wed, 9 Nov 2016 20:10:53 +0100 Subject: [PATCH 12/12] Add whatsnew --- doc/whats_new.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 08fbfedc79c92..d72e9ad8a40b1 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -97,6 +97,10 @@ Bug fixes attribute in `transform()`. :issue:`7553` by :user:`Ekaterina Krivich `. + - :class:`sklearn.linear_model.LogisticRegressionCV` now correctly handles + string labels. :issue:`5874` by `Raghav RV`_. + + .. _changes_0_18_1: Version 0.18.1