Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5fd9e03

Browse files
ENH Raise an error when pos_label is not in binary y_true (#12313)
1 parent 4e2e1fa commit 5fd9e03

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

sklearn/metrics/ranking.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,11 @@ def _binary_uninterpolated_average_precision(
230230
raise ValueError("Parameter pos_label is fixed to 1 for "
231231
"multilabel-indicator y_true. Do not set "
232232
"pos_label or set pos_label to 1.")
233+
elif y_type == "binary":
234+
present_labels = np.unique(y_true)
235+
if len(present_labels) == 2 and pos_label not in present_labels:
236+
raise ValueError("pos_label=%r is invalid. Set it to a label in "
237+
"y_true." % pos_label)
233238
average_precision = partial(_binary_uninterpolated_average_precision,
234239
pos_label=pos_label)
235240
return _average_binary_score(average_precision, y_true, y_score,

sklearn/metrics/tests/test_ranking.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,15 +682,21 @@ def test_average_precision_constant_values():
682682
assert_equal(average_precision_score(y_true, y_score), .25)
683683

684684

685-
def test_average_precision_score_pos_label_multilabel_indicator():
685+
def test_average_precision_score_pos_label_errors():
686+
# Raise an error when pos_label is not in binary y_true
687+
y_true = np.array([0, 1])
688+
y_pred = np.array([0, 1])
689+
error_message = ("pos_label=2 is invalid. Set it to a label in y_true.")
690+
assert_raise_message(ValueError, error_message, average_precision_score,
691+
y_true, y_pred, pos_label=2)
686692
# Raise an error for multilabel-indicator y_true with
687693
# pos_label other than 1
688694
y_true = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
689695
y_pred = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8]])
690-
erorr_message = ("Parameter pos_label is fixed to 1 for multilabel"
696+
error_message = ("Parameter pos_label is fixed to 1 for multilabel"
691697
"-indicator y_true. Do not set pos_label or set "
692698
"pos_label to 1.")
693-
assert_raise_message(ValueError, erorr_message, average_precision_score,
699+
assert_raise_message(ValueError, error_message, average_precision_score,
694700
y_true, y_pred, pos_label=0)
695701

696702

0 commit comments

Comments
 (0)