22
22
from sklearn .utils .testing import assert_raise_message
23
23
from sklearn .utils .testing import assert_true
24
24
from sklearn .utils .testing import ignore_warnings
25
+ from sklearn .utils .testing import _named_check
25
26
26
27
from sklearn .metrics import accuracy_score
27
28
from sklearn .metrics import average_precision_score
@@ -891,8 +892,8 @@ def test_averaging_multiclass(n_samples=50, n_classes=3):
891
892
y_pred_binarize = lb .transform (y_pred )
892
893
893
894
for name in METRICS_WITH_AVERAGING :
894
- yield (check_averaging , name , y_true , y_true_binarize , y_pred ,
895
- y_pred_binarize , y_score )
895
+ yield (_named_check ( check_averaging , name ), name , y_true ,
896
+ y_true_binarize , y_pred , y_pred_binarize , y_score )
896
897
897
898
898
899
def test_averaging_multilabel (n_classes = 5 , n_samples = 40 ):
@@ -906,8 +907,8 @@ def test_averaging_multilabel(n_classes=5, n_samples=40):
906
907
y_pred_binarize = y_pred
907
908
908
909
for name in METRICS_WITH_AVERAGING + THRESHOLDED_METRICS_WITH_AVERAGING :
909
- yield (check_averaging , name , y_true , y_true_binarize , y_pred ,
910
- y_pred_binarize , y_score )
910
+ yield (_named_check ( check_averaging , name ), name , y_true ,
911
+ y_true_binarize , y_pred , y_pred_binarize , y_score )
911
912
912
913
913
914
def test_averaging_multilabel_all_zeroes ():
@@ -918,8 +919,8 @@ def test_averaging_multilabel_all_zeroes():
918
919
y_pred_binarize = y_pred
919
920
920
921
for name in METRICS_WITH_AVERAGING :
921
- yield (check_averaging , name , y_true , y_true_binarize , y_pred ,
922
- y_pred_binarize , y_score )
922
+ yield (_named_check ( check_averaging , name ), name , y_true ,
923
+ y_true_binarize , y_pred , y_pred_binarize , y_score )
923
924
924
925
# Test _average_binary_score for weight.sum() == 0
925
926
binary_metric = (lambda y_true , y_score , average = "macro" :
@@ -937,8 +938,8 @@ def test_averaging_multilabel_all_ones():
937
938
y_pred_binarize = y_pred
938
939
939
940
for name in METRICS_WITH_AVERAGING :
940
- yield (check_averaging , name , y_true , y_true_binarize , y_pred ,
941
- y_pred_binarize , y_score )
941
+ yield (_named_check ( check_averaging , name ), name , y_true ,
942
+ y_true_binarize , y_pred , y_pred_binarize , y_score )
942
943
943
944
944
945
@ignore_warnings
@@ -1025,9 +1026,11 @@ def test_sample_weight_invariance(n_samples=50):
1025
1026
continue
1026
1027
metric = ALL_METRICS [name ]
1027
1028
if name in THRESHOLDED_METRICS :
1028
- yield check_sample_weight_invariance , name , metric , y_true , y_score
1029
+ yield _named_check (check_sample_weight_invariance , name ), name ,\
1030
+ metric , y_true , y_score
1029
1031
else :
1030
- yield check_sample_weight_invariance , name , metric , y_true , y_pred
1032
+ yield _named_check (check_sample_weight_invariance , name ), name ,\
1033
+ metric , y_true , y_pred
1031
1034
1032
1035
# multiclass
1033
1036
random_state = check_random_state (0 )
@@ -1040,9 +1043,11 @@ def test_sample_weight_invariance(n_samples=50):
1040
1043
continue
1041
1044
metric = ALL_METRICS [name ]
1042
1045
if name in THRESHOLDED_METRICS :
1043
- yield check_sample_weight_invariance , name , metric , y_true , y_score
1046
+ yield _named_check (check_sample_weight_invariance , name ), name ,\
1047
+ metric , y_true , y_score
1044
1048
else :
1045
- yield check_sample_weight_invariance , name , metric , y_true , y_pred
1049
+ yield _named_check (check_sample_weight_invariance , name ), name ,\
1050
+ metric , y_true , y_pred
1046
1051
1047
1052
# multilabel indicator
1048
1053
_ , ya = make_multilabel_classification (n_features = 1 , n_classes = 20 ,
@@ -1062,11 +1067,11 @@ def test_sample_weight_invariance(n_samples=50):
1062
1067
1063
1068
metric = ALL_METRICS [name ]
1064
1069
if name in THRESHOLDED_METRICS :
1065
- yield (check_sample_weight_invariance , name , metric , y_true ,
1066
- y_score )
1070
+ yield (_named_check ( check_sample_weight_invariance , name ), name ,
1071
+ metric , y_true , y_score )
1067
1072
else :
1068
- yield (check_sample_weight_invariance , name , metric , y_true ,
1069
- y_pred )
1073
+ yield (_named_check ( check_sample_weight_invariance , name ), name ,
1074
+ metric , y_true , y_pred )
1070
1075
1071
1076
1072
1077
@ignore_warnings
0 commit comments