diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 00db32e1ef389..5b9f6d2b69027 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -369,11 +369,11 @@ # No Sample weight support METRICS_WITHOUT_SAMPLE_WEIGHT = [ - "confusion_matrix", # Left this one here because the tests in this file do - # not work for confusion_matrix, as its output is a - # matrix instead of a number. Testing of - # confusion_matrix with sample_weight is in - # test_classification.py + "confusion_matrix", # Left this one here because the tests in this file do + # not work for confusion_matrix, as its output is a + # matrix instead of a number. Testing of + # confusion_matrix with sample_weight is in + # test_classification.py "median_absolute_error", ] @@ -619,9 +619,9 @@ def test_invariance_string_vs_numbers_labels(): def test_inf_nan_input(): - invalids =[([0, 1], [np.inf, np.inf]), - ([0, 1], [np.nan, np.nan]), - ([0, 1], [np.nan, np.inf])] + invalids = [([0, 1], [np.inf, np.inf]), + ([0, 1], [np.nan, np.nan]), + ([0, 1], [np.nan, np.inf])] METRICS = dict() METRICS.update(THRESHOLDED_METRICS) @@ -1011,7 +1011,8 @@ def check_sample_weight_invariance(name, metric, y1, y2): sample_weight=np.hstack([sample_weight, sample_weight])) -def test_sample_weight_invariance(n_samples=50): +def generate_sample_weight_invariance(n_samples=50): + # create generative function to iterate through each relevant metric random_state = check_random_state(0) # regression y_true = random_state.random_sample(size=(n_samples,)) @@ -1088,6 +1089,12 @@ def test_sample_weight_invariance(n_samples=50): metric, y_true, y_pred) +def test_sample_weight_invariance(n_samples=50): + # iterate through each metric testing each case + for metrics in generate_sample_weight_invariance(n_samples): + pass + + @ignore_warnings def test_no_averaging_labels(): # test labels argument when not using averaging