From c4cff8514e5ef455773fcc824e06d585d207bab9 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 13 Sep 2021 20:47:50 +0200
Subject: [PATCH 1/6] ENH add pos_label in calibration tools

---
 sklearn/calibration.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/sklearn/calibration.py b/sklearn/calibration.py
index 9a7e08c9d9ff2..3eef8ddf08494 100644
--- a/sklearn/calibration.py
+++ b/sklearn/calibration.py
@@ -37,12 +37,17 @@
 from .utils.multiclass import check_classification_targets
 from .utils.fixes import delayed
-from .utils.validation import check_is_fitted, check_consistent_length
-from .utils.validation import _check_sample_weight, _num_samples
+from .utils.validation import (
+    _check_sample_weight,
+    _num_samples,
+    check_consistent_length,
+    check_is_fitted,
+)
 from .utils import _safe_indexing
 from .isotonic import IsotonicRegression
 from .svm import LinearSVC
 from .model_selection import check_cv, cross_val_predict
+from .metrics._base import _check_pos_label_consistency
 from .metrics._plot.base import _get_response
@@ -847,7 +852,9 @@ def predict(self, T):
         return expit(-(self.a_ * T + self.b_))
 
 
-def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="uniform"):
+def calibration_curve(
+    y_true, y_prob, *, pos_label=None, normalize=False, n_bins=5, strategy="uniform"
+):
     """Compute true and predicted probabilities for a calibration curve.
 
     The method assumes the inputs come from a binary classifier, and
@@ -865,6 +872,11 @@ def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="un
     y_prob : array-like of shape (n_samples,)
         Probabilities of the positive class.
 
+    pos_label : int or str, default=None
+        The label of the positive class.
+
+        .. versionadded:: 1.1
+
     normalize : bool, default=False
         Whether y_prob needs to be normalized into the [0, 1] interval,
         i.e. is not a proper probability. If True, the smallest value in y_prob
@@ -915,6 +927,7 @@ def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="un
     y_true = column_or_1d(y_true)
     y_prob = column_or_1d(y_prob)
     check_consistent_length(y_true, y_prob)
+    pos_label = _check_pos_label_consistency(pos_label, y_true)
 
     if normalize:  # Normalize predicted values into interval [0, 1]
         y_prob = (y_prob - y_prob.min()) / (y_prob.max() - y_prob.min())
@@ -926,9 +939,9 @@ def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="un
     labels = np.unique(y_true)
     if len(labels) > 2:
         raise ValueError(
-            "Only binary classification is supported. Provided labels %s." % labels
+            "Only binary classification is supported. Provided labels {labels}."
         )
-    y_true = label_binarize(y_true, classes=labels)[:, 0]
+    y_true = y_true == pos_label
 
     if strategy == "quantile":  # Determine bin edges by distribution of data
         quantiles = np.linspace(0, 1, n_bins + 1)
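
A usage sketch of the new parameter, not part of the patch series: the toy
labels and probabilities below are invented for illustration, and it assumes
scikit-learn >= 1.1, where this change lands.

    import numpy as np
    from sklearn.calibration import calibration_curve

    # With string targets the positive class is ambiguous, so pos_label
    # must be given explicitly (otherwise a ValueError is raised).
    y_true = np.array(["spam", "spam", "egg", "egg", "egg", "egg"])
    y_prob = np.array([0.1, 0.2, 0.6, 0.7, 0.8, 0.9])

    prob_true, prob_pred = calibration_curve(
        y_true, y_prob, n_bins=2, pos_label="egg"
    )
    print(prob_true)  # [0.  1. ]   fraction of "egg" samples per bin
    print(prob_pred)  # [0.15 0.75] mean predicted probability per bin
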
From 8d6b10fa74a127d8c67e849f79b4fd921ad1087b Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Tue, 14 Sep 2021 10:24:42 +0200
Subject: [PATCH 2/6] TST check that we raise a consistent error message

---
 sklearn/tests/test_calibration.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py
index b06f14b082cf5..c568d89b913e6 100644
--- a/sklearn/tests/test_calibration.py
+++ b/sklearn/tests/test_calibration.py
@@ -785,3 +785,19 @@ def test_calibration_display_ref_line(pyplot, iris_data_binary):
 
     labels = viz2.ax_.get_legend_handles_labels()[1]
     assert labels.count("Perfectly calibrated") == 1
+
+
+@pytest.mark.parametrize("dtype_y_str", [str, object])
+def test_calibration_curve_pos_label_error_str(dtype_y_str):
+    """Check error message when a `pos_label` is not specified with `str` targets."""
+    rng = np.random.RandomState(42)
+    y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=dtype_y_str)
+    y2 = rng.randint(0, 2, size=y1.size)
+
+    err_msg = (
+        "y_true takes value in {'eggs', 'spam'} and pos_label is not "
+        "specified: either make y_true take value in {0, 1} or {-1, 1} or "
+        "pass pos_label explicit"
+    )
+    with pytest.raises(ValueError, match=err_msg):
+        calibration_curve(y1, y2)

From 65b149913e7ee12a3c5c39068ab1ffd827fa67fe Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Tue, 14 Sep 2021 10:27:24 +0200
Subject: [PATCH 3/6] add whats new

---
 doc/whats_new/v1.1.rst | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst
index fba40e25a9e7e..2616f55d9561f 100644
--- a/doc/whats_new/v1.1.rst
+++ b/doc/whats_new/v1.1.rst
@@ -38,6 +38,13 @@ Changelog
     :pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
     where 123456 is the *pull request* number, not the issue number.
 
+:mod:`sklearn.calibration`
+..........................
+
+- |Enhancement| :func:`calibration.calibration_curve` accepts a parameter
+  `pos_label` to specify the positive class label.
+  :pr:`21032` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.utils`
 ....................
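
The error message asserted in the test above comes from the shared
`_check_pos_label_consistency` helper imported in patch 1. As a rough sketch
of the rule that helper enforces — a simplified re-implementation for
illustration only; `infer_pos_label` is a made-up name, not the sklearn API:

    import numpy as np

    def infer_pos_label(pos_label, y_true):
        # pos_label may only default to 1 when the labels already look
        # binary in the conventional sense: a subset of {0, 1} or {-1, 1}.
        classes = np.unique(y_true)
        if pos_label is None:
            if classes.dtype.kind in "OUS" or not set(classes) <= {-1, 0, 1}:
                raise ValueError(
                    f"y_true takes value in {set(classes)} and pos_label is "
                    "not specified: either make y_true take value in "
                    "{0, 1} or {-1, 1} or pass pos_label explicitly."
                )
            pos_label = 1
        return pos_label

    infer_pos_label(None, np.array([0, 1, 1]))           # -> 1
    infer_pos_label("eggs", np.array(["spam", "eggs"]))  # -> "eggs"
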
From 3865f60f85e71894f74b5b2f74577bad04aaf2da Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Tue, 14 Sep 2021 11:53:20 +0200
Subject: [PATCH 4/6] TST add test for pos_label

---
 sklearn/tests/test_calibration.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py
index c568d89b913e6..c78c273c7038c 100644
--- a/sklearn/tests/test_calibration.py
+++ b/sklearn/tests/test_calibration.py
@@ -801,3 +801,24 @@ def test_calibration_curve_pos_label_error_str(dtype_y_str):
     )
     with pytest.raises(ValueError, match=err_msg):
         calibration_curve(y1, y2)
+
+
+@pytest.mark.parametrize("dtype_y_str", [str, object])
+def test_calibration_curve_pos_label(dtype_y_str):
+    """Check the behaviour when passing explicitly `pos_label`."""
+    y_true = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
+    classes = np.array(["spam", "egg"], dtype=dtype_y_str)
+    y_true_str = classes[y_true]
+    y_pred = np.array([0.1, 0.2, 0.3, 0.4, 0.65, 0.7, 0.8, 0.9, 1.0])
+
+    # default case
+    prob_true, _ = calibration_curve(y_true, y_pred, n_bins=4)
+    assert_allclose(prob_true, [0, 0.5, 1, 1])
+    # if `y_true` contains `str`, then `pos_label` is required
+    prob_true, _ = calibration_curve(y_true_str, y_pred, n_bins=4, pos_label="egg")
+    assert_allclose(prob_true, [0, 0.5, 1, 1])
+
+    prob_true, _ = calibration_curve(y_true, 1 - y_pred, n_bins=4, pos_label=0)
+    assert_allclose(prob_true, [0, 0, 0.5, 1])
+    prob_true, _ = calibration_curve(y_true_str, 1 - y_pred, n_bins=4, pos_label="spam")
+    assert_allclose(prob_true, [0, 0, 0.5, 1])

From 55636db589fd43f56bd43f4258d382d20f5b5d77 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Fri, 24 Sep 2021 15:31:40 +0200
Subject: [PATCH 5/6] Update sklearn/tests/test_calibration.py

Co-authored-by: Olivier Grisel
---
 sklearn/tests/test_calibration.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py
index c78c273c7038c..be790c2a98ba8 100644
--- a/sklearn/tests/test_calibration.py
+++ b/sklearn/tests/test_calibration.py
@@ -797,7 +797,7 @@ def test_calibration_curve_pos_label_error_str(dtype_y_str):
     err_msg = (
         "y_true takes value in {'eggs', 'spam'} and pos_label is not "
         "specified: either make y_true take value in {0, 1} or {-1, 1} or "
-        "pass pos_label explicit"
+        "pass pos_label explicitly"
     )
     with pytest.raises(ValueError, match=err_msg):
         calibration_curve(y1, y2)

From 23e87df80b7507a7b66c86c42b4b5cb920f57633 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 27 Sep 2021 11:46:07 +0200
Subject: [PATCH 6/6] Update sklearn/calibration.py

Co-authored-by: Thomas J. Fan
---
 sklearn/calibration.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sklearn/calibration.py b/sklearn/calibration.py
index 104cd84f4f219..a390159c2db38 100644
--- a/sklearn/calibration.py
+++ b/sklearn/calibration.py
@@ -939,7 +939,7 @@ def calibration_curve(
     labels = np.unique(y_true)
     if len(labels) > 2:
         raise ValueError(
-            "Only binary classification is supported. Provided labels {labels}."
+            f"Only binary classification is supported. Provided labels {labels}."
        )
     y_true = y_true == pos_label
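
Putting the series together, an end-to-end sketch — again illustrative and
not part of the PR: the dataset, estimator and label names are invented.

    import numpy as np
    from sklearn.calibration import calibration_curve
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    X, y_num = make_classification(n_samples=1000, random_state=0)
    # Use string targets to exercise the new pos_label code path.
    y = np.where(y_num == 1, "spam", "ham")

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = LogisticRegression().fit(X_train, y_train)

    # predict_proba column corresponding to the "spam" class.
    spam_col = list(clf.classes_).index("spam")
    proba_spam = clf.predict_proba(X_test)[:, spam_col]

    # Before this series, string targets required manual binarization;
    # now pos_label selects the positive class directly.
    prob_true, prob_pred = calibration_curve(
        y_test, proba_spam, n_bins=5, pos_label="spam"
    )
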