diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst
index e9f733d1e6ae0..bad51d4f045bb 100644
--- a/doc/whats_new/v0.22.rst
+++ b/doc/whats_new/v0.22.rst
@@ -364,6 +364,13 @@ Changelog
 :mod:`sklearn.metrics`
 ......................
 
+- |Feature| Added a new parameter ``zero_division`` to multiple classification
+  metrics: :func:`precision_score`, :func:`recall_score`, :func:`f1_score`,
+  :func:`fbeta_score`, :func:`precision_recall_fscore_support`,
+  :func:`classification_report`. This allows setting the value returned for
+  ill-defined metrics.
+  :pr:`14900` by :user:`Marc Torrellas Socastro `.
+
 - |MajorFeature| :func:`metrics.plot_roc_curve` has been added to plot roc
   curves. This function introduces the visualization API described in
   the :ref:`User Guide `. :pr:`14357` by `Thomas Fan`_.
diff --git a/sklearn/metrics/classification.py b/sklearn/metrics/classification.py
index 14a3ab15af545..fdbfd52425a41 100644
--- a/sklearn/metrics/classification.py
+++ b/sklearn/metrics/classification.py
@@ -41,6 +41,15 @@
 from ..exceptions import UndefinedMetricWarning
 
 
+def _check_zero_division(zero_division):
+    if isinstance(zero_division, str) and zero_division == "warn":
+        return
+    elif isinstance(zero_division, (int, float)) and zero_division in [0, 1]:
+        return
+    raise ValueError('Got zero_division={0}.'
+                     ' Must be one of ["warn", 0, 1]'.format(zero_division))
+
+
 def _check_targets(y_true, y_pred):
     """Check that y_true and y_pred belong to the same classification task
 
@@ -947,7 +956,7 @@ def zero_one_loss(y_true, y_pred, normalize=True, sample_weight=None):
 
 
 def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
-             sample_weight=None):
+             sample_weight=None, zero_division="warn"):
     """Compute the F1 score, also known as balanced F-score or F-measure
 
     The F1 score can be interpreted as a weighted average of the precision and
@@ -1017,6 +1026,11 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.
 
+    zero_division : "warn", 0 or 1, default="warn"
+        Sets the value to return when there is a zero division, i.e. when all
+        predictions and labels are negative. If set to "warn", this acts as 0,
+        but warnings are also raised.
+
     Returns
     -------
     f1_score : float or array of float, shape = [n_unique_labels]
@@ -1046,20 +1060,27 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
     0.26...
     >>> f1_score(y_true, y_pred, average=None)
     array([0.8, 0. , 0. ])
+    >>> y_true = [0, 0, 0, 0, 0, 0]
+    >>> y_pred = [0, 0, 0, 0, 0, 0]
+    >>> f1_score(y_true, y_pred, zero_division=1)
+    1.0...
 
     Notes
     -----
-    When ``true positive + false positive == 0`` or
-    ``true positive + false negative == 0``, f-score returns 0 and raises
-    ``UndefinedMetricWarning``.
+    When ``true positive + false positive == 0``, precision is undefined;
+    when ``true positive + false negative == 0``, recall is undefined.
+    In such cases, by default the metric will be set to 0, as will f-score,
+    and ``UndefinedMetricWarning`` will be raised. This behavior can be
+    modified with ``zero_division``.
""" return fbeta_score(y_true, y_pred, 1, labels=labels, pos_label=pos_label, average=average, - sample_weight=sample_weight) + sample_weight=sample_weight, + zero_division=zero_division) def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1, - average='binary', sample_weight=None): + average='binary', sample_weight=None, zero_division="warn"): """Compute the F-beta score The F-beta score is the weighted harmonic mean of precision and recall, @@ -1129,6 +1150,11 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1, sample_weight : array-like of shape = [n_samples], optional Sample weights. + zero_division : "warn", 0 or 1, default="warn" + Sets the value to return when there is a zero division, i.e. when all + predictions and labels are negative. If set to "warn", this acts as 0, + but warnings are also raised. + Returns ------- fbeta_score : float (if average is not None) or array of float, shape =\ @@ -1166,23 +1192,28 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1, ----- When ``true positive + false positive == 0`` or ``true positive + false negative == 0``, f-score returns 0 and raises - ``UndefinedMetricWarning``. + ``UndefinedMetricWarning``. This behavior can be + modified with ``zero_division``. """ + _, _, f, _ = precision_recall_fscore_support(y_true, y_pred, beta=beta, labels=labels, pos_label=pos_label, average=average, warn_for=('f-score',), - sample_weight=sample_weight) + sample_weight=sample_weight, + zero_division=zero_division) return f -def _prf_divide(numerator, denominator, metric, modifier, average, warn_for): +def _prf_divide(numerator, denominator, metric, + modifier, average, warn_for, zero_division="warn"): """Performs division and handles divide-by-zero. - On zero-division, sets the corresponding result elements to zero - and raises a warning. + On zero-division, sets the corresponding result elements equal to + 0 or 1 (according to ``zero_division``). Plus, if + ``zero_division != "warn"`` raises a warning. The metric, modifier and average arguments are used only for determining an appropriate warning. @@ -1191,16 +1222,23 @@ def _prf_divide(numerator, denominator, metric, modifier, average, warn_for): denominator = denominator.copy() denominator[mask] = 1 # avoid infs/nans result = numerator / denominator + if not np.any(mask): return result + # if ``zero_division=1``, set those with denominator == 0 equal to 1 + result[mask] = 0.0 if zero_division in ["warn", 0] else 1.0 + + # the user will be removing warnings if zero_division is set to something + # different than its default value. If we are computing only f-score + # the warning will be raised only if precision and recall are ill-defined + if zero_division != "warn" or metric not in warn_for: + return result + # build appropriate warning # E.g. "Precision and F-score are ill-defined and being set to 0.0 in - # labels with no predicted samples" - axis0 = 'sample' - axis1 = 'label' - if average == 'samples': - axis0, axis1 = axis1, axis0 + # labels with no predicted samples. Use ``zero_division`` parameter to + # control this behavior." 
if metric in warn_for and 'f-score' in warn_for: msg_start = '{0} and F-score are'.format(metric.title()) @@ -1211,14 +1249,23 @@ def _prf_divide(numerator, denominator, metric, modifier, average, warn_for): else: return result + _warn_prf(average, modifier, msg_start, len(result)) + + return result + + +def _warn_prf(average, modifier, msg_start, result_size): + axis0, axis1 = 'sample', 'label' + if average == 'samples': + axis0, axis1 = axis1, axis0 msg = ('{0} ill-defined and being set to 0.0 {{0}} ' - 'no {1} {2}s.'.format(msg_start, modifier, axis0)) - if len(mask) == 1: + 'no {1} {2}s. Use `zero_division` parameter to control' + ' this behavior.'.format(msg_start, modifier, axis0)) + if result_size == 1: msg = msg.format('due to') else: msg = msg.format('in {0}s with'.format(axis1)) warnings.warn(msg, UndefinedMetricWarning, stacklevel=2) - return result def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label): @@ -1259,7 +1306,8 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, pos_label=1, average=None, warn_for=('precision', 'recall', 'f-score'), - sample_weight=None): + sample_weight=None, + zero_division="warn"): """Compute precision, recall, F-measure and support for each class The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of @@ -1343,6 +1391,13 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, sample_weight : array-like of shape = [n_samples], optional Sample weights. + zero_division : "warn", 0 or 1, default="warn" + Sets the value to return when there is a zero division: + - recall: when there are no positive labels + - precision: when there are no positive predictions + - f-score: both + If set to "warn", this acts as 0, but warnings are also raised. + Returns ------- precision : float (if average is not None) or array of float, shape =\ @@ -1397,9 +1452,11 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, ----- When ``true positive + false positive == 0``, precision is undefined; When ``true positive + false negative == 0``, recall is undefined. - In such cases, the metric will be set to 0, as will f-score, and - ``UndefinedMetricWarning`` will be raised. + In such cases, by default the metric will be set to 0, as will f-score, + and ``UndefinedMetricWarning`` will be raised. This behavior can be + modified with ``zero_division``. """ + _check_zero_division(zero_division) if beta < 0: raise ValueError("beta should be >=0 in the F-beta score") labels = _check_set_wise_labels(y_true, y_pred, average, labels, @@ -1422,18 +1479,28 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, # Finally, we have all our sufficient statistics. Divide! 
# beta2 = beta ** 2 - # Divide, and on zero-division, set scores to 0 and warn: + # Divide, and on zero-division, set scores and/or warn according to + # zero_division: + precision = _prf_divide(tp_sum, pred_sum, 'precision', + 'predicted', average, warn_for, zero_division) + recall = _prf_divide(tp_sum, true_sum, 'recall', + 'true', average, warn_for, zero_division) + + # warn for f-score only if zero_division is warn, it is in warn_for + # and BOTH prec and rec are ill-defined + if zero_division == "warn" and ("f-score",) == warn_for: + if (pred_sum[true_sum == 0] == 0).any(): + _warn_prf( + average, "true nor predicted", 'F-score is', len(true_sum) + ) - precision = _prf_divide(tp_sum, pred_sum, - 'precision', 'predicted', average, warn_for) - recall = _prf_divide(tp_sum, true_sum, - 'recall', 'true', average, warn_for) + # if tp == 0 F will be 1 only if all predictions are zero, all labels are + # zero, and zero_division=1. In all other case, 0 if np.isposinf(beta): f_score = recall else: - # Don't need to warn for F: either P or R warned, or tp == 0 where pos - # and true are nonzero, in which case, F is well-defined and zero denom = beta2 * precision + recall + denom[denom == 0.] = 1 # avoid division by 0 f_score = (1 + beta2) * precision * recall / denom @@ -1441,7 +1508,16 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, if average == 'weighted': weights = true_sum if weights.sum() == 0: - return 0, 0, 0, None + zero_division_value = 0.0 if zero_division in ["warn", 0] else 1.0 + # precision is zero_division if there are no positive predictions + # recall is zero_division if there are no positive labels + # fscore is zero_division if all labels AND predictions are + # negative + return (zero_division_value if pred_sum.sum() == 0 else 0, + zero_division_value, + zero_division_value if pred_sum.sum() == 0 else 0, + None) + elif average == 'samples': weights = sample_weight else: @@ -1458,7 +1534,8 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, def precision_score(y_true, y_pred, labels=None, pos_label=1, - average='binary', sample_weight=None): + average='binary', sample_weight=None, + zero_division="warn"): """Compute the precision The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of @@ -1524,6 +1601,10 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, sample_weight : array-like of shape = [n_samples], optional Sample weights. + zero_division : "warn", 0 or 1, default="warn" + Sets the value to return when there is a zero division. If set to + "warn", this acts as 0, but warnings are also raised. + Returns ------- precision : float (if average is not None) or array of float, shape =\ @@ -1548,23 +1629,31 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, 0.22... >>> precision_score(y_true, y_pred, average=None) array([0.66..., 0. , 0. ]) + >>> y_pred = [0, 0, 0, 0, 0, 0] + >>> precision_score(y_true, y_pred, average=None) + array([0.33..., 0. , 0. ]) + >>> precision_score(y_true, y_pred, average=None, zero_division=1) + array([0.33..., 1. , 1. ]) Notes ----- When ``true positive + false positive == 0``, precision returns 0 and - raises ``UndefinedMetricWarning``. + raises ``UndefinedMetricWarning``. This behavior can be + modified with ``zero_division``. 
+ """ p, _, _, _ = precision_recall_fscore_support(y_true, y_pred, labels=labels, pos_label=pos_label, average=average, warn_for=('precision',), - sample_weight=sample_weight) + sample_weight=sample_weight, + zero_division=zero_division) return p def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary', - sample_weight=None): + sample_weight=None, zero_division="warn"): """Compute the recall The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of @@ -1629,6 +1718,10 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight : array-like of shape = [n_samples], optional Sample weights. + zero_division : "warn", 0 or 1, default="warn" + Sets the value to return when there is a zero division. If set to + "warn", this acts as 0, but warnings are also raised. + Returns ------- recall : float (if average is not None) or array of float, shape =\ @@ -1654,18 +1747,25 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary', 0.33... >>> recall_score(y_true, y_pred, average=None) array([1., 0., 0.]) + >>> y_true = [0, 0, 0, 0, 0, 0] + >>> recall_score(y_true, y_pred, average=None) + array([0.5, 0. , 0. ]) + >>> recall_score(y_true, y_pred, average=None, zero_division=1) + array([0.5, 1. , 1. ]) Notes ----- When ``true positive + false negative == 0``, recall returns 0 and raises - ``UndefinedMetricWarning``. + ``UndefinedMetricWarning``. This behavior can be modified with + ``zero_division``. """ _, r, _, _ = precision_recall_fscore_support(y_true, y_pred, labels=labels, pos_label=pos_label, average=average, warn_for=('recall',), - sample_weight=sample_weight) + sample_weight=sample_weight, + zero_division=zero_division) return r @@ -1747,7 +1847,8 @@ def balanced_accuracy_score(y_true, y_pred, sample_weight=None, def classification_report(y_true, y_pred, labels=None, target_names=None, - sample_weight=None, digits=2, output_dict=False): + sample_weight=None, digits=2, output_dict=False, + zero_division="warn"): """Build a text report showing the main classification metrics Read more in the :ref:`User Guide `. @@ -1777,6 +1878,10 @@ def classification_report(y_true, y_pred, labels=None, target_names=None, output_dict : bool (default = False) If True, return output as dict + zero_division : "warn", 0 or 1, default="warn" + Sets the value to return when there is a zero division. If set to + "warn", this acts as 0, but warnings are also raised. 
+ Returns ------- report : string / dict @@ -1876,7 +1981,8 @@ class 2 1.00 0.67 0.80 3 p, r, f1, s = precision_recall_fscore_support(y_true, y_pred, labels=labels, average=None, - sample_weight=sample_weight) + sample_weight=sample_weight, + zero_division=zero_division) rows = zip(target_names, p, r, f1, s) if y_type.startswith('multilabel'): diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 5373d9af56d84..f668b253b553b 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -287,7 +287,7 @@ def test_precision_recall_f_ignored_labels(): # ensure the above were meaningful tests: for average in ['macro', 'weighted', 'micro']: assert (recall_13(average=average) != - recall_all(average=average)) + recall_all(average=average)) def test_average_precision_score_score_non_binary_class(): @@ -1450,28 +1450,33 @@ def test_precision_recall_f1_score_multilabel_2(): @ignore_warnings -def test_precision_recall_f1_score_with_an_empty_prediction(): +@pytest.mark.parametrize('zero_division', ["warn", 0, 1]) +def test_precision_recall_f1_score_with_an_empty_prediction(zero_division): y_true = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 1, 1, 0]]) y_pred = np.array([[0, 0, 0, 0], [0, 0, 0, 1], [0, 1, 1, 0]]) # true_pos = [ 0. 1. 1. 0.] # false_pos = [ 0. 0. 0. 1.] # false_neg = [ 1. 1. 0. 0.] + zero_division = 1.0 if zero_division == 1.0 else 0.0 p, r, f, s = precision_recall_fscore_support(y_true, y_pred, - average=None) - assert_array_almost_equal(p, [0.0, 1.0, 1.0, 0.0], 2) - assert_array_almost_equal(r, [0.0, 0.5, 1.0, 0.0], 2) + average=None, + zero_division=zero_division) + assert_array_almost_equal(p, [zero_division, 1.0, 1.0, 0.0], 2) + assert_array_almost_equal(r, [0.0, 0.5, 1.0, zero_division], 2) assert_array_almost_equal(f, [0.0, 1 / 1.5, 1, 0.0], 2) assert_array_almost_equal(s, [1, 2, 1, 0], 2) - f2 = fbeta_score(y_true, y_pred, beta=2, average=None) + f2 = fbeta_score(y_true, y_pred, beta=2, average=None, + zero_division=zero_division) support = s assert_array_almost_equal(f2, [0, 0.55, 1, 0], 2) p, r, f, s = precision_recall_fscore_support(y_true, y_pred, - average="macro") - assert_almost_equal(p, 0.5) - assert_almost_equal(r, 1.5 / 4) + average="macro", + zero_division=zero_division) + assert_almost_equal(p, (2 + zero_division) / 4) + assert_almost_equal(r, (1.5 + zero_division) / 4) assert_almost_equal(f, 2.5 / (4 * 1.5)) assert s is None assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, @@ -1479,24 +1484,29 @@ def test_precision_recall_f1_score_with_an_empty_prediction(): np.mean(f2)) p, r, f, s = precision_recall_fscore_support(y_true, y_pred, - average="micro") + average="micro", + zero_division=zero_division) assert_almost_equal(p, 2 / 3) assert_almost_equal(r, 0.5) assert_almost_equal(f, 2 / 3 / (2 / 3 + 0.5)) assert s is None assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, - average="micro"), + average="micro", + zero_division=zero_division), (1 + 4) * p * r / (4 * p + r)) p, r, f, s = precision_recall_fscore_support(y_true, y_pred, - average="weighted") - assert_almost_equal(p, 3 / 4) + average="weighted", + zero_division=zero_division) + assert_almost_equal(p, 3 / 4 if zero_division == 0 else 1.0) assert_almost_equal(r, 0.5) assert_almost_equal(f, (2 / 1.5 + 1) / 4) assert s is None assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, - average="weighted"), - np.average(f2, weights=support)) + average="weighted", + zero_division=zero_division), + 
np.average(f2, weights=support), + ) p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="samples") @@ -1508,36 +1518,93 @@ def test_precision_recall_f1_score_with_an_empty_prediction(): assert_almost_equal(f, 1 / 3) assert s is None assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, - average="samples"), + average="samples", + zero_division=zero_division), 0.333, 2) @pytest.mark.parametrize('beta', [1]) @pytest.mark.parametrize('average', ["macro", "micro", "weighted", "samples"]) -def test_precision_recall_f1_no_labels(beta, average): +@pytest.mark.parametrize('zero_division', [0, 1]) +def test_precision_recall_f1_no_labels(beta, average, zero_division): + y_true = np.zeros((20, 3)) + y_pred = np.zeros_like(y_true) + + p, r, f, s = assert_no_warnings(precision_recall_fscore_support, y_true, + y_pred, average=average, beta=beta, + zero_division=zero_division) + fbeta = assert_no_warnings(fbeta_score, y_true, y_pred, beta=beta, + average=average, zero_division=zero_division) + + zero_division = float(zero_division) + assert_almost_equal(p, zero_division) + assert_almost_equal(r, zero_division) + assert_almost_equal(f, zero_division) + assert s is None + + assert_almost_equal(fbeta, float(zero_division)) + + +@pytest.mark.parametrize('average', ["macro", "micro", "weighted", "samples"]) +def test_precision_recall_f1_no_labels_check_warnings(average): y_true = np.zeros((20, 3)) y_pred = np.zeros_like(y_true) - p, r, f, s = assert_warns(UndefinedMetricWarning, - precision_recall_fscore_support, - y_true, y_pred, average=average, - beta=beta) + func = precision_recall_fscore_support + with pytest.warns(UndefinedMetricWarning): + p, r, f, s = func(y_true, y_pred, average=average, beta=1.0) + assert_almost_equal(p, 0) assert_almost_equal(r, 0) assert_almost_equal(f, 0) assert s is None - fbeta = assert_warns(UndefinedMetricWarning, fbeta_score, - y_true, y_pred, - beta=beta, average=average) + with pytest.warns(UndefinedMetricWarning): + fbeta = fbeta_score(y_true, y_pred, average=average, beta=1.0) + assert_almost_equal(fbeta, 0) -def test_precision_recall_f1_no_labels_average_none(): +@pytest.mark.parametrize('zero_division', [0, 1]) +def test_precision_recall_f1_no_labels_average_none(zero_division): y_true = np.zeros((20, 3)) y_pred = np.zeros_like(y_true) - beta = 1 + # tp = [0, 0, 0] + # fn = [0, 0, 0] + # fp = [0, 0, 0] + # support = [0, 0, 0] + # |y_hat_i inter y_i | = [0, 0, 0] + # |y_i| = [0, 0, 0] + # |y_hat_i| = [0, 0, 0] + + p, r, f, s = assert_no_warnings(precision_recall_fscore_support, + y_true, y_pred, + average=None, beta=1.0, + zero_division=zero_division) + fbeta = assert_no_warnings(fbeta_score, y_true, y_pred, beta=1.0, + average=None, zero_division=zero_division) + + zero_division = float(zero_division) + assert_array_almost_equal( + p, [zero_division, zero_division, zero_division], 2 + ) + assert_array_almost_equal( + r, [zero_division, zero_division, zero_division], 2 + ) + assert_array_almost_equal( + f, [zero_division, zero_division, zero_division], 2 + ) + assert_array_almost_equal(s, [0, 0, 0], 2) + + assert_array_almost_equal( + fbeta, [zero_division, zero_division, zero_division], 2 + ) + + +def test_precision_recall_f1_no_labels_average_none_warn(): + y_true = np.zeros((20, 3)) + y_pred = np.zeros_like(y_true) # tp = [0, 0, 0] # fn = [0, 0, 0] @@ -1547,138 +1614,227 @@ def test_precision_recall_f1_no_labels_average_none(): # |y_i| = [0, 0, 0] # |y_hat_i| = [0, 0, 0] - p, r, f, s = assert_warns(UndefinedMetricWarning, - 
precision_recall_fscore_support, - y_true, y_pred, average=None, beta=beta) + with pytest.warns(UndefinedMetricWarning): + p, r, f, s = precision_recall_fscore_support( + y_true, y_pred, average=None, beta=1 + ) + assert_array_almost_equal(p, [0, 0, 0], 2) assert_array_almost_equal(r, [0, 0, 0], 2) assert_array_almost_equal(f, [0, 0, 0], 2) assert_array_almost_equal(s, [0, 0, 0], 2) - fbeta = assert_warns(UndefinedMetricWarning, fbeta_score, - y_true, y_pred, beta=beta, average=None) + with pytest.warns(UndefinedMetricWarning): + fbeta = fbeta_score(y_true, y_pred, beta=1, average=None) + assert_array_almost_equal(fbeta, [0, 0, 0], 2) def test_prf_warnings(): # average of per-label scores f, w = precision_recall_fscore_support, UndefinedMetricWarning - my_assert = assert_warns_message for average in [None, 'weighted', 'macro']: + msg = ('Precision and F-score are ill-defined and ' - 'being set to 0.0 in labels with no predicted samples.') - my_assert(w, msg, f, [0, 1, 2], [1, 1, 2], average=average) + 'being set to 0.0 in labels with no predicted samples.' + ' Use `zero_division` parameter to control' + ' this behavior.') + assert_warns_message(w, msg, f, [0, 1, 2], [1, 1, 2], average=average) msg = ('Recall and F-score are ill-defined and ' - 'being set to 0.0 in labels with no true samples.') - my_assert(w, msg, f, [1, 1, 2], [0, 1, 2], average=average) + 'being set to 0.0 in labels with no true samples.' + ' Use `zero_division` parameter to control' + ' this behavior.') + assert_warns_message(w, msg, f, [1, 1, 2], [0, 1, 2], average=average) # average of per-sample scores msg = ('Precision and F-score are ill-defined and ' - 'being set to 0.0 in samples with no predicted labels.') - my_assert(w, msg, f, np.array([[1, 0], [1, 0]]), - np.array([[1, 0], [0, 0]]), average='samples') + 'being set to 0.0 in samples with no predicted labels.' + ' Use `zero_division` parameter to control' + ' this behavior.') + assert_warns_message(w, msg, f, np.array([[1, 0], [1, 0]]), + np.array([[1, 0], [0, 0]]), average='samples') msg = ('Recall and F-score are ill-defined and ' - 'being set to 0.0 in samples with no true labels.') - my_assert(w, msg, f, np.array([[1, 0], [0, 0]]), - np.array([[1, 0], [1, 0]]), - average='samples') + 'being set to 0.0 in samples with no true labels.' + ' Use `zero_division` parameter to control' + ' this behavior.') + assert_warns_message(w, msg, f, np.array([[1, 0], [0, 0]]), + np.array([[1, 0], [1, 0]]), average='samples') # single score: micro-average msg = ('Precision and F-score are ill-defined and ' - 'being set to 0.0 due to no predicted samples.') - my_assert(w, msg, f, np.array([[1, 1], [1, 1]]), - np.array([[0, 0], [0, 0]]), average='micro') + 'being set to 0.0 due to no predicted samples.' + ' Use `zero_division` parameter to control' + ' this behavior.') + assert_warns_message(w, msg, f, np.array([[1, 1], [1, 1]]), + np.array([[0, 0], [0, 0]]), average='micro') msg = ('Recall and F-score are ill-defined and ' - 'being set to 0.0 due to no true samples.') - my_assert(w, msg, f, np.array([[0, 0], [0, 0]]), - np.array([[1, 1], [1, 1]]), average='micro') + 'being set to 0.0 due to no true samples.' 
+ ' Use `zero_division` parameter to control' + ' this behavior.') + assert_warns_message(w, msg, f, np.array([[0, 0], [0, 0]]), + np.array([[1, 1], [1, 1]]), average='micro') # single positive label msg = ('Precision and F-score are ill-defined and ' - 'being set to 0.0 due to no predicted samples.') - my_assert(w, msg, f, [1, 1], [-1, -1], average='binary') + 'being set to 0.0 due to no predicted samples.' + ' Use `zero_division` parameter to control' + ' this behavior.') + assert_warns_message(w, msg, f, [1, 1], [-1, -1], average='binary') msg = ('Recall and F-score are ill-defined and ' - 'being set to 0.0 due to no true samples.') - my_assert(w, msg, f, [-1, -1], [1, 1], average='binary') + 'being set to 0.0 due to no true samples.' + ' Use `zero_division` parameter to control' + ' this behavior.') + assert_warns_message(w, msg, f, [-1, -1], [1, 1], average='binary') with warnings.catch_warnings(record=True) as record: warnings.simplefilter('always') precision_recall_fscore_support([0, 0], [0, 0], average="binary") msg = ('Recall and F-score are ill-defined and ' - 'being set to 0.0 due to no true samples.') + 'being set to 0.0 due to no true samples.' + ' Use `zero_division` parameter to control' + ' this behavior.') assert str(record.pop().message) == msg msg = ('Precision and F-score are ill-defined and ' - 'being set to 0.0 due to no predicted samples.') + 'being set to 0.0 due to no predicted samples.' + ' Use `zero_division` parameter to control' + ' this behavior.') assert str(record.pop().message) == msg -def test_recall_warnings(): +@pytest.mark.parametrize('zero_division', [0, 1]) +def test_prf_no_warnings_if_zero_division_set(zero_division): + # average of per-label scores + f = precision_recall_fscore_support + for average in [None, 'weighted', 'macro']: + + assert_no_warnings(f, [0, 1, 2], [1, 1, 2], average=average, + zero_division=zero_division) + + assert_no_warnings(f, [1, 1, 2], [0, 1, 2], average=average, + zero_division=zero_division) + + # average of per-sample scores + assert_no_warnings(f, np.array([[1, 0], [1, 0]]), + np.array([[1, 0], [0, 0]]), average='samples', + zero_division=zero_division) + + assert_no_warnings(f, np.array([[1, 0], [0, 0]]), + np.array([[1, 0], [1, 0]]), + average='samples', zero_division=zero_division) + + # single score: micro-average + assert_no_warnings(f, np.array([[1, 1], [1, 1]]), + np.array([[0, 0], [0, 0]]), average='micro', + zero_division=zero_division) + + assert_no_warnings(f, np.array([[0, 0], [0, 0]]), + np.array([[1, 1], [1, 1]]), average='micro', + zero_division=zero_division) + + # single positive label + assert_no_warnings(f, [1, 1], [-1, -1], average='binary', + zero_division=zero_division) + + assert_no_warnings(f, [-1, -1], [1, 1], average='binary', + zero_division=zero_division) + + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter('always') + precision_recall_fscore_support([0, 0], [0, 0], average="binary", + zero_division=zero_division) + assert len(record) == 0 + + +@pytest.mark.parametrize('zero_division', ["warn", 0, 1]) +def test_recall_warnings(zero_division): assert_no_warnings(recall_score, np.array([[1, 1], [1, 1]]), np.array([[0, 0], [0, 0]]), - average='micro') + average='micro', zero_division=zero_division) with warnings.catch_warnings(record=True) as record: warnings.simplefilter('always') recall_score(np.array([[0, 0], [0, 0]]), np.array([[1, 1], [1, 1]]), - average='micro') - assert (str(record.pop().message) == - 'Recall is ill-defined and ' - 'being set to 0.0 due to no 
true samples.') + average='micro', zero_division=zero_division) + if zero_division == "warn": + assert (str(record.pop().message) == + 'Recall is ill-defined and ' + 'being set to 0.0 due to no true samples.' + ' Use `zero_division` parameter to control' + ' this behavior.') + else: + assert len(record) == 0 + recall_score([0, 0], [0, 0]) - assert (str(record.pop().message) == - 'Recall is ill-defined and ' - 'being set to 0.0 due to no true samples.') + if zero_division == "warn": + assert (str(record.pop().message) == + 'Recall is ill-defined and ' + 'being set to 0.0 due to no true samples.' + ' Use `zero_division` parameter to control' + ' this behavior.') -def test_precision_warnings(): +@pytest.mark.parametrize('zero_division', ["warn", 0, 1]) +def test_precision_warnings(zero_division): with warnings.catch_warnings(record=True) as record: warnings.simplefilter('always') precision_score(np.array([[1, 1], [1, 1]]), np.array([[0, 0], [0, 0]]), - average='micro') - assert (str(record.pop().message) == - 'Precision is ill-defined and ' - 'being set to 0.0 due to no predicted samples.') + average='micro', zero_division=zero_division) + if zero_division == "warn": + assert (str(record.pop().message) == + 'Precision is ill-defined and ' + 'being set to 0.0 due to no predicted samples.' + ' Use `zero_division` parameter to control' + ' this behavior.') + else: + assert len(record) == 0 + precision_score([0, 0], [0, 0]) - assert (str(record.pop().message) == - 'Precision is ill-defined and ' - 'being set to 0.0 due to no predicted samples.') + if zero_division == "warn": + assert (str(record.pop().message) == + 'Precision is ill-defined and ' + 'being set to 0.0 due to no predicted samples.' + ' Use `zero_division` parameter to control' + ' this behavior.') assert_no_warnings(precision_score, np.array([[0, 0], [0, 0]]), np.array([[1, 1], [1, 1]]), - average='micro') + average='micro', zero_division=zero_division) -def test_fscore_warnings(): +@pytest.mark.parametrize('zero_division', ["warn", 0, 1]) +def test_fscore_warnings(zero_division): with warnings.catch_warnings(record=True) as record: warnings.simplefilter('always') for score in [f1_score, partial(fbeta_score, beta=2)]: score(np.array([[1, 1], [1, 1]]), np.array([[0, 0], [0, 0]]), - average='micro') - assert (str(record.pop().message) == - 'F-score is ill-defined and ' - 'being set to 0.0 due to no predicted samples.') + average='micro', zero_division=zero_division) + assert len(record) == 0 + score(np.array([[0, 0], [0, 0]]), np.array([[1, 1], [1, 1]]), - average='micro') - assert (str(record.pop().message) == - 'F-score is ill-defined and ' - 'being set to 0.0 due to no true samples.') - score([0, 0], [0, 0]) - assert (str(record.pop().message) == - 'F-score is ill-defined and ' - 'being set to 0.0 due to no true samples.') - assert (str(record.pop().message) == - 'F-score is ill-defined and ' - 'being set to 0.0 due to no predicted samples.') + average='micro', zero_division=zero_division) + assert len(record) == 0 + + score(np.array([[0, 0], [0, 0]]), + np.array([[0, 0], [0, 0]]), + average='micro', zero_division=zero_division) + if zero_division == "warn": + assert (str(record.pop().message) == + 'F-score is ill-defined and ' + 'being set to 0.0 due to no true nor predicted ' + 'samples. 
Use `zero_division` parameter to ' + 'control this behavior.') + else: + assert len(record) == 0 def test_prf_average_binary_data_non_binary(): @@ -1902,7 +2058,7 @@ def test_hinge_loss_multiclass_invariance_lists(): np.clip(dummy_losses, 0, None, out=dummy_losses) dummy_hinge_loss = np.mean(dummy_losses) assert (hinge_loss(y_true, pred_decision) == - dummy_hinge_loss) + dummy_hinge_loss) def test_log_loss():
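
A brief usage sketch of the behaviour this patch adds, reusing the doctest
examples from the updated docstrings above (the commented values follow those
doctests)::

    from sklearn.metrics import f1_score, precision_score

    # All-negative labels and predictions: precision, recall and F-score are
    # ill-defined for pos_label=1.
    y_true = [0, 0, 0, 0, 0, 0]
    y_pred = [0, 0, 0, 0, 0, 0]

    f1_score(y_true, y_pred)                   # 0.0, with UndefinedMetricWarning
    f1_score(y_true, y_pred, zero_division=0)  # 0.0, no warning
    f1_score(y_true, y_pred, zero_division=1)  # 1.0, no warning

    # Per-class scores: only the ill-defined entries are replaced.
    y_true = [0, 1, 2, 0, 1, 2]
    y_pred = [0, 0, 0, 0, 0, 0]     # classes 1 and 2 are never predicted
    precision_score(y_true, y_pred, average=None, zero_division=0)
    # array([0.33..., 0., 0.])
    precision_score(y_true, y_pred, average=None, zero_division=1)
    # array([0.33..., 1., 1.])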
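Invalid values are rejected up front by the new ``_check_zero_division``
helper, which ``precision_recall_fscore_support`` calls before any
computation. A small sketch of the resulting error, with the message taken
from the helper added in this diff::

    from sklearn.metrics import precision_score

    try:
        precision_score([0, 1], [0, 1], zero_division=0.5)
    except ValueError as exc:
        print(exc)  # Got zero_division=0.5. Must be one of ["warn", 0, 1]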