From 840adc0c59e887eecb4888afb9e81fb93c7506ef Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Thu, 31 Oct 2019 11:37:53 +0800 Subject: [PATCH 01/13] FIX Correctly infer pos_label in brier_score_loss --- sklearn/metrics/_classification.py | 2 +- sklearn/metrics/tests/test_classification.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 04d0a009df4b0..8630d10b7d6fd 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2486,6 +2486,6 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): np.array_equal(labels, [-1])): pos_label = 1 else: - pos_label = y_true.max() + pos_label = labels[-1] y_true = np.array(y_true == pos_label, int) return np.average((y_true - y_prob) ** 2, weights=sample_weight) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 6d981ee4da53c..05d496f8a259f 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -2180,6 +2180,16 @@ def test_brier_score_loss(): assert_almost_equal( brier_score_loss(['foo'], [0.4], pos_label='foo'), 0.36) + # correctly infer pos_label + y_true = np.array([0, 1, 1, 0]) + y_pred = np.array([0.8, 0.6, 0.4, 0.2]) + score1 = brier_score_loss(y_true, y_pred, pos_label=1) + score2 = brier_score_loss(y_true, y_pred) + assert_almost_equal(score1, score2) + y_true = np.array(["neg", "pos", "pos", "neg"]) + score2 = brier_score_loss(y_true, y_pred) + assert_almost_equal(score1, score2) + def test_balanced_accuracy_score_unseen(): assert_warns_message(UserWarning, 'y_pred contains classes not in y_true', From b9d750cd44cc96c19ecf4a427e9980b382864b95 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sat, 2 Nov 2019 22:54:46 +0800 Subject: [PATCH 02/13] whats new --- doc/whats_new/v0.22.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git 
a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 39235625093bc..66a5c82614f0e 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -506,6 +506,10 @@ Changelog ``multioutput`` parameter. :pr:`14732` by :user:`Agamemnon Krasoulis `. +- |Fix| FIxed a bug where :func:`metrics.brier_score_loss` will raise an error + when ``y_true`` is string and ``pos_label`` is not specified. + :pr:`15412` by `Hanmin Qin`_. + :mod:`sklearn.model_selection` .............................. From 538317e7e64f6c3c2e564fe934a838264f98c7dc Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Thu, 7 Nov 2019 08:10:44 -0600 Subject: [PATCH 03/13] Apply suggestions from code review Co-Authored-By: Guillaume Lemaitre --- doc/whats_new/v0.22.rst | 2 +- sklearn/metrics/tests/test_classification.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 66a5c82614f0e..17cad9a3b371f 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -506,7 +506,7 @@ Changelog ``multioutput`` parameter. :pr:`14732` by :user:`Agamemnon Krasoulis `. -- |Fix| FIxed a bug where :func:`metrics.brier_score_loss` will raise an error +- |Fix| Fixed a bug where :func:`metrics.brier_score_loss` will raise an error when ``y_true`` is string and ``pos_label`` is not specified. :pr:`15412` by `Hanmin Qin`_. 
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 05d496f8a259f..72bf278711623 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -2185,10 +2185,10 @@ def test_brier_score_loss(): y_pred = np.array([0.8, 0.6, 0.4, 0.2]) score1 = brier_score_loss(y_true, y_pred, pos_label=1) score2 = brier_score_loss(y_true, y_pred) - assert_almost_equal(score1, score2) + assert score1 == pytest.approx(score2) y_true = np.array(["neg", "pos", "pos", "neg"]) score2 = brier_score_loss(y_true, y_pred) - assert_almost_equal(score1, score2) + assert score1 == pytest.approx(score2) def test_balanced_accuracy_score_unseen(): From bf45832e6e434c05efe9c055fd00fee12e310f49 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Fri, 8 Nov 2019 22:59:49 +0800 Subject: [PATCH 04/13] review --- doc/whats_new/v0.22.rst | 2 +- sklearn/metrics/_classification.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 5255a52a5176b..9e83b2b6472b8 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -604,7 +604,7 @@ Changelog used as the :term:`scoring` parameter of model-selection tools. :pr:`14417` by `Thomas Fan`_. -- |Fix| Fixed a bug where :func:`metrics.brier_score_loss` will raise an error +- |Fix| Fixed a bug where :func:`metrics.brier_score_loss` would raise an error when ``y_true`` is string and ``pos_label`` is not specified. :pr:`15412` by `Hanmin Qin`_. diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 8630d10b7d6fd..5fc0e4feb4907 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2432,8 +2432,8 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): pos_label : int or str, default=None Label of the positive class. 
- Defaults to the greater label unless y_true is all 0 or all -1 - in which case pos_label defaults to 1. + Defaults to the greater label according to lexicographic order, unless + y_true is all 0 or all -1 in which case pos_label defaults to 1. Returns ------- From 6dfde656c06d61cadcf1339154d3cdaa1050a115 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Fri, 15 Nov 2019 16:03:58 +0800 Subject: [PATCH 05/13] new solution --- doc/whats_new/v0.22.rst | 4 --- sklearn/metrics/_classification.py | 28 ++++++++++++-------- sklearn/metrics/tests/test_classification.py | 14 +++------- 3 files changed, 21 insertions(+), 25 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 3803bf66ad034..6ea54ed8b7f44 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -631,10 +631,6 @@ Changelog used as the :term:`scoring` parameter of model-selection tools. :pr:`14417` by `Thomas Fan`_. -- |Fix| Fixed a bug where :func:`metrics.brier_score_loss` would raise an error - when ``y_true`` is string and ``pos_label`` is not specified. - :pr:`15412` by `Hanmin Qin`_. - :mod:`sklearn.model_selection` .............................. diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 5fc0e4feb4907..52bab435ce0c0 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -34,6 +34,7 @@ from ..utils import check_array from ..utils import check_consistent_length from ..utils import column_or_1d +from ..utils import _determine_key_type from ..utils.multiclass import unique_labels from ..utils.multiclass import type_of_target from ..utils.validation import _num_samples @@ -2428,7 +2429,7 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): Probabilities of the positive class. sample_weight : array-like of shape (n_samples,), default=None - Sample weights. + The label of the positive class pos_label : int or str, default=None Label of the positive class. 
@@ -2477,15 +2478,20 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): raise ValueError("y_prob contains values less than 0.") # if pos_label=None, when y_true is in {-1, 1} or {0, 1}, - # pos_label is set to 1 (consistent with precision_recall_curve/roc_curve), - # otherwise pos_label is set to the greater label - # (different from precision_recall_curve/roc_curve, - # the purpose is to keep backward compatibility). - if pos_label is None: - if (np.array_equal(labels, [0]) or - np.array_equal(labels, [-1])): - pos_label = 1 - else: - pos_label = labels[-1] + # pos_label is set to 1 (consistent with precision_recall_curve/roc_curve) + if (pos_label is None and ( + _determine_key_type(labels) == 'str' or + not (np.array_equal(labels, [0, 1]) or + np.array_equal(labels, [-1, 1]) or + np.array_equal(labels, [0]) or + np.array_equal(labels, [-1]) or + np.array_equal(labels, [1])))): + raise ValueError("y_true takes value in {classes} and pos_label is " + "not specified: either make y_true take integer " + "value in {{0, 1}} or {{-1, 1}} or pass pos_label " + "explicitly.".format(classes=labels)) + elif pos_label is None: + pos_label = 1. + y_true = np.array(y_true == pos_label, int) return np.average((y_true - y_prob) ** 2, weights=sample_weight) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 72bf278711623..355b35dfbcd3c 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -2152,10 +2152,6 @@ def test_brier_score_loss(): assert_almost_equal(brier_score_loss(y_true, y_true), 0.0) assert_almost_equal(brier_score_loss(y_true, y_pred), true_score) - assert_almost_equal(brier_score_loss(1. 
+ y_true, y_pred), - true_score) - assert_almost_equal(brier_score_loss(2 * y_true - 1, y_pred), - true_score) with pytest.raises(ValueError): brier_score_loss(y_true, y_pred[1:]) with pytest.raises(ValueError): @@ -2175,19 +2171,17 @@ def test_brier_score_loss(): assert_almost_equal(brier_score_loss([-1], [0.4]), 0.16) assert_almost_equal(brier_score_loss([0], [0.4]), 0.16) assert_almost_equal(brier_score_loss([1], [0.4]), 0.36) - assert_almost_equal( - brier_score_loss(['foo'], [0.4], pos_label='bar'), 0.16) - assert_almost_equal( - brier_score_loss(['foo'], [0.4], pos_label='foo'), 0.36) - # correctly infer pos_label + # raise error when y_true is str and pos_label is not specified y_true = np.array([0, 1, 1, 0]) y_pred = np.array([0.8, 0.6, 0.4, 0.2]) score1 = brier_score_loss(y_true, y_pred, pos_label=1) score2 = brier_score_loss(y_true, y_pred) assert score1 == pytest.approx(score2) y_true = np.array(["neg", "pos", "pos", "neg"]) - score2 = brier_score_loss(y_true, y_pred) + with pytest.raises(ValueError, match="y_true takes value"): + brier_score_loss(y_true, y_pred) + score2 = brier_score_loss(y_true, y_pred, pos_label="pos") assert score1 == pytest.approx(score2) From 60b65e61a60d6f1655eee4bf3b359361f6f55efd Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Fri, 15 Nov 2019 16:05:58 +0800 Subject: [PATCH 06/13] typo --- sklearn/metrics/_classification.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 52bab435ce0c0..c311ca4706bed 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2429,12 +2429,10 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): Probabilities of the positive class. sample_weight : array-like of shape (n_samples,), default=None - The label of the positive class + Sample weights. pos_label : int or str, default=None - Label of the positive class. 
- Defaults to the greater label according to lexicographic order, unless - y_true is all 0 or all -1 in which case pos_label defaults to 1. + The label of the positive class. Returns ------- From c03f3101011c2bebbb7d97a3efb8a1a3ab5fdd47 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sat, 16 Nov 2019 13:04:33 +0800 Subject: [PATCH 07/13] address comment --- sklearn/metrics/_classification.py | 31 ++++++++++---------- sklearn/metrics/tests/test_classification.py | 4 +++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index c311ca4706bed..3332e5dd6f579 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2432,7 +2432,10 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): Sample weights. pos_label : int or str, default=None - The label of the positive class. + Label of the positive class. + Defaults to the greater label unless y_true is all 0 or all -1 + in which case pos_label defaults to 1. If `y_true` is str and + `pos_label` is not specified, an error will be raised. Returns ------- @@ -2475,21 +2478,17 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): if y_prob.min() < 0: raise ValueError("y_prob contains values less than 0.") - # if pos_label=None, when y_true is in {-1, 1} or {0, 1}, - # pos_label is set to 1 (consistent with precision_recall_curve/roc_curve) - if (pos_label is None and ( - _determine_key_type(labels) == 'str' or - not (np.array_equal(labels, [0, 1]) or - np.array_equal(labels, [-1, 1]) or - np.array_equal(labels, [0]) or - np.array_equal(labels, [-1]) or - np.array_equal(labels, [1])))): - raise ValueError("y_true takes value in {classes} and pos_label is " - "not specified: either make y_true take integer " - "value in {{0, 1}} or {{-1, 1}} or pass pos_label " - "explicitly.".format(classes=labels)) - elif pos_label is None: - pos_label = 1. 
+ if pos_label is None: + if labels.dtype.kind == 'U': + raise ValueError("y_true takes value in {classes} and pos_label " + "is not specified: either make y_true take " + "integer value or pass pos_label " + "explicitly.".format(classes=labels)) + elif (np.array_equal(labels, [0]) or + np.array_equal(labels, [-1])): + pos_label = 1 + else: + pos_label = y_true.max() y_true = np.array(y_true == pos_label, int) return np.average((y_true - y_prob) ** 2, weights=sample_weight) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 355b35dfbcd3c..9414fbc85b283 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -2152,6 +2152,10 @@ def test_brier_score_loss(): assert_almost_equal(brier_score_loss(y_true, y_true), 0.0) assert_almost_equal(brier_score_loss(y_true, y_pred), true_score) + assert_almost_equal(brier_score_loss(1. + y_true, y_pred), + true_score) + assert_almost_equal(brier_score_loss(2 * y_true - 1, y_pred), + true_score) with pytest.raises(ValueError): brier_score_loss(y_true, y_pred[1:]) with pytest.raises(ValueError): From a544fdc1a13323129f3eeba82b3c9f6e9e889152 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sat, 16 Nov 2019 13:08:55 +0800 Subject: [PATCH 08/13] flake8 --- sklearn/metrics/_classification.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 3332e5dd6f579..1e14bee56fe90 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -34,7 +34,6 @@ from ..utils import check_array from ..utils import check_consistent_length from ..utils import column_or_1d -from ..utils import _determine_key_type from ..utils.multiclass import unique_labels from ..utils.multiclass import type_of_target from ..utils.validation import _num_samples @@ -2435,7 +2434,7 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, 
pos_label=None): Label of the positive class. Defaults to the greater label unless y_true is all 0 or all -1 in which case pos_label defaults to 1. If `y_true` is str and - `pos_label` is not specified, an error will be raised. + `pos_label` is not specified, an error will be raised. Returns ------- From bc59eda4d2e669ed2a36bcd82857953b99dabcc7 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sat, 16 Nov 2019 15:37:32 +0800 Subject: [PATCH 09/13] perhaps better --- sklearn/metrics/_classification.py | 9 ++++----- sklearn/metrics/tests/test_classification.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 1e14bee56fe90..75cda76ee172f 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2478,11 +2478,10 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): raise ValueError("y_prob contains values less than 0.") if pos_label is None: - if labels.dtype.kind == 'U': - raise ValueError("y_true takes value in {classes} and pos_label " - "is not specified: either make y_true take " - "integer value or pass pos_label " - "explicitly.".format(classes=labels)) + if labels.dtype.kind in ('O', 'S', 'U'): + raise ValueError("y_true takes str values in {classes} " + "and pos_label is not " + "specified".format(classes=labels)) elif (np.array_equal(labels, [0]) or np.array_equal(labels, [-1])): pos_label = 1 diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 9414fbc85b283..df1eef039404a 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -2183,7 +2183,7 @@ def test_brier_score_loss(): score2 = brier_score_loss(y_true, y_pred) assert score1 == pytest.approx(score2) y_true = np.array(["neg", "pos", "pos", "neg"]) - with pytest.raises(ValueError, match="y_true takes value"): + with pytest.raises(ValueError, 
match="y_true takes str values"): brier_score_loss(y_true, y_pred) score2 = brier_score_loss(y_true, y_pred, pos_label="pos") assert score1 == pytest.approx(score2) From 91e99cf9d75eb530962bf664862cefc366ba0f9d Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sun, 17 Nov 2019 11:56:13 +0800 Subject: [PATCH 10/13] review --- sklearn/metrics/_classification.py | 20 +++++++++++++------- sklearn/metrics/tests/test_classification.py | 5 +++-- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 75cda76ee172f..9bd2051b6c65e 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2433,8 +2433,8 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): pos_label : int or str, default=None Label of the positive class. Defaults to the greater label unless y_true is all 0 or all -1 - in which case pos_label defaults to 1. If `y_true` is str and - `pos_label` is not specified, an error will be raised. + in which case pos_label defaults to 1. If `y_true` contains strings + and `pos_label` is not specified, an error will be raised. Returns ------- @@ -2477,16 +2477,22 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): if y_prob.min() < 0: raise ValueError("y_prob contains values less than 0.") + # Default behavior when pos_label=None: + # When y_true contains strings, an error will be raised. + # (differ from other functions to keep backward compatibility) + # When y_true is in {-1, 1} or {0, 1}, pos_label is set to 1. + # (consistent with precision_recall_curve/roc_curve) + # Otherwise pos_label is set to the greater label. 
+ # (differ from other functions to keep backward compatibility) if pos_label is None: - if labels.dtype.kind in ('O', 'S', 'U'): - raise ValueError("y_true takes str values in {classes} " - "and pos_label is not " - "specified".format(classes=labels)) + if labels.dtype.kind in ('S', 'U'): + raise ValueError("pos_label must be specified when y_true " + "contains strings.") elif (np.array_equal(labels, [0]) or np.array_equal(labels, [-1])): pos_label = 1 else: - pos_label = y_true.max() + pos_label = labels.max() y_true = np.array(y_true == pos_label, int) return np.average((y_true - y_prob) ** 2, weights=sample_weight) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index df1eef039404a..ec9acf65c799b 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -2176,14 +2176,15 @@ def test_brier_score_loss(): assert_almost_equal(brier_score_loss([0], [0.4]), 0.16) assert_almost_equal(brier_score_loss([1], [0.4]), 0.36) - # raise error when y_true is str and pos_label is not specified + # make sure the positive class is correctly inferred y_true = np.array([0, 1, 1, 0]) y_pred = np.array([0.8, 0.6, 0.4, 0.2]) score1 = brier_score_loss(y_true, y_pred, pos_label=1) score2 = brier_score_loss(y_true, y_pred) assert score1 == pytest.approx(score2) y_true = np.array(["neg", "pos", "pos", "neg"]) - with pytest.raises(ValueError, match="y_true takes str values"): + # raise error when y_true contains strings and pos_label is not specified + with pytest.raises(ValueError, match="pos_label must be specified"): brier_score_loss(y_true, y_pred) score2 = brier_score_loss(y_true, y_pred, pos_label="pos") assert score1 == pytest.approx(score2) From a9d84be2f27da21d4502288f3aaffe15316b4247 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Fri, 22 Nov 2019 11:20:07 +0800 Subject: [PATCH 11/13] consistency with merged PR --- sklearn/metrics/_classification.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index efacfbbcf5644..b57cace3908f9 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2500,7 +2500,7 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): # Otherwise pos_label is set to the greater label. # (differ from other functions to keep backward compatibility) if pos_label is None: - if labels.dtype.kind in ('S', 'U'): + if labels.dtype.kind in ('O', 'U', 'S'): raise ValueError("pos_label must be specified when y_true " "contains strings.") elif (np.array_equal(labels, [0]) or From fb199bd680431fe6c6e4b5c7c2834299ac9e7c24 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Mon, 27 Apr 2020 18:36:15 -0400 Subject: [PATCH 12/13] CLN Address comments --- sklearn/metrics/_classification.py | 2 +- sklearn/metrics/tests/test_classification.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 03e36fa33dedb..f042acc7c87d9 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2448,7 +2448,7 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None): # Otherwise pos_label is set to the greater label. 
# (differ from other functions to keep backward compatibility) if pos_label is None: - if labels.dtype.kind in ('O', 'U', 'S'): + if any(isinstance(label, str) for label in labels): raise ValueError("pos_label must be specified when y_true " "contains strings.") elif (np.array_equal(labels, [0]) or diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 2f27125f94c6e..1c197e38cb14a 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -2260,6 +2260,10 @@ def test_brier_score_loss(): score2 = brier_score_loss(y_true, y_pred, pos_label="pos") assert score1 == pytest.approx(score2) + y_pred_num_obj = np.array([0, 1, 1, 0], dtype=object) + score3 = brier_score_loss(y_pred_num_obj, y_pred) + assert score1 == pytest.approx(score3) + def test_balanced_accuracy_score_unseen(): assert_warns_message(UserWarning, 'y_pred contains classes not in y_true', From 8635fd816cbaa2bc96b417f8a4c682df2b7c7aed Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Mon, 27 Apr 2020 19:25:06 -0400 Subject: [PATCH 13/13] DOC Adds comment --- sklearn/metrics/tests/test_classification.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 1c197e38cb14a..b831cad6a07da 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -2260,6 +2260,7 @@ def test_brier_score_loss(): score2 = brier_score_loss(y_true, y_pred, pos_label="pos") assert score1 == pytest.approx(score2) + # positive class is correctly inferred on an object array with all ints y_pred_num_obj = np.array([0, 1, 1, 0], dtype=object) score3 = brier_score_loss(y_pred_num_obj, y_pred) assert score1 == pytest.approx(score3)