diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index be0259879a2dc..c057580877f11 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -252,15 +252,16 @@ Some also work in the multilabel case:
    precision_recall_fscore_support
    precision_score
    recall_score
+   roc_auc_score
    zero_one_loss
 
-And some work with binary and multilabel (but not multiclass) problems:
+
+Some work with binary and multilabel (but not multiclass) problems:
 
 .. autosummary::
    :template: function.rst
 
    average_precision_score
-   roc_auc_score
 
 
 In the following sub-sections, we will describe each of those functions,
@@ -976,10 +977,41 @@ In multi-label classification, the :func:`roc_auc_score` function is
 extended by averaging over the labels as :ref:`above `.
 
 Compared to metrics such as the subset accuracy, the Hamming loss, or the
-F1 score, ROC doesn't require optimizing a threshold for each label. The
-:func:`roc_auc_score` function can also be used in multi-class classification,
-if the predicted outputs have been binarized.
+F1 score, ROC doesn't require optimizing a threshold for each label.
+
+The :func:`roc_auc_score` function can also be used in multi-class
+classification [F2009]_. Two averaging strategies are currently supported: the
+one-vs-one algorithm computes the average of the pairwise ROC AUC scores, and
+the one-vs-rest algorithm computes the average of the ROC AUC scores of each
+class against all other classes. In both cases, the true class labels are
+provided in an array with values from 0 to ``n_classes - 1``, and the scores
+correspond to the probability estimates that a sample belongs to a particular
+class.
+
+**One-vs-one Algorithm**
+Computes the AUC of each class against every other class, i.e. the AUC of
+all :math:`c(c-1)` possible ordered class pairs for :math:`c` classes.
+
+[HT2001]_ uses the uniform class distribution:
+
+.. math:: \frac{1}{c(c-1)}\sum_{j=1}^c\sum_{k \neq j}^c \textnormal{AUC}(j, k)
+
+[F2009]_ weights each pair by the prevalence of classes `j` and `k`:
+
+.. math:: \frac{1}{2(c-1)}\sum_{j=1}^c\sum_{k \neq j}^c p(j \cup k)\textnormal{AUC}(j, k)
+
+**One-vs-rest Algorithm**
+Computes the AUC of each class against the rest, treating a problem with
+:math:`c` classes as :math:`c` binary classification problems.
+
+[F2006]_ uses the uniform class distribution:
+
+.. math:: \frac{\sum_{j=1}^c \textnormal{AUC}(j, \textnormal{rest}_j)}{c}
+
+[F2001]_ weights each class by its a priori prevalence:
+
+.. math:: \sum_{j=1}^c p(j)\textnormal{AUC}(j, \textnormal{rest}_j)
 
 
 .. image:: ../auto_examples/model_selection/images/sphx_glr_plot_roc_002.png
    :target: ../auto_examples/model_selection/plot_roc.html
@@ -1000,6 +1032,24 @@ if the predicted outputs have been binarized.
   for an example of using ROC to model species distribution.
 
+.. topic:: References:
+
+  .. [F2001] Fawcett, T., 2001. `Using rule sets to maximize
+     ROC performance `_
+     In Data Mining, 2001.
+     Proceedings IEEE International Conference, pp. 131-138.
+  .. [F2006] Fawcett, T., 2006. `An introduction to ROC analysis.
+     `_
+     Pattern Recognition Letters, 27(8), pp. 861-874.
+  .. [F2009] Ferri, C., Hernandez-Orallo, J., and Modroiu, R., 2009.
+     `An experimental comparison of performance measures for classification.
+     `_
+     Pattern Recognition Letters, 30(1), pp. 27-38.
+  .. [HT2001] Hand, D.J. and Till, R.J., 2001. `A simple generalisation
+     of the area under the ROC curve for multiple class classification problems.
+     `_
+     Machine Learning, 45(2), pp. 171-186.
+
 .. _zero_one_loss:
 
 Zero one loss
 
diff --git a/examples/model_selection/plot_roc.py b/examples/model_selection/plot_roc.py
index 475d7b4aba7a6..3a233eb5b79ae 100644
--- a/examples/model_selection/plot_roc.py
+++ b/examples/model_selection/plot_roc.py
@@ -19,16 +19,39 @@
 -------------------
 
 ROC curves are typically used in binary classification to study the output of
-a classifier. In order to extend ROC curve and ROC area to multi-class
-or multi-label classification, it is necessary to binarize the output. One ROC
-curve can be drawn per label, but one can also draw a ROC curve by considering
+a classifier. Extensions of the ROC curve and ROC area to multi-class
+or multi-label classification can use the One-vs-Rest or One-vs-One scheme.
+
+One-vs-Rest
+-----------
+
+The output is binarized and one ROC curve is drawn per label, where each
+label in turn is treated as the positive class and all other labels (the
+"rest") are considered the negative class.
+
+The ROC area can be approximated by taking the average of the one-vs-rest
+ROC areas, either unweighted or weighted by the a priori class distribution.
+
+One can also draw a ROC curve by considering
 each element of the label indicator matrix as a binary prediction
 (micro-averaging).
 
-Another evaluation measure for multi-class classification is
+Another evaluation measure for one-vs-rest multi-class classification is
 macro-averaging, which gives equal weight to the classification of each label.
 
+One-vs-One
+----------
+
+Two ROC curves can be drawn per pair of labels because either of the two
+labels can be considered the positive class (and the other the negative
+class). The ROC area of a label pair is approximated by taking the average
+of these two ROC AUC scores.
+
+The One-vs-One approximation of a multi-class ROC AUC score is the average,
+either unweighted or weighted by class prevalence, of all the pairwise
+approximate ROC AUC scores.
+
 ..
note:: See also :func:`sklearn.metrics.roc_auc_score`, @@ -39,10 +62,10 @@ import numpy as np import matplotlib.pyplot as plt -from itertools import cycle +from itertools import combinations, cycle from sklearn import svm, datasets -from sklearn.metrics import roc_curve, auc +from sklearn.metrics import roc_curve, auc, roc_auc_score from sklearn.model_selection import train_test_split from sklearn.preprocessing import label_binarize from sklearn.multiclass import OneVsRestClassifier @@ -53,9 +76,8 @@ X = iris.data y = iris.target -# Binarize the output -y = label_binarize(y, classes=[0, 1, 2]) -n_classes = y.shape[1] +classes = np.unique(y) +n_classes = len(classes) # Add noisy features to make the problem harder random_state = np.random.RandomState(0) @@ -72,17 +94,17 @@ y_score = classifier.fit(X_train, y_train).decision_function(X_test) # Compute ROC curve and ROC area for each class + +# Binarize y_test to compute the ROC curve +y_test_binarized = label_binarize(y_test, classes=classes) + fpr = dict() tpr = dict() roc_auc = dict() for i in range(n_classes): - fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i]) + fpr[i], tpr[i], _ = roc_curve(y_test_binarized[:, i], y_score[:, i]) roc_auc[i] = auc(fpr[i], tpr[i]) -# Compute micro-average ROC curve and ROC area -fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel()) -roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) - ############################################################################## # Plot of a ROC curve for a specific class @@ -101,7 +123,12 @@ ############################################################################## -# Plot ROC curves for the multiclass problem +# Plot ROC curves for the multiclass problem using One vs. Rest classification. + +# Compute micro-average ROC curve and ROC area +fpr["micro"], tpr["micro"], _ = roc_curve( + y_test_binarized.ravel(), y_score.ravel()) +roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) # Compute macro-average ROC curve and ROC area @@ -143,6 +170,63 @@ plt.ylim([0.0, 1.05]) plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') -plt.title('Some extension of Receiver operating characteristic to multi-class') +plt.title('An extension of ROC to multi-class ' + 'using One-vs-Rest') plt.legend(loc="lower right") plt.show() + +# Compute the One-vs-Rest ROC AUC score, weighted and unweighted +unweighted_roc_auc_ovr = roc_auc_score(y_test, y_score, multiclass="ovr") +weighted_roc_auc_ovr = roc_auc_score( + y_test, y_score, multiclass="ovr", average="weighted") +print("One-vs-Rest ROC AUC scores: {0} (unweighted), {1} (weighted)".format( + unweighted_roc_auc_ovr, weighted_roc_auc_ovr)) + +############################################################################## +# Plot ROC curves for the multiclass problem using One vs. One classification. + +for a, b in combinations(range(n_classes), 2): + # Filter `y_test` and `y_score` to only consider the current + # `a` and `b` class pair. 
+ ab_mask = np.logical_or(y_test == a, y_test == b) + y_true_filtered = y_test[ab_mask] + y_score_filtered = y_score[ab_mask] + + # Compute ROC curve and ROC area with `a` as the positive class + class_a = y_true_filtered == a + fpr[(a, b)], tpr[(a, b)], _ = roc_curve( + class_a, y_score_filtered[:, a]) + roc_auc[(a, b)] = auc(fpr[(a, b)], tpr[(a, b)]) + + # Compute ROC curve and ROC area with `b` as the positive class + class_b = y_true_filtered == b + fpr[(b, a)], tpr[(b, a)], _ = roc_curve( + class_b, y_score_filtered[:, b]) + roc_auc[(b, a)] = auc(fpr[(b, a)], tpr[(b, a)]) + +plt.figure() +for a, b in combinations(range(n_classes), 2): + plt.plot(fpr[(a, b)], tpr[(a, b)], lw=lw, + label='ROC curve: class {0} vs. {1} ' + '(area = {2:0.2f})'.format( + a, b, roc_auc[(a, b)])) + plt.plot(fpr[(b, a)], tpr[(b, a)], lw=lw, + label='ROC curve: class {0} vs. {1} ' + '(area = {2:0.2f})'.format( + b, a, roc_auc[(b, a)])) +plt.plot([0, 1], [0, 1], 'k--', lw=lw) +plt.xlim([0.0, 1.0]) +plt.ylim([0.0, 1.05]) +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('An extension of ROC to multi-class ' + 'using One-vs-One') +plt.legend(bbox_to_anchor=(1.1, 0.30)) +plt.show() + +# Compute the One-vs-One ROC AUC score, weighted and unweighted +unweighted_roc_auc_ovo = roc_auc_score(y_test, y_score, multiclass="ovo") +weighted_roc_auc_ovo = roc_auc_score( + y_test, y_score, multiclass="ovo", average="weighted") +print("One-vs-One ROC AUC scores: {0} (unweighted), {1} (weighted)".format( + unweighted_roc_auc_ovo, weighted_roc_auc_ovo)) diff --git a/sklearn/metrics/base.py b/sklearn/metrics/base.py index 0ad96c1afd059..d2edee1902126 100644 --- a/sklearn/metrics/base.py +++ b/sklearn/metrics/base.py @@ -13,6 +13,7 @@ # License: BSD 3 clause from __future__ import division +import itertools import numpy as np @@ -131,3 +132,69 @@ def _average_binary_score(binary_metric, y_true, y_score, average, return np.average(score, weights=average_weight) else: return score + + +def _average_multiclass_ovo_score(binary_metric, y_true, y_score, average): + """Uses the binary metric for one-vs-one multiclass classification, + where the score is computed according to the Hand & Till (2001) algorithm. + + Parameters + ---------- + y_true : array, shape = [n_samples] + True multiclass labels. + Assumes labels have been recoded to 0 to n_classes. + + y_score : array, shape = [n_samples, n_classes] + Target scores corresponding to probability estimates of a sample + belonging to a particular class + + average : 'macro' or 'weighted', default='macro' + ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. Classes + are assumed to be uniformly distributed. + ``'weighted'``: + Calculate metrics for each label, taking into account the a priori + distribution of the classes. + + binary_metric : callable, the binary metric function to use. + Accepts the following as input + y_true_target : array, shape = [n_samples_target] + Some sub-array of y_true for a pair of classes designated + positive and negative in the one-vs-one scheme. 
+ y_score_target : array, shape = [n_samples_target] + Scores corresponding to the probability estimates + of a sample belonging to the designated positive class label + + Returns + ------- + score : float + Average the sum of pairwise binary metric scores + """ + n_classes = len(np.unique(y_true)) + n_pairs = n_classes * (n_classes - 1) // 2 + prevalence = np.empty(n_pairs) + pair_scores = np.empty(n_pairs) + + ix = 0 + for a, b in itertools.combinations(range(n_classes), 2): + a_mask = y_true == a + ab_mask = np.logical_or(a_mask, y_true == b) + + prevalence[ix] = np.sum(ab_mask) / len(y_true) + + y_score_filtered = y_score[ab_mask] + + a_true = a_mask[ab_mask] + b_true = np.logical_not(a_true) + + a_true_score = binary_metric( + a_true, y_score_filtered[:, a]) + b_true_score = binary_metric( + b_true, y_score_filtered[:, b]) + binary_avg_score = (a_true_score + b_true_score) / 2 + pair_scores[ix] = binary_avg_score + + ix += 1 + return (np.average(pair_scores, weights=prevalence) + if average == "weighted" else np.average(pair_scores)) diff --git a/sklearn/metrics/ranking.py b/sklearn/metrics/ranking.py index d1f58772de595..6bae5c6759cb6 100644 --- a/sklearn/metrics/ranking.py +++ b/sklearn/metrics/ranking.py @@ -23,6 +23,7 @@ import numpy as np from scipy.sparse import csr_matrix +from ..preprocessing import LabelBinarizer from ..utils import assert_all_finite from ..utils import check_consistent_length from ..utils import column_or_1d, check_array @@ -34,7 +35,7 @@ from ..utils.sparsefuncs import count_nonzero from ..exceptions import UndefinedMetricWarning -from .base import _average_binary_score +from .base import _average_binary_score, _average_multiclass_ovo_score def auc(x, y, reorder=False): @@ -184,23 +185,36 @@ def _binary_average_precision(y_true, y_score, sample_weight=None): average, sample_weight=sample_weight) -def roc_auc_score(y_true, y_score, average="macro", sample_weight=None): +def roc_auc_score(y_true, y_score, multiclass="ovr", average="macro", + sample_weight=None): """Compute Area Under the Curve (AUC) from prediction scores - Note: this implementation is restricted to the binary classification task - or multilabel classification task in label indicator format. - Read more in the :ref:`User Guide `. Parameters ---------- y_true : array, shape = [n_samples] or [n_samples, n_classes] True binary labels in binary label indicators. + The multiclass case expects shape = [n_samples] and labels + with values from 0 to (n_classes-1), inclusive. y_score : array, shape = [n_samples] or [n_samples, n_classes] Target scores, can either be probability estimates of the positive class, confidence values, or non-thresholded measure of decisions (as returned by "decision_function" on some classifiers). + The multiclass case expects shape = [n_samples, n_classes] + where the scores correspond to probability estimates. + + multiclass : string, ['ovr' (default), 'ovo'] + Note: multiclass ROC AUC currently only handles the 'macro' and + 'weighted' averages. + + ``'ovr'``: + Calculate metrics for the multiclass case using the one-vs-rest + approach. + ``'ovo'``: + Calculate metrics for the multiclass case using the one-vs-one + approach. average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted'] If ``None``, the scores for each class are returned. 
Otherwise, @@ -255,9 +269,42 @@ def _binary_roc_auc_score(y_true, y_score, sample_weight=None): sample_weight=sample_weight) return auc(fpr, tpr, reorder=True) - return _average_binary_score( - _binary_roc_auc_score, y_true, y_score, average, - sample_weight=sample_weight) + y_type = type_of_target(y_true) + y_true = check_array(y_true, ensure_2d=False) + y_score = check_array(y_score, ensure_2d=False) + + if y_type == "multiclass" or (y_type == "binary" and + y_score.ndim == 2 and + y_score.shape[1] > 2): + # validation for multiclass parameter specifications + average_options = ("macro", "weighted") + if average not in average_options: + raise ValueError("Parameter 'average' must be one of {0} for" + " multiclass problems.".format(average_options)) + multiclass_options = ("ovo", "ovr") + if multiclass not in multiclass_options: + raise ValueError("Parameter multiclass='{0}' is not supported" + " for multiclass ROC AUC. 'multiclass' must be" + " one of {1}.".format( + multiclass, multiclass_options)) + if sample_weight is not None: + raise ValueError("Parameter 'sample_weight' is not supported" + " for multiclass one-vs-one ROC AUC." + " 'sample_weight' must be None in this case.") + + if multiclass == "ovo": + return _average_multiclass_ovo_score( + _binary_roc_auc_score, y_true, y_score, average) + else: + y_true = y_true.reshape((-1, 1)) + y_true_multilabel = LabelBinarizer().fit_transform(y_true) + return _average_binary_score( + _binary_roc_auc_score, y_true_multilabel, y_score, average, + sample_weight=sample_weight) + else: + return _average_binary_score( + _binary_roc_auc_score, y_true, y_score, average, + sample_weight=sample_weight) def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 0ba1d858ab7de..12eea9a97f2dc 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -391,6 +391,125 @@ def test_auc_errors(): assert_raises(ValueError, auc, [1.0, 0.0, 0.5], [0.0, 0.0, 0.0]) +def test_multi_ovo_auc_toydata(): + # Tests the one-vs-one multiclass ROC AUC algorithm + # on a small example, representative of an expected use case. + y_true = np.array([0, 1, 0, 2]) + n_labels = len(np.unique(y_true)) + y_scores = np.array( + [[0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.35, 0.5, 0.15], [0, 0.2, 0.8]]) + + # Used to compute the expected output. + # Consider labels 0 and 1: + # positive label is 0, negative label is 1 + score_01 = roc_auc_score([1, 0, 1], [0.1, 0.3, 0.35]) + # positive label is 1, negative label is 0 + score_10 = roc_auc_score([0, 1, 0], [0.8, 0.4, 0.5]) + average_score_01 = (score_01 + score_10) / 2. + + # Consider labels 0 and 2: + score_02 = roc_auc_score([1, 1, 0], [0.1, 0.35, 0]) + score_20 = roc_auc_score([0, 0, 1], [0.1, 0.15, 0.8]) + average_score_02 = (score_02 + score_20) / 2. + + # Consider labels 1 and 2: + score_12 = roc_auc_score([1, 0], [0.4, 0.2]) + score_21 = roc_auc_score([0, 1], [0.3, 0.8]) + average_score_12 = (score_12 + score_21) / 2. + + # Unweighted, one-vs-one multiclass ROC AUC algorithm + sum_avg_scores = average_score_01 + average_score_02 + average_score_12 + ovo_unweighted_coefficient = 2. 
/ (n_labels * (n_labels - 1)) + ovo_unweighted_score = ovo_unweighted_coefficient * sum_avg_scores + assert_almost_equal( + roc_auc_score(y_true, y_scores, multiclass="ovo"), + ovo_unweighted_score) + + # Weighted, one-vs-one multiclass ROC AUC algorithm + # Each term is weighted by the posterior for the positive label. + pair_scores = [average_score_01, average_score_02, average_score_12] + prevalence = [0.75, 0.75, 0.50] + ovo_weighted_score = np.average(pair_scores, weights=prevalence) + assert_almost_equal( + roc_auc_score(y_true, y_scores, multiclass="ovo", average="weighted"), + ovo_weighted_score) + + +def test_multi_ovr_auc_toydata(): + # Tests the unweighted, one-vs-rest multiclass ROC AUC algorithm + # on a small example, representative of an expected use case. + y_true = np.array([0, 1, 2, 2]) + y_scores = np.array( + [[1.0, 0.0, 0.0], [0.1, 0.5, 0.4], [0.1, 0.1, 0.8], [0.3, 0.3, 0.4]]) + # Compute the expected result by individually computing the 'one-vs-rest' + # ROC AUC scores for classes 0, 1, and 2. + out_0 = roc_auc_score([1, 0, 0, 0], y_scores[:, 0]) + out_1 = roc_auc_score([0, 1, 0, 0], y_scores[:, 1]) + out_2 = roc_auc_score([0, 0, 1, 1], y_scores[:, 2]) + result_unweighted = (out_0 + out_1 + out_2) / 3. + + assert_almost_equal( + roc_auc_score(y_true, y_scores, multiclass="ovr"), + result_unweighted) + + # Tests the weighted, one-vs-rest multiclass ROC AUC algorithm + # on the same input + result_weighted = out_0 * 0.25 + out_1 * 0.25 + out_2 * 0.5 + assert_almost_equal( + roc_auc_score(y_true, y_scores, multiclass="ovr", average="weighted"), + result_weighted) + + +def test_multi_auc_score_under_permutation(): + y_score = np.random.rand(100, 3) + y_score[:, 2] += .1 + y_score[:, 1] -= .1 + y_true = np.argmax(y_score, axis=1) + y_true[np.random.randint(len(y_score), size=20)] = np.random.randint( + 2, size=20) + for multiclass in ['ovr', 'ovo']: + for average in ['macro', 'weighted']: + same_score_under_permutation = None + for perm in [[0, 1, 2], [0, 2, 1], [1, 0, 2], + [1, 2, 0], [2, 0, 1], [2, 1, 0]]: + inv_perm = np.zeros(3, dtype=int) + inv_perm[perm] = np.arange(3) + y_score_perm = y_score[:, inv_perm] + y_true_perm = np.take(perm, y_true) + score = roc_auc_score(y_true_perm, y_score_perm, + multiclass=multiclass, average=average) + if same_score_under_permutation is None: + same_score_under_permutation = score + else: + assert_almost_equal(score, same_score_under_permutation) + + +def test_auc_score_multi_error(): + # Test that roc_auc_score function returns an error when trying + # to compute multiclass AUC for parameters where an output + # is not defined. + rng = check_random_state(404) + y_pred = rng.rand(10) + y_true = rng.randint(0, 3, size=10) + average_error_msg = ("Parameter 'average' must be one of " + "('macro', 'weighted') for multiclass problems.") + assert_raise_message(ValueError, average_error_msg, + roc_auc_score, y_true, y_pred, average="sample") + assert_raise_message(ValueError, average_error_msg, + roc_auc_score, y_true, y_pred, average="micro") + multiclass_error_msg = ("Parameter multiclass='invalid' is not " + "supported for multiclass ROC AUC. 'multiclass' " + "must be one of ('ovo', 'ovr').") + assert_raise_message(ValueError, multiclass_error_msg, + roc_auc_score, y_true, y_pred, multiclass="invalid") + sample_weight_error_msg = ("Parameter 'sample_weight' is not supported " + "for multiclass one-vs-one ROC AUC. 
" + "'sample_weight' must be None in this case.") + assert_raise_message(ValueError, sample_weight_error_msg, + roc_auc_score, y_true, y_pred, + multiclass="ovo", sample_weight=[]) + + def test_auc_score_non_binary_class(): # Test that roc_auc_score function returns an error when trying # to compute AUC for non-binary class values. @@ -406,10 +525,6 @@ def test_auc_score_non_binary_class(): y_true = -np.ones(10, dtype="int") assert_raise_message(ValueError, "ROC AUC score is not defined", roc_auc_score, y_true, y_pred) - # y_true contains three different class values - y_true = rng.randint(0, 3, size=10) - assert_raise_message(ValueError, "multiclass format is not supported", - roc_auc_score, y_true, y_pred) clean_warning_registry() with warnings.catch_warnings(record=True): @@ -426,11 +541,6 @@ def test_auc_score_non_binary_class(): assert_raise_message(ValueError, "ROC AUC score is not defined", roc_auc_score, y_true, y_pred) - # y_true contains three different class values - y_true = rng.randint(0, 3, size=10) - assert_raise_message(ValueError, "multiclass format is not supported", - roc_auc_score, y_true, y_pred) - def test_precision_recall_curve(): y_true, _, probas_pred = make_prediction(binary=True) @@ -601,7 +711,6 @@ def test_score_scale_invariance(): # issue #3864 (and others), where overly aggressive rounding was causing # problems for users with very small y_score values y_true, _, probas_pred = make_prediction(binary=True) - roc_auc = roc_auc_score(y_true, probas_pred) roc_auc_scaled_up = roc_auc_score(y_true, 100 * probas_pred) roc_auc_scaled_down = roc_auc_score(y_true, 1e-6 * probas_pred)
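As a sanity check on the averaging strategies documented and implemented above, the sketch below recomputes the one-vs-one and one-vs-rest averages by hand from the existing binary ``roc_auc_score``, mirroring the logic of ``_average_multiclass_ovo_score`` and ``_average_binary_score``. This is an illustration, not part of the patch: the toy data is the same as in ``test_multi_ovo_auc_toydata``, the ``multiclass=`` keyword shown in the trailing comments exists only with this diff applied, and helpers such as ``np.isin`` / ``np.bincount`` are convenience choices for the sketch.

import numpy as np
from itertools import combinations
from sklearn.metrics import roc_auc_score  # binary form, available without this patch

# Toy data in the format the patch expects: labels in {0, ..., n_classes - 1},
# one column of probability-like scores per class (same example as the tests).
y_true = np.array([0, 1, 0, 2])
y_score = np.array([[0.1, 0.8, 0.1],
                    [0.3, 0.4, 0.3],
                    [0.35, 0.5, 0.15],
                    [0.0, 0.2, 0.8]])
n_classes = y_score.shape[1]

# One-vs-one: average the two directed AUCs of every class pair, then average
# over pairs, optionally weighting each pair by its prevalence p(a U b).
pair_scores, pair_weights = [], []
for a, b in combinations(range(n_classes), 2):
    ab_mask = np.isin(y_true, [a, b])
    a_true = y_true[ab_mask] == a
    auc_a = roc_auc_score(a_true, y_score[ab_mask, a])   # `a` as positive class
    auc_b = roc_auc_score(~a_true, y_score[ab_mask, b])  # `b` as positive class
    pair_scores.append((auc_a + auc_b) / 2)
    pair_weights.append(ab_mask.mean())                  # p(a U b)

ovo_macro = np.mean(pair_scores)                              # [HT2001]
ovo_weighted = np.average(pair_scores, weights=pair_weights)  # [F2009]

# One-vs-rest: AUC of each class against all others, unweighted or weighted
# by the class prior p(j).
rest_scores = [roc_auc_score(y_true == j, y_score[:, j]) for j in range(n_classes)]
class_priors = np.bincount(y_true) / len(y_true)
ovr_macro = np.mean(rest_scores)                              # [F2006]
ovr_weighted = np.average(rest_scores, weights=class_priors)  # [F2001]

print(ovo_macro, ovo_weighted, ovr_macro, ovr_weighted)

# With the patch applied, the same values are expected from the proposed API:
#   roc_auc_score(y_true, y_score, multiclass="ovo")
#   roc_auc_score(y_true, y_score, multiclass="ovo", average="weighted")
#   roc_auc_score(y_true, y_score, multiclass="ovr")
#   roc_auc_score(y_true, y_score, multiclass="ovr", average="weighted")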