ENH add pos_label to CalibrationDisplay #21038

Merged (23 commits, Oct 27, 2021)

Commits
c4cff85  ENH add in calibration tools (glemaitre, Sep 13, 2021)
8d6b10f  TST check that we raise a consistent error message (glemaitre, Sep 14, 2021)
65b1499  add whats new (glemaitre, Sep 14, 2021)
3865f60  TST add test for pos_label (glemaitre, Sep 14, 2021)
aa3d76d  Merge remote-tracking branch 'origin/main' into is/21029 (glemaitre, Sep 14, 2021)
1423fb7  ENH add pos_label to CalibrationDisplay (glemaitre, Sep 14, 2021)
2644263  DOC add to whats new (glemaitre, Sep 14, 2021)
65b3bd1  TST add test for pos_label in CalibrationDisplay (glemaitre, Sep 14, 2021)
ec9bdf2  TST add unit tests for current _get_response (glemaitre, Sep 14, 2021)
a989c67  TST add unit tests for current _get_response (glemaitre, Sep 14, 2021)
bc6efda  add a proper way to check the warning raised (glemaitre, Sep 14, 2021)
cf4c2a4  revert hidding warning (glemaitre, Sep 14, 2021)
4e30df2  Merge remote-tracking branch 'glemaitre/is/_get_response_test' into i… (glemaitre, Sep 14, 2021)
55636db  Update sklearn/tests/test_calibration.py (glemaitre, Sep 24, 2021)
9f4ab8d  Merge branch 'main' into is/21029 (glemaitre, Sep 24, 2021)
973b326  Merge remote-tracking branch 'origin/main' into is/follow_up_21029 (glemaitre, Sep 24, 2021)
21ed564  Merge remote-tracking branch 'glemaitre/is/21029' into is/follow_up_2… (glemaitre, Sep 24, 2021)
02dc3a5  Address ogrisel comments (glemaitre, Sep 24, 2021)
23e87df  Update sklearn/calibration.py (glemaitre, Sep 27, 2021)
1e2a386  Merge remote-tracking branch 'glemaitre/is/21029' into is/follow_up_2… (glemaitre, Sep 27, 2021)
367d489  Merge remote-tracking branch 'origin/main' into is/follow_up_21029 (glemaitre, Oct 26, 2021)
3570c69  Apply suggestions from code review (glemaitre, Oct 27, 2021)
c1cebfe  TST add multiple cases for pos_label (glemaitre, Oct 27, 2021)
4 changes: 4 additions & 0 deletions doc/whats_new/v1.1.rst
@@ -45,6 +45,10 @@ Changelog
   `pos_label` to specify the positive class label.
   :pr:`21032` by :user:`Guillaume Lemaitre <glemaitre>`.
 
+- |Enhancement| :class:`CalibrationDisplay` accepts a parameter `pos_label` to
+  add this information to the plot.
+  :pr:`21038` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.cross_decomposition`
 ..................................
 
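To make the changelog entry concrete, here is a minimal usage sketch of the new parameter (hypothetical data; assumes a scikit-learn build that includes this PR):

    from sklearn.calibration import CalibrationDisplay
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    # Hypothetical binary problem with classes 0 and 1.
    X, y = make_classification(n_samples=1000, random_state=0)
    clf = LogisticRegression().fit(X, y)

    # Request the calibration curve for class 0 instead of the default
    # `clf.classes_[1]`; "(Positive class: 0)" is appended to both axis labels.
    disp = CalibrationDisplay.from_estimator(clf, X, y, pos_label=0)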
53 changes: 44 additions & 9 deletions sklearn/calibration.py
@@ -1015,6 +1015,13 @@ class CalibrationDisplay:
     estimator_name : str, default=None
         Name of estimator. If None, the estimator name is not shown.
 
+    pos_label : str or int, default=None
+        The positive class when computing the calibration curve.
+        By default, `estimator.classes_[1]` is considered the
+        positive class.
+
+        .. versionadded:: 1.1
+
     Attributes
     ----------
     line_ : matplotlib Artist
@@ -1054,11 +1061,14 @@ class CalibrationDisplay:
     <...>
     """
 
-    def __init__(self, prob_true, prob_pred, y_prob, *, estimator_name=None):
+    def __init__(
+        self, prob_true, prob_pred, y_prob, *, estimator_name=None, pos_label=None
+    ):
         self.prob_true = prob_true
         self.prob_pred = prob_pred
         self.y_prob = y_prob
         self.estimator_name = estimator_name
+        self.pos_label = pos_label
 
     def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
         """Plot visualization.
@@ -1095,6 +1105,9 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
             fig, ax = plt.subplots()
 
         name = self.estimator_name if name is None else name
+        info_pos_label = (
+            f"(Positive class: {self.pos_label})" if self.pos_label is not None else ""
+        )
 
         line_kwargs = {}
         if name is not None:
@@ -1110,7 +1123,9 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
         if "label" in line_kwargs:
             ax.legend(loc="lower right")
 
-        ax.set(xlabel="Mean predicted probability", ylabel="Fraction of positives")
+        xlabel = f"Mean predicted probability {info_pos_label}"
+        ylabel = f"Fraction of positives {info_pos_label}"
+        ax.set(xlabel=xlabel, ylabel=ylabel)
 
         self.ax_ = ax
         self.figure_ = ax.figure
@@ -1125,6 +1140,7 @@ def from_estimator(
         *,
         n_bins=5,
         strategy="uniform",
+        pos_label=None,
         name=None,
         ref_line=True,
         ax=None,
@@ -1170,6 +1186,13 @@ def from_estimator(
             - `'quantile'`: The bins have the same number of samples and depend
               on predicted probabilities.
 
+        pos_label : str or int, default=None
+            The positive class when computing the calibration curve.
+            By default, `estimator.classes_[1]` is considered the
+            positive class.
+
+            .. versionadded:: 1.1
+
         name : str, default=None
             Name for labeling curve. If `None`, the name of the estimator is
             used.
@@ -1217,10 +1240,8 @@ def from_estimator(
         if not is_classifier(estimator):
             raise ValueError("'estimator' should be a fitted classifier.")
 
-        # FIXME: `pos_label` should not be set to None
-        # We should allow any int or string in `calibration_curve`.
-        y_prob, _ = _get_response(
-            X, estimator, response_method="predict_proba", pos_label=None
+        y_prob, pos_label = _get_response(
+            X, estimator, response_method="predict_proba", pos_label=pos_label
         )
 
         name = name if name is not None else estimator.__class__.__name__
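For readers unfamiliar with the private helper used above: `_get_response` resolves `pos_label` and returns the matching probability column together with the resolved label, which is why the return value is no longer discarded. A simplified sketch of that contract (the `_get_response_sketch` name is hypothetical; the real helper also supports `decision_function` and validates its inputs):

    def _get_response_sketch(X, estimator, pos_label=None):
        # Default to the second class, i.e. `estimator.classes_[1]`.
        if pos_label is None:
            pos_label = estimator.classes_[1]
        # Select the `predict_proba` column of the requested class.
        class_idx = list(estimator.classes_).index(pos_label)
        return estimator.predict_proba(X)[:, class_idx], pos_label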
@@ -1229,6 +1250,7 @@ def from_estimator(
             y_prob,
             n_bins=n_bins,
             strategy=strategy,
+            pos_label=pos_label,
             name=name,
             ref_line=ref_line,
             ax=ax,
@@ -1243,6 +1265,7 @@ def from_predictions(
         *,
         n_bins=5,
         strategy="uniform",
+        pos_label=None,
         name=None,
         ref_line=True,
         ax=None,
@@ -1283,6 +1306,13 @@ def from_predictions(
             - `'quantile'`: The bins have the same number of samples and depend
               on predicted probabilities.
 
+        pos_label : str or int, default=None
+            The positive class when computing the calibration curve.
+            By default, `estimator.classes_[1]` is considered the
+            positive class.
+
+            .. versionadded:: 1.1
+
         name : str, default=None
             Name for labeling curve.
 
@@ -1328,11 +1358,16 @@ def from_predictions(
         check_matplotlib_support(method_name)
 
         prob_true, prob_pred = calibration_curve(
-            y_true, y_prob, n_bins=n_bins, strategy=strategy
+            y_true, y_prob, n_bins=n_bins, strategy=strategy, pos_label=pos_label
        )
-        name = name if name is not None else "Classifier"
+        name = "Classifier" if name is None else name
+        pos_label = _check_pos_label_consistency(pos_label, y_true)
 
         disp = cls(
-            prob_true=prob_true, prob_pred=prob_pred, y_prob=y_prob, estimator_name=name
+            prob_true=prob_true,
+            prob_pred=prob_pred,
+            y_prob=y_prob,
+            estimator_name=name,
+            pos_label=pos_label,
         )
         return disp.plot(ax=ax, ref_line=ref_line, **kwargs)
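The `from_predictions` path matters most for string-encoded targets, where a positive class cannot be inferred automatically. A hedged usage sketch (hypothetical labels; relies on `calibration_curve` having gained `pos_label` in the companion PR referenced in the changelog above):

    import numpy as np
    from sklearn.calibration import CalibrationDisplay

    rng = np.random.RandomState(0)
    y_true = rng.choice(["ham", "spam"], size=200)
    # Hypothetical scores interpreted as P(class == "spam").
    y_prob = rng.uniform(size=200)

    # With string targets and no `pos_label`, the consistency check would
    # raise; passing `pos_label` makes the positive class explicit.
    disp = CalibrationDisplay.from_predictions(y_true, y_prob, pos_label="spam")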
32 changes: 30 additions & 2 deletions sklearn/tests/test_calibration.py
@@ -703,8 +703,8 @@ def test_calibration_display_compute(pyplot, iris_data_binary, n_bins, strategy):
     assert isinstance(viz.ax_, mpl.axes.Axes)
     assert isinstance(viz.figure_, mpl.figure.Figure)
 
-    assert viz.ax_.get_xlabel() == "Mean predicted probability"
-    assert viz.ax_.get_ylabel() == "Fraction of positives"
+    assert viz.ax_.get_xlabel() == "Mean predicted probability (Positive class: 1)"
+    assert viz.ax_.get_ylabel() == "Fraction of positives (Positive class: 1)"
     assert viz.line_.get_label() == "LogisticRegression"
 
 
@@ -823,6 +823,34 @@ def test_calibration_curve_pos_label(dtype_y_str):
     assert_allclose(prob_true, [0, 0, 0.5, 1])
 
 
+@pytest.mark.parametrize("pos_label, expected_pos_label", [(None, 1), (0, 0), (1, 1)])
+def test_calibration_display_pos_label(
+    pyplot, iris_data_binary, pos_label, expected_pos_label
+):
+    """Check the behaviour of `pos_label` in the `CalibrationDisplay`."""
+    X, y = iris_data_binary
+
+    lr = LogisticRegression().fit(X, y)
+    viz = CalibrationDisplay.from_estimator(lr, X, y, pos_label=pos_label)
+
+    y_prob = lr.predict_proba(X)[:, expected_pos_label]
+    prob_true, prob_pred = calibration_curve(y, y_prob, pos_label=pos_label)
+
+    assert_allclose(viz.prob_true, prob_true)
+    assert_allclose(viz.prob_pred, prob_pred)
+    assert_allclose(viz.y_prob, y_prob)
+
+    assert (
+        viz.ax_.get_xlabel()
+        == f"Mean predicted probability (Positive class: {expected_pos_label})"
+    )
+    assert (
+        viz.ax_.get_ylabel()
+        == f"Fraction of positives (Positive class: {expected_pos_label})"
+    )
+    assert viz.line_.get_label() == "LogisticRegression"
+
+
 @pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
 @pytest.mark.parametrize("ensemble", [True, False])
 def test_calibrated_classifier_cv_double_sample_weights_equivalence(method, ensemble):
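A quick interactive check mirroring the new test (a hedged sketch that builds a binary iris problem by hand instead of using the `iris_data_binary` fixture):

    from sklearn.calibration import CalibrationDisplay
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression

    X, y = load_iris(return_X_y=True)
    X, y = X[y < 2], y[y < 2]  # keep only classes 0 and 1

    clf = LogisticRegression().fit(X, y)
    viz = CalibrationDisplay.from_estimator(clf, X, y, pos_label=0)
    print(viz.ax_.get_xlabel())
    # Expected: "Mean predicted probability (Positive class: 0)"

From a development checkout, the new test itself should run with `pytest sklearn/tests/test_calibration.py -k pos_label`.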