diff --git a/sklearn/metrics/_base.py b/sklearn/metrics/_base.py
index 7f71c841773f9..bacf7519390f3 100644
--- a/sklearn/metrics/_base.py
+++ b/sklearn/metrics/_base.py
@@ -200,3 +200,52 @@ def _average_multiclass_ovo_score(binary_metric, y_true, y_score,
         pair_scores[ix] = (a_true_score + b_true_score) / 2
 
     return np.average(pair_scores, weights=prevalence)
+
+
+def _check_pos_label_consistency(pos_label, y_true):
+    """Check if `pos_label` needs to be specified or not.
+
+    In binary classification, we fix `pos_label=1` if the labels are in the set
+    {-1, 1} or {0, 1}. Otherwise, we raise an error asking to specify the
+    `pos_label` parameter.
+
+    Parameters
+    ----------
+    pos_label : int, str or None
+        The positive label.
+    y_true : ndarray of shape (n_samples,)
+        The target vector.
+
+    Returns
+    -------
+    pos_label : int
+        If `pos_label` can be inferred, it will be returned.
+
+    Raises
+    ------
+    ValueError
+        In the case that `y_true` does not have labels in {-1, 1} or {0, 1},
+        it will raise a `ValueError`.
+    """
+    # ensure binary classification if pos_label is not specified
+    # classes.dtype.kind in ('O', 'U', 'S') is required to avoid
+    # triggering a FutureWarning by calling np.array_equal(a, b)
+    # when elements in the two arrays are not comparable.
+    classes = np.unique(y_true)
+    if (pos_label is None and (
+            classes.dtype.kind in 'OUS' or
+            not (np.array_equal(classes, [0, 1]) or
+                 np.array_equal(classes, [-1, 1]) or
+                 np.array_equal(classes, [0]) or
+                 np.array_equal(classes, [-1]) or
+                 np.array_equal(classes, [1])))):
+        classes_repr = ", ".join(repr(c) for c in classes)
+        raise ValueError(
+            f"y_true takes value in {{{classes_repr}}} and pos_label is not "
+            f"specified: either make y_true take value in {{0, 1}} or "
+            f"{{-1, 1}} or pass pos_label explicitly."
+        )
+    elif pos_label is None:
+        pos_label = 1.0
+
+    return pos_label
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 89b86f5146d70..fa843d550077e 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -41,6 +41,8 @@
 from ..utils.sparsefuncs import count_nonzero
 from ..exceptions import UndefinedMetricWarning
 
+from ._base import _check_pos_label_consistency
+
 
 def _check_zero_division(zero_division):
     if isinstance(zero_division, str) and zero_division == "warn":
@@ -2437,9 +2439,14 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
         Sample weights.
 
     pos_label : int or str, default=None
-        Label of the positive class.
-        Defaults to the greater label unless y_true is all 0 or all -1
-        in which case pos_label defaults to 1.
+        Label of the positive class. `pos_label` will be inferred in the
+        following manner:
+
+        * if `y_true` in {-1, 1} or {0, 1}, `pos_label` defaults to 1;
+        * else if `y_true` contains strings, an error will be raised and
+          `pos_label` should be explicitly specified;
+        * otherwise, `pos_label` defaults to the greater label,
+          i.e. `np.unique(y_true)[-1]`.
 
     Returns
     -------
@@ -2473,25 +2480,27 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
     assert_all_finite(y_prob)
     check_consistent_length(y_true, y_prob, sample_weight)
 
-    labels = np.unique(y_true)
-    if len(labels) > 2:
-        raise ValueError("Only binary classification is supported. "
-                         "Labels in y_true: %s." % labels)
+    y_type = type_of_target(y_true)
+    if y_type != "binary":
+        raise ValueError(
+            f"Only binary classification is supported. The type of the target "
+            f"is {y_type}."
+        )
+
     if y_prob.max() > 1:
         raise ValueError("y_prob contains values greater than 1.")
     if y_prob.min() < 0:
         raise ValueError("y_prob contains values less than 0.")
 
-    # if pos_label=None, when y_true is in {-1, 1} or {0, 1},
-    # pos_label is set to 1 (consistent with precision_recall_curve/roc_curve),
-    # otherwise pos_label is set to the greater label
-    # (different from precision_recall_curve/roc_curve,
-    # the purpose is to keep backward compatibility).
-    if pos_label is None:
-        if (np.array_equal(labels, [0]) or
-                np.array_equal(labels, [-1])):
-            pos_label = 1
+    try:
+        pos_label = _check_pos_label_consistency(pos_label, y_true)
+    except ValueError:
+        classes = np.unique(y_true)
+        if classes.dtype.kind not in ('O', 'U', 'S'):
+            # for backward compatibility, if classes are not strings then
+            # `pos_label` will correspond to the greater label
+            pos_label = classes[-1]
         else:
-            pos_label = y_true.max()
+            raise
     y_true = np.array(y_true == pos_label, int)
     return np.average((y_true - y_prob) ** 2, weights=sample_weight)
diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index eab32dbc6b2f6..f7e99891a5fe0 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -36,7 +36,11 @@
 from ..preprocessing import label_binarize
 from ..utils._encode import _encode, _unique
 
-from ._base import _average_binary_score, _average_multiclass_ovo_score
+from ._base import (
+    _average_binary_score,
+    _average_multiclass_ovo_score,
+    _check_pos_label_consistency,
+)
 
 
 def auc(x, y):
@@ -694,26 +698,7 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)
 
-    # ensure binary classification if pos_label is not specified
-    # classes.dtype.kind in ('O', 'U', 'S') is required to avoid
-    # triggering a FutureWarning by calling np.array_equal(a, b)
-    # when elements in the two arrays are not comparable.
-    classes = np.unique(y_true)
-    if (pos_label is None and (
-            classes.dtype.kind in ('O', 'U', 'S') or
-            not (np.array_equal(classes, [0, 1]) or
-                 np.array_equal(classes, [-1, 1]) or
-                 np.array_equal(classes, [0]) or
-                 np.array_equal(classes, [-1]) or
-                 np.array_equal(classes, [1])))):
-        classes_repr = ", ".join(repr(c) for c in classes)
-        raise ValueError("y_true takes value in {{{classes_repr}}} and "
-                         "pos_label is not specified: either make y_true "
-                         "take value in {{0, 1}} or {{-1, 1}} or "
-                         "pass pos_label explicitly.".format(
-                             classes_repr=classes_repr))
-    elif pos_label is None:
-        pos_label = 1.
+    pos_label = _check_pos_label_consistency(pos_label, y_true)
 
     # make y_true a boolean vector
     y_true = (y_true == pos_label)
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 0fa97eea609f2..c32e9c89ada47 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -4,7 +4,6 @@
 from itertools import chain
 from itertools import permutations
 import warnings
-import re
 
 import numpy as np
 from scipy import linalg
@@ -2320,9 +2319,12 @@ def test_brier_score_loss():
     # ensure to raise an error for multiclass y_true
     y_true = np.array([0, 1, 2, 0])
     y_pred = np.array([0.8, 0.6, 0.4, 0.2])
-    error_message = ("Only binary classification is supported. Labels "
-                     "in y_true: {}".format(np.array([0, 1, 2])))
-    with pytest.raises(ValueError, match=re.escape(error_message)):
+    error_message = (
+        "Only binary classification is supported. The type of the target is "
+        "multiclass"
+    )
+
+    with pytest.raises(ValueError, match=error_message):
         brier_score_loss(y_true, y_pred)
 
     # calculate correctly when there's only one class in y_true
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 297793541007b..e503a64a47769 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1432,12 +1432,7 @@ def test_metrics_consistent_type_error(metric_name):
     "metric, y_pred_threshold",
     [
         (average_precision_score, True),
-        # FIXME: `brier_score_loss` does not follow this convention.
-        # See discussion in:
-        # https://github.com/scikit-learn/scikit-learn/issues/18307
-        pytest.param(
-            brier_score_loss, True, marks=pytest.mark.xfail(reason="#18307")
-        ),
+        (brier_score_loss, True),
         (f1_score, False),
         (partial(fbeta_score, beta=1), False),
         (jaccard_score, False),
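
For reviewers, a small illustrative sketch (not part of the patch) of the `pos_label` inference rules that the updated `brier_score_loss` docstring describes; the expected values in the comments assume the behaviour added in this diff:

import numpy as np
from sklearn.metrics import brier_score_loss

y_prob = np.array([0.1, 0.9, 0.8, 0.3])

# Labels in {0, 1}: pos_label is inferred as 1, no argument needed.
brier_score_loss([0, 1, 1, 0], y_prob)                       # 0.0375

# String labels: inference is ambiguous, so a ValueError now asks for an
# explicit pos_label instead of silently picking the greater label.
try:
    brier_score_loss(["ham", "spam", "spam", "ham"], y_prob)
except ValueError:
    pass
brier_score_loss(["ham", "spam", "spam", "ham"], y_prob,
                 pos_label="spam")                           # 0.0375

# Other numeric labels keep the backward-compatible fallback: pos_label
# defaults to the greater label, here 3.
brier_score_loss([2, 3, 3, 2], y_prob)                       # 0.0375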