Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
11 changes: 11 additions & 0 deletions sklearn/metrics/_plot/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,17 @@ def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
if prediction_method is None:
raise ValueError('response methods not defined')

# Note: set(estimator.classes_) is required to compare str and
# int/float values without raising a numpy FutureWarning.
expected_classes = set(estimator.classes_)
if pos_label is not None and pos_label not in expected_classes:
estimator_name = estimator.__class__.__name__
raise ValueError("pos_label={} is not a valid class label for {}. "
"Expected one of {}."
.format(repr(pos_label),
estimator_name,
expected_classes))

y_pred = prediction_method(X)

if y_pred.ndim != 1:
Expand Down
10 changes: 10 additions & 0 deletions sklearn/metrics/_plot/tests/test_plot_roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,13 @@ def test_plot_roc_curve(pyplot, response_method, data_binary,
assert viz.line_.get_label() == expected_label
assert viz.ax_.get_ylabel() == "True Positive Rate"
assert viz.ax_.get_xlabel() == "False Positive Rate"


def test_invalid_pos_label(pyplot, data_binary):
    """Check that plot_roc_curve raises an informative ValueError when
    pos_label is not one of the fitted estimator's classes."""
    # Local import: only needed to build a literal regex from the message.
    import re

    X, y = data_binary
    lr = LogisticRegression()
    lr.fit(X, y)
    msg = ("pos_label='invalid' is not a valid class label for "
           "LogisticRegression. Expected one of {0, 1}.")
    # `match` is interpreted as a regular expression; the expected message
    # contains regex metacharacters ('{', '}', '.'), so escape it to make
    # the comparison literal instead of accidentally pattern-based.
    with pytest.raises(ValueError, match=re.escape(msg)):
        plot_roc_curve(lr, X, y, pos_label="invalid")
19 changes: 12 additions & 7 deletions sklearn/metrics/_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ..utils.multiclass import type_of_target
from ..utils.extmath import stable_cumsum
from ..utils.sparsefuncs import count_nonzero
from ..utils import _determine_key_type
from ..exceptions import UndefinedMetricWarning
from ..preprocessing import label_binarize
from ..preprocessing._label import _encode
Expand Down Expand Up @@ -526,13 +527,17 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):

# ensure binary classification if pos_label is not specified
classes = np.unique(y_true)
if (pos_label is None and
not (np.array_equal(classes, [0, 1]) or
np.array_equal(classes, [-1, 1]) or
np.array_equal(classes, [0]) or
np.array_equal(classes, [-1]) or
np.array_equal(classes, [1]))):
raise ValueError("Data is not binary and pos_label is not specified")
if (pos_label is None and (
_determine_key_type(classes) == 'str' or
not (np.array_equal(classes, [0, 1]) or
np.array_equal(classes, [-1, 1]) or
np.array_equal(classes, [0]) or
np.array_equal(classes, [-1]) or
np.array_equal(classes, [1])))):
raise ValueError("y_true takes value in {classes} and pos_label is "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need something like the following?

if (pos_label is None and _determine_key_type(classes) == 'str'
    ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks indeed, I updated my PR based on your suggestion and I could get rid of the remaining numpy FutureWarning in array_equal.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the error message here need to be updated.

"not specified: either make y_true take integer "
"value in {{0, 1}} or {{-1, 1}} or pass pos_label "
"explicitly.".format(classes=set(classes)))
elif pos_label is None:
pos_label = 1.

Expand Down
11 changes: 10 additions & 1 deletion sklearn/metrics/tests/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def test_auc_score_non_binary_class():
roc_auc_score(y_true, y_pred)


def test_binary_clf_curve():
def test_binary_clf_curve_multiclass_error():
rng = check_random_state(404)
y_true = rng.randint(0, 3, size=10)
y_pred = rng.rand(10)
Expand All @@ -671,6 +671,15 @@ def test_binary_clf_curve():
precision_recall_curve(y_true, y_pred)


def test_binary_clf_curve_implicit_pos_label():
    """String class labels without an explicit pos_label must raise a
    ValueError that tells the user how to fix the call."""
    # Local import: only needed to build a literal regex from the message.
    import re

    y_true = ["a", "b"]
    y_pred = [0., 1.]
    msg = ("make y_true take integer value in {0, 1} or {-1, 1}"
           " or pass pos_label explicitly.")
    # The expected message contains '{', '}' and '.', which are regex
    # metacharacters; escape it so pytest.raises matches it literally.
    with pytest.raises(ValueError, match=re.escape(msg)):
        precision_recall_curve(y_true, y_pred)
Copy link
Member Author

@ogrisel ogrisel Nov 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qinhanmin2014 I have changed the test in this PR to trigger the exception in _binary_clf_curve from precision_recall_curve rather than plot_roc_curve so that we can merge both #15316 while also improving the error message.

Copy link
Member Author

@ogrisel ogrisel Nov 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is still useful to:

  • raise a meaningful error message when the user passes an invalid pos_label value;
  • improve the error message when passing string labels without passing an explicit pos_label for function that do not automatically label encode string labels (e.g. precision_recall_curve);
  • have more tests.



def test_precision_recall_curve():
y_true, _, probas_pred = make_prediction(binary=True)
_test_precision_recall_curve(y_true, probas_pred)
Expand Down