@@ -682,15 +682,21 @@ def test_average_precision_constant_values():
682
682
assert_equal (average_precision_score (y_true , y_score ), .25 )
683
683
684
684
685
- def test_average_precision_score_pos_label_multilabel_indicator ():
685
+ def test_average_precision_score_pos_label_errors ():
686
+ # Raise an error when pos_label is not in binary y_true
687
+ y_true = np .array ([0 , 1 ])
688
+ y_pred = np .array ([0 , 1 ])
689
+ error_message = ("pos_label=2 is invalid. Set it to a label in y_true." )
690
+ assert_raise_message (ValueError , error_message , average_precision_score ,
691
+ y_true , y_pred , pos_label = 2 )
686
692
# Raise an error for multilabel-indicator y_true with
687
693
# pos_label other than 1
688
694
y_true = np .array ([[1 , 0 ], [0 , 1 ], [0 , 1 ], [1 , 0 ]])
689
695
y_pred = np .array ([[0.9 , 0.1 ], [0.1 , 0.9 ], [0.8 , 0.2 ], [0.2 , 0.8 ]])
690
- erorr_message = ("Parameter pos_label is fixed to 1 for multilabel"
696
+ error_message = ("Parameter pos_label is fixed to 1 for multilabel"
691
697
"-indicator y_true. Do not set pos_label or set "
692
698
"pos_label to 1." )
693
- assert_raise_message (ValueError , erorr_message , average_precision_score ,
699
+ assert_raise_message (ValueError , error_message , average_precision_score ,
694
700
y_true , y_pred , pos_label = 0 )
695
701
696
702
0 commit comments