diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 9a7e08c9d9ff2..fe5e21577a434 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -980,8 +980,8 @@ class CalibrationDisplay: y_prob : ndarray of shape (n_samples,) Probability estimates for the positive class, for each sample. - name : str, default=None - Name for labeling curve. + estimator_name : str, default=None + Name of estimator. If None, the estimator name is not shown. Attributes ---------- @@ -1022,11 +1022,11 @@ class CalibrationDisplay: <...> """ - def __init__(self, prob_true, prob_pred, y_prob, *, name=None): + def __init__(self, prob_true, prob_pred, y_prob, *, estimator_name=None): self.prob_true = prob_true self.prob_pred = prob_pred self.y_prob = y_prob - self.name = name + self.estimator_name = estimator_name def plot(self, *, ax=None, name=None, ref_line=True, **kwargs): """Plot visualization. @@ -1041,7 +1041,8 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs): created. name : str, default=None - Name for labeling curve. + Name for labeling curve. If `None`, use `estimator_name` if + not `None`, otherwise no labeling is shown. ref_line : bool, default=True If `True`, plots a reference line representing a perfectly @@ -1061,8 +1062,7 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs): if ax is None: fig, ax = plt.subplots() - name = self.name if name is None else name - self.name = name + name = self.estimator_name if name is None else name line_kwargs = {} if name is not None: @@ -1298,6 +1298,9 @@ def from_predictions( prob_true, prob_pred = calibration_curve( y_true, y_prob, n_bins=n_bins, strategy=strategy ) + name = name if name is not None else "Classifier" - disp = cls(prob_true=prob_true, prob_pred=prob_pred, y_prob=y_prob, name=name) + disp = cls( + prob_true=prob_true, prob_pred=prob_pred, y_prob=y_prob, estimator_name=name + ) return disp.plot(ax=ax, ref_line=ref_line, **kwargs) diff --git a/sklearn/metrics/_plot/det_curve.py b/sklearn/metrics/_plot/det_curve.py index cb71c6f9cbe98..92e84ce9b7974 100644 --- a/sklearn/metrics/_plot/det_curve.py +++ b/sklearn/metrics/_plot/det_curve.py @@ -294,8 +294,8 @@ def plot(self, ax=None, *, name=None, **kwargs): created. name : str, default=None - Name of DET curve for labeling. If `None`, use the name of the - estimator. + Name of DET curve for labeling. If `None`, use `estimator_name` if + it is not `None`, otherwise no labeling is shown. **kwargs : dict Additional keywords arguments passed to matplotlib `plot` function. diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index fb09d299d39d4..eaf8240062174 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -109,8 +109,8 @@ def plot(self, ax=None, *, name=None, **kwargs): created. name : str, default=None - Name of precision recall curve for labeling. If `None`, use the - name of the estimator. + Name of precision recall curve for labeling. If `None`, use + `estimator_name` if not `None`, otherwise no labeling is shown. **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 1eed3557e4553..7d222b82e4638 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -94,8 +94,8 @@ def plot(self, ax=None, *, name=None, **kwargs): created. name : str, default=None - Name of ROC Curve for labeling. If `None`, use the name of the - estimator. + Name of ROC Curve for labeling. If `None`, use `estimator_name` if + not `None`, otherwise no labeling is shown. Returns ------- diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index b06f14b082cf5..040571df4681b 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -693,7 +693,7 @@ def test_calibration_display_compute(pyplot, iris_data_binary, n_bins, strategy) assert_allclose(viz.prob_pred, prob_pred) assert_allclose(viz.y_prob, y_prob) - assert viz.name == "LogisticRegression" + assert viz.estimator_name == "LogisticRegression" # cannot fail thanks to pyplot fixture import matplotlib as mpl # noqa @@ -715,7 +715,7 @@ def test_plot_calibration_curve_pipeline(pyplot, iris_data_binary): clf.fit(X, y) viz = CalibrationDisplay.from_estimator(clf, X, y) assert clf.__class__.__name__ in viz.line_.get_label() - assert viz.name == clf.__class__.__name__ + assert viz.estimator_name == clf.__class__.__name__ @pytest.mark.parametrize( @@ -726,24 +726,23 @@ def test_calibration_display_default_labels(pyplot, name, expected_label): prob_pred = np.array([0.2, 0.8, 0.8, 0.4]) y_prob = np.array([]) - viz = CalibrationDisplay(prob_true, prob_pred, y_prob, name=name) + viz = CalibrationDisplay(prob_true, prob_pred, y_prob, estimator_name=name) viz.plot() assert viz.line_.get_label() == expected_label def test_calibration_display_label_class_plot(pyplot): # Checks that when instantiating `CalibrationDisplay` class then calling - # `plot`, `self.name` is the one given in `plot` + # `plot`, `self.estimator_name` is the one given in `plot` prob_true = np.array([0, 1, 1, 0]) prob_pred = np.array([0.2, 0.8, 0.8, 0.4]) y_prob = np.array([]) name = "name one" - viz = CalibrationDisplay(prob_true, prob_pred, y_prob, name=name) - assert viz.name == name + viz = CalibrationDisplay(prob_true, prob_pred, y_prob, estimator_name=name) + assert viz.estimator_name == name name = "name two" viz.plot(name=name) - assert viz.name == name assert viz.line_.get_label() == name @@ -764,7 +763,7 @@ def test_calibration_display_name_multiple_calls( params = (clf, X, y) if constructor_name == "from_estimator" else (y, y_prob) viz = constructor(*params, name=clf_name) - assert viz.name == clf_name + assert viz.estimator_name == clf_name pyplot.close("all") viz.plot() assert clf_name == viz.line_.get_label()