diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 856a104469bfc..789ffa038f25d 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -313,6 +313,7 @@ Others also work in the multiclass case:
   confusion_matrix
   hinge_loss
   matthews_corrcoef
+  roc_auc_score

Some also work in the multilabel case:

@@ -331,6 +332,7 @@ Some also work in the multilabel case:
   precision_recall_fscore_support
   precision_score
   recall_score
+  roc_auc_score
   zero_one_loss

And some work with binary and multilabel (but not multiclass) problems:

@@ -339,7 +341,6 @@ And some work with binary and multilabel (but not multiclass) problems:
   :template: function.rst

   average_precision_score
-  roc_auc_score

In the following sub-sections, we will describe each of those functions,

@@ -1313,9 +1314,52 @@ In multi-label classification, the :func:`roc_auc_score` function is
extended by averaging over the labels as :ref:`above `.

Compared to metrics such as the subset accuracy, the Hamming loss, or the
-F1 score, ROC doesn't require optimizing a threshold for each label. The
-:func:`roc_auc_score` function can also be used in multi-class classification,
-if the predicted outputs have been binarized.
+F1 score, ROC doesn't require optimizing a threshold for each label.
+
+The :func:`roc_auc_score` function can also be used in multi-class
+classification. Two averaging strategies are currently supported: the
+one-vs-one algorithm computes the average of the pairwise ROC AUC scores, and
+the one-vs-rest algorithm computes the average of the ROC AUC scores for each
+class against all other classes. In both cases, the multiclass labels are
+provided in an array with values from 0 to ``n_classes - 1``, and the scores
+correspond to the probability estimates that a sample belongs to a particular
+class. The OvO and OvR algorithms support weighting uniformly
+(``average='macro'``) and weighting by the prevalence (``average='weighted'``).
+
+**One-vs-one Algorithm**: Computes the average AUC of all possible pairwise
+combinations of classes. [HT2001]_ defines a multiclass AUC metric weighted
+uniformly:
+
+.. math::
+
+   \frac{1}{c(c-1)}\sum_{j=1}^{c}\sum_{k > j}^c (\text{AUC}(j | k) +
+   \text{AUC}(k | j))
+
+where :math:`c` is the number of classes and :math:`\text{AUC}(j | k)` is the
+AUC with class :math:`j` as the positive class and class :math:`k` as the
+negative class. In general,
+:math:`\text{AUC}(j | k) \neq \text{AUC}(k | j)` in the multiclass
+case. This algorithm is used by setting the keyword argument ``multi_class``
+to ``'ovo'`` and ``average`` to ``'macro'``.
+
+The [HT2001]_ multiclass AUC metric can be extended to be weighted by the
+prevalence:
+
+.. math::
+
+   \frac{1}{c(c-1)}\sum_{j=1}^{c}\sum_{k > j}^c p(j \cup k)(
+   \text{AUC}(j | k) + \text{AUC}(k | j))
+
+where :math:`c` is the number of classes. This algorithm is used by setting
+the keyword argument ``multi_class`` to ``'ovo'`` and ``average`` to
+``'weighted'``. The ``'weighted'`` option returns a prevalence-weighted average
+as described in [FC2009]_.
+
+**One-vs-rest Algorithm**: Computes the AUC of each class against the rest.
+The algorithm is functionally the same as the multilabel case. To enable this
+algorithm, set the keyword argument ``multi_class`` to ``'ovr'``. Similar to
+OvO, OvR supports two types of averaging: ``'macro'`` [F2006]_ and
+``'weighted'`` [F2001]_.
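A minimal usage sketch of the two strategies described above (the labels and
probability estimates are made up for illustration; the multiclass code path
requires each row of ``y_score`` to sum to one)::

  >>> import numpy as np
  >>> from sklearn.metrics import roc_auc_score
  >>> y_true = np.array([0, 1, 0, 2])
  >>> y_score = np.array([[0.7, 0.2, 0.1],
  ...                     [0.2, 0.7, 0.1],
  ...                     [0.6, 0.3, 0.1],
  ...                     [0.2, 0.3, 0.5]])
  >>> roc_auc_score(y_true, y_score, multi_class='ovo', average='macro')
  1.0
  >>> roc_auc_score(y_true, y_score, multi_class='ovr', average='weighted')
  1.0

Both calls return 1.0 because, in this toy example, the samples of each class
receive strictly higher scores in that class's column than all other samples,
so every pairwise and one-vs-rest AUC equals 1.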
In applications where a high false positive rate is not tolerable the parameter ``max_fpr`` of :func:`roc_auc_score` can be used to summarize the ROC curve up @@ -1341,6 +1385,28 @@ to the given limit. for an example of using ROC to model species distribution. +.. topic:: References: + + .. [HT2001] Hand, D.J. and Till, R.J., (2001). `A simple generalisation + of the area under the ROC curve for multiple class classification problems. + `_ + Machine learning, 45(2), pp.171-186. + + .. [FC2009] Ferri, Cèsar & Hernandez-Orallo, Jose & Modroiu, R. (2009). + `An Experimental Comparison of Performance Measures for Classification. + `_ + Pattern Recognition Letters. 30. 27-38. + + .. [F2006] Fawcett, T., 2006. `An introduction to ROC analysis. + `_ + Pattern Recognition Letters, 27(8), pp. 861-874. + + .. [F2001] Fawcett, T., 2001. `Using rule sets to maximize + ROC performance `_ + In Data Mining, 2001. + Proceedings IEEE International Conference, pp. 131-138. + + .. _zero_one_loss: Zero one loss diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 3c3e07ec249b3..f2046cc6b64f1 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -118,6 +118,13 @@ Changelog requires less memory. :pr:`14108`, pr:`14170` by :user:`Alex Henrie `. +:mod:`sklearn.metrics` +...................... + +- |Feature| Added multiclass support to :func:`metrics.roc_auc_score`. + :issue:`12789` by :user:`Kathy Chen `, + :user:`Mohamed Maskani `, and :user:`Thomas Fan `. + :mod:`sklearn.model_selection` .................. diff --git a/examples/model_selection/plot_roc.py b/examples/model_selection/plot_roc.py index 475d7b4aba7a6..653c448d5cda4 100644 --- a/examples/model_selection/plot_roc.py +++ b/examples/model_selection/plot_roc.py @@ -15,24 +15,21 @@ The "steepness" of ROC curves is also important, since it is ideal to maximize the true positive rate while minimizing the false positive rate. -Multiclass settings -------------------- - ROC curves are typically used in binary classification to study the output of -a classifier. In order to extend ROC curve and ROC area to multi-class -or multi-label classification, it is necessary to binarize the output. One ROC +a classifier. In order to extend ROC curve and ROC area to multi-label +classification, it is necessary to binarize the output. One ROC curve can be drawn per label, but one can also draw a ROC curve by considering each element of the label indicator matrix as a binary prediction (micro-averaging). -Another evaluation measure for multi-class classification is +Another evaluation measure for multi-label classification is macro-averaging, which gives equal weight to the classification of each label. .. note:: See also :func:`sklearn.metrics.roc_auc_score`, - :ref:`sphx_glr_auto_examples_model_selection_plot_roc_crossval.py`. + :ref:`sphx_glr_auto_examples_model_selection_plot_roc_crossval.py` """ print(__doc__) @@ -47,6 +44,7 @@ from sklearn.preprocessing import label_binarize from sklearn.multiclass import OneVsRestClassifier from scipy import interp +from sklearn.metrics import roc_auc_score # Import some data to play with iris = datasets.load_iris() @@ -101,8 +99,8 @@ ############################################################################## -# Plot ROC curves for the multiclass problem - +# Plot ROC curves for the multilabel problem +# .......................................... 
 # Compute macro-average ROC curve and ROC area

 # First aggregate all false positive rates
@@ -146,3 +144,29 @@
 plt.title('Some extension of Receiver operating characteristic to multi-class')
 plt.legend(loc="lower right")
 plt.show()
+
+
+##############################################################################
+# Area under ROC for the multiclass problem
+# .........................................
+# The :func:`sklearn.metrics.roc_auc_score` function can be used for
+# multi-class classification. The multiclass One-vs-One scheme compares every
+# unique pairwise combination of classes. In this section, we calculate the
+# AUC using the OvR and OvO schemes. We report a macro average and a
+# prevalence-weighted average.
+y_prob = classifier.predict_proba(X_test)
+
+macro_roc_auc_ovo = roc_auc_score(y_test, y_prob, multi_class="ovo",
+                                  average="macro")
+weighted_roc_auc_ovo = roc_auc_score(y_test, y_prob, multi_class="ovo",
+                                     average="weighted")
+macro_roc_auc_ovr = roc_auc_score(y_test, y_prob, multi_class="ovr",
+                                  average="macro")
+weighted_roc_auc_ovr = roc_auc_score(y_test, y_prob, multi_class="ovr",
+                                     average="weighted")
+print("One-vs-One ROC AUC scores:\n{:.6f} (macro),\n{:.6f} "
+      "(weighted by prevalence)"
+      .format(macro_roc_auc_ovo, weighted_roc_auc_ovo))
+print("One-vs-Rest ROC AUC scores:\n{:.6f} (macro),\n{:.6f} "
+      "(weighted by prevalence)"
+      .format(macro_roc_auc_ovr, weighted_roc_auc_ovr))
diff --git a/sklearn/metrics/base.py b/sklearn/metrics/base.py
index d727d150e4728..288730354139c 100644
--- a/sklearn/metrics/base.py
+++ b/sklearn/metrics/base.py
@@ -12,6 +12,7 @@
 # Noel Dawe
 # License: BSD 3 clause

+from itertools import combinations

 import numpy as np

@@ -123,3 +124,74 @@ def _average_binary_score(binary_metric, y_true, y_score, average,
         return np.average(score, weights=average_weight)
     else:
         return score
+
+
+def _average_multiclass_ovo_score(binary_metric, y_true, y_score,
+                                  average='macro'):
+    """Average one-versus-one scores for multiclass classification.
+
+    Uses the binary metric for one-vs-one multiclass classification,
+    where the score is computed according to the Hand & Till (2001) algorithm.
+
+    Parameters
+    ----------
+    binary_metric : callable
+        The binary metric function to use that accepts the following as input:
+            y_true_target : array, shape = [n_samples_target]
+                Some sub-array of y_true for a pair of classes designated
+                positive and negative in the one-vs-one scheme.
+            y_score_target : array, shape = [n_samples_target]
+                Scores corresponding to the probability estimates
+                of a sample belonging to the designated positive class label.
+
+    y_true : array-like, shape = (n_samples, )
+        True multiclass labels.
+
+    y_score : array-like, shape = (n_samples, n_classes)
+        Target scores corresponding to probability estimates of a sample
+        belonging to a particular class.
+
+    average : 'macro' or 'weighted', optional (default='macro')
+        Determines the type of averaging performed on the pairwise binary
+        metric scores
+        ``'macro'``:
+            Calculate metrics for each label, and find their unweighted
+            mean. This does not take label imbalance into account. Classes
+            are assumed to be uniformly distributed.
+        ``'weighted'``:
+            Calculate metrics for each label, taking into account the
+            prevalence of the classes.
+ + Returns + ------- + score : float + Average of the pairwise binary metric scores + """ + check_consistent_length(y_true, y_score) + + y_true_unique = np.unique(y_true) + n_classes = y_true_unique.shape[0] + n_pairs = n_classes * (n_classes - 1) // 2 + pair_scores = np.empty(n_pairs) + + is_weighted = average == "weighted" + prevalence = np.empty(n_pairs) if is_weighted else None + + # Compute scores treating a as positive class and b as negative class, + # then b as positive class and a as negative class + for ix, (a, b) in enumerate(combinations(y_true_unique, 2)): + a_mask = y_true == a + b_mask = y_true == b + ab_mask = np.logical_or(a_mask, b_mask) + + if is_weighted: + prevalence[ix] = np.average(ab_mask) + + a_true = a_mask[ab_mask] + b_true = b_mask[ab_mask] + + a_true_score = binary_metric(a_true, y_score[ab_mask, a]) + b_true_score = binary_metric(b_true, y_score[ab_mask, b]) + pair_scores[ix] = (a_true_score + b_true_score) / 2 + + return np.average(pair_scores, weights=prevalence) diff --git a/sklearn/metrics/ranking.py b/sklearn/metrics/ranking.py index 0f75d4bb48685..8e1775e80e635 100644 --- a/sklearn/metrics/ranking.py +++ b/sklearn/metrics/ranking.py @@ -33,8 +33,9 @@ from ..utils.sparsefuncs import count_nonzero from ..exceptions import UndefinedMetricWarning from ..preprocessing import label_binarize +from ..preprocessing.label import _encode -from .base import _average_binary_score +from .base import _average_binary_score, _average_multiclass_ovo_score def auc(x, y): @@ -214,8 +215,36 @@ def _binary_uninterpolated_average_precision( average, sample_weight=sample_weight) +def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None): + """Binary roc auc score""" + if len(np.unique(y_true)) != 2: + raise ValueError("Only one class present in y_true. ROC AUC score " + "is not defined in that case.") + + fpr, tpr, _ = roc_curve(y_true, y_score, + sample_weight=sample_weight) + if max_fpr is None or max_fpr == 1: + return auc(fpr, tpr) + if max_fpr <= 0 or max_fpr > 1: + raise ValueError("Expected max_fpr in range (0, 1], got: %r" % max_fpr) + + # Add a single point at max_fpr by linear interpolation + stop = np.searchsorted(fpr, max_fpr, 'right') + x_interp = [fpr[stop - 1], fpr[stop]] + y_interp = [tpr[stop - 1], tpr[stop]] + tpr = np.append(tpr[:stop], np.interp(max_fpr, x_interp, y_interp)) + fpr = np.append(fpr[:stop], max_fpr) + partial_auc = auc(fpr, tpr) + + # McClish correction: standardize result to be 0.5 if non-discriminant + # and 1 if maximal + min_area = 0.5 * max_fpr**2 + max_area = max_fpr + return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) + + def roc_auc_score(y_true, y_score, average="macro", sample_weight=None, - max_fpr=None): + max_fpr=None, multi_class="raise", labels=None): """Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores. @@ -228,17 +257,22 @@ def roc_auc_score(y_true, y_score, average="macro", sample_weight=None, ---------- y_true : array, shape = [n_samples] or [n_samples, n_classes] True binary labels or binary label indicators. + The multiclass case expects shape = [n_samples] and labels + with values in ``range(n_classes)``. y_score : array, shape = [n_samples] or [n_samples, n_classes] Target scores, can either be probability estimates of the positive class, confidence values, or non-thresholded measure of decisions (as returned by "decision_function" on some classifiers). 
        For binary y_true, y_score is supposed to be the score of the class with greater
-        label.
+        label. The multiclass case expects shape = [n_samples, n_classes]
+        where the scores correspond to probability estimates.

    average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted']
        If ``None``, the scores for each class are returned. Otherwise,
        this determines the type of averaging performed on the data:
+        Note: multiclass ROC AUC currently only handles the 'macro' and
+        'weighted' averages.

        ``'micro'``:
            Calculate metrics globally by considering each element of the label
@@ -259,7 +293,23 @@ def roc_auc_score(y_true, y_score, average="macro", sample_weight=None,

    max_fpr : float > 0 and <= 1, optional
        If not ``None``, the standardized partial AUC [3]_ over the range
-        [0, max_fpr] is returned.
+        [0, max_fpr] is returned. For the multiclass case, ``max_fpr``
+        should be either equal to ``None`` or ``1.0``, as partial AUC
+        computation is currently not supported for multiclass.
+
+    multi_class : string, 'ovr' or 'ovo', optional (default='raise')
+        Determines the type of multiclass configuration to use.
+        ``multi_class`` must be provided when ``y_true`` is multiclass.
+        ``'ovr'``:
+            Calculate metrics for the multiclass case using the one-vs-rest
+            approach.
+        ``'ovo'``:
+            Calculate metrics for the multiclass case using the one-vs-one
+            approach.
+
+    labels : array, shape = [n_classes] or None, optional (default=None)
+        List of labels to index ``y_score`` used for multiclass. If ``None``,
+        the lexical order of ``y_true`` is used to index ``y_score``.

    Returns
    -------
@@ -292,41 +342,136 @@ def roc_auc_score(y_true, y_score, average="macro", sample_weight=None,
    0.75
    """
-    def _binary_roc_auc_score(y_true, y_score, sample_weight=None):
-        if len(np.unique(y_true)) != 2:
-            raise ValueError("Only one class present in y_true. 
ROC AUC score " - "is not defined in that case.") - - fpr, tpr, _ = roc_curve(y_true, y_score, - sample_weight=sample_weight) - if max_fpr is None or max_fpr == 1: - return auc(fpr, tpr) - if max_fpr <= 0 or max_fpr > 1: - raise ValueError("Expected max_fpr in range (0, 1], got: %r" - % max_fpr) - - # Add a single point at max_fpr by linear interpolation - stop = np.searchsorted(fpr, max_fpr, 'right') - x_interp = [fpr[stop - 1], fpr[stop]] - y_interp = [tpr[stop - 1], tpr[stop]] - tpr = np.append(tpr[:stop], np.interp(max_fpr, x_interp, y_interp)) - fpr = np.append(fpr[:stop], max_fpr) - partial_auc = auc(fpr, tpr) - - # McClish correction: standardize result to be 0.5 if non-discriminant - # and 1 if maximal - min_area = 0.5 * max_fpr**2 - max_area = max_fpr - return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) y_type = type_of_target(y_true) - if y_type == "binary": + y_true = check_array(y_true, ensure_2d=False, dtype=None) + y_score = check_array(y_score, ensure_2d=False) + + if y_type == "multiclass" or (y_type == "binary" and + y_score.ndim == 2 and + y_score.shape[1] > 2): + # do not support partial ROC computation for multiclass + if max_fpr is not None and max_fpr != 1.: + raise ValueError("Partial AUC computation not available in " + "multiclass setting, 'max_fpr' must be" + " set to `None`, received `max_fpr={0}` " + "instead".format(max_fpr)) + if multi_class == 'raise': + raise ValueError("multi_class must be in ('ovo', 'ovr')") + return _multiclass_roc_auc_score(y_true, y_score, labels, + multi_class, average, sample_weight) + elif y_type == "binary": labels = np.unique(y_true) y_true = label_binarize(y_true, labels)[:, 0] + return _average_binary_score(partial(_binary_roc_auc_score, + max_fpr=max_fpr), + y_true, y_score, average, + sample_weight=sample_weight) + else: # multilabel-indicator + return _average_binary_score(partial(_binary_roc_auc_score, + max_fpr=max_fpr), + y_true, y_score, average, + sample_weight=sample_weight) + + +def _multiclass_roc_auc_score(y_true, y_score, labels, + multi_class, average, sample_weight): + """Multiclass roc auc score + + Parameters + ---------- + y_true : array-like, shape = (n_samples, ) + True multiclass labels. + + y_score : array-like, shape = (n_samples, n_classes) + Target scores corresponding to probability estimates of a sample + belonging to a particular class + + labels : array, shape = [n_classes] or None, optional (default=None) + List of labels to index ``y_score`` used for multiclass. If ``None``, + the lexical order of ``y_true`` is used to index ``y_score``. + + multi_class : string, 'ovr' or 'ovo' + Determines the type of multiclass configuration to use. + ``'ovr'``: + Calculate metrics for the multiclass case using the one-vs-rest + approach. + ``'ovo'``: + Calculate metrics for the multiclass case using the one-vs-one + approach. + + average : 'macro' or 'weighted', optional (default='macro') + Determines the type of averaging performed on the pairwise binary + metric scores + ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. Classes + are assumed to be uniformly distributed. + ``'weighted'``: + Calculate metrics for each label, taking into account the + prevalence of the classes. + + sample_weight : array-like of shape = [n_samples], optional + Sample weights. 
+ + """ + # validation of the input y_score + if not np.allclose(1, y_score.sum(axis=1)): + raise ValueError( + "Target scores need to be probabilities for multiclass " + "roc_auc, i.e. they should sum up to 1.0 over classes") + + # validation for multiclass parameter specifications + average_options = ("macro", "weighted") + if average not in average_options: + raise ValueError("average must be one of {0} for " + "multiclass problems".format(average_options)) + + multiclass_options = ("ovo", "ovr") + if multi_class not in multiclass_options: + raise ValueError("multi_class='{0}' is not supported " + "for multiclass ROC AUC, multi_class must be " + "in {1}".format( + multi_class, multiclass_options)) + + if labels is not None: + labels = column_or_1d(labels) + classes = _encode(labels) + if len(classes) != len(labels): + raise ValueError("Parameter 'labels' must be unique") + if not np.array_equal(classes, labels): + raise ValueError("Parameter 'labels' must be ordered") + if len(classes) != y_score.shape[1]: + raise ValueError( + "Number of given labels, {0}, not equal to the number " + "of columns in 'y_score', {1}".format( + len(classes), y_score.shape[1])) + if len(np.setdiff1d(y_true, classes)): + raise ValueError( + "'y_true' contains labels not in parameter 'labels'") + else: + classes = _encode(y_true) + if len(classes) != y_score.shape[1]: + raise ValueError( + "Number of classes in y_true not equal to the number of " + "columns in 'y_score'") - return _average_binary_score( - _binary_roc_auc_score, y_true, y_score, average, - sample_weight=sample_weight) + if multi_class == "ovo": + if sample_weight is not None: + raise ValueError("sample_weight is not supported " + "for multiclass one-vs-one ROC AUC, " + "'sample_weight' must be None in this case.") + _, y_true_encoded = _encode(y_true, uniques=classes, encode=True) + # Hand & Till (2001) implementation (ovo) + return _average_multiclass_ovo_score(_binary_roc_auc_score, + y_true_encoded, + y_score, average=average) + else: + # ovr is same as multi-label + y_true_multilabel = label_binarize(y_true, classes) + return _average_binary_score(_binary_roc_auc_score, y_true_multilabel, + y_score, average, + sample_weight=sample_weight) def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py index 087ade8c3642f..80a0427647a3a 100644 --- a/sklearn/metrics/scorer.py +++ b/sklearn/metrics/scorer.py @@ -487,6 +487,16 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False, needs_threshold=True) average_precision_scorer = make_scorer(average_precision_score, needs_threshold=True) +roc_auc_ovo_scorer = make_scorer(roc_auc_score, needs_threshold=True, + multi_class='ovo') +roc_auc_ovo_weighted_scorer = make_scorer(roc_auc_score, needs_threshold=True, + multi_class='ovo', + average='weighted') +roc_auc_ovr_scorer = make_scorer(roc_auc_score, needs_threshold=True, + multi_class='ovr') +roc_auc_ovr_weighted_scorer = make_scorer(roc_auc_score, needs_threshold=True, + multi_class='ovr', + average='weighted') # Score function for probabilistic classification neg_log_loss_scorer = make_scorer(log_loss, greater_is_better=False, @@ -515,6 +525,8 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False, neg_mean_squared_error=neg_mean_squared_error_scorer, neg_mean_squared_log_error=neg_mean_squared_log_error_scorer, accuracy=accuracy_scorer, roc_auc=roc_auc_scorer, + roc_auc_ovr=roc_auc_ovr_scorer, + roc_auc_ovo=roc_auc_ovo_scorer, 
balanced_accuracy=balanced_accuracy_scorer, average_precision=average_precision_scorer, neg_log_loss=neg_log_loss_scorer, diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 67e9b66a4b695..6442b11834671 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -211,6 +211,12 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "weighted_roc_auc": partial(roc_auc_score, average="weighted"), "samples_roc_auc": partial(roc_auc_score, average="samples"), "micro_roc_auc": partial(roc_auc_score, average="micro"), + "ovr_roc_auc": partial(roc_auc_score, average="macro", multi_class='ovr'), + "weighted_ovr_roc_auc": partial(roc_auc_score, average="weighted", + multi_class='ovr'), + "ovo_roc_auc": partial(roc_auc_score, average="macro", multi_class='ovo'), + "weighted_ovo_roc_auc": partial(roc_auc_score, average="weighted", + multi_class='ovo'), "partial_roc_auc": partial(roc_auc_score, max_fpr=0.5), "average_precision_score": @@ -258,11 +264,11 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): METRIC_UNDEFINED_MULTICLASS = { "brier_score_loss", - "roc_auc_score", "micro_roc_auc", - "weighted_roc_auc", "samples_roc_auc", "partial_roc_auc", + "roc_auc_score", + "weighted_roc_auc", "average_precision_score", "weighted_average_precision_score", @@ -457,7 +463,9 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): # No Sample weight support METRICS_WITHOUT_SAMPLE_WEIGHT = { "median_absolute_error", - "max_error" + "max_error", + "ovo_roc_auc", + "weighted_ovo_roc_auc" } @@ -1184,7 +1192,10 @@ def test_multiclass_sample_weight_invariance(name): y_score = random_state.random_sample(size=(n_samples, 5)) metric = ALL_METRICS[name] if name in THRESHOLDED_METRICS: - check_sample_weight_invariance(name, metric, y_true, y_score) + # softmax + temp = np.exp(-y_score) + y_score_norm = temp / temp.sum(axis=-1).reshape(-1, 1) + check_sample_weight_invariance(name, metric, y_true, y_score_norm) else: check_sample_weight_invariance(name, metric, y_true, y_pred) @@ -1280,3 +1291,27 @@ def test_thresholded_multilabel_multioutput_permutations_invariance(name): current_score = metric(y_true_perm, y_score_perm) assert_almost_equal(score, current_score) + + +@pytest.mark.parametrize( + 'name', + sorted(set(THRESHOLDED_METRICS) - METRIC_UNDEFINED_BINARY_MULTICLASS)) +def test_thresholded_metric_permutation_invariance(name): + n_samples, n_classes = 100, 3 + random_state = check_random_state(0) + + y_score = random_state.rand(n_samples, n_classes) + temp = np.exp(-y_score) + y_score = temp / temp.sum(axis=-1).reshape(-1, 1) + y_true = random_state.randint(0, n_classes, size=n_samples) + + metric = ALL_METRICS[name] + score = metric(y_true, y_score) + for perm in permutations(range(n_classes), n_classes): + inverse_perm = np.zeros(n_classes, dtype=int) + inverse_perm[list(perm)] = np.arange(n_classes) + y_score_perm = y_score[:, inverse_perm] + y_true_perm = np.take(perm, y_true) + + current_score = metric(y_true_perm, y_score_perm) + assert_almost_equal(score, current_score) diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 140c1c7abad9c..c202aef1added 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -1,4 +1,3 @@ - import pytest import numpy as np import warnings @@ -7,6 +6,7 @@ from sklearn import datasets from sklearn import svm +from sklearn.utils.extmath import softmax from sklearn.datasets import 
make_multilabel_classification from sklearn.random_projection import sparse_random_matrix from sklearn.utils.validation import check_array, check_consistent_length @@ -440,6 +440,185 @@ def test_auc_errors(): assert_raise_message(ValueError, error_message, auc, x, y) +@pytest.mark.parametrize( + "y_true, labels", + [(np.array([0, 1, 0, 2]), [0, 1, 2]), + (np.array([0, 1, 0, 2]), None), + (["a", "b", "a", "c"], ["a", "b", "c"]), + (["a", "b", "a", "c"], None)] +) +def test_multiclass_ovo_roc_auc_toydata(y_true, labels): + # Tests the one-vs-one multiclass ROC AUC algorithm + # on a small example, representative of an expected use case. + y_scores = np.array( + [[0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.35, 0.5, 0.15], [0, 0.2, 0.8]]) + + # Used to compute the expected output. + # Consider labels 0 and 1: + # positive label is 0, negative label is 1 + score_01 = roc_auc_score([1, 0, 1], [0.1, 0.3, 0.35]) + # positive label is 1, negative label is 0 + score_10 = roc_auc_score([0, 1, 0], [0.8, 0.4, 0.5]) + average_score_01 = (score_01 + score_10) / 2 + + # Consider labels 0 and 2: + score_02 = roc_auc_score([1, 1, 0], [0.1, 0.35, 0]) + score_20 = roc_auc_score([0, 0, 1], [0.1, 0.15, 0.8]) + average_score_02 = (score_02 + score_20) / 2 + + # Consider labels 1 and 2: + score_12 = roc_auc_score([1, 0], [0.4, 0.2]) + score_21 = roc_auc_score([0, 1], [0.3, 0.8]) + average_score_12 = (score_12 + score_21) / 2 + + # Unweighted, one-vs-one multiclass ROC AUC algorithm + ovo_unweighted_score = ( + average_score_01 + average_score_02 + average_score_12) / 3 + assert_almost_equal( + roc_auc_score(y_true, y_scores, labels=labels, multi_class="ovo"), + ovo_unweighted_score) + + # Weighted, one-vs-one multiclass ROC AUC algorithm + # Each term is weighted by the prevalence for the positive label. + pair_scores = [average_score_01, average_score_02, average_score_12] + prevalence = [0.75, 0.75, 0.50] + ovo_weighted_score = np.average(pair_scores, weights=prevalence) + assert_almost_equal( + roc_auc_score( + y_true, + y_scores, + labels=labels, + multi_class="ovo", + average="weighted"), ovo_weighted_score) + + +@pytest.mark.parametrize("y_true, labels", + [(np.array([0, 2, 0, 2]), [0, 1, 2]), + (np.array(['a', 'd', 'a', 'd']), ['a', 'b', 'd'])]) +def test_multiclass_ovo_roc_auc_toydata_binary(y_true, labels): + # Tests the one-vs-one multiclass ROC AUC algorithm for binary y_true + # + # on a small example, representative of an expected use case. + y_scores = np.array( + [[0.2, 0.0, 0.8], [0.6, 0.0, 0.4], [0.55, 0.0, 0.45], [0.4, 0.0, 0.6]]) + + # Used to compute the expected output. 
+ # Consider labels 0 and 1: + # positive label is 0, negative label is 1 + score_01 = roc_auc_score([1, 0, 1, 0], [0.2, 0.6, 0.55, 0.4]) + # positive label is 1, negative label is 0 + score_10 = roc_auc_score([0, 1, 0, 1], [0.8, 0.4, 0.45, 0.6]) + ovo_score = (score_01 + score_10) / 2 + + assert_almost_equal( + roc_auc_score(y_true, y_scores, labels=labels, multi_class='ovo'), + ovo_score) + + # Weighted, one-vs-one multiclass ROC AUC algorithm + assert_almost_equal( + roc_auc_score(y_true, y_scores, labels=labels, multi_class='ovo', + average="weighted"), ovo_score) + + +@pytest.mark.parametrize( + "y_true, labels", + [(np.array([0, 1, 2, 2]), None), + (["a", "b", "c", "c"], None), + ([0, 1, 2, 2], [0, 1, 2]), + (["a", "b", "c", "c"], ["a", "b", "c"])]) +def test_multiclass_ovr_roc_auc_toydata(y_true, labels): + # Tests the unweighted, one-vs-rest multiclass ROC AUC algorithm + # on a small example, representative of an expected use case. + y_scores = np.array( + [[1.0, 0.0, 0.0], [0.1, 0.5, 0.4], [0.1, 0.1, 0.8], [0.3, 0.3, 0.4]]) + # Compute the expected result by individually computing the 'one-vs-rest' + # ROC AUC scores for classes 0, 1, and 2. + out_0 = roc_auc_score([1, 0, 0, 0], y_scores[:, 0]) + out_1 = roc_auc_score([0, 1, 0, 0], y_scores[:, 1]) + out_2 = roc_auc_score([0, 0, 1, 1], y_scores[:, 2]) + result_unweighted = (out_0 + out_1 + out_2) / 3. + + assert_almost_equal( + roc_auc_score(y_true, y_scores, multi_class="ovr", labels=labels), + result_unweighted) + + # Tests the weighted, one-vs-rest multiclass ROC AUC algorithm + # on the same input (Provost & Domingos, 2001) + result_weighted = out_0 * 0.25 + out_1 * 0.25 + out_2 * 0.5 + assert_almost_equal( + roc_auc_score( + y_true, + y_scores, + multi_class="ovr", + labels=labels, + average="weighted"), result_weighted) + + +@pytest.mark.parametrize( + "msg, y_true, labels", + [("Parameter 'labels' must be unique", np.array([0, 1, 2, 2]), [0, 2, 0]), + ("Parameter 'labels' must be unique", np.array(["a", "b", "c", "c"]), + ["a", "a", "b"]), + ("Number of classes in y_true not equal to the number of columns " + "in 'y_score'", np.array([0, 2, 0, 2]), None), + ("Parameter 'labels' must be ordered", np.array(["a", "b", "c", "c"]), + ["a", "c", "b"]), + ("Number of given labels, 2, not equal to the number of columns in " + "'y_score', 3", + np.array([0, 1, 2, 2]), [0, 1]), + ("Number of given labels, 2, not equal to the number of columns in " + "'y_score', 3", + np.array(["a", "b", "c", "c"]), ["a", "b"]), + ("Number of given labels, 4, not equal to the number of columns in " + "'y_score', 3", + np.array([0, 1, 2, 2]), [0, 1, 2, 3]), + ("Number of given labels, 4, not equal to the number of columns in " + "'y_score', 3", + np.array(["a", "b", "c", "c"]), ["a", "b", "c", "d"]), + ("'y_true' contains labels not in parameter 'labels'", + np.array(["a", "b", "c", "e"]), ["a", "b", "c"]), + ("'y_true' contains labels not in parameter 'labels'", + np.array(["a", "b", "c", "d"]), ["a", "b", "c"]), + ("'y_true' contains labels not in parameter 'labels'", + np.array([0, 1, 2, 3]), [0, 1, 2])]) +@pytest.mark.parametrize("multi_class", ["ovo", "ovr"]) +def test_roc_auc_score_multiclass_labels_error( + msg, y_true, labels, multi_class): + y_scores = np.array( + [[0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.35, 0.5, 0.15], [0, 0.2, 0.8]]) + + with pytest.raises(ValueError, match=msg): + roc_auc_score(y_true, y_scores, labels=labels, multi_class=multi_class) + + +@pytest.mark.parametrize("msg, kwargs", [ + ((r"average must be one of \('macro', 
'weighted'\) for " + r"multiclass problems"), {"average": "samples", "multi_class": "ovo"}), + ((r"average must be one of \('macro', 'weighted'\) for " + r"multiclass problems"), {"average": "micro", "multi_class": "ovr"}), + ((r"sample_weight is not supported for multiclass one-vs-one " + r"ROC AUC, 'sample_weight' must be None in this case"), + {"multi_class": "ovo", "sample_weight": []}), + ((r"Partial AUC computation not available in multiclass setting, " + r"'max_fpr' must be set to `None`, received `max_fpr=0.5` " + r"instead"), {"multi_class": "ovo", "max_fpr": 0.5}), + ((r"multi_class='ovp' is not supported for multiclass ROC AUC, " + r"multi_class must be in \('ovo', 'ovr'\)"), + {"multi_class": "ovp"}), + (r"multi_class must be in \('ovo', 'ovr'\)", {}) +]) +def test_roc_auc_score_multiclass_error(msg, kwargs): + # Test that roc_auc_score function returns an error when trying + # to compute multiclass AUC for parameters where an output + # is not defined. + rng = check_random_state(404) + y_score = rng.rand(20, 3) + y_prob = softmax(y_score) + y_true = rng.randint(0, 3, size=20) + with pytest.raises(ValueError, match=msg): + roc_auc_score(y_true, y_prob, **kwargs) + + def test_auc_score_non_binary_class(): # Test that roc_auc_score function returns an error when trying # to compute AUC for non-binary class values. @@ -455,10 +634,6 @@ def test_auc_score_non_binary_class(): y_true = np.full(10, -1, dtype="int") assert_raise_message(ValueError, "ROC AUC score is not defined", roc_auc_score, y_true, y_pred) - # y_true contains three different class values - y_true = rng.randint(0, 3, size=10) - assert_raise_message(ValueError, "multiclass format is not supported", - roc_auc_score, y_true, y_pred) with warnings.catch_warnings(record=True): rng = check_random_state(404) @@ -474,11 +649,6 @@ def test_auc_score_non_binary_class(): assert_raise_message(ValueError, "ROC AUC score is not defined", roc_auc_score, y_true, y_pred) - # y_true contains three different class values - y_true = rng.randint(0, 3, size=10) - assert_raise_message(ValueError, "multiclass format is not supported", - roc_auc_score, y_true, y_pred) - def test_binary_clf_curve(): rng = check_random_state(404) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index f1b9120b06442..f7d41eda0075c 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -52,7 +52,7 @@ 'recall', 'recall_weighted', 'recall_macro', 'recall_micro', 'neg_log_loss', 'log_loss', 'brier_score_loss', 'jaccard', 'jaccard_weighted', 'jaccard_macro', - 'jaccard_micro'] + 'jaccard_micro', 'roc_auc_ovr', 'roc_auc_ovo'] # All supervised cluster scorers (They behave like classification metric) CLUSTER_SCORERS = ["adjusted_rand_score",