From 359b1147183296fbabcdf7868ead6d6860675d69 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 19 Aug 2020 10:18:03 +0200
Subject: [PATCH 01/10] TST make consistent error for metric having pos_label

---
 sklearn/metrics/tests/test_common.py | 40 ++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 24f01d46610a7..3a8f11ba4d89a 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1,5 +1,6 @@
 from functools import partial
+from inspect import signature
 from itertools import product
 from itertools import chain
 from itertools import permutations
@@ -1412,3 +1413,42 @@ def test_thresholded_metric_permutation_invariance(name):
 
         current_score = metric(y_true_perm, y_score_perm)
         assert_almost_equal(score, current_score)
+
+
+@pytest.mark.parametrize(
+    "metric",
+    [
+        average_precision_score,
+        brier_score_loss,
+        f1_score,
+        fbeta_score,
+        jaccard_score,
+        precision_recall_curve,
+        precision_score,
+        recall_score,
+        roc_curve,
+    ],
+)
+def test_metrics_pos_label_error_str(metric):
+    # check that the error message if `pos_label` is not specified and the
+    # targets is made of strings.
+    rng = np.random.RandomState(42)
+    y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=object)
+    y2 = rng.randint(0, 2, size=y1.size)
+
+    err_msg_pos_label_None = (
+        "y_true takes value in {'eggs', 'spam'} and pos_label is not "
+        "specified: either make y_true take value in {0, 1} or {-1, 1} or "
+        "pass pos_label explicit"
+    )
+    err_msg_pos_label_1 = "pos_label=1 is invalid. Set it to a label in y_true"
+
+    pos_label_default = signature(metric).parameters["pos_label"].default
+
+    err_msg = (
+        err_msg_pos_label_1
+        if pos_label_default == 1
+        else err_msg_pos_label_None
+    )
+    with pytest.raises(ValueError, match=err_msg):
+        metric(y1, y2)

From 8630620a7ef14283be5c6e0c1f37d07391b6e705 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 19 Aug 2020 11:27:31 +0200
Subject: [PATCH 02/10] iter

---
 sklearn/metrics/_classification.py   | 23 ++++++++++---
 sklearn/metrics/_ranking.py          |  8 +++--
 sklearn/metrics/tests/test_common.py | 50 ++++++++++++++++++++--------
 3 files changed, 61 insertions(+), 20 deletions(-)

diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index b67f5bd972c1d..3456b61bfe781 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -101,7 +101,20 @@ def _check_targets(y_true, y_pred):
         y_true = column_or_1d(y_true)
         y_pred = column_or_1d(y_pred)
         if y_type == "binary":
-            unique_values = np.union1d(y_true, y_pred)
+            try:
+                unique_values = np.union1d(y_true, y_pred)
+            except TypeError as e:
+                # We expect y_true and y_pred to be of the same data type.
+                # If `y_true` was provided to the classifier as strings,
+                # `y_pred` given by the classifier will also be encoded with
+                # strings. So we raise a meaningful error
+                raise TypeError(
+                    f"Labels in y_true and y_pred should be of the same type. "
+                    f"Got y_true={np.unique(y_true)} and "
+                    f"y_pred={np.unique(y_pred)}. Make sure that the "
+                    f"predictions provided by the classifier coincides with "
+                    f"the true labels."
+                ) from e
             if len(unique_values) > 2:
                 y_type = "multiclass"
@@ -1252,13 +1265,15 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
                          str(average_options))
 
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
-    present_labels = unique_labels(y_true, y_pred)
+    present_labels = unique_labels(y_true, y_pred).tolist()
     if average == 'binary':
         if y_type == 'binary':
             if pos_label not in present_labels:
                 if len(present_labels) >= 2:
-                    raise ValueError("pos_label=%r is not a valid label: "
-                                     "%r" % (pos_label, present_labels))
+                    raise ValueError(
+                        f"pos_label={pos_label} is not a valid label. It "
+                        f"should be one of {present_labels}"
+                    )
             labels = [pos_label]
     else:
         average_options = list(average_options)
diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 6727de0c05c65..f786cd770ca53 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -208,10 +208,12 @@ def _binary_uninterpolated_average_precision(
                          "multilabel-indicator y_true. Do not set "
                          "pos_label or set pos_label to 1.")
     elif y_type == "binary":
-        present_labels = np.unique(y_true)
+        present_labels = np.unique(y_true).tolist()
         if len(present_labels) == 2 and pos_label not in present_labels:
-            raise ValueError("pos_label=%r is invalid. Set it to a label in "
-                             "y_true." % pos_label)
+            raise ValueError(
+                f"pos_label={pos_label} is not a valid label. It should be "
+                f"one of {present_labels}"
+            )
     average_precision = partial(_binary_uninterpolated_average_precision,
                                 pos_label=pos_label)
     return _average_binary_score(average_precision, y_true, y_score,
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 3a8f11ba4d89a..1a661eff93229 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1415,33 +1415,57 @@ def test_thresholded_metric_permutation_invariance(name):
         assert_almost_equal(score, current_score)
 
 
+@pytest.mark.parametrize("metric_name", CLASSIFICATION_METRICS)
+def test_metrics_consistent_type_error(metric_name):
+    # check that an understable message is raised when the type between y_true
+    # and y_pred mismatch
+    rng = np.random.RandomState(42)
+    y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=object)
+    y2 = rng.randint(0, 2, size=y1.size)
+
+    err_msg = "Labels in y_true and y_pred should be of the same type."
+    with pytest.raises(TypeError, match=err_msg):
+        CLASSIFICATION_METRICS[metric_name](y1, y2)
+
+
 @pytest.mark.parametrize(
-    "metric",
+    "metric, y_pred_threshold",
     [
-        average_precision_score,
-        brier_score_loss,
-        f1_score,
-        fbeta_score,
-        jaccard_score,
-        precision_recall_curve,
-        precision_score,
-        recall_score,
-        roc_curve,
+        (average_precision_score, True),
+        (brier_score_loss, True),
+        (f1_score, False),
+        (partial(fbeta_score, beta=1), False),
+        (jaccard_score, False),
+        (precision_recall_curve, True),
+        (precision_score, False),
+        (recall_score, False),
+        (roc_curve, True),
     ],
 )
-def test_metrics_pos_label_error_str(metric):
+@pytest.mark.parametrize("dtype_y_str", [None, object])
+def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
     # check that the error message if `pos_label` is not specified and the
     # targets is made of strings.
     rng = np.random.RandomState(42)
-    y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=object)
+    y1 = np.array(["spam"] * 3 + ["eggs"] * 2)
+    if dtype_y_str is not None:
+        y1 = y1.astype(dtype_y_str)
     y2 = rng.randint(0, 2, size=y1.size)
 
+    if not y_pred_threshold:
+        y2 = np.array(["spam", "eggs"])[y2]
+        if dtype_y_str is not None:
+            y2 = y2.astype(dtype_y_str)
+
     err_msg_pos_label_None = (
         "y_true takes value in {'eggs', 'spam'} and pos_label is not "
         "specified: either make y_true take value in {0, 1} or {-1, 1} or "
         "pass pos_label explicit"
     )
-    err_msg_pos_label_1 = "pos_label=1 is invalid. Set it to a label in y_true"
+    err_msg_pos_label_1 = (
+        r"pos_label=1 is not a valid label. It should be one of "
+        r"\['eggs', 'spam'\]"
+    )
 
     pos_label_default = signature(metric).parameters["pos_label"].default

From 2d08bcdcbfddc45b041aa379581e732481988433 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 19 Aug 2020 12:04:48 +0200
Subject: [PATCH 03/10] fix

---
 sklearn/metrics/tests/test_classification.py |  2 +-
 sklearn/metrics/tests/test_ranking.py        | 13 +++++++------
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 6677f3119dacd..1bfe5af3a7cdf 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -1247,7 +1247,7 @@ def test_multilabel_hamming_loss():
 def test_jaccard_score_validation():
     y_true = np.array([0, 1, 0, 1, 1])
     y_pred = np.array([0, 1, 0, 1, 1])
-    err_msg = r"pos_label=2 is not a valid label: array\(\[0, 1\]\)"
+    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
     with pytest.raises(ValueError, match=err_msg):
         jaccard_score(y_true, y_pred, average='binary', pos_label=2)
diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py
index e08a8909cfe72..f49e469973f97 100644
--- a/sklearn/metrics/tests/test_ranking.py
+++ b/sklearn/metrics/tests/test_ranking.py
@@ -888,17 +888,18 @@ def test_average_precision_score_pos_label_errors():
     # Raise an error when pos_label is not in binary y_true
     y_true = np.array([0, 1])
     y_pred = np.array([0, 1])
-    error_message = ("pos_label=2 is invalid. Set it to a label in y_true.")
-    with pytest.raises(ValueError, match=error_message):
+    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
+    with pytest.raises(ValueError, match=err_msg):
         average_precision_score(y_true, y_pred, pos_label=2)
     # Raise an error for multilabel-indicator y_true with
     # pos_label other than 1
     y_true = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
     y_pred = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8]])
-    error_message = ("Parameter pos_label is fixed to 1 for multilabel"
-                     "-indicator y_true. Do not set pos_label or set "
-                     "pos_label to 1.")
-    with pytest.raises(ValueError, match=error_message):
+    err_msg = (
+        "Parameter pos_label is fixed to 1 for multilabel-indicator y_true. "
+        "Do not set pos_label or set pos_label to 1."
+    )
+    with pytest.raises(ValueError, match=err_msg):
         average_precision_score(y_true, y_pred, pos_label=0)

From b84c5d9e4c9539c8bc05265d41f861eea4f6a250 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 26 Aug 2020 11:01:54 +0200
Subject: [PATCH 04/10] iter

---
 sklearn/metrics/_classification.py           |  4 ++--
 sklearn/metrics/_ranking.py                  |  5 +++--
 sklearn/metrics/tests/test_classification.py |  2 +-
 sklearn/metrics/tests/test_common.py         | 12 ++++--------
 sklearn/metrics/tests/test_ranking.py        |  2 +-
 5 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 3456b61bfe781..380aad62f24df 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -1265,10 +1265,10 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
                          str(average_options))
 
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
-    present_labels = unique_labels(y_true, y_pred).tolist()
+    present_labels = unique_labels(y_true, y_pred)
     if average == 'binary':
         if y_type == 'binary':
-            if pos_label not in present_labels:
+            if not np.isin(pos_label, present_labels):
                 if len(present_labels) >= 2:
                     raise ValueError(
                         f"pos_label={pos_label} is not a valid label. It "
diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index f786cd770ca53..a6aac2cbce87f 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -29,6 +29,7 @@
 from ..utils import check_consistent_length
 from ..utils import column_or_1d, check_array
 from ..utils.multiclass import type_of_target
+from ..utils.multiclass import unique_labels
 from ..utils.extmath import stable_cumsum
 from ..utils.sparsefuncs import count_nonzero
 from ..utils.validation import _deprecate_positional_args
@@ -208,8 +209,8 @@ def _binary_uninterpolated_average_precision(
                          "multilabel-indicator y_true. Do not set "
                          "pos_label or set pos_label to 1.")
     elif y_type == "binary":
-        present_labels = np.unique(y_true).tolist()
-        if len(present_labels) == 2 and pos_label not in present_labels:
+        present_labels = unique_labels(y_true)
+        if len(present_labels) == 2 and not np.isin(pos_label, present_labels):
             raise ValueError(
                 f"pos_label={pos_label} is not a valid label. It should be "
                 f"one of {present_labels}"
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 1bfe5af3a7cdf..e34bcb71382f4 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -1247,7 +1247,7 @@ def test_multilabel_hamming_loss():
 def test_jaccard_score_validation():
     y_true = np.array([0, 1, 0, 1, 1])
     y_pred = np.array([0, 1, 0, 1, 1])
-    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
+    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0 1\]"
     with pytest.raises(ValueError, match=err_msg):
         jaccard_score(y_true, y_pred, average='binary', pos_label=2)
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 1a661eff93229..a5a7834f052fe 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1442,20 +1442,16 @@ def test_metrics_consistent_type_error(metric_name):
         (roc_curve, True),
     ],
 )
-@pytest.mark.parametrize("dtype_y_str", [None, object])
+@pytest.mark.parametrize("dtype_y_str", [str, object])
 def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
     # check that the error message if `pos_label` is not specified and the
     # targets is made of strings.
     rng = np.random.RandomState(42)
-    y1 = np.array(["spam"] * 3 + ["eggs"] * 2)
-    if dtype_y_str is not None:
-        y1 = y1.astype(dtype_y_str)
+    y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=dtype_y_str)
     y2 = rng.randint(0, 2, size=y1.size)
 
     if not y_pred_threshold:
-        y2 = np.array(["spam", "eggs"])[y2]
-        if dtype_y_str is not None:
-            y2 = y2.astype(dtype_y_str)
+        y2 = np.array(["spam", "eggs"], dtype=dtype_y_str)[y2]
 
     err_msg_pos_label_None = (
         "y_true takes value in {'eggs', 'spam'} and pos_label is not "
         "specified: either make y_true take value in {0, 1} or {-1, 1} or "
         "pass pos_label explicit"
     )
     err_msg_pos_label_1 = (
         r"pos_label=1 is not a valid label. It should be one of "
-        r"\['eggs', 'spam'\]"
+        r"\['eggs' 'spam'\]"
     )
 
     pos_label_default = signature(metric).parameters["pos_label"].default
diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py
index f49e469973f97..ea54c31485fca 100644
--- a/sklearn/metrics/tests/test_ranking.py
+++ b/sklearn/metrics/tests/test_ranking.py
@@ -888,7 +888,7 @@ def test_average_precision_score_pos_label_errors():
     # Raise an error when pos_label is not in binary y_true
     y_true = np.array([0, 1])
     y_pred = np.array([0, 1])
-    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
+    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0 1\]"
     with pytest.raises(ValueError, match=err_msg):
         average_precision_score(y_true, y_pred, pos_label=2)
     # Raise an error for multilabel-indicator y_true with

From f019bde3b15712b8acc0e50e4192d0d70da4f99d Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 31 Aug 2020 10:31:32 +0200
Subject: [PATCH 05/10] iter

---
 sklearn/metrics/_ranking.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index a6aac2cbce87f..3481f697e9dd7 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -29,7 +29,6 @@
 from ..utils import check_consistent_length
 from ..utils import column_or_1d, check_array
 from ..utils.multiclass import type_of_target
-from ..utils.multiclass import unique_labels
 from ..utils.extmath import stable_cumsum
 from ..utils.sparsefuncs import count_nonzero
 from ..utils.validation import _deprecate_positional_args
@@ -209,7 +208,7 @@ def _binary_uninterpolated_average_precision(
                          "multilabel-indicator y_true. Do not set "
                          "pos_label or set pos_label to 1.")
     elif y_type == "binary":
-        present_labels = unique_labels(y_true)
+        present_labels = np.unique(y_true)
         if len(present_labels) == 2 and not np.isin(pos_label, present_labels):
             raise ValueError(
                 f"pos_label={pos_label} is not a valid label. It should be "

From cbc441f34bfd1f1309dcc9a8365c9e3e784278ce Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 31 Aug 2020 10:44:39 +0200
Subject: [PATCH 06/10] iter

---
 sklearn/metrics/tests/test_common.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index a5a7834f052fe..d734eaa956597 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1432,7 +1432,10 @@ def test_metrics_consistent_type_error(metric_name):
     "metric, y_pred_threshold",
     [
         (average_precision_score, True),
-        (brier_score_loss, True),
+        # FIXME: `brier_score_loss` does not follow this convention.
+        # See discussion in:
+        # https://github.com/scikit-learn/scikit-learn/issues/18307
+        # (brier_score_loss, True),
         (f1_score, False),
         (partial(fbeta_score, beta=1), False),
         (jaccard_score, False),

From c0dc3cf5b64532406e4586e228214ccce1bf7ac7 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 31 Aug 2020 11:58:14 +0200
Subject: [PATCH 07/10] iter

---
 sklearn/metrics/_classification.py   | 6 ++++--
 sklearn/metrics/_ranking.py          | 6 ++++--
 sklearn/metrics/tests/test_common.py | 2 +-
 3 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 380aad62f24df..8a4a148272257 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -1265,10 +1265,12 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
                          str(average_options))
 
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
-    present_labels = unique_labels(y_true, y_pred)
+    # Convert to Python primitive type to avoid NumPy type / Python str
+    # comparison. See https://github.com/numpy/numpy/issues/6784
+    present_labels = unique_labels(y_true, y_pred).tolist()
     if average == 'binary':
         if y_type == 'binary':
-            if not np.isin(pos_label, present_labels):
+            if pos_label not in present_labels:
                 if len(present_labels) >= 2:
                     raise ValueError(
                         f"pos_label={pos_label} is not a valid label. It "
diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 3481f697e9dd7..813c9892624ed 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -208,8 +208,10 @@ def _binary_uninterpolated_average_precision(
                          "multilabel-indicator y_true. Do not set "
                          "pos_label or set pos_label to 1.")
     elif y_type == "binary":
-        present_labels = np.unique(y_true)
-        if len(present_labels) == 2 and not np.isin(pos_label, present_labels):
+        # Convert to Python primitive type to avoid NumPy type / Python str
+        # comparison. See https://github.com/numpy/numpy/issues/6784
+        present_labels = np.unique(y_true).tolist()
+        if len(present_labels) == 2 and pos_label not in present_labels:
             raise ValueError(
                 f"pos_label={pos_label} is not a valid label. It should be "
                 f"one of {present_labels}"
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index d734eaa956597..8235d62924fe8 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1463,7 +1463,7 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
     )
     err_msg_pos_label_1 = (
         r"pos_label=1 is not a valid label. It should be one of "
-        r"\['eggs' 'spam'\]"
+        r"\['eggs', 'spam'\]"
     )
 
     pos_label_default = signature(metric).parameters["pos_label"].default

From ff89a931fb4e09d01c858d24c8d9d1ff6646119a Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 31 Aug 2020 12:17:31 +0200
Subject: [PATCH 08/10] iter

---
 sklearn/metrics/tests/test_classification.py | 2 +-
 sklearn/metrics/tests/test_ranking.py        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index e34bcb71382f4..1bfe5af3a7cdf 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -1247,7 +1247,7 @@ def test_multilabel_hamming_loss():
 def test_jaccard_score_validation():
     y_true = np.array([0, 1, 0, 1, 1])
     y_pred = np.array([0, 1, 0, 1, 1])
-    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0 1\]"
+    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
     with pytest.raises(ValueError, match=err_msg):
         jaccard_score(y_true, y_pred, average='binary', pos_label=2)
diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py
index ea54c31485fca..f49e469973f97 100644
--- a/sklearn/metrics/tests/test_ranking.py
+++ b/sklearn/metrics/tests/test_ranking.py
@@ -888,7 +888,7 @@ def test_average_precision_score_pos_label_errors():
     # Raise an error when pos_label is not in binary y_true
     y_true = np.array([0, 1])
     y_pred = np.array([0, 1])
-    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0 1\]"
+    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
     with pytest.raises(ValueError, match=err_msg):
         average_precision_score(y_true, y_pred, pos_label=2)
     # Raise an error for multilabel-indicator y_true with

From 7bc0799d9495d01b04f30c750f6bbec13393adab Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 2 Sep 2020 10:18:45 +0200
Subject: [PATCH 09/10] Apply suggestions from code review

Co-authored-by: Thomas J. Fan
---
 sklearn/metrics/tests/test_common.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 8235d62924fe8..6592d8d1b5e5b 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1435,7 +1435,8 @@ def test_metrics_consistent_type_error(metric_name):
         # FIXME: `brier_score_loss` does not follow this convention.
         # See discussion in:
         # https://github.com/scikit-learn/scikit-learn/issues/18307
-        # (brier_score_loss, True),
+        pytest.param(brier_score_loss, True,
+                     marks=pytest.mark.xfail(reason="#18307"),
         (f1_score, False),
         (partial(fbeta_score, beta=1), False),
         (jaccard_score, False),

From 55420c5b33666d075a2938f6d9b37e8dbc53ebdc Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 2 Sep 2020 10:55:06 +0200
Subject: [PATCH 10/10] FIX syntax error

---
 sklearn/metrics/tests/test_common.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 6592d8d1b5e5b..4641a7875a11d 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1435,8 +1435,9 @@ def test_metrics_consistent_type_error(metric_name):
         # FIXME: `brier_score_loss` does not follow this convention.
         # See discussion in:
         # https://github.com/scikit-learn/scikit-learn/issues/18307
-        pytest.param(brier_score_loss, True,
-                     marks=pytest.mark.xfail(reason="#18307"),
+        pytest.param(
+            brier_score_loss, True, marks=pytest.mark.xfail(reason="#18307")
+        ),
         (f1_score, False),
         (partial(fbeta_score, beta=1), False),
         (jaccard_score, False),
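
The snippet below is an illustrative usage sketch, not one of the patches above. It assumes a scikit-learn build with these commits applied; the exception text quoted in the comments follows the message template introduced in _check_set_wise_labels.

# Illustrative sketch only (not part of the patch series). It assumes a
# scikit-learn build that includes the commits above.
import numpy as np
from sklearn.metrics import f1_score

y_true = np.array(["spam", "spam", "spam", "eggs", "eggs"])
y_pred = np.array(["eggs", "spam", "spam", "eggs", "eggs"])

try:
    # pos_label defaults to 1, which is not among the string labels in y_true,
    # so the metric raises the uniform error message added by these commits.
    f1_score(y_true, y_pred)
except ValueError as exc:
    print(exc)
    # expected to read along the lines of:
    # pos_label=1 is not a valid label. It should be one of ['eggs', 'spam']

# Passing pos_label explicitly resolves the ambiguity.
print(f1_score(y_true, y_pred, pos_label="spam"))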