[MRG] Adds multiclass ROC AUC #12789
Changes from all commits
@@ -313,6 +313,7 @@ Others also work in the multiclass case:

    confusion_matrix
    hinge_loss
    matthews_corrcoef
+   roc_auc_score

 Some also work in the multilabel case:

@@ -331,6 +332,7 @@ Some also work in the multilabel case:

    precision_recall_fscore_support
    precision_score
    recall_score
+   roc_auc_score
    zero_one_loss

 And some work with binary and multilabel (but not multiclass) problems:

@@ -339,7 +341,6 @@ And some work with binary and multilabel (but not multiclass) problems:

    :template: function.rst

    average_precision_score
-   roc_auc_score

 In the following sub-sections, we will describe each of those functions,
@@ -1313,9 +1314,52 @@ In multi-label classification, the :func:`roc_auc_score` function is
 extended by averaging over the labels as :ref:`above <average>`.

 Compared to metrics such as the subset accuracy, the Hamming loss, or the
-F1 score, ROC doesn't require optimizing a threshold for each label. The
-:func:`roc_auc_score` function can also be used in multi-class classification,
-if the predicted outputs have been binarized.
+F1 score, ROC doesn't require optimizing a threshold for each label.
+
+The :func:`roc_auc_score` function can also be used in multi-class
+classification. Two averaging strategies are currently supported: the
+one-vs-one algorithm computes the average of the pairwise ROC AUC scores, and
+the one-vs-rest algorithm computes the average of the ROC AUC scores for each
+class against all other classes. In both cases, the predicted labels are
+provided in an array with values from 0 to ``n_classes - 1``, and the scores
+correspond to the probability estimates that a sample belongs to a particular
+class. The OvO and OvR algorithms support weighting uniformly
+(``average='macro'``) and weighting by the prevalence (``average='weighted'``).
+
+**One-vs-one Algorithm**: Computes the average AUC of all possible pairwise
+combinations of classes. [HT2001]_ defines a multiclass AUC metric weighted
+uniformly:
+
+.. math::
+
+   \frac{1}{c(c-1)}\sum_{j=1}^{c}\sum_{k > j}^c (\text{AUC}(j | k) +
+   \text{AUC}(k | j))
+
+where :math:`c` is the number of classes and :math:`\text{AUC}(j | k)` is the
+AUC with class :math:`j` as the positive class and class :math:`k` as the
+negative class. In general,
+:math:`\text{AUC}(j | k) \neq \text{AUC}(k | j)` in the multiclass case. This
+algorithm is used by setting the keyword argument ``multiclass`` to ``'ovo'``
+and ``average`` to ``'macro'``.
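As a sanity check on the description above, the Hand & Till macro-average can be sketched outside scikit-learn. The helper names below (`binary_auc`, `ovo_macro_auc`) are hypothetical, a rank-based (Mann-Whitney) statistic stands in for the binary AUC, and labels are assumed to be integers 0 to n_classes - 1 so they index the score columns:

```python
import numpy as np
from itertools import combinations

def binary_auc(y_true, y_score):
    # Rank-based (Mann-Whitney) AUC: the probability that a random positive
    # sample is scored above a random negative one, counting ties as one half.
    pos = y_score[y_true]
    neg = y_score[~y_true]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

def ovo_macro_auc(y_true, y_score):
    # Hand & Till (2001): average AUC(j | k) and AUC(k | j) over every
    # unordered pair of classes, restricted to samples of those two classes.
    classes = np.unique(y_true)
    pair_scores = []
    for a, b in combinations(classes, 2):
        mask = (y_true == a) | (y_true == b)
        auc_a = binary_auc(y_true[mask] == a, y_score[mask, a])
        auc_b = binary_auc(y_true[mask] == b, y_score[mask, b])
        pair_scores.append((auc_a + auc_b) / 2)
    return float(np.mean(pair_scores))
```

Perfectly separated scores give 1.0 and constant (uninformative) scores give exactly 0.5, which matches the unweighted formula above.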
+
+The [HT2001]_ multiclass AUC metric can be extended to be weighted by the
+prevalence:
+
+.. math::
+
+   \frac{1}{c(c-1)}\sum_{j=1}^{c}\sum_{k > j}^c p(j \cup k)(
+   \text{AUC}(j | k) + \text{AUC}(k | j))
+
+where :math:`c` is the number of classes. This algorithm is used by setting
+the keyword argument ``multiclass`` to ``'ovo'`` and ``average`` to
+``'weighted'``. The ``'weighted'`` option returns a prevalence-weighted
+average as described in [FC2009]_.
+
+**One-vs-rest Algorithm**: Computes the AUC of each class against the rest.
+The algorithm is functionally the same as the multilabel case. To enable this
+algorithm set the keyword argument ``multiclass`` to ``'ovr'``. Similar to
+OvO, OvR supports two types of averaging: ``'macro'`` [F2006]_ and
+``'weighted'`` [F2001]_.
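The OvR strategy described above can likewise be sketched with hypothetical helpers (these names are illustrative, not the scikit-learn API): each class is binarized against the rest, and the per-class AUCs are averaged uniformly or by prevalence.

```python
import numpy as np

def binary_auc(y_true, y_score):
    # Rank-based AUC with ties counted as one half.
    pos = y_score[y_true]
    neg = y_score[~y_true]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

def ovr_auc(y_true, y_score, average="macro"):
    # One class versus the rest, then average the per-class AUCs.
    classes = np.unique(y_true)
    scores = np.array([binary_auc(y_true == c, y_score[:, c]) for c in classes])
    if average == "weighted":
        # Weight each class AUC by its prevalence in y_true.
        weights = np.array([(y_true == c).mean() for c in classes])
        return float(np.average(scores, weights=weights))
    return float(scores.mean())
```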
Review comment: Do either of these metrics have any notable properties? Are they 0.5 for random predictions?

Reply: They do output 0.5 for random predictions. Do you think this information should be added to the docs?
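The claim in this exchange is easy to verify numerically: with completely uninformative scores every positive/negative comparison is a tie, so each per-class AUC, and hence any average of them, is exactly 0.5. A minimal check (the `binary_auc` helper is hypothetical, standing in for the binary metric):

```python
import numpy as np

def binary_auc(y_true, y_score):
    # Rank-based AUC with ties counted as one half.
    pos = y_score[y_true]
    neg = y_score[~y_true]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

# An uninformative classifier assigns every sample the same score for every
# class, so every comparison between a positive and a negative is a tie and
# each per-class AUC is exactly 0.5.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_score = np.full((6, 3), 1.0 / 3.0)
aucs = [float(binary_auc(y_true == c, y_score[:, c])) for c in (0, 1, 2)]
print(aucs)  # [0.5, 0.5, 0.5]
```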

 In applications where a high false positive rate is not tolerable the parameter
 ``max_fpr`` of :func:`roc_auc_score` can be used to summarize the ROC curve up

@@ -1341,6 +1385,28 @@ to the given limit.

 for an example of using ROC to
 model species distribution.
+
+.. topic:: References:
+
+  .. [HT2001] Hand, D.J. and Till, R.J., (2001). `A simple generalisation
+     of the area under the ROC curve for multiple class classification problems.
+     <http://link.springer.com/article/10.1023/A:1010920819831>`_
+     Machine Learning, 45(2), pp. 171-186.
+
+  .. [FC2009] Ferri, C., Hernandez-Orallo, J. and Modroiu, R., (2009).
+     `An experimental comparison of performance measures for classification.
+     <https://www.math.ucdavis.edu/~saito/data/roc/ferri-class-perf-metrics.pdf>`_
+     Pattern Recognition Letters, 30(1), pp. 27-38.

Review comment: this is not cited. should it be?

Reply: the Provost citation is missing for the OvR.

+  .. [F2006] Fawcett, T., 2006. `An introduction to ROC analysis.
+     <http://www.sciencedirect.com/science/article/pii/S016786550500303X>`_
+     Pattern Recognition Letters, 27(8), pp. 861-874.
+
+  .. [F2001] Fawcett, T., 2001. `Using rule sets to maximize
+     ROC performance <http://ieeexplore.ieee.org/document/989510/>`_
+     In Data Mining, 2001. Proceedings IEEE International Conference,
+     pp. 131-138.

 .. _zero_one_loss:

 Zero one loss
@@ -12,6 +12,7 @@
 # Noel Dawe <[email protected]>
 # License: BSD 3 clause

+from itertools import combinations

 import numpy as np
@@ -123,3 +124,74 @@ def _average_binary_score(binary_metric, y_true, y_score, average,
         return np.average(score, weights=average_weight)
     else:
         return score


+def _average_multiclass_ovo_score(binary_metric, y_true, y_score,
+                                  average='macro'):
+    """Average one-versus-one scores for multiclass classification.
+
+    Uses the binary metric for one-vs-one multiclass classification,
+    where the score is computed according to the Hand & Till (2001) algorithm.
+
+    Parameters
+    ----------

Review comment: Ideally the list should follow the same order as the parameters passed to …

+    binary_metric : callable
+        The binary metric function to use that accepts the following as input:
+            y_true_target : array, shape = [n_samples_target]
+                Some sub-array of y_true for a pair of classes designated
+                positive and negative in the one-vs-one scheme.
+            y_score_target : array, shape = [n_samples_target]
+                Scores corresponding to the probability estimates
+                of a sample belonging to the designated positive class label.
+
+    y_true : array-like, shape = (n_samples,)
+        True multiclass labels.
+
+    y_score : array-like, shape = (n_samples, n_classes)
+        Target scores corresponding to probability estimates of a sample
+        belonging to a particular class.
+
+    average : 'macro' or 'weighted', optional (default='macro')
+        Determines the type of averaging performed on the pairwise binary
+        metric scores:
+        ``'macro'``:
+            Calculate metrics for each label, and find their unweighted
+            mean. This does not take label imbalance into account. Classes
+            are assumed to be uniformly distributed.
+        ``'weighted'``:
+            Calculate metrics for each label, taking into account the
+            prevalence of the classes.
+
+    Returns
+    -------
+    score : float
+        Average of the pairwise binary metric scores.
+    """
+    check_consistent_length(y_true, y_score)
+
+    y_true_unique = np.unique(y_true)
+    n_classes = y_true_unique.shape[0]
+    n_pairs = n_classes * (n_classes - 1) // 2
+    pair_scores = np.empty(n_pairs)
+
+    is_weighted = average == "weighted"
+    prevalence = np.empty(n_pairs) if is_weighted else None
+
+    # Compute scores treating a as the positive class and b as the negative
+    # class, then b as the positive class and a as the negative class
+    for ix, (a, b) in enumerate(combinations(y_true_unique, 2)):
+        a_mask = y_true == a
+        b_mask = y_true == b
+        ab_mask = np.logical_or(a_mask, b_mask)
+
+        if is_weighted:
+            prevalence[ix] = np.average(ab_mask)
+
+        a_true = a_mask[ab_mask]
+        b_true = b_mask[ab_mask]
+
+        a_true_score = binary_metric(a_true, y_score[ab_mask, a])
+        b_true_score = binary_metric(b_true, y_score[ab_mask, b])
+        pair_scores[ix] = (a_true_score + b_true_score) / 2
+
+    return np.average(pair_scores, weights=prevalence)
Review comment: mention weighted vs unweighted in this paragraph maybe?
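On the weighted-vs-unweighted point: the final `np.average(pair_scores, weights=prevalence)` call reduces to the plain macro mean whenever all pair prevalences are equal, and diverges under class imbalance. A small numeric sketch with made-up pairwise scores:

```python
import numpy as np

# Hypothetical pairwise AUCs for 3 classes, mimicking pair_scores above.
pair_scores = np.array([0.95, 0.60, 0.70])

# Balanced classes: every pair covers the same fraction of samples, so the
# prevalence weights are uniform and the weighted average equals the macro mean.
balanced = np.full(3, 2.0 / 3.0)
assert np.isclose(np.average(pair_scores, weights=balanced), pair_scores.mean())

# Imbalanced classes: pairs involving the common class dominate the average.
imbalanced = np.array([0.9, 0.9, 0.2])
weighted = float(np.average(pair_scores, weights=imbalanced))
macro = float(pair_scores.mean())
```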