Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 329b7ee

Browse files
aashilTomDLT
authored andcommitted
dev-7369: Make common metric tests look nicer (scikit-learn#7620)
* Change the tests to yield after performing the _named_check. * Tests affected: -test_averaging_multiclass -test_averaging_multilabel -test_averaging_multilabel_all_zeroes -test_averaging_multilabel_all_ones -test_sample_weight_invariance * File updated: sklearn/metrics/tests/test_common.py
1 parent 0e12b2f commit 329b7ee

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sklearn.utils.testing import assert_raise_message
2323
from sklearn.utils.testing import assert_true
2424
from sklearn.utils.testing import ignore_warnings
25+
from sklearn.utils.testing import _named_check
2526

2627
from sklearn.metrics import accuracy_score
2728
from sklearn.metrics import average_precision_score
@@ -891,8 +892,8 @@ def test_averaging_multiclass(n_samples=50, n_classes=3):
891892
y_pred_binarize = lb.transform(y_pred)
892893

893894
for name in METRICS_WITH_AVERAGING:
894-
yield (check_averaging, name, y_true, y_true_binarize, y_pred,
895-
y_pred_binarize, y_score)
895+
yield (_named_check(check_averaging, name), name, y_true,
896+
y_true_binarize, y_pred, y_pred_binarize, y_score)
896897

897898

898899
def test_averaging_multilabel(n_classes=5, n_samples=40):
@@ -906,8 +907,8 @@ def test_averaging_multilabel(n_classes=5, n_samples=40):
906907
y_pred_binarize = y_pred
907908

908909
for name in METRICS_WITH_AVERAGING + THRESHOLDED_METRICS_WITH_AVERAGING:
909-
yield (check_averaging, name, y_true, y_true_binarize, y_pred,
910-
y_pred_binarize, y_score)
910+
yield (_named_check(check_averaging, name), name, y_true,
911+
y_true_binarize, y_pred, y_pred_binarize, y_score)
911912

912913

913914
def test_averaging_multilabel_all_zeroes():
@@ -918,8 +919,8 @@ def test_averaging_multilabel_all_zeroes():
918919
y_pred_binarize = y_pred
919920

920921
for name in METRICS_WITH_AVERAGING:
921-
yield (check_averaging, name, y_true, y_true_binarize, y_pred,
922-
y_pred_binarize, y_score)
922+
yield (_named_check(check_averaging, name), name, y_true,
923+
y_true_binarize, y_pred, y_pred_binarize, y_score)
923924

924925
# Test _average_binary_score for weight.sum() == 0
925926
binary_metric = (lambda y_true, y_score, average="macro":
@@ -937,8 +938,8 @@ def test_averaging_multilabel_all_ones():
937938
y_pred_binarize = y_pred
938939

939940
for name in METRICS_WITH_AVERAGING:
940-
yield (check_averaging, name, y_true, y_true_binarize, y_pred,
941-
y_pred_binarize, y_score)
941+
yield (_named_check(check_averaging, name), name, y_true,
942+
y_true_binarize, y_pred, y_pred_binarize, y_score)
942943

943944

944945
@ignore_warnings
@@ -1025,9 +1026,11 @@ def test_sample_weight_invariance(n_samples=50):
10251026
continue
10261027
metric = ALL_METRICS[name]
10271028
if name in THRESHOLDED_METRICS:
1028-
yield check_sample_weight_invariance, name, metric, y_true, y_score
1029+
yield _named_check(check_sample_weight_invariance, name), name,\
1030+
metric, y_true, y_score
10291031
else:
1030-
yield check_sample_weight_invariance, name, metric, y_true, y_pred
1032+
yield _named_check(check_sample_weight_invariance, name), name,\
1033+
metric, y_true, y_pred
10311034

10321035
# multiclass
10331036
random_state = check_random_state(0)
@@ -1040,9 +1043,11 @@ def test_sample_weight_invariance(n_samples=50):
10401043
continue
10411044
metric = ALL_METRICS[name]
10421045
if name in THRESHOLDED_METRICS:
1043-
yield check_sample_weight_invariance, name, metric, y_true, y_score
1046+
yield _named_check(check_sample_weight_invariance, name), name,\
1047+
metric, y_true, y_score
10441048
else:
1045-
yield check_sample_weight_invariance, name, metric, y_true, y_pred
1049+
yield _named_check(check_sample_weight_invariance, name), name,\
1050+
metric, y_true, y_pred
10461051

10471052
# multilabel indicator
10481053
_, ya = make_multilabel_classification(n_features=1, n_classes=20,
@@ -1062,11 +1067,11 @@ def test_sample_weight_invariance(n_samples=50):
10621067

10631068
metric = ALL_METRICS[name]
10641069
if name in THRESHOLDED_METRICS:
1065-
yield (check_sample_weight_invariance, name, metric, y_true,
1066-
y_score)
1070+
yield (_named_check(check_sample_weight_invariance, name), name,
1071+
metric, y_true, y_score)
10671072
else:
1068-
yield (check_sample_weight_invariance, name, metric, y_true,
1069-
y_pred)
1073+
yield (_named_check(check_sample_weight_invariance, name), name,
1074+
metric, y_true, y_pred)
10701075

10711076

10721077
@ignore_warnings

0 commit comments

Comments
 (0)