diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst
index 3aabed6214771..9d9cf3c95b450 100644
--- a/doc/whats_new/v1.1.rst
+++ b/doc/whats_new/v1.1.rst
@@ -38,6 +38,14 @@ Changelog
     :pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
     where 123456 is the *pull request* number, not the issue number.
 
+
+:mod:`sklearn.calibration`
+..........................
+
+- |Enhancement| :func:`calibration.calibration_curve` accepts a parameter
+  `pos_label` to specify the positive class label.
+  :pr:`21032` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.linear_model`
 ...........................
 
@@ -45,6 +53,7 @@ Changelog
   message when the solver does not support sparse matrices with int64
   indices. :pr:`21093` by `Tom Dupre la Tour`_.
 
+
 :mod:`sklearn.utils`
 ....................
 
diff --git a/sklearn/calibration.py b/sklearn/calibration.py
index 0785938135513..6d9abf82d3470 100644
--- a/sklearn/calibration.py
+++ b/sklearn/calibration.py
@@ -37,12 +37,17 @@
 from .utils.multiclass import check_classification_targets
 from .utils.fixes import delayed
-from .utils.validation import check_is_fitted, check_consistent_length
-from .utils.validation import _check_sample_weight, _num_samples
+from .utils.validation import (
+    _check_sample_weight,
+    _num_samples,
+    check_consistent_length,
+    check_is_fitted,
+)
 from .utils import _safe_indexing
 from .isotonic import IsotonicRegression
 from .svm import LinearSVC
 from .model_selection import check_cv, cross_val_predict
+from .metrics._base import _check_pos_label_consistency
 from .metrics._plot.base import _get_response
@@ -866,7 +871,9 @@ def predict(self, T):
         return expit(-(self.a_ * T + self.b_))
 
 
-def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="uniform"):
+def calibration_curve(
+    y_true, y_prob, *, pos_label=None, normalize=False, n_bins=5, strategy="uniform"
+):
     """Compute true and predicted probabilities for a calibration curve.
 
     The method assumes the inputs come from a binary classifier, and
@@ -884,6 +891,11 @@ def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="un
     y_prob : array-like of shape (n_samples,)
         Probabilities of the positive class.
 
+    pos_label : int or str, default=None
+        The label of the positive class.
+
+        .. versionadded:: 1.1
+
     normalize : bool, default=False
         Whether y_prob needs to be normalized into the [0, 1] interval,
         i.e. is not a proper probability. If True, the smallest value in y_prob
@@ -934,6 +946,7 @@ def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="un
     y_true = column_or_1d(y_true)
     y_prob = column_or_1d(y_prob)
     check_consistent_length(y_true, y_prob)
+    pos_label = _check_pos_label_consistency(pos_label, y_true)
 
     if normalize:  # Normalize predicted values into interval [0, 1]
         y_prob = (y_prob - y_prob.min()) / (y_prob.max() - y_prob.min())
@@ -945,9 +958,9 @@ def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="un
     labels = np.unique(y_true)
     if len(labels) > 2:
         raise ValueError(
-            "Only binary classification is supported. Provided labels %s." % labels
+            f"Only binary classification is supported. Provided labels {labels}."
         )
-    y_true = label_binarize(y_true, classes=labels)[:, 0]
+    y_true = y_true == pos_label
 
     if strategy == "quantile":  # Determine bin edges by distribution of data
         quantiles = np.linspace(0, 1, n_bins + 1)
diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py
index 7b8d656bef939..4ad983d72e007 100644
--- a/sklearn/tests/test_calibration.py
+++ b/sklearn/tests/test_calibration.py
@@ -786,6 +786,43 @@ def test_calibration_display_ref_line(pyplot, iris_data_binary):
     assert labels.count("Perfectly calibrated") == 1
 
 
+@pytest.mark.parametrize("dtype_y_str", [str, object])
+def test_calibration_curve_pos_label_error_str(dtype_y_str):
+    """Check error message when a `pos_label` is not specified with `str` targets."""
+    rng = np.random.RandomState(42)
+    y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=dtype_y_str)
+    y2 = rng.randint(0, 2, size=y1.size)
+
+    err_msg = (
+        "y_true takes value in {'eggs', 'spam'} and pos_label is not "
+        "specified: either make y_true take value in {0, 1} or {-1, 1} or "
+        "pass pos_label explicitly"
+    )
+    with pytest.raises(ValueError, match=err_msg):
+        calibration_curve(y1, y2)
+
+
+@pytest.mark.parametrize("dtype_y_str", [str, object])
+def test_calibration_curve_pos_label(dtype_y_str):
+    """Check the behaviour when passing explicitly `pos_label`."""
+    y_true = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
+    classes = np.array(["spam", "egg"], dtype=dtype_y_str)
+    y_true_str = classes[y_true]
+    y_pred = np.array([0.1, 0.2, 0.3, 0.4, 0.65, 0.7, 0.8, 0.9, 1.0])
+
+    # default case
+    prob_true, _ = calibration_curve(y_true, y_pred, n_bins=4)
+    assert_allclose(prob_true, [0, 0.5, 1, 1])
+    # if `y_true` contains `str`, then `pos_label` is required
+    prob_true, _ = calibration_curve(y_true_str, y_pred, n_bins=4, pos_label="egg")
+    assert_allclose(prob_true, [0, 0.5, 1, 1])
+
+    prob_true, _ = calibration_curve(y_true, 1 - y_pred, n_bins=4, pos_label=0)
+    assert_allclose(prob_true, [0, 0, 0.5, 1])
+    prob_true, _ = calibration_curve(y_true_str, 1 - y_pred, n_bins=4, pos_label="spam")
+    assert_allclose(prob_true, [0, 0, 0.5, 1])
+
+
 @pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
 @pytest.mark.parametrize("ensemble", [True, False])
 def test_calibrated_classifier_cv_double_sample_weights_equivalence(method, ensemble):
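
For quick reference, here is an illustrative usage sketch of the new `pos_label` parameter. It is not part of the patch; the data mirrors the values asserted in `test_calibration_curve_pos_label` above:

import numpy as np
from sklearn.calibration import calibration_curve

# With string targets there is no implicit positive class, so `pos_label`
# must be passed explicitly; omitting it raises the ValueError tested above.
y_true = np.array(["spam"] * 3 + ["egg"] * 6)
y_prob = np.array([0.1, 0.2, 0.3, 0.4, 0.65, 0.7, 0.8, 0.9, 1.0])

prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=4, pos_label="egg")
print(prob_true)  # [0.  0.5 1.  1. ], matching the test expectation

Passing `pos_label="spam"` instead (or `pos_label=0` with integer targets) treats the other class as positive, which is what the second half of the test exercises with `1 - y_pred`.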