diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index a8beea4f8c2f9..0215acbe5b74c 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -34,6 +34,7 @@
 from ._classification import zero_one_loss
 from ._classification import brier_score_loss
 from ._classification import multilabel_confusion_matrix
+from ._classification import tpr_fpr_tnr_fnr_scores

 from . import cluster
 from .cluster import adjusted_mutual_info_score
@@ -160,6 +161,7 @@
     'SCORERS',
     'silhouette_samples',
     'silhouette_score',
+    'tpr_fpr_tnr_fnr_scores',
     'v_measure_score',
     'zero_one_loss',
     'brier_score_loss',
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index e845e02808872..62eb4b42c189f 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -1538,6 +1538,209 @@ def precision_recall_fscore_support(y_true, y_pred, *, beta=1.0, labels=None,
     return precision, recall, f_score, true_sum


+@_deprecate_positional_args
+def tpr_fpr_tnr_fnr_scores(y_true, y_pred, *, labels=None, pos_label=1,
+                           average=None, warn_for=('tpr', 'fpr',
+                                                   'tnr', 'fnr'),
+                           sample_weight=None, zero_division="warn"):
+    """Compute True Positive Rate (TPR), False Positive Rate (FPR),\
+    True Negative Rate (TNR), False Negative Rate (FNR) for each class
+
+    The TPR is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
+    true positives and ``fn`` the number of false negatives.
+
+    The FPR is the ratio ``fp / (tn + fp)`` where ``tn`` is the number of
+    true negatives and ``fp`` the number of false positives.
+
+    The TNR is the ratio ``tn / (tn + fp)`` where ``tn`` is the number of
+    true negatives and ``fp`` the number of false positives.
+
+    The FNR is the ratio ``fn / (tp + fn)`` where ``tp`` is the number of
+    true positives and ``fn`` the number of false negatives.
+
+    If ``pos_label is None`` and in binary classification, this function
+    returns the average true positive rate, false positive rate, true
+    negative rate and false negative rate if ``average`` is one of
+    ``'micro'``, ``'macro'``, ``'weighted'`` or ``'samples'``.
+
+    Parameters
+    ----------
+    y_true : {array-like, label indicator array, sparse matrix} \
+            of shape (n_samples,)
+        Ground truth (correct) target values.
+
+    y_pred : {array-like, label indicator array, sparse matrix} \
+            of shape (n_samples,)
+        Estimated targets as returned by a classifier.
+
+    labels : list, default=None
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
+
+    pos_label : str or int, default=1
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data are multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.
+
+    average : str, {None, 'binary', 'micro', 'macro', 'samples', 'weighted'}, \
+            default=None
+        If ``None``, the scores for each class are returned. Otherwise, this
+        determines the type of averaging performed on the data:
+
+        ``'binary'``:
+            Only report results for the class specified by ``pos_label``.
+            This is applicable only if targets (``y_{true,pred}``) are binary.
+        ``'micro'``:
+            Calculate metrics globally by counting the total true positives,
+            true negatives, false negatives and false positives.
+        ``'macro'``:
+            Calculate metrics for each label, and find their unweighted
+            mean. This does not take label imbalance into account.
+        ``'weighted'``:
+            Calculate metrics for each label, and find their average weighted
+            by support (the number of true instances for each label). This
+            alters 'macro' to account for label imbalance.
+        ``'samples'``:
+            Calculate metrics for each instance, and find their average (only
+            meaningful for multilabel classification where this differs from
+            :func:`accuracy_score`).
+
+    warn_for : tuple or set, for internal use
+        This determines which warnings will be made in the case that this
+        function is being used to return only one of its metrics.
+
+    sample_weight : array-like of shape (n_samples,), default=None
+        Sample weights.
+
+    zero_division : str or int, {'warn', 0, 1}, default="warn"
+        Sets the value to return when there is a zero division:
+
+            - tpr, fnr: when there are no positive labels
+            - fpr, tnr: when there are no negative labels
+
+        If set to "warn", this acts as 0, but warnings are also raised.
+
+    Returns
+    -------
+    tpr : float (if average is not None), \
+            or ndarray of shape (n_unique_labels,)
+        True positive rate (sensitivity, recall), ``tp / (tp + fn)``.
+
+    fpr : float (if average is not None), \
+            or ndarray of shape (n_unique_labels,)
+        False positive rate (fall-out), ``fp / (tn + fp)``.
+
+    tnr : float (if average is not None), \
+            or ndarray of shape (n_unique_labels,)
+        True negative rate (specificity), ``tn / (tn + fp)``.
+
+    fnr : float (if average is not None), \
+            or ndarray of shape (n_unique_labels,)
+        False negative rate (miss rate), ``fn / (tp + fn)``.
+
+    References
+    ----------
+    .. [1] `Wikipedia entry for confusion matrix
+           <https://en.wikipedia.org/wiki/Confusion_matrix>`_
+
+    .. [2] `Discriminative Methods for Multi-labeled Classification Advances
+           in Knowledge Discovery and Data Mining (2004), pp. 22-30 by Shantanu
+           Godbole, Sunita Sarawagi
+           <http://www.godbole.net/shantanu/pubs/multilabelsvm-pakdd04.pdf>`_
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> y_true = np.array(['cat', 'dog', 'pig', 'cat', 'dog', 'pig'])
+    >>> y_pred = np.array(['cat', 'pig', 'dog', 'cat', 'cat', 'dog'])
+    >>> tpr_fpr_tnr_fnr_scores(y_true, y_pred, average='macro')
+    (0.3333333333333333, 0.3333333333333333, 0.6666666666666666,
+    0.6666666666666666)
+    >>> tpr_fpr_tnr_fnr_scores(y_true, y_pred, average='micro')
+    (0.3333333333333333, 0.3333333333333333, 0.6666666666666666,
+    0.6666666666666666)
+    >>> tpr_fpr_tnr_fnr_scores(y_true, y_pred, average='weighted')
+    (0.3333333333333333, 0.3333333333333333, 0.6666666666666666,
+    0.6666666666666666)
+
+    It is possible to compute per-label TPR, FPR, TNR and FNR
+    instead of averaging:
+
+    >>> tpr_fpr_tnr_fnr_scores(y_true, y_pred, average=None,
+    ...                        labels=['pig', 'dog', 'cat'])
+    (array([0., 0., 1.]), array([0.25, 0.5 , 0.25]),
+    array([0.75, 0.5 , 0.75]), array([1., 1., 0.]))
+
+    Notes
+    -----
+    When ``true positive + false negative == 0``, TPR and FNR are undefined;
+    when ``true negative + false positive == 0``, FPR and TNR are undefined.
+    In such cases, by default the undefined metric will be set to 0 and
+    ``UndefinedMetricWarning`` will be raised. This behavior can be
+    modified with ``zero_division``.
+ """ + _check_zero_division(zero_division) + + labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label) + + # Calculate tp_sum, fp_sum, tn_sum, fn_sum, pos_sum, neg_sum + samplewise = average == 'samples' + MCM = multilabel_confusion_matrix(y_true, y_pred, + sample_weight=sample_weight, + labels=labels, samplewise=samplewise) + tn_sum = MCM[:, 0, 0] + fp_sum = MCM[:, 0, 1] + fn_sum = MCM[:, 1, 0] + tp_sum = MCM[:, 1, 1] + neg_sum = tn_sum + fp_sum + pos_sum = fn_sum + tp_sum + + if average == 'micro': + tp_sum = np.array([tp_sum.sum()]) + fp_sum = np.array([fp_sum.sum()]) + tn_sum = np.array([tn_sum.sum()]) + fn_sum = np.array([fn_sum.sum()]) + neg_sum = np.array([neg_sum.sum()]) + pos_sum = np.array([pos_sum.sum()]) + + # Divide, and on zero-division, set scores and/or warn according to + # zero_division: + tpr = _prf_divide(tp_sum, pos_sum, 'tpr', 'positives', + average, warn_for, zero_division) + fpr = _prf_divide(fp_sum, neg_sum, 'fpr', 'negatives', + average, warn_for, zero_division) + tnr = _prf_divide(tn_sum, neg_sum, 'tnr', 'negatives', + average, warn_for, zero_division) + fnr = _prf_divide(fn_sum, pos_sum, 'fnr', 'positives', + average, warn_for, zero_division) + # Average the results + if average == 'weighted': + weights = pos_sum + if weights.sum() == 0: + zero_division_value = 0.0 if zero_division in ["warn", 0] else 1.0 + # TPR and FNR is zero_division if there are no positive labels + # FPR and TNR is zero_division if there are no negative labels + return (zero_division_value if pos_sum.sum() == 0 else 0, + zero_division_value if neg_sum.sum() == 0 else 0, + zero_division_value if neg_sum.sum() == 0 else 0, + zero_division_value if pos_sum.sum() == 0 else 0) + + elif average == 'samples': + weights = sample_weight + else: + weights = None + + if average is not None: + assert average != 'binary' or len(fpr) == 1 + fpr = np.average(fpr, weights=weights) + tnr = np.average(tnr, weights=weights) + fnr = np.average(fnr, weights=weights) + tpr = np.average(tpr, weights=weights) + return tpr, fpr, tnr, fnr + + @_deprecate_positional_args def precision_score(y_true, y_pred, *, labels=None, pos_label=1, average='binary', sample_weight=None, @@ -2174,7 +2377,8 @@ def log_loss(y_true, y_pred, *, eps=1e-15, normalize=True, sample_weight=None, y_true : array-like or label indicator matrix Ground truth (correct) labels for n_samples samples. - y_pred : array-like of float, shape = (n_samples, n_classes) or (n_samples,) + y_pred : array-like of float, shape = (n_samples, n_classes) \ + or (n_samples,) Predicted probabilities, as returned by a classifier's predict_proba method. 
If ``y_pred.shape = (n_samples,)`` the probabilities provided are assumed to be that of the diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 0fa97eea609f2..38619cfde94d6 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -42,6 +42,7 @@ from sklearn.metrics import precision_recall_fscore_support from sklearn.metrics import precision_score from sklearn.metrics import recall_score +from sklearn.metrics import tpr_fpr_tnr_fnr_scores from sklearn.metrics import zero_one_loss from sklearn.metrics import brier_score_loss from sklearn.metrics import multilabel_confusion_matrix @@ -332,6 +333,145 @@ def test_precision_recall_f_ignored_labels(): recall_all(average=average)) +def test_tpr_fpr_tnr_fnr_scores_binary_averaged(): + # Test TPR, FPR, TNR, FNR Score for binary classification task + y_true, y_pred, _ = make_prediction(binary=True) + + # compute scores with default labels introspection + tprs, fprs, tnrs, fnrs = tpr_fpr_tnr_fnr_scores( + y_true, y_pred, average=None + ) + assert_array_almost_equal(tprs, [0.88, 0.68], 2) + assert_array_almost_equal(fprs, [0.32, 0.12], 2) + assert_array_almost_equal(tnrs, [0.68, 0.88], 2) + assert_array_almost_equal(fnrs, [0.12, 0.32], 2) + + tn, fp, fn, tp = assert_no_warnings( + confusion_matrix, y_true, y_pred + ).ravel() + assert_array_almost_equal(tp / (tp + fn), 0.68, 2) + assert_array_almost_equal(fp / (tn + fp), 0.12, 2) + assert_array_almost_equal(tn / (tn + fp), 0.88, 2) + assert_array_almost_equal(fn / (tp + fn), 0.32, 2) + + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores( + y_true, y_pred, average='macro' + ) + assert tpr == np.mean(tprs) + assert fpr == np.mean(fprs) + assert tnr == np.mean(tnrs) + assert fnr == np.mean(fnrs) + + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores( + y_true, y_pred, average='weighted' + ) + support = np.bincount(y_true) + assert tpr == np.average(tprs, weights=support) + assert fpr == np.average(fprs, weights=support) + assert tnr == np.average(tnrs, weights=support) + assert fnr == np.average(fnrs, weights=support) + + +def test_tpr_fpr_tnr_fnr_scores_multiclass(): + # Test TPR, FPR, TNR, FNR Score for multiclass classification task + y_true, y_pred, _ = make_prediction(binary=False) + + # compute scores with default labels introspection + tprs, fprs, tnrs, fnrs = tpr_fpr_tnr_fnr_scores( + y_true, y_pred, average=None + ) + assert_array_almost_equal(tprs, [0.79, 0.1, 0.9], 2) + assert_array_almost_equal(fprs, [0.08, 0.14, 0.45], 2) + assert_array_almost_equal(tnrs, [0.92, 0.86, 0.55], 2) + assert_array_almost_equal(fnrs, [0.21, 0.9, 0.1], 2) + + # averaging tests + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores( + y_true, y_pred, average='micro' + ) + assert_array_almost_equal(tpr, 0.53, 2) + assert_array_almost_equal(fpr, 0.23, 2) + assert_array_almost_equal(tnr, 0.77, 2) + assert_array_almost_equal(fnr, 0.47, 2) + + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores( + y_true, y_pred, average='macro' + ) + assert_array_almost_equal(tpr, 0.6, 2) + assert_array_almost_equal(fpr, 0.22, 2) + assert_array_almost_equal(tnr, 0.78, 2) + assert_array_almost_equal(fnr, 0.4, 2) + + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores( + y_true, y_pred, average='weighted' + ) + assert_array_almost_equal(tpr, 0.53, 2) + assert_array_almost_equal(fpr, 0.2, 2) + assert_array_almost_equal(tnr, 0.8, 2) + assert_array_almost_equal(fnr, 0.47, 2) + + with pytest.raises(ValueError): + tpr_fpr_tnr_fnr_scores(y_true, y_pred, average="samples") 
+
+    # same prediction but with an explicit label ordering
+    tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores(
+        y_true, y_pred, labels=[0, 2, 1], average=None
+    )
+    assert_array_almost_equal(tpr, [0.79, 0.9, 0.1], 2)
+    assert_array_almost_equal(fpr, [0.08, 0.45, 0.14], 2)
+    assert_array_almost_equal(tnr, [0.92, 0.55, 0.86], 2)
+    assert_array_almost_equal(fnr, [0.21, 0.1, 0.9], 2)
+
+
+@ignore_warnings
+@pytest.mark.parametrize('zero_division', ["warn", 0, 1])
+def test_tpr_fpr_tnr_fnr_scores_with_an_empty_prediction(zero_division):
+    y_true = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 1, 1, 0]])
+    y_pred = np.array([[0, 0, 0, 0], [0, 0, 0, 1], [0, 1, 1, 0]])
+
+    zero_division = 1.0 if zero_division == 1.0 else 0.0
+    tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores(y_true, y_pred,
+                                                average=None,
+                                                zero_division=zero_division)
+    assert_array_almost_equal(tpr, [0.0, 0.5, 1.0, zero_division], 2)
+    assert_array_almost_equal(fpr, [0.0, 0.0, 0.0, 1 / 3.0], 2)
+    assert_array_almost_equal(tnr, [1.0, 1.0, 1.0, 2 / 3.0], 2)
+    assert_array_almost_equal(fnr, [1.0, 0.5, 0.0, zero_division], 2)
+
+    tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores(y_true, y_pred,
+                                                average="macro",
+                                                zero_division=zero_division)
+    assert_almost_equal(tpr, 0.625 if zero_division else 0.375)
+    assert_almost_equal(fpr, 1 / 3.0 / 4.0)
+    assert_almost_equal(tnr, 0.91666, 5)
+    assert_almost_equal(fnr, 0.625 if zero_division else 0.375)
+
+    tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores(y_true, y_pred,
+                                                average="micro",
+                                                zero_division=zero_division)
+    assert_almost_equal(tpr, 0.5)
+    assert_almost_equal(fpr, 0.125)
+    assert_almost_equal(tnr, 0.875)
+    assert_almost_equal(fnr, 0.5)
+
+    tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores(y_true, y_pred,
+                                                average="weighted",
+                                                zero_division=zero_division)
+    assert_almost_equal(tpr, 0.5)
+    assert_almost_equal(fpr, 0)
+    assert_almost_equal(tnr, 1.0)
+    assert_almost_equal(fnr, 0.5)
+
+    tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores(y_true, y_pred,
+                                                average="samples",
+                                                sample_weight=[1, 1, 2],
+                                                zero_division=zero_division)
+    assert_almost_equal(tpr, 0.5)
+    assert_almost_equal(fpr, 0.08333, 5)
+    assert_almost_equal(tnr, 0.91666, 5)
+    assert_almost_equal(fnr, 0.5)
+
+
 def test_average_precision_score_score_non_binary_class():
     # Test that average_precision_score function returns an error when trying
     # to compute average_precision_score for multiclass task.
@@ -2118,8 +2258,9 @@ def test_hinge_loss_multiclass(): ]) np.clip(dummy_losses, 0, None, out=dummy_losses) dummy_hinge_loss = np.mean(dummy_losses) - assert (hinge_loss(y_true, pred_decision) == - dummy_hinge_loss) + assert ( + hinge_loss(y_true, pred_decision) == dummy_hinge_loss + ) def test_hinge_loss_multiclass_missing_labels_with_labels_none(): diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 297793541007b..cd2073b82f9b7 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -56,6 +56,7 @@ from sklearn.metrics import recall_score from sklearn.metrics import roc_auc_score from sklearn.metrics import roc_curve +from sklearn.metrics import tpr_fpr_tnr_fnr_scores from sklearn.metrics import zero_one_loss from sklearn.metrics import ndcg_score from sklearn.metrics import dcg_score @@ -144,6 +145,9 @@ "f2_score": partial(fbeta_score, beta=2), "f0.5_score": partial(fbeta_score, beta=0.5), "matthews_corrcoef_score": matthews_corrcoef, + "tpr_fpr_tnr_fnr_scores": tpr_fpr_tnr_fnr_scores, + "binary_tpr_fpr_tnr_fnr_scores": + partial(tpr_fpr_tnr_fnr_scores, average="binary"), "weighted_f0.5_score": partial(fbeta_score, average="weighted", beta=0.5), "weighted_f1_score": partial(f1_score, average="weighted"), @@ -151,6 +155,8 @@ "weighted_precision_score": partial(precision_score, average="weighted"), "weighted_recall_score": partial(recall_score, average="weighted"), "weighted_jaccard_score": partial(jaccard_score, average="weighted"), + "weighted_tpr_fpr_tnr_fnr_scores": + partial(tpr_fpr_tnr_fnr_scores, average="weighted"), "micro_f0.5_score": partial(fbeta_score, average="micro", beta=0.5), "micro_f1_score": partial(f1_score, average="micro"), @@ -158,6 +164,8 @@ "micro_precision_score": partial(precision_score, average="micro"), "micro_recall_score": partial(recall_score, average="micro"), "micro_jaccard_score": partial(jaccard_score, average="micro"), + "micro_tpr_fpr_tnr_fnr_scores": + partial(tpr_fpr_tnr_fnr_scores, average="micro"), "macro_f0.5_score": partial(fbeta_score, average="macro", beta=0.5), "macro_f1_score": partial(f1_score, average="macro"), @@ -165,6 +173,8 @@ "macro_precision_score": partial(precision_score, average="macro"), "macro_recall_score": partial(recall_score, average="macro"), "macro_jaccard_score": partial(jaccard_score, average="macro"), + "macro_tpr_fpr_tnr_fnr_scores": + partial(tpr_fpr_tnr_fnr_scores, average="macro"), "samples_f0.5_score": partial(fbeta_score, average="samples", beta=0.5), "samples_f1_score": partial(f1_score, average="samples"), @@ -172,6 +182,8 @@ "samples_precision_score": partial(precision_score, average="samples"), "samples_recall_score": partial(recall_score, average="samples"), "samples_jaccard_score": partial(jaccard_score, average="samples"), + "samples_tpr_fpr_tnr_fnr_scores": + partial(tpr_fpr_tnr_fnr_scores, average="samples"), "cohen_kappa_score": cohen_kappa_score, } @@ -269,6 +281,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "samples_precision_score", "samples_recall_score", "samples_jaccard_score", + "samples_tpr_fpr_tnr_fnr_scores", "coverage_error", "unnormalized_multilabel_confusion_matrix_sample", "label_ranking_loss", @@ -287,6 +300,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "roc_auc_score", "weighted_roc_auc", + "tpr_fpr_tnr_fnr_scores", "average_precision_score", "weighted_average_precision_score", "micro_average_precision_score", @@ -300,6 +314,7 @@ def 
precision_recall_curve_padded_thresholds(*args, **kwargs): "f1_score", "f2_score", "f0.5_score", + "binary_tpr_fpr_tnr_fnr_scores", # curves "roc_curve", @@ -333,6 +348,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score", "jaccard_score", + "tpr_fpr_tnr_fnr_scores", "average_precision_score", "weighted_average_precision_score", "micro_average_precision_score", @@ -362,17 +378,21 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score", "jaccard_score", + "tpr_fpr_tnr_fnr_scores", "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score", "weighted_precision_score", "weighted_recall_score", "weighted_jaccard_score", + "weighted_tpr_fpr_tnr_fnr_scores", "micro_f0.5_score", "micro_f1_score", "micro_f2_score", "micro_precision_score", "micro_recall_score", "micro_jaccard_score", + "micro_tpr_fpr_tnr_fnr_scores", "macro_f0.5_score", "macro_f1_score", "macro_f2_score", "macro_precision_score", "macro_recall_score", "macro_jaccard_score", + "macro_tpr_fpr_tnr_fnr_scores", "unnormalized_multilabel_confusion_matrix", "unnormalized_multilabel_confusion_matrix_sample", @@ -414,20 +434,24 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score", "weighted_precision_score", "weighted_recall_score", "weighted_jaccard_score", + "weighted_tpr_fpr_tnr_fnr_scores", "macro_f0.5_score", "macro_f1_score", "macro_f2_score", "macro_precision_score", "macro_recall_score", "macro_jaccard_score", + "macro_tpr_fpr_tnr_fnr_scores", "micro_f0.5_score", "micro_f1_score", "micro_f2_score", "micro_precision_score", "micro_recall_score", "micro_jaccard_score", + "micro_tpr_fpr_tnr_fnr_scores", "unnormalized_multilabel_confusion_matrix", "samples_f0.5_score", "samples_f1_score", "samples_f2_score", "samples_precision_score", "samples_recall_score", "samples_jaccard_score", + "samples_tpr_fpr_tnr_fnr_scores", } # Regression metrics with "multioutput-continuous" format support @@ -452,6 +476,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): # P = R = F = accuracy in multiclass case "micro_f0.5_score", "micro_f1_score", "micro_f2_score", "micro_precision_score", "micro_recall_score", + "micro_tpr_fpr_tnr_fnr_scores", "matthews_corrcoef_score", "mean_absolute_error", "mean_squared_error", "median_absolute_error", "max_error", @@ -474,6 +499,10 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "precision_score", "recall_score", "f2_score", "f0.5_score", + "tpr_fpr_tnr_fnr_scores", + "weighted_tpr_fpr_tnr_fnr_scores", + "macro_tpr_fpr_tnr_fnr_scores", + "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score", "weighted_precision_score", "weighted_jaccard_score", "unnormalized_multilabel_confusion_matrix",
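
Usage sketch (not part of the patch): a minimal example of the new API, assuming a scikit-learn build with this diff applied, since tpr_fpr_tnr_fnr_scores is introduced here and is not in a released version. It reuses the docstring's cat/dog/pig data and cross-checks the per-class rates against multilabel_confusion_matrix, which the implementation itself is built on.

    import numpy as np
    from sklearn.metrics import multilabel_confusion_matrix
    # Only importable once this patch is applied; not part of released sklearn.
    from sklearn.metrics import tpr_fpr_tnr_fnr_scores

    y_true = np.array(['cat', 'dog', 'pig', 'cat', 'dog', 'pig'])
    y_pred = np.array(['cat', 'pig', 'dog', 'cat', 'cat', 'dog'])

    # Per-class rates, in sorted label order ('cat', 'dog', 'pig').
    tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_scores(y_true, y_pred, average=None)

    # Cross-check against the per-class confusion matrices:
    # MCM[i] == [[tn, fp], [fn, tp]] for class i.
    MCM = multilabel_confusion_matrix(y_true, y_pred)
    tn, fp, fn, tp = MCM[:, 0, 0], MCM[:, 0, 1], MCM[:, 1, 0], MCM[:, 1, 1]
    assert np.allclose(tpr, tp / (tp + fn))
    assert np.allclose(fpr, fp / (tn + fp))
    assert np.allclose(tnr, tn / (tn + fp))
    assert np.allclose(fnr, fn / (tp + fn))

    # 'macro' averaging is simply the unweighted mean of the per-class rates.
    macro = tpr_fpr_tnr_fnr_scores(y_true, y_pred, average='macro')
    assert np.allclose(macro, (tpr.mean(), fpr.mean(), tnr.mean(), fnr.mean()))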