Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d21847c

Browse files
FIX use same API for CalibrationDisplay than other Display (#21031)
* FIX use same API for CalibrationDisplay than other Display * Update sklearn/calibration.py Co-authored-by: Thomas J. Fan <[email protected]> * iter Co-authored-by: Thomas J. Fan <[email protected]>
1 parent 66f6fb3 commit d21847c

File tree

5 files changed

+24
-22
lines changed

5 files changed

+24
-22
lines changed

sklearn/calibration.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -980,8 +980,8 @@ class CalibrationDisplay:
980980
y_prob : ndarray of shape (n_samples,)
981981
Probability estimates for the positive class, for each sample.
982982
983-
name : str, default=None
984-
Name for labeling curve.
983+
estimator_name : str, default=None
984+
Name of estimator. If None, the estimator name is not shown.
985985
986986
Attributes
987987
----------
@@ -1022,11 +1022,11 @@ class CalibrationDisplay:
10221022
<...>
10231023
"""
10241024

1025-
def __init__(self, prob_true, prob_pred, y_prob, *, name=None):
1025+
def __init__(self, prob_true, prob_pred, y_prob, *, estimator_name=None):
10261026
self.prob_true = prob_true
10271027
self.prob_pred = prob_pred
10281028
self.y_prob = y_prob
1029-
self.name = name
1029+
self.estimator_name = estimator_name
10301030

10311031
def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
10321032
"""Plot visualization.
@@ -1041,7 +1041,8 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
10411041
created.
10421042
10431043
name : str, default=None
1044-
Name for labeling curve.
1044+
Name for labeling curve. If `None`, use `estimator_name` if
1045+
not `None`, otherwise no labeling is shown.
10451046
10461047
ref_line : bool, default=True
10471048
If `True`, plots a reference line representing a perfectly
@@ -1061,8 +1062,7 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
10611062
if ax is None:
10621063
fig, ax = plt.subplots()
10631064

1064-
name = self.name if name is None else name
1065-
self.name = name
1065+
name = self.estimator_name if name is None else name
10661066

10671067
line_kwargs = {}
10681068
if name is not None:
@@ -1298,6 +1298,9 @@ def from_predictions(
12981298
prob_true, prob_pred = calibration_curve(
12991299
y_true, y_prob, n_bins=n_bins, strategy=strategy
13001300
)
1301+
name = name if name is not None else "Classifier"
13011302

1302-
disp = cls(prob_true=prob_true, prob_pred=prob_pred, y_prob=y_prob, name=name)
1303+
disp = cls(
1304+
prob_true=prob_true, prob_pred=prob_pred, y_prob=y_prob, estimator_name=name
1305+
)
13031306
return disp.plot(ax=ax, ref_line=ref_line, **kwargs)

sklearn/metrics/_plot/det_curve.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,8 @@ def plot(self, ax=None, *, name=None, **kwargs):
294294
created.
295295
296296
name : str, default=None
297-
Name of DET curve for labeling. If `None`, use the name of the
298-
estimator.
297+
Name of DET curve for labeling. If `None`, use `estimator_name` if
298+
it is not `None`, otherwise no labeling is shown.
299299
300300
**kwargs : dict
301301
Additional keywords arguments passed to matplotlib `plot` function.

sklearn/metrics/_plot/precision_recall_curve.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def plot(self, ax=None, *, name=None, **kwargs):
109109
created.
110110
111111
name : str, default=None
112-
Name of precision recall curve for labeling. If `None`, use the
113-
name of the estimator.
112+
Name of precision recall curve for labeling. If `None`, use
113+
`estimator_name` if not `None`, otherwise no labeling is shown.
114114
115115
**kwargs : dict
116116
Keyword arguments to be passed to matplotlib's `plot`.

sklearn/metrics/_plot/roc_curve.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ def plot(self, ax=None, *, name=None, **kwargs):
9494
created.
9595
9696
name : str, default=None
97-
Name of ROC Curve for labeling. If `None`, use the name of the
98-
estimator.
97+
Name of ROC Curve for labeling. If `None`, use `estimator_name` if
98+
not `None`, otherwise no labeling is shown.
9999
100100
Returns
101101
-------

sklearn/tests/test_calibration.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ def test_calibration_display_compute(pyplot, iris_data_binary, n_bins, strategy)
693693
assert_allclose(viz.prob_pred, prob_pred)
694694
assert_allclose(viz.y_prob, y_prob)
695695

696-
assert viz.name == "LogisticRegression"
696+
assert viz.estimator_name == "LogisticRegression"
697697

698698
# cannot fail thanks to pyplot fixture
699699
import matplotlib as mpl # noqa
@@ -715,7 +715,7 @@ def test_plot_calibration_curve_pipeline(pyplot, iris_data_binary):
715715
clf.fit(X, y)
716716
viz = CalibrationDisplay.from_estimator(clf, X, y)
717717
assert clf.__class__.__name__ in viz.line_.get_label()
718-
assert viz.name == clf.__class__.__name__
718+
assert viz.estimator_name == clf.__class__.__name__
719719

720720

721721
@pytest.mark.parametrize(
@@ -726,24 +726,23 @@ def test_calibration_display_default_labels(pyplot, name, expected_label):
726726
prob_pred = np.array([0.2, 0.8, 0.8, 0.4])
727727
y_prob = np.array([])
728728

729-
viz = CalibrationDisplay(prob_true, prob_pred, y_prob, name=name)
729+
viz = CalibrationDisplay(prob_true, prob_pred, y_prob, estimator_name=name)
730730
viz.plot()
731731
assert viz.line_.get_label() == expected_label
732732

733733

734734
def test_calibration_display_label_class_plot(pyplot):
735735
# Checks that when instantiating `CalibrationDisplay` class then calling
736-
# `plot`, `self.name` is the one given in `plot`
736+
# `plot`, `self.estimator_name` is the one given in `plot`
737737
prob_true = np.array([0, 1, 1, 0])
738738
prob_pred = np.array([0.2, 0.8, 0.8, 0.4])
739739
y_prob = np.array([])
740740

741741
name = "name one"
742-
viz = CalibrationDisplay(prob_true, prob_pred, y_prob, name=name)
743-
assert viz.name == name
742+
viz = CalibrationDisplay(prob_true, prob_pred, y_prob, estimator_name=name)
743+
assert viz.estimator_name == name
744744
name = "name two"
745745
viz.plot(name=name)
746-
assert viz.name == name
747746
assert viz.line_.get_label() == name
748747

749748

@@ -764,7 +763,7 @@ def test_calibration_display_name_multiple_calls(
764763
params = (clf, X, y) if constructor_name == "from_estimator" else (y, y_prob)
765764

766765
viz = constructor(*params, name=clf_name)
767-
assert viz.name == clf_name
766+
assert viz.estimator_name == clf_name
768767
pyplot.close("all")
769768
viz.plot()
770769
assert clf_name == viz.line_.get_label()

0 commit comments

Comments
 (0)