From 9d7e92a9e470c8873b8a49e479f6de67ac0309f6 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 30 Oct 2019 15:09:30 +0100 Subject: [PATCH 1/9] Improve error message with implicit pos_label in _binary_clf_curve --- sklearn/metrics/_ranking.py | 22 +++++++++++++++------- sklearn/metrics/tests/test_ranking.py | 11 ++++++++++- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index c271781638668..76149c703b0d2 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -31,6 +31,7 @@ from ..utils.multiclass import type_of_target from ..utils.extmath import stable_cumsum from ..utils.sparsefuncs import count_nonzero +from ..utils import _determine_key_type from ..exceptions import UndefinedMetricWarning from ..preprocessing import label_binarize from ..preprocessing._label import _encode @@ -525,14 +526,21 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): sample_weight = column_or_1d(sample_weight) # ensure binary classification if pos_label is not specified + # _determine_key_type(classes) == 'str' is required to avoid + # triggering a FutureWarning by calling np.array_equal(a, b) + # where a has string values and b has integer values. classes = np.unique(y_true) - if (pos_label is None and - not (np.array_equal(classes, [0, 1]) or - np.array_equal(classes, [-1, 1]) or - np.array_equal(classes, [0]) or - np.array_equal(classes, [-1]) or - np.array_equal(classes, [1]))): - raise ValueError("Data is not binary and pos_label is not specified") + if (pos_label is None and ( + _determine_key_type(classes) == 'str' or + not (np.array_equal(classes, [0, 1]) or + np.array_equal(classes, [-1, 1]) or + np.array_equal(classes, [0]) or + np.array_equal(classes, [-1]) or + np.array_equal(classes, [1])))): + raise ValueError("y_true takes value in {classes} and pos_label is " + "not specified: either make y_true take integer " + "value in {{0, 1}} or {{-1, 1}} or pass pos_label " + "explicitly.".format(classes=set(classes))) elif pos_label is None: pos_label = 1. diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 0275a26055915..18c9444184231 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -662,7 +662,7 @@ def test_auc_score_non_binary_class(): roc_auc_score(y_true, y_pred) -def test_binary_clf_curve(): +def test_binary_clf_curve_multiclass_error(): rng = check_random_state(404) y_true = rng.randint(0, 3, size=10) y_pred = rng.rand(10) @@ -671,6 +671,15 @@ def test_binary_clf_curve(): precision_recall_curve(y_true, y_pred) +def test_binary_clf_curve_implicit_pos_label(): + y_true = ["a", "b"] + y_pred = [0., 1.] + msg = ("make y_true take integer value in {0, 1} or {-1, 1}" + " or pass pos_label explicitly.") + with pytest.raises(ValueError, match=msg): + precision_recall_curve(y_true, y_pred) + + def test_precision_recall_curve(): y_true, _, probas_pred = make_prediction(binary=True) _test_precision_recall_curve(y_true, probas_pred) From 2af412d03c4970002ca4a57014c56b0130feb817 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 12 Nov 2019 09:49:00 +0100 Subject: [PATCH 2/9] PEP8 fix indentation in test --- sklearn/metrics/tests/test_ranking.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 18c9444184231..db4ab8263534f 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -1086,8 +1086,8 @@ def check_alternative_lrap_implementation(lrap_score, n_classes=5, # Score with ties y_score = _sparse_random_matrix(n_components=y_true.shape[0], - n_features=y_true.shape[1], - random_state=random_state) + n_features=y_true.shape[1], + random_state=random_state) if hasattr(y_score, "toarray"): y_score = y_score.toarray() From eb1e3a24930153f21832debe1483915206f628f3 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 12 Nov 2019 10:11:21 +0100 Subject: [PATCH 3/9] Improve comment --- sklearn/metrics/_ranking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 76149c703b0d2..405e06e3ad3dd 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -528,7 +528,7 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): # ensure binary classification if pos_label is not specified # _determine_key_type(classes) == 'str' is required to avoid # triggering a FutureWarning by calling np.array_equal(a, b) - # where a has string values and b has integer values. + # when elements in the two arrays are not comparable. classes = np.unique(y_true) if (pos_label is None and ( _determine_key_type(classes) == 'str' or From 89588754b3e414c5ab7aea0ef147178c89a34d3a Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 12 Nov 2019 10:12:08 +0100 Subject: [PATCH 4/9] More tests --- sklearn/metrics/tests/test_ranking.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index db4ab8263534f..d1b28d6566a2a 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -667,18 +667,26 @@ def test_binary_clf_curve_multiclass_error(): y_true = rng.randint(0, 3, size=10) y_pred = rng.rand(10) msg = "multiclass format is not supported" + with pytest.raises(ValueError, match=msg): precision_recall_curve(y_true, y_pred) + with pytest.raises(ValueError, match=msg): + roc_curve(y_true, y_pred) + def test_binary_clf_curve_implicit_pos_label(): y_true = ["a", "b"] y_pred = [0., 1.] msg = ("make y_true take integer value in {0, 1} or {-1, 1}" " or pass pos_label explicitly.") + with pytest.raises(ValueError, match=msg): precision_recall_curve(y_true, y_pred) + with pytest.raises(ValueError, match=msg): + roc_curve(y_true, y_pred) + def test_precision_recall_curve(): y_true, _, probas_pred = make_prediction(binary=True) From cf4c97970b03cb208d1ab12cbcfcac473e3d2103 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 20 Nov 2019 14:34:26 +0100 Subject: [PATCH 5/9] Improve support and test for various y_true encoding --- sklearn/metrics/_ranking.py | 7 ++--- sklearn/metrics/tests/test_ranking.py | 39 +++++++++++++++++++++------ 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 405e06e3ad3dd..bd46065aaff30 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -531,16 +531,17 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): # when elements in the two arrays are not comparable. classes = np.unique(y_true) if (pos_label is None and ( - _determine_key_type(classes) == 'str' or + classes.dtype.kind in ('O', 'U', 'S') or not (np.array_equal(classes, [0, 1]) or np.array_equal(classes, [-1, 1]) or np.array_equal(classes, [0]) or np.array_equal(classes, [-1]) or np.array_equal(classes, [1])))): - raise ValueError("y_true takes value in {classes} and pos_label is " + classes_repr = ", ".join(repr(c) for c in classes) + raise ValueError("y_true takes value in {{{classes_repr}}} and pos_label is " "not specified: either make y_true take integer " "value in {{0, 1}} or {{-1, 1}} or pass pos_label " - "explicitly.".format(classes=set(classes))) + "explicitly.".format(classes_repr=classes_repr)) elif pos_label is None: pos_label = 1. diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index d1b28d6566a2a..5fa82e2d175da 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -675,17 +675,40 @@ def test_binary_clf_curve_multiclass_error(): roc_curve(y_true, y_pred) -def test_binary_clf_curve_implicit_pos_label(): - y_true = ["a", "b"] - y_pred = [0., 1.] - msg = ("make y_true take integer value in {0, 1} or {-1, 1}" - " or pass pos_label explicitly.") - +@pytest.mark.parametrize("curve_func", [ + precision_recall_curve, + roc_curve, +]) +def test_binary_clf_curve_implicit_pos_label(curve_func): + # Check that using string class labels raises an informative + # error for any supported string dtype: + msg = ("y_true takes value in {'a', 'b'} and pos_label is " + "not specified: either make y_true take integer " + "value in {0, 1} or {-1, 1} or pass pos_label " + "explicitly.") with pytest.raises(ValueError, match=msg): - precision_recall_curve(y_true, y_pred) + roc_curve(np.array(["a", "b"], dtype=' Date: Wed, 20 Nov 2019 14:36:22 +0100 Subject: [PATCH 6/9] cosmit --- sklearn/metrics/tests/test_ranking.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 5fa82e2d175da..ae0296718f43a 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -701,7 +701,6 @@ def test_binary_clf_curve_implicit_pos_label(curve_func): with pytest.raises(ValueError, match=msg): roc_curve(np.array([b"a", b"b"], dtype=' Date: Wed, 20 Nov 2019 15:20:46 +0100 Subject: [PATCH 7/9] PEP8 --- sklearn/metrics/_ranking.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index bd46065aaff30..5ca767bdbe760 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -538,10 +538,11 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): np.array_equal(classes, [-1]) or np.array_equal(classes, [1])))): classes_repr = ", ".join(repr(c) for c in classes) - raise ValueError("y_true takes value in {{{classes_repr}}} and pos_label is " - "not specified: either make y_true take integer " - "value in {{0, 1}} or {{-1, 1}} or pass pos_label " - "explicitly.".format(classes_repr=classes_repr)) + raise ValueError("y_true takes value in {{{classes_repr}}} and " + "pos_label is not specified: either make y_true " + "take integer value in {{0, 1}} or {{-1, 1}} or " + "pass pos_label explicitly.".format( + classes_repr=classes_repr)) elif pos_label is None: pos_label = 1. From da58f5ad5a7bd2cf681a3da2c70528ae4c2fdeb8 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 20 Nov 2019 15:23:08 +0100 Subject: [PATCH 8/9] Unused import --- sklearn/metrics/_ranking.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 5ca767bdbe760..3deea4dcabec5 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -31,7 +31,6 @@ from ..utils.multiclass import type_of_target from ..utils.extmath import stable_cumsum from ..utils.sparsefuncs import count_nonzero -from ..utils import _determine_key_type from ..exceptions import UndefinedMetricWarning from ..preprocessing import label_binarize from ..preprocessing._label import _encode From 27c0de315bfca2dfcf198e2a54a10b62bf99494c Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 20 Nov 2019 17:33:48 +0100 Subject: [PATCH 9/9] Fix old comment --- sklearn/metrics/_ranking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 3deea4dcabec5..c88fe685e97c9 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -525,7 +525,7 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): sample_weight = column_or_1d(sample_weight) # ensure binary classification if pos_label is not specified - # _determine_key_type(classes) == 'str' is required to avoid + # classes.dtype.kind in ('O', 'U', 'S') is required to avoid # triggering a FutureWarning by calling np.array_equal(a, b) # when elements in the two arrays are not comparable. classes = np.unique(y_true)