Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

FIX Infer pos_label automatically in plot_roc_curve #15316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions sklearn/metrics/_plot/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
Parameters
----------
estimator : estimator instance
Trained classifier.
Trained binary classifier.

X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.
Expand All @@ -122,9 +122,9 @@ def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
Target values.

pos_label : int or str, default=None
The label of the positive class.
When `pos_label=None`, if y_true is in {-1, 1} or {0, 1},
`pos_label` is set to 1, otherwise an error will be raised.
Label of the positive class.
By default, pos_label is inferred automatically by taking the last
class label from the estimator.classes_ attribute.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
Expand Down Expand Up @@ -186,15 +186,26 @@ def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
raise ValueError('response methods not defined')

y_pred = prediction_method(X)
estimator_name = estimator.__class__.__name__
estimator_classes = getattr(estimator, "classes_", [])

if len(estimator_classes) != 2:
raise ValueError("Estimator {} is not a binary classifier: "
"its classes_ attribute is set to: {}".format(
estimator_name, estimator_classes))

if y_pred.ndim != 1:
if y_pred.shape[1] > 2:
raise ValueError("Estimator should solve a "
"binary classification problem")
raise ValueError("Predictions by {}.{} should have shape ({}, 2),"
" got {}.".format(
estimator_name, prediction_method.__name__,
y_pred.shape[0], y_pred.shape))
y_pred = y_pred[:, 1]
if pos_label is None:
pos_label = estimator_classes[1]
fpr, tpr, _ = roc_curve(y, y_pred, pos_label=pos_label,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate)
roc_auc = auc(fpr, tpr)
viz = RocCurveDisplay(fpr, tpr, roc_auc, estimator.__class__.__name__)
viz = RocCurveDisplay(fpr, tpr, roc_auc, estimator_name)
return viz.plot(ax=ax, name=name, **kwargs)
43 changes: 41 additions & 2 deletions sklearn/metrics/_plot/tests/test_plot_roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sklearn.metrics import plot_roc_curve
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import roc_curve, auc, roc_auc_score


@pytest.fixture(scope="module")
Expand All @@ -25,7 +25,22 @@ def test_plot_roc_curve_error_non_binary(pyplot, data):
clf = DecisionTreeClassifier()
clf.fit(X, y)

msg = "Estimator should solve a binary classification problem"
msg = "Estimator DecisionTreeClassifier is not a binary classifier"
with pytest.raises(ValueError, match=msg):
plot_roc_curve(clf, X, y)


def test_plot_roc_curve_error_invalid_response_shape(pyplot, data):
    """Check the error raised when predictions do not have two columns.

    The fitted tree is made to advertise two classes through ``classes_``
    while its ``predict_proba`` output keeps three columns, so
    ``plot_roc_curve`` is expected to raise a ``ValueError`` reporting the
    shape mismatch.
    """
    X, y = data
    tree = DecisionTreeClassifier()
    tree.fit(X, y)

    # Overwrite ``classes_`` so the estimator passes for a binary classifier
    # even though it was fitted on three classes.
    tree.classes_ = np.array([0, 1])

    expected_error = (
        r"Predictions by DecisionTreeClassifier\.predict_proba "
        r"should have shape \(150, 2\), got \(150, 3\)\."
    )
    with pytest.raises(ValueError, match=expected_error):
        plot_roc_curve(tree, X, y)

Expand Down Expand Up @@ -93,3 +108,27 @@ def test_plot_roc_curve(pyplot, response_method, data_binary,
assert viz.line_.get_label() == expected_label
assert viz.ax_.get_ylabel() == "True Positive Rate"
assert viz.ax_.get_xlabel() == "False Positive Rate"


def test_plot_roc_curve_pos_label(pyplot, data_binary):
    """Check the plotted AUC for string labels, with and without pos_label.

    The binary targets are mapped to ``"neg"``/``"pos"``. The AUC stored on
    the returned display must match ``roc_auc_score`` computed from the
    second ``predict_proba`` column, both when ``pos_label`` is omitted and
    when it is passed explicitly as ``"pos"``.
    """
    X, y = data_binary
    class_names = np.array(["neg", "pos"])
    y = class_names[y]

    model = LogisticRegression()
    model.fit(X, y)
    scores = model.predict_proba(X)[:, 1]

    # First rely on the inferred positive label, then pass it explicitly.
    for extra_kwargs in ({}, {"pos_label": "pos"}):
        display = plot_roc_curve(model, X, y, **extra_kwargs)
        assert_allclose(display.roc_auc, roc_auc_score(y, scores))


def test_plot_roc_curve_pos_label_non_standard_integers(pyplot, data_binary):
    """Check the plotted AUC for integer labels ``1`` and ``2``.

    The binary targets are remapped to ``1``/``2``. The AUC stored on the
    returned display must match ``roc_auc_score`` computed from the second
    ``predict_proba`` column, both when ``pos_label`` is omitted and when it
    is passed explicitly as ``2``.
    """
    X, y = data_binary
    remapped_labels = np.array([1, 2])
    y = remapped_labels[y]

    model = LogisticRegression()
    model.fit(X, y)
    scores = model.predict_proba(X)[:, 1]

    # First rely on the inferred positive label, then pass it explicitly.
    for extra_kwargs in ({}, {"pos_label": 2}):
        display = plot_roc_curve(model, X, y, **extra_kwargs)
        assert_allclose(display.roc_auc, roc_auc_score(y, scores))