From 0b19e2da51ab095d7566c0c76fec97d088b70fce Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Tue, 18 Aug 2020 12:27:18 +0200
Subject: [PATCH 01/14] MNT make error message consistent in brier_score_loss

---
 sklearn/metrics/_base.py           | 49 ++++++++++++++++++++++++++++++
 sklearn/metrics/_classification.py | 26 ++++++++++------
 sklearn/metrics/_ranking.py        | 27 ++++------------
 3 files changed, 72 insertions(+), 30 deletions(-)

diff --git a/sklearn/metrics/_base.py b/sklearn/metrics/_base.py
index 21d0ab38f6a91..5fa63cf18de71 100644
--- a/sklearn/metrics/_base.py
+++ b/sklearn/metrics/_base.py
@@ -200,3 +200,52 @@ def _average_multiclass_ovo_score(binary_metric, y_true, y_score,
         pair_scores[ix] = (a_true_score + b_true_score) / 2

     return np.average(pair_scores, weights=prevalence)
+
+
+def _check_ambiguity_pos_label(pos_label, y_true):
+    """Check if `pos_label` needs to be specified or not.
+
+    In binary classification, we fix `pos_label=1` if the labels are in the set
+    {-1, 1} or {0, 1}. Otherwise, we raise an error asking to specify the
+    `pos_label` parameter.
+
+    Parameters
+    ----------
+    pos_label : int, str or None
+        The positive label.
+    y_true : ndarray of shape (n_samples,)
+        The target vector.
+
+    Returns
+    -------
+    pos_label : int
+        If `pos_label` can be inferred, it will be returned.
+
+    Raises
+    ------
+    ValueError
+        In the case that `y_true` does not have labels in {-1, 1} or {0, 1},
+        it will raise a `ValueError`.
+    """
+    # ensure binary classification if pos_label is not specified
+    # classes.dtype.kind in ('O', 'U', 'S') is required to avoid
+    # triggering a FutureWarning by calling np.array_equal(a, b)
+    # when elements in the two arrays are not comparable.
+    classes = np.unique(y_true)
+    if (pos_label is None and (
+            classes.dtype.kind in ('O', 'U', 'S') or
+            not (np.array_equal(classes, [0, 1]) or
+                 np.array_equal(classes, [-1, 1]) or
+                 np.array_equal(classes, [0]) or
+                 np.array_equal(classes, [-1]) or
+                 np.array_equal(classes, [1])))):
+        classes_repr = ", ".join(repr(c) for c in classes)
+        raise ValueError("y_true takes value in {{{classes_repr}}} and "
+                         "pos_label is not specified: either make y_true "
+                         "take value in {{0, 1}} or {{-1, 1}} or "
+                         "pass pos_label explicitly.".format(
+                             classes_repr=classes_repr))
+    elif pos_label is None:
+        pos_label = 1.0
+
+    return pos_label
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index b67f5bd972c1d..27dd48a43d97f 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -41,6 +41,8 @@
 from ..utils.sparsefuncs import count_nonzero
 from ..exceptions import UndefinedMetricWarning

+from ._base import _check_ambiguity_pos_label
+

 def _check_zero_division(zero_division):
     if isinstance(zero_division, str) and zero_division == "warn":
@@ -2451,10 +2453,13 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
     assert_all_finite(y_prob)
     check_consistent_length(y_true, y_prob, sample_weight)

-    labels = np.unique(y_true)
-    if len(labels) > 2:
-        raise ValueError("Only binary classification is supported. "
-                         "Labels in y_true: %s." % labels)
+    y_type = type_of_target(y_true)
+    if y_type != "binary":
+        raise ValueError(
+            f"Only binary classification is supported. The type of the target "
+            f"is {y_type}."
+        )
+
     if y_prob.max() > 1:
         raise ValueError("y_prob contains values greater than 1.")
     if y_prob.min() < 0:
         raise ValueError("y_prob contains values less than 0.")
@@ -2465,11 +2470,14 @@
     # if pos_label=None, when y_true is in {-1, 1} or {0, 1},
     # pos_label is set to 1 (consistent with precision_recall_curve/roc_curve,
     # otherwise pos_label is set to the greater label
     # (different from precision_recall_curve/roc_curve,
     # the purpose is to keep backward compatibility).
-    if pos_label is None:
-        if (np.array_equal(labels, [0]) or
-                np.array_equal(labels, [-1])):
-            pos_label = 1
+    try:
+        pos_label = _check_ambiguity_pos_label(pos_label, y_true)
+    except ValueError:
+        classes = np.unique(y_true)
+        if classes.dtype.kind not in ('O', 'U', 'S'):
+            # for backward compatibility
+            pos_label = classes[-1]
         else:
-            pos_label = y_true.max()
+            raise
     y_true = np.array(y_true == pos_label, int)
     return np.average((y_true - y_prob) ** 2, weights=sample_weight)
diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 6727de0c05c65..e27c2a2f2903c 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -36,7 +36,11 @@
 from ..preprocessing import label_binarize
 from ..utils._encode import _encode, _unique

-from ._base import _average_binary_score, _average_multiclass_ovo_score
+from ._base import (
+    _average_binary_score,
+    _average_multiclass_ovo_score,
+    _check_ambiguity_pos_label,
+)

 def auc(x, y):
@@ -638,26 +642,7 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)

-    # ensure binary classification if pos_label is not specified
-    # classes.dtype.kind in ('O', 'U', 'S') is required to avoid
-    # triggering a FutureWarning by calling np.array_equal(a, b)
-    # when elements in the two arrays are not comparable.
-    classes = np.unique(y_true)
-    if (pos_label is None and (
-            classes.dtype.kind in ('O', 'U', 'S') or
-            not (np.array_equal(classes, [0, 1]) or
-                 np.array_equal(classes, [-1, 1]) or
-                 np.array_equal(classes, [0]) or
-                 np.array_equal(classes, [-1]) or
-                 np.array_equal(classes, [1])))):
-        classes_repr = ", ".join(repr(c) for c in classes)
-        raise ValueError("y_true takes value in {{{classes_repr}}} and "
-                         "pos_label is not specified: either make y_true "
-                         "take value in {{0, 1}} or {{-1, 1}} or "
-                         "pass pos_label explicitly.".format(
-                             classes_repr=classes_repr))
-    elif pos_label is None:
-        pos_label = 1.
+    pos_label = _check_ambiguity_pos_label(pos_label, y_true)

     # make y_true a boolean vector
     y_true = (y_true == pos_label)
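
The inference rule PATCH 01 centralizes can be exercised standalone (a minimal
sketch; the name `infer_pos_label` is hypothetical, but the logic mirrors
`_check_ambiguity_pos_label` from the diff above):

    import numpy as np

    def infer_pos_label(pos_label, y_true):
        # mirror of _check_ambiguity_pos_label: only {0, 1} and {-1, 1}
        # style targets allow pos_label to be inferred automatically
        if pos_label is not None:
            return pos_label
        classes = np.unique(y_true)
        allowed = ([0, 1], [-1, 1], [0], [-1], [1])
        if (classes.dtype.kind in ('O', 'U', 'S') or
                not any(np.array_equal(classes, c) for c in allowed)):
            classes_repr = ", ".join(repr(c) for c in classes)
            raise ValueError(f"y_true takes value in {{{classes_repr}}} "
                             f"and pos_label is not specified")
        return 1.0

    infer_pos_label(None, np.array([0, 1, 1]))           # 1.0
    infer_pos_label("spam", np.array(["spam", "eggs"]))  # 'spam'
    # infer_pos_label(None, np.array(["spam", "eggs"]))  # raises ValueError
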
From d7a783c9188d16ddf581ffeb422bffee70742c8f Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Tue, 18 Aug 2020 12:28:50 +0200
Subject: [PATCH 02/14] iter

---
 sklearn/metrics/_classification.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 27dd48a43d97f..f84d72e8e1237 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -2417,9 +2417,14 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
         Sample weights.

     pos_label : int or str, default=None
-        Label of the positive class.
-        Defaults to the greater label unless y_true is all 0 or all -1
-        in which case pos_label defaults to 1.
+        Label of the positive class. `pos_label` will be inferred in the
+        following manner:
+
+        * if `y_true` in {-1, 1} or {0, 1}, `pos_label` defaults to 1;
+        * else if `y_true` contains strings, an error will be raised and
+          `pos_label` should be explicitly specified;
+        * otherwise, `pos_label` defaults to the greater label,
+          i.e. `np.unique(y_true)[-1]`.

     Returns
     -------
@@ -2465,17 +2470,13 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
     if y_prob.min() < 0:
         raise ValueError("y_prob contains values less than 0.")

-    # if pos_label=None, when y_true is in {-1, 1} or {0, 1},
-    # pos_label is set to 1 (consistent with precision_recall_curve/roc_curve),
-    # otherwise pos_label is set to the greater label
-    # (different from precision_recall_curve/roc_curve,
-    # the purpose is to keep backward compatibility).
     try:
         pos_label = _check_ambiguity_pos_label(pos_label, y_true)
     except ValueError:
         classes = np.unique(y_true)
         if classes.dtype.kind not in ('O', 'U', 'S'):
-            # for backward compatibility
+            # for backward compatibility, if classes are not strings then
+            # `pos_label` will correspond to the greater label
             pos_label = classes[-1]
         else:
             raise
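
Taking the docstring added in PATCH 02 at face value, the three resolution
rules play out as below (a sketch against the patched `brier_score_loss`;
the arrays are illustrative only):

    import numpy as np
    from sklearn.metrics import brier_score_loss

    y_prob = np.array([0.1, 0.9, 0.8, 0.3])

    # {0, 1} targets: pos_label is inferred as 1
    brier_score_loss(np.array([0, 1, 1, 0]), y_prob)

    # string targets: pos_label must be given explicitly
    brier_score_loss(np.array(["eggs", "spam", "spam", "eggs"]), y_prob,
                     pos_label="spam")

    # other numeric binary targets: the greater label is used,
    # i.e. pos_label is inferred as np.unique(y_true)[-1] == 3 here
    brier_score_loss(np.array([2, 3, 3, 2]), y_prob)
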
From 9f7fc167d305993ab1290ddeeff3703e4c825fee Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Tue, 18 Aug 2020 14:17:20 +0200
Subject: [PATCH 03/14] TST change error matching

---
 sklearn/metrics/tests/test_classification.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 6677f3119dacd..b41ff6985a6cd 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -2262,9 +2262,12 @@ def test_brier_score_loss():
     # ensure to raise an error for multiclass y_true
     y_true = np.array([0, 1, 2, 0])
     y_pred = np.array([0.8, 0.6, 0.4, 0.2])
-    error_message = ("Only binary classification is supported. Labels "
-                     "in y_true: {}".format(np.array([0, 1, 2])))
-    with pytest.raises(ValueError, match=re.escape(error_message)):
+    error_message = (
+        "Only binary classification is supported. The type of the target is "
+        "multiclass"
+    )
+
+    with pytest.raises(ValueError, match=error_message):
         brier_score_loss(y_true, y_pred)

     # calculate correctly when there's only one class in y_true
From e2d3760cc4e062cfd0ed1e835550b59f6e0a3f8b Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Tue, 18 Aug 2020 14:21:04 +0200
Subject: [PATCH 04/14] PEP8

---
 sklearn/metrics/tests/test_classification.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index b41ff6985a6cd..63ac6378d0b4d 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -4,7 +4,6 @@
 from itertools import chain
 from itertools import permutations
 import warnings
-import re

 import numpy as np
 from scipy import linalg
From 785b33f6d845810bc179251953a56a3c923331df Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 19 Aug 2020 09:35:52 +0200
Subject: [PATCH 05/14] TST check consistent error message

---
 sklearn/metrics/tests/test_common.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 24f01d46610a7..8a028ea79db41 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1412,3 +1412,27 @@ def test_thresholded_metric_permutation_invariance(name):
         current_score = metric(y_true_perm, y_score_perm)
         assert_almost_equal(score, current_score)
+
+
+@pytest.mark.parametrize(
+    "metric",
+    [
+        roc_curve,
+        precision_recall_curve,
+        brier_score_loss,
+    ],
+)
+def test_classification_pos_label_error_with_string(metric):
+    # check that we raise a consistent error if pos_label is not provided
+    # when the target is composed of strings
+    random_state = check_random_state(0)
+    y1 = np.array(["eggs"] * 2 + ["spam"] * 3, dtype=object)
+    y2 = random_state.randint(0, 2, size=(5,))
+
+    err_msg = (
+        "y_true takes value in {'eggs', 'spam'} and pos_label is not "
+        "specified: either make y_true take value in {0, 1} or {-1, 1} or "
+        "pass pos_label explicit"
+    )
+    with pytest.raises(ValueError, match=err_msg):
+        metric(y1, y2)
From 247596cee172225af871b6ac6c89b2518ae2c8da Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 19 Aug 2020 12:11:17 +0200
Subject: [PATCH 06/14] STY

---
 sklearn/metrics/_base.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sklearn/metrics/_base.py b/sklearn/metrics/_base.py
index 5fa63cf18de71..7aedba6b325c1 100644
--- a/sklearn/metrics/_base.py
+++ b/sklearn/metrics/_base.py
@@ -240,11 +240,11 @@ def _check_ambiguity_pos_label(pos_label, y_true):
                  np.array_equal(classes, [-1]) or
                  np.array_equal(classes, [1])))):
         classes_repr = ", ".join(repr(c) for c in classes)
-        raise ValueError("y_true takes value in {{{classes_repr}}} and "
-                         "pos_label is not specified: either make y_true "
-                         "take value in {{0, 1}} or {{-1, 1}} or "
-                         "pass pos_label explicitly.".format(
-                             classes_repr=classes_repr))
+        raise ValueError(
+            f"y_true takes value in {{{classes_repr}}} and pos_label is not "
+            f"specified: either make y_true take value in {{0, 1}} or "
+            f"{{-1, 1}} or pass pos_label explicitly."
+        )
     elif pos_label is None:
         pos_label = 1.0
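
The error asserted by the new common test in PATCH 05 can be reproduced
directly (a sketch; the message text comes from PATCH 01 and is reformatted
as an f-string by PATCH 06):

    import numpy as np
    from sklearn.metrics import roc_curve

    y_true = np.array(["eggs"] * 2 + ["spam"] * 3, dtype=object)
    y_score = np.array([0.1, 0.2, 0.8, 0.7, 0.9])

    try:
        roc_curve(y_true, y_score)
    except ValueError as exc:
        # y_true takes value in {'eggs', 'spam'} and pos_label is not
        # specified: either make y_true take value in {0, 1} or {-1, 1}
        # or pass pos_label explicitly.
        print(exc)

    # passing the positive class explicitly resolves the ambiguity
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label="spam")
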
From 8e0f7fbfdd717859e13635377b5b36b30c80f8e6 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 26 Aug 2020 10:34:42 +0200
Subject: [PATCH 07/14] Apply suggestions from code review

Co-authored-by: Thomas J. Fan
---
 sklearn/metrics/_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sklearn/metrics/_base.py b/sklearn/metrics/_base.py
index 7aedba6b325c1..feb5b933b1a4e 100644
--- a/sklearn/metrics/_base.py
+++ b/sklearn/metrics/_base.py
@@ -233,7 +233,7 @@ def _check_ambiguity_pos_label(pos_label, y_true):
     # when elements in the two arrays are not comparable.
     classes = np.unique(y_true)
     if (pos_label is None and (
-            classes.dtype.kind in ('O', 'U', 'S') or
+            classes.dtype.kind in 'OUS' or
             not (np.array_equal(classes, [0, 1]) or
                  np.array_equal(classes, [-1, 1]) or
                  np.array_equal(classes, [0]) or
From dfee51257030c1fadbb40ac22461d38fb9d12ee1 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 26 Aug 2020 10:37:00 +0200
Subject: [PATCH 08/14] iter

---
 sklearn/metrics/tests/test_common.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 8a028ea79db41..5c53154972df0 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1422,11 +1422,14 @@ def test_thresholded_metric_permutation_invariance(name):
         brier_score_loss,
     ],
 )
-def test_classification_pos_label_error_with_string(metric):
+@pytest.mark.parametrize("dtype", [None, object])
+def test_classification_pos_label_error_with_string(metric, dtype):
     # check that we raise a consistent error if pos_label is not provided
     # when the target is composed of strings
     random_state = check_random_state(0)
-    y1 = np.array(["eggs"] * 2 + ["spam"] * 3, dtype=object)
+    y1 = np.array(["eggs"] * 2 + ["spam"] * 3)
+    if dtype is not None:
+        y1 = y1.astype(dtype)
     y2 = random_state.randint(0, 2, size=(5,))

     err_msg = (
From 2d024fbaa4df5389926c48b4deb8d68dd8aca30e Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 26 Aug 2020 10:39:13 +0200
Subject: [PATCH 09/14] cleaner

---
 sklearn/metrics/tests/test_common.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 5c53154972df0..b954b7952afea 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1422,14 +1422,12 @@ def test_thresholded_metric_permutation_invariance(name):
         brier_score_loss,
     ],
 )
-@pytest.mark.parametrize("dtype", [None, object])
+@pytest.mark.parametrize("dtype", [str, object])
 def test_classification_pos_label_error_with_string(metric, dtype):
     # check that we raise a consistent error if pos_label is not provided
     # when the target is composed of strings
     random_state = check_random_state(0)
-    y1 = np.array(["eggs"] * 2 + ["spam"] * 3)
-    if dtype is not None:
-        y1 = y1.astype(dtype)
+    y1 = np.array(["eggs"] * 2 + ["spam"] * 3, dtype=dtype)
     y2 = random_state.randint(0, 2, size=(5,))

     err_msg = (
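
The `str`/`object` parametrization settled on in PATCH 09 covers two of the
three dtype kinds that the `'OUS'` membership check from PATCH 07 treats as
string-like (a quick NumPy-only illustration):

    import numpy as np

    y_str = np.array(["eggs", "spam"])                # dtype kind 'U'
    y_obj = np.array(["eggs", "spam"], dtype=object)  # dtype kind 'O'
    y_bytes = np.array([b"eggs", b"spam"])            # dtype kind 'S'

    for y in (y_str, y_obj, y_bytes):
        print(y.dtype.kind, y.dtype.kind in 'OUS')    # always True here
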
From 383fb5039325534d44b75b10627db54d3a5211a3 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Thu, 3 Sep 2020 16:48:38 +0200
Subject: [PATCH 10/14] iter

---
 sklearn/metrics/tests/test_common.py | 86 ++++++++++++++++++++++++----
 1 file changed, 76 insertions(+), 10 deletions(-)

diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 5ff4b08368e90..6b18921184981 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1392,25 +1392,91 @@ def test_thresholded_multilabel_multioutput_permutations_invariance(name):

 @pytest.mark.parametrize(
-    "metric",
+    'name',
+    sorted(set(THRESHOLDED_METRICS) - METRIC_UNDEFINED_BINARY_MULTICLASS))
+def test_thresholded_metric_permutation_invariance(name):
+    n_samples, n_classes = 100, 3
+    random_state = check_random_state(0)
+
+    y_score = random_state.rand(n_samples, n_classes)
+    temp = np.exp(-y_score)
+    y_score = temp / temp.sum(axis=-1).reshape(-1, 1)
+    y_true = random_state.randint(0, n_classes, size=n_samples)
+
+    metric = ALL_METRICS[name]
+    score = metric(y_true, y_score)
+    for perm in permutations(range(n_classes), n_classes):
+        inverse_perm = np.zeros(n_classes, dtype=int)
+        inverse_perm[list(perm)] = np.arange(n_classes)
+        y_score_perm = y_score[:, inverse_perm]
+        y_true_perm = np.take(perm, y_true)
+
+        current_score = metric(y_true_perm, y_score_perm)
+        assert_almost_equal(score, current_score)
+
+
+@pytest.mark.parametrize("metric_name", CLASSIFICATION_METRICS)
+def test_metrics_consistent_type_error(metric_name):
+    # check that an understandable message is raised when the types of
+    # y_true and y_pred mismatch
+    rng = np.random.RandomState(42)
+    y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=object)
+    y2 = rng.randint(0, 2, size=y1.size)
+
+    err_msg = "Labels in y_true and y_pred should be of the same type."
+    with pytest.raises(TypeError, match=err_msg):
+        CLASSIFICATION_METRICS[metric_name](y1, y2)
+
+
+@pytest.mark.parametrize(
+    "metric, y_pred_threshold",
     [
+        (average_precision_score, True),
+        # FIXME: `brier_score_loss` does not follow this convention.
+        # See discussion in:
+        # https://github.com/scikit-learn/scikit-learn/issues/18307
+        pytest.param(
+            brier_score_loss, True, marks=pytest.mark.xfail(reason="#18307")
+        ),
+        (f1_score, False),
+        (partial(fbeta_score, beta=1), False),
+        (jaccard_score, False),
+        (precision_recall_curve, True),
+        (precision_score, False),
+        (recall_score, False),
+        (roc_curve, True),
         roc_curve,
         precision_recall_curve,
         brier_score_loss,
     ],
 )
-@pytest.mark.parametrize("dtype", [str, object])
-def test_classification_pos_label_error_with_string(metric, dtype):
-    # check that we raise a consistent error if pos_label is not provided
-    # when the target is composed of strings
-    random_state = check_random_state(0)
-    y1 = np.array(["eggs"] * 2 + ["spam"] * 3, dtype=dtype)
-    y2 = random_state.randint(0, 2, size=(5,))
-
-    err_msg = (
+@pytest.mark.parametrize("dtype_y_str", [str, object])
+def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
+    # check the error message when `pos_label` is not specified and the
+    # target is made of strings
+    rng = np.random.RandomState(42)
+    y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=dtype_y_str)
+    y2 = rng.randint(0, 2, size=y1.size)
+
+    if not y_pred_threshold:
+        y2 = np.array(["spam", "eggs"], dtype=dtype_y_str)[y2]
+
+    err_msg_pos_label_None = (
         "y_true takes value in {'eggs', 'spam'} and pos_label is not "
         "specified: either make y_true take value in {0, 1} or {-1, 1} or "
         "pass pos_label explicit"
     )
+    err_msg_pos_label_1 = (
+        r"pos_label=1 is not a valid label. It should be one of "
+        r"\['eggs', 'spam'\]"
+    )
+
+    pos_label_default = signature(metric).parameters["pos_label"].default
+
+    err_msg = (
+        err_msg_pos_label_1
+        if pos_label_default == 1
+        else err_msg_pos_label_None
+    )
     with pytest.raises(ValueError, match=err_msg):
         metric(y1, y2)
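
The two error messages distinguished in PATCH 10 can be provoked side by side
(a sketch; `f1_score` defaults to `pos_label=1` while `precision_recall_curve`
defaults to `pos_label=None`):

    import numpy as np
    from sklearn.metrics import f1_score, precision_recall_curve

    y_true = np.array(["spam"] * 3 + ["eggs"] * 2)
    y_pred = np.array(["spam", "eggs", "spam", "eggs", "spam"])
    y_score = np.array([0.9, 0.2, 0.8, 0.1, 0.7])

    # pos_label defaults to 1 -> "pos_label=1 is not a valid label. ..."
    try:
        f1_score(y_true, y_pred)
    except ValueError as exc:
        print(exc)

    # pos_label defaults to None -> "y_true takes value in {...} and ..."
    try:
        precision_recall_curve(y_true, y_score)
    except ValueError as exc:
        print(exc)

    # both succeed once the positive class is explicit
    f1_score(y_true, y_pred, pos_label="spam")
    precision_recall_curve(y_true, y_score, pos_label="spam")
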
From 3a9a55a9730befef36a125d5b1259cb8b8ec1a17 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Thu, 3 Sep 2020 16:51:39 +0200
Subject: [PATCH 11/14] iter

---
 sklearn/metrics/tests/test_common.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 6b18921184981..297793541007b 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1445,9 +1445,6 @@ def test_metrics_consistent_type_error(metric_name):
         (precision_score, False),
         (recall_score, False),
         (roc_curve, True),
-        roc_curve,
-        precision_recall_curve,
-        brier_score_loss,
     ],
 )
From fc8aaf9fe4142ee132abd6098cc3eaf89d0cae66 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Thu, 3 Sep 2020 16:53:46 +0200
Subject: [PATCH 12/14] iter

---
 sklearn/metrics/tests/test_common.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 297793541007b..e503a64a47769 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1432,12 +1432,7 @@ def test_metrics_consistent_type_error(metric_name):
     "metric, y_pred_threshold",
     [
         (average_precision_score, True),
-        # FIXME: `brier_score_loss` does not follow this convention.
-        # See discussion in:
-        # https://github.com/scikit-learn/scikit-learn/issues/18307
-        pytest.param(
-            brier_score_loss, True, marks=pytest.mark.xfail(reason="#18307")
-        ),
+        (brier_score_loss, True),
         (f1_score, False),
         (partial(fbeta_score, beta=1), False),
         (jaccard_score, False),
From d65a9e69dd6739f6c0fd6bfc7ab86ba7193b3a42 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Thu, 10 Sep 2020 23:44:25 +0200
Subject: [PATCH 13/14] rename function with better name

---
 sklearn/metrics/_base.py    | 2 +-
 sklearn/metrics/_ranking.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sklearn/metrics/_base.py b/sklearn/metrics/_base.py
index d9d0127b0c4b3..bacf7519390f3 100644
--- a/sklearn/metrics/_base.py
+++ b/sklearn/metrics/_base.py
@@ -202,7 +202,7 @@ def _average_multiclass_ovo_score(binary_metric, y_true, y_score,
     return np.average(pair_scores, weights=prevalence)


-def _check_ambiguity_pos_label(pos_label, y_true):
+def _check_pos_label_consistency(pos_label, y_true):
     """Check if `pos_label` needs to be specified or not.

     In binary classification, we fix `pos_label=1` if the labels are in the set
diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 092f8c49ccaba..f7e99891a5fe0 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -39,7 +39,7 @@
 from ._base import (
     _average_binary_score,
     _average_multiclass_ovo_score,
-    _check_ambiguity_pos_label,
+    _check_pos_label_consistency,
 )
@@ -698,7 +698,7 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)

-    pos_label = _check_ambiguity_pos_label(pos_label, y_true)
+    pos_label = _check_pos_label_consistency(pos_label, y_true)

     # make y_true a boolean vector
     y_true = (y_true == pos_label)
From 7ce18700a90a032e7c447649b0e062fe78d5d81f Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Tue, 29 Sep 2020 22:58:01 +0200
Subject: [PATCH 14/14] iter

---
 sklearn/metrics/_classification.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index c02904bfbd0e0..fa843d550077e 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -41,7 +41,7 @@
 from ..utils.sparsefuncs import count_nonzero
 from ..exceptions import UndefinedMetricWarning

-from ._base import _check_ambiguity_pos_label
+from ._base import _check_pos_label_consistency


 def _check_zero_division(zero_division):
@@ -2493,7 +2493,7 @@
         raise ValueError("y_prob contains values less than 0.")

     try:
-        pos_label = _check_ambiguity_pos_label(pos_label, y_true)
+        pos_label = _check_pos_label_consistency(pos_label, y_true)
     except ValueError:
         classes = np.unique(y_true)
         if classes.dtype.kind not in ('O', 'U', 'S'):
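
After the rename in PATCH 13 and PATCH 14, the shared helper behaves as
sketched below (it is private API, imported here for illustration only):

    import numpy as np
    from sklearn.metrics._base import _check_pos_label_consistency

    # {0, 1} labels: pos_label is fixed to 1
    _check_pos_label_consistency(None, np.array([0, 1, 1]))    # 1.0

    # an explicit pos_label is passed through unchanged
    _check_pos_label_consistency("spam",
                                 np.array(["spam", "eggs"]))   # 'spam'

    # ambiguous labels without pos_label raise the shared error
    try:
        _check_pos_label_consistency(None, np.array([2, 3]))
    except ValueError as exc:
        print(exc)  # y_true takes value in {2, 3} and pos_label is not ...
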