diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 1911cc5cbde57..d676312e240de 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -138,6 +138,12 @@ Bug fixes ``partial_fit`` was less than the total number of classes in the data. :issue:`7786` by `Srivatsan Ramesh`_ + - Fixes issue in :class:`calibration.CalibratedClassifierCV` where + the sum of probabilities of each class for a data was not 1, and + ``CalibratedClassifierCV`` now handles the case where the training set + has less number of classes than the total data. :issue:`7799` by + `Srivatsan Ramesh`_ + API changes summary ------------------- diff --git a/sklearn/calibration.py b/sklearn/calibration.py index ed3e85b643815..b96799f73d13d 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -14,9 +14,10 @@ import numpy as np from scipy.optimize import fmin_bfgs +from sklearn.preprocessing import LabelEncoder from .base import BaseEstimator, ClassifierMixin, RegressorMixin, clone -from .preprocessing import LabelBinarizer +from .preprocessing import label_binarize, LabelBinarizer from .utils import check_X_y, check_array, indexable, column_or_1d from .utils.validation import check_is_fitted from .utils.fixes import signature @@ -50,7 +51,8 @@ class CalibratedClassifierCV(BaseEstimator, ClassifierMixin): The method to use for calibration. Can be 'sigmoid' which corresponds to Platt's method or 'isotonic' which is a non-parametric approach. It is not advised to use isotonic calibration - with too few calibration samples ``(<<1000)`` since it tends to overfit. + with too few calibration samples ``(<<1000)`` since it tends to + overfit. Use sigmoids (Platt's calibration) in this case. cv : integer, cross-validation generator, iterable or "prefit", optional @@ -63,8 +65,8 @@ class CalibratedClassifierCV(BaseEstimator, ClassifierMixin): - An iterable yielding train/test splits. For integer/None inputs, if ``y`` is binary or multiclass, - :class:`sklearn.model_selection.StratifiedKFold` is used. If ``y`` - is neither binary nor multiclass, :class:`sklearn.model_selection.KFold` + :class:`sklearn.model_selection.StratifiedKFold` is used. If ``y`` is + neither binary nor multiclass, :class:`sklearn.model_selection.KFold` is used. Refer :ref:`User Guide ` for the various @@ -124,15 +126,16 @@ def fit(self, X, y, sample_weight=None): X, y = check_X_y(X, y, accept_sparse=['csc', 'csr', 'coo'], force_all_finite=False) X, y = indexable(X, y) - lb = LabelBinarizer().fit(y) - self.classes_ = lb.classes_ + le = LabelBinarizer().fit(y) + self.classes_ = le.classes_ # Check that each cross-validation fold can have at least one # example per class n_folds = self.cv if isinstance(self.cv, int) \ else self.cv.n_folds if hasattr(self.cv, "n_folds") else None if n_folds and \ - np.any([np.sum(y == class_) < n_folds for class_ in self.classes_]): + np.any([np.sum(y == class_) < n_folds for class_ in + self.classes_]): raise ValueError("Requesting %d-fold cross-validation but provided" " less than %d examples for at least one class." % (n_folds, n_folds)) @@ -175,7 +178,8 @@ def fit(self, X, y, sample_weight=None): this_estimator.fit(X[train], y[train]) calibrated_classifier = _CalibratedClassifier( - this_estimator, method=self.method) + this_estimator, method=self.method, + classes=self.classes_) if sample_weight is not None: calibrated_classifier.fit(X[test], y[test], sample_weight[test]) @@ -253,6 +257,11 @@ class _CalibratedClassifier(object): corresponds to Platt's method or 'isotonic' which is a non-parametric approach based on isotonic regression. + classes : array-like, shape (n_classes,), optional + Contains unique classes used to fit the base estimator. + if None, then classes is extracted from the given target values + in fit(). + References ---------- .. [1] Obtaining calibrated probability estimates from decision trees @@ -267,9 +276,10 @@ class _CalibratedClassifier(object): .. [4] Predicting Good Probabilities with Supervised Learning, A. Niculescu-Mizil & R. Caruana, ICML 2005 """ - def __init__(self, base_estimator, method='sigmoid'): + def __init__(self, base_estimator, method='sigmoid', classes=None): self.base_estimator = base_estimator self.method = method + self.classes = classes def _preproc(self, X): n_classes = len(self.classes_) @@ -285,7 +295,8 @@ def _preproc(self, X): raise RuntimeError('classifier has no decision_function or ' 'predict_proba method.') - idx_pos_class = np.arange(df.shape[1]) + idx_pos_class = self.label_encoder_.\ + transform(self.base_estimator.classes_) return df, idx_pos_class @@ -308,9 +319,15 @@ def fit(self, X, y, sample_weight=None): self : object Returns an instance of self. """ - lb = LabelBinarizer() - Y = lb.fit_transform(y) - self.classes_ = lb.classes_ + + self.label_encoder_ = LabelEncoder() + if self.classes is None: + self.label_encoder_.fit(y) + else: + self.label_encoder_.fit(self.classes) + + self.classes_ = self.label_encoder_.classes_ + Y = label_binarize(y, self.classes_) df, idx_pos_class = self._preproc(X) self.calibrators_ = [] diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index 68a6efb395971..e4499e35d5a67 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -1,8 +1,10 @@ # Authors: Alexandre Gramfort # License: BSD 3 clause +from __future__ import division import numpy as np from scipy import sparse +from sklearn.model_selection import LeaveOneOut from sklearn.utils.testing import (assert_array_almost_equal, assert_equal, assert_greater, assert_almost_equal, @@ -14,7 +16,6 @@ from sklearn.naive_bayes import MultinomialNB from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from sklearn.svm import LinearSVC -from sklearn.linear_model import Ridge from sklearn.pipeline import Pipeline from sklearn.preprocessing import Imputer from sklearn.metrics import brier_score_loss, log_loss @@ -87,12 +88,6 @@ def test_calibration(): brier_score_loss((y_test + 1) % 2, prob_pos_pc_clf_relabeled)) - # check that calibration can also deal with regressors that have - # a decision_function - clf_base_regressor = CalibratedClassifierCV(Ridge()) - clf_base_regressor.fit(X_train, y_train) - clf_base_regressor.predict(X_test) - # Check failure cases: # only "isotonic" and "sigmoid" should be accepted as methods clf_invalid_method = CalibratedClassifierCV(clf, method="foo") @@ -159,6 +154,7 @@ def test_calibration_multiclass(): def softmax(y_pred): e = np.exp(-y_pred) return e / e.sum(axis=1).reshape(-1, 1) + uncalibrated_log_loss = \ log_loss(y_test, softmax(clf.decision_function(X_test))) calibrated_log_loss = log_loss(y_test, probas) @@ -275,3 +271,36 @@ def test_calibration_nan_imputer(): clf_c = CalibratedClassifierCV(clf, cv=2, method='isotonic') clf_c.fit(X, y) clf_c.predict(X) + + +def test_calibration_prob_sum(): + # Test that sum of probabilities is 1. A non-regression test for + # issue #7796 + num_classes = 2 + X, y = make_classification(n_samples=10, n_features=5, + n_classes=num_classes) + clf = LinearSVC(C=1.0) + clf_prob = CalibratedClassifierCV(clf, method="sigmoid", cv=LeaveOneOut()) + clf_prob.fit(X, y) + + probs = clf_prob.predict_proba(X) + assert_array_almost_equal(probs.sum(axis=1), np.ones(probs.shape[0])) + + +def test_calibration_less_classes(): + # Test to check calibration works fine when train set in a test-train + # split does not contain all classes + # Since this test uses LOO, at each iteration train set will not contain a + # class label + X = np.random.randn(10, 5) + y = np.arange(10) + clf = LinearSVC(C=1.0) + cal_clf = CalibratedClassifierCV(clf, method="sigmoid", cv=LeaveOneOut()) + cal_clf.fit(X, y) + + for i, calibrated_classifier in \ + enumerate(cal_clf.calibrated_classifiers_): + proba = calibrated_classifier.predict_proba(X) + assert_array_equal(proba[:, i], np.zeros(len(y))) + assert_equal(np.all(np.hstack([proba[:, :i], + proba[:, i + 1:]])), True)