diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index c271781638668..c88fe685e97c9 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -525,14 +525,23 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): sample_weight = column_or_1d(sample_weight) # ensure binary classification if pos_label is not specified + # classes.dtype.kind in ('O', 'U', 'S') is required to avoid + # triggering a FutureWarning by calling np.array_equal(a, b) + # when elements in the two arrays are not comparable. classes = np.unique(y_true) - if (pos_label is None and - not (np.array_equal(classes, [0, 1]) or - np.array_equal(classes, [-1, 1]) or - np.array_equal(classes, [0]) or - np.array_equal(classes, [-1]) or - np.array_equal(classes, [1]))): - raise ValueError("Data is not binary and pos_label is not specified") + if (pos_label is None and ( + classes.dtype.kind in ('O', 'U', 'S') or + not (np.array_equal(classes, [0, 1]) or + np.array_equal(classes, [-1, 1]) or + np.array_equal(classes, [0]) or + np.array_equal(classes, [-1]) or + np.array_equal(classes, [1])))): + classes_repr = ", ".join(repr(c) for c in classes) + raise ValueError("y_true takes value in {{{classes_repr}}} and " + "pos_label is not specified: either make y_true " + "take integer value in {{0, 1}} or {{-1, 1}} or " + "pass pos_label explicitly.".format( + classes_repr=classes_repr)) elif pos_label is None: pos_label = 1. diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 0275a26055915..ae0296718f43a 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -662,14 +662,53 @@ def test_auc_score_non_binary_class(): roc_auc_score(y_true, y_pred) -def test_binary_clf_curve(): +def test_binary_clf_curve_multiclass_error(): rng = check_random_state(404) y_true = rng.randint(0, 3, size=10) y_pred = rng.rand(10) msg = "multiclass format is not supported" + with pytest.raises(ValueError, match=msg): precision_recall_curve(y_true, y_pred) + with pytest.raises(ValueError, match=msg): + roc_curve(y_true, y_pred) + + +@pytest.mark.parametrize("curve_func", [ + precision_recall_curve, + roc_curve, +]) +def test_binary_clf_curve_implicit_pos_label(curve_func): + # Check that using string class labels raises an informative + # error for any supported string dtype: + msg = ("y_true takes value in {'a', 'b'} and pos_label is " + "not specified: either make y_true take integer " + "value in {0, 1} or {-1, 1} or pass pos_label " + "explicitly.") + with pytest.raises(ValueError, match=msg): + roc_curve(np.array(["a", "b"], dtype='