
[MRG] Adds multiclass ROC AUC #12789


Merged
merged 123 commits into from
Jul 17, 2019
Commits
a666180
Add Hand & Till (OvO) and Provost & Domingos (OvR) implementations
maskani-moh Jan 16, 2018
118a700
Add multi-class implementation in roc_auc_score method
maskani-moh Jan 16, 2018
3371b1d
Add tests for multi-class settings OvO and OvR
maskani-moh Jan 16, 2018
d74ce16
Fix binary case roc computation
maskani-moh Jan 17, 2018
805d804
Make scores add up to 1.0
maskani-moh Jan 17, 2018
2bd693e
Fix typo
maskani-moh Jan 17, 2018
fc54dde
Differenciate binary case explicitly to avoid error when multilabel-i…
maskani-moh Jan 17, 2018
133a09a
Fix prediciton scores
maskani-moh Jan 19, 2018
bc40110
Merge remote-tracking branch 'upstream/master' into multiclass-roc-au…
maskani-moh Feb 8, 2018
0d035e3
Merge remote-tracking branch 'upstream/master' into multiclass-roc-au…
maskani-moh Mar 26, 2018
d08f084
Fix test error by setting param dtype=None
maskani-moh Mar 26, 2018
4c7a656
Quick fix
maskani-moh Mar 26, 2018
4723b00
Raise error for partial computation in multiclass
maskani-moh Mar 26, 2018
aa6dd49
Fix pep8
maskani-moh Mar 26, 2018
5af924b
Merge branch 'master' into multiclass_roc_auc
amueller Oct 5, 2018
5c094cd
try adding ovo multiclass scores
amueller Oct 5, 2018
d0393d7
allow roc_auc and macro_roc_auc for multiclass in test_common
amueller Oct 5, 2018
4a0ded6
add multiclass roc_auc metrics to scores, more common tests
amueller Oct 5, 2018
d599552
ovr is same as multilabel
amueller Oct 5, 2018
2cc343a
remove non-existant import
amueller Oct 5, 2018
74bef0d
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_2
thomasjpfan Dec 13, 2018
c91a9bd
RFC: Removes unrelated diffs
thomasjpfan Dec 13, 2018
e4d2443
ENH: Optimizes ovo
thomasjpfan Dec 13, 2018
0f5a088
WIP: Adds tests back
thomasjpfan Dec 13, 2018
1de4333
WIP: ovr supports sample_weigth
thomasjpfan Dec 13, 2018
e169e0d
RFC: Rename with weighted prefix
thomasjpfan Dec 13, 2018
95a117c
RFC: Moves permutation test to common
thomasjpfan Dec 13, 2018
01ba344
RFC: Uses pytest parameters
thomasjpfan Dec 13, 2018
0517eae
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Dec 13, 2018
67f2376
RFC: Minimizes diffs
thomasjpfan Dec 13, 2018
a4ea7a6
DOC: Adds narative
thomasjpfan Dec 13, 2018
96f2c2d
RFC: Lowers line count
thomasjpfan Dec 13, 2018
236504d
DOC: Fixes latex errors
thomasjpfan Dec 14, 2018
8fbcd35
DOC: Update plot_roc for multiclass
thomasjpfan Dec 14, 2018
e98ee89
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Dec 14, 2018
5a89add
DOC: Adds whats new
thomasjpfan Dec 14, 2018
6ea7aa3
RFC: Clears up test
thomasjpfan Dec 15, 2018
c7e1aa8
RFC: Small
thomasjpfan Dec 17, 2018
75a0d7e
RFC: Clears up test
thomasjpfan Dec 17, 2018
e5290f5
RFC: Clears up test
thomasjpfan Dec 17, 2018
d68c3f0
RFC: Simplifies test
thomasjpfan Dec 17, 2018
a5f1845
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Dec 17, 2018
031ed06
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Dec 18, 2018
0b38e38
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Dec 20, 2018
3fb95d6
TST: Adds ValueError test
thomasjpfan Dec 20, 2018
b049962
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Dec 21, 2018
9f34776
DOC: Show plots
thomasjpfan Dec 21, 2018
0044f41
DOC: Adds names
thomasjpfan Dec 21, 2018
a93f06b
ENH: Adds support for strings and labels
thomasjpfan Jan 3, 2019
07a6f8a
RFC: Encodes y_true before passing to auc score
thomasjpfan Jan 3, 2019
f80f584
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Jan 3, 2019
d006581
RFC: Adds ovr scorer
thomasjpfan Jan 3, 2019
3a0961e
DOC: Adds roc_auc_score to multiclass docs
thomasjpfan Jan 3, 2019
6e6d998
RFC: Rename variable
thomasjpfan Jan 15, 2019
9058b24
RFC: Rewords error msg
thomasjpfan Jan 15, 2019
be75778
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Jan 15, 2019
1699604
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Jan 19, 2019
d26d528
RFC: Increases efficiency
thomasjpfan Jan 19, 2019
6978943
STY: Flake8
thomasjpfan Jan 19, 2019
acedbaa
RFC
thomasjpfan Jan 19, 2019
8c038d9
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Jan 30, 2019
f39a067
ENH: Uses object form of plt
thomasjpfan Jan 30, 2019
765c71b
STY: Flake8
thomasjpfan Jan 30, 2019
f2c0c2b
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Feb 5, 2019
139f3af
DOC Address comments
thomasjpfan Feb 5, 2019
99f5498
RFC: Uses pytest.rases(match=...)
thomasjpfan Feb 5, 2019
8b4dd6e
RFC Address comments
thomasjpfan Feb 5, 2019
f3c1f0f
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Feb 5, 2019
bd032f9
REV Add tests back
thomasjpfan Feb 5, 2019
36fe583
RFC Uses better name
thomasjpfan Feb 5, 2019
21203e4
DOC
thomasjpfan Feb 5, 2019
43bd6bb
DOC
thomasjpfan Feb 5, 2019
941d810
RFC Uses average
thomasjpfan Feb 5, 2019
fa11e2d
DOC Adds multiclass macro and weighted curves
thomasjpfan Feb 5, 2019
1ec9b3f
STY Flake8
thomasjpfan Feb 5, 2019
5da21f0
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Feb 6, 2019
e40218e
ENH Adds support for integer y_true
thomasjpfan Feb 6, 2019
5a4eaf5
Trigger CI
thomasjpfan Feb 7, 2019
a6f1984
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Feb 7, 2019
8c00a1f
RF Address comments
thomasjpfan Feb 18, 2019
9864378
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Feb 18, 2019
2da0866
RF Address comments
thomasjpfan Feb 18, 2019
66b6690
RF Uses _encode_python
thomasjpfan Feb 18, 2019
870941e
RF Adds comments
thomasjpfan Feb 18, 2019
9b8a843
STY flake8
thomasjpfan Feb 18, 2019
8c8f1de
RFC Address comments
thomasjpfan Feb 19, 2019
5ff4f4c
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Mar 8, 2019
49d6ea3
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Apr 4, 2019
c7d686d
CLN Address comments
thomasjpfan Apr 4, 2019
d96a7d9
CLN Style
thomasjpfan Apr 4, 2019
40cc0a1
DOC Adds new example for mutliclass roc
thomasjpfan Apr 4, 2019
551c32a
DOC Updates example
thomasjpfan Apr 4, 2019
3536851
CLN Uses softmax
thomasjpfan Apr 4, 2019
e3c9e79
BUG Fix
thomasjpfan Apr 4, 2019
994e949
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Apr 18, 2019
8628672
ENH Forces order for labels
thomasjpfan Apr 18, 2019
24f7c98
CLN Uses the word processing
thomasjpfan Apr 18, 2019
3304b66
CLN Address comments
thomasjpfan May 31, 2019
29908e1
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan May 31, 2019
566f313
REV Removes test
thomasjpfan May 31, 2019
9acd61b
DOC Adds reference
thomasjpfan May 31, 2019
8f7c4ef
BLD Trigger CI
thomasjpfan May 31, 2019
71c1d78
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan May 31, 2019
3fcf96f
DOC Removes whats_new
thomasjpfan May 31, 2019
ceed5ee
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Jun 15, 2019
38466ff
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Jun 16, 2019
1095d7f
DOC Adds more references
thomasjpfan Jun 16, 2019
95e25e9
CLN Address comments
thomasjpfan Jun 16, 2019
76036a7
CLN Change order
thomasjpfan Jun 16, 2019
146491d
CLN Removes multiclass example
thomasjpfan Jun 16, 2019
f01b435
TST Pytest-dist ordering
thomasjpfan Jun 17, 2019
3b2b436
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Jun 24, 2019
89de04f
DOC Spacing
thomasjpfan Jun 24, 2019
7999125
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Jul 2, 2019
11e87bb
ENH Raises when multi_class is not specified
thomasjpfan Jul 2, 2019
df7efe0
REV Defaults to ovr
thomasjpfan Jul 2, 2019
0646612
STY Minor
thomasjpfan Jul 2, 2019
bfc73c9
TST roc_auc_score defaults to not support multiclass
thomasjpfan Jul 2, 2019
eaf979b
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Jul 3, 2019
65fea8e
ENH Adds weighted scorers
thomasjpfan Jul 4, 2019
2085b4d
Merge remote-tracking branch 'upstream/master' into multiclass_roc_auc_3
thomasjpfan Jul 17, 2019
c5101e2
CLN Address comments
thomasjpfan Jul 17, 2019
1399ddd
CLN Uses partial
thomasjpfan Jul 17, 2019
74 changes: 70 additions & 4 deletions doc/modules/model_evaluation.rst
@@ -313,6 +313,7 @@ Others also work in the multiclass case:
confusion_matrix
hinge_loss
matthews_corrcoef
roc_auc_score


Some also work in the multilabel case:
@@ -331,6 +332,7 @@ Some also work in the multilabel case:
precision_recall_fscore_support
precision_score
recall_score
roc_auc_score
zero_one_loss

And some work with binary and multilabel (but not multiclass) problems:
@@ -339,7 +341,6 @@ And some work with binary and multilabel (but not multiclass) problems:
:template: function.rst

average_precision_score
roc_auc_score


In the following sub-sections, we will describe each of those functions,
@@ -1313,9 +1314,52 @@ In multi-label classification, the :func:`roc_auc_score` function is
extended by averaging over the labels as :ref:`above <average>`.

Compared to metrics such as the subset accuracy, the Hamming loss, or the
F1 score, ROC doesn't require optimizing a threshold for each label. The
:func:`roc_auc_score` function can also be used in multi-class classification,
if the predicted outputs have been binarized.
F1 score, ROC doesn't require optimizing a threshold for each label.

The :func:`roc_auc_score` function can also be used in multi-class
classification. Two averaging strategies are currently supported: the
Review comment (Member): mention weighted vs unweighted in this paragraph maybe?
one-vs-one algorithm computes the average of the pairwise ROC AUC scores, and
the one-vs-rest algorithm computes the average of the ROC AUC scores for each
class against all other classes. In both cases, the predicted labels are
provided in an array with values from 0 to ``n_classes - 1``, and the scores
correspond to the probability estimates that a sample belongs to a particular
class. The OvO and OvR algorithms support weighting uniformly
(``average='macro'``) and weighting by prevalence (``average='weighted'``).

**One-vs-one Algorithm**: Computes the average AUC of all possible pairwise
combinations of classes. [HT2001]_ defines a multiclass AUC metric weighted
uniformly:

.. math::

\frac{2}{c(c-1)}\sum_{j=1}^{c}\sum_{k > j}^c (\text{AUC}(j | k) +
\text{AUC}(k | j))

where :math:`c` is the number of classes and :math:`\text{AUC}(j | k)` is the
AUC with class :math:`j` as the positive class and class :math:`k` as the
negative class. In general,
:math:`\text{AUC}(j | k) \neq \text{AUC}(k | j)` in the multiclass
case. This algorithm is used by setting the keyword argument ``multi_class``
to ``'ovo'`` and ``average`` to ``'macro'``.
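The macro-averaged OvO computation above can be reproduced by hand. A minimal sketch, assuming iris data and a logistic-regression model purely for illustration (neither is part of this PR):

```python
# Sketch: reproduce the Hand & Till (2001) macro OvO average manually and
# compare it against roc_auc_score(multi_class='ovo', average='macro').
from itertools import combinations

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

X, y = load_iris(return_X_y=True)
proba = LogisticRegression(max_iter=1000).fit(X, y).predict_proba(X)

pair_scores = []
for a, b in combinations(np.unique(y), 2):
    ab_mask = np.isin(y, [a, b])  # restrict to samples of class a or b
    auc_a = roc_auc_score(y[ab_mask] == a, proba[ab_mask, a])  # AUC(a | b)
    auc_b = roc_auc_score(y[ab_mask] == b, proba[ab_mask, b])  # AUC(b | a)
    pair_scores.append((auc_a + auc_b) / 2)

manual = np.mean(pair_scores)
builtin = roc_auc_score(y, proba, multi_class="ovo", average="macro")
assert np.isclose(manual, builtin)
```

The unweighted mean over all class pairs matches the built-in score, mirroring the formula above.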

The [HT2001]_ multiclass AUC metric can be extended to be weighted by the
prevalence:

.. math::

\frac{2}{c(c-1)}\sum_{j=1}^{c}\sum_{k > j}^c p(j \cup k)(
\text{AUC}(j | k) + \text{AUC}(k | j))

where :math:`c` is the number of classes. This algorithm is used by setting
the keyword argument ``multi_class`` to ``'ovo'`` and ``average`` to
``'weighted'``. The ``'weighted'`` option returns a prevalence-weighted average
as described in [FC2009]_.

**One-vs-rest Algorithm**: Computes the AUC of each class against the rest.
The algorithm is functionally the same as the multilabel case. To enable this
algorithm, set the keyword argument ``multi_class`` to ``'ovr'``. Similar to
OvO, OvR supports two types of averaging: ``'macro'`` [F2006]_ and
``'weighted'`` [F2001]_.
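As a sanity check of the OvR description, a hedged sketch (iris and logistic regression are again illustrative choices): per-class one-vs-rest AUCs, averaged uniformly or by prevalence, should reproduce the built-in scores.

```python
# Sketch: OvR multiclass ROC AUC as the average of per-class one-vs-rest AUCs.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

X, y = load_iris(return_X_y=True)
proba = LogisticRegression(max_iter=1000).fit(X, y).predict_proba(X)

classes = np.unique(y)
per_class = [roc_auc_score(y == c, proba[:, c]) for c in classes]
prevalence = [np.mean(y == c) for c in classes]

# 'macro': unweighted mean over classes; 'weighted': weighted by prevalence.
assert np.isclose(np.mean(per_class),
                  roc_auc_score(y, proba, multi_class="ovr", average="macro"))
assert np.isclose(np.average(per_class, weights=prevalence),
                  roc_auc_score(y, proba, multi_class="ovr", average="weighted"))
```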

Review comment (Member): Do either of these metrics have any notable properties? Are they 0.5 for random predictions?

Review comment (Member Author): They do output 0.5 for random predictions. Do you think this information should be added to the docs?

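The chance-level behaviour discussed in this thread can be verified directly; a small sketch on made-up random data:

```python
# Sketch: with uninformative scores, OvO and OvR multiclass ROC AUC are ~0.5.
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.RandomState(0)
y = rng.randint(0, 3, size=5000)               # random labels, 3 classes
scores = rng.dirichlet(np.ones(3), size=5000)  # random probability rows

ovo = roc_auc_score(y, scores, multi_class="ovo")
ovr = roc_auc_score(y, scores, multi_class="ovr")
assert abs(ovo - 0.5) < 0.05 and abs(ovr - 0.5) < 0.05
```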
In applications where a high false positive rate is not tolerable the parameter
``max_fpr`` of :func:`roc_auc_score` can be used to summarize the ROC curve up
@@ -1341,6 +1385,28 @@ to the given limit.
for an example of using ROC to
model species distribution.
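A hedged sketch of the ``max_fpr`` option on a tiny made-up binary problem; the partial area is standardized so that 0.5 still corresponds to chance level:

```python
# Sketch: summarizing only the low-FPR region of the ROC curve with max_fpr.
import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_score = np.array([0.1, 0.3, 0.35, 0.8, 0.4, 0.5, 0.7, 0.9])

full_auc = roc_auc_score(y_true, y_score)  # area over the whole FPR range
partial_auc = roc_auc_score(y_true, y_score, max_fpr=0.25)
assert np.isclose(full_auc, 0.8125)
assert 0.5 <= partial_auc <= 1.0  # standardized so 0.5 remains chance level
```

Note that ``max_fpr`` applies to the binary case; it is not combined with the multiclass averaging schemes above.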

.. topic:: References:

.. [HT2001] Hand, D.J. and Till, R.J., (2001). `A simple generalisation
of the area under the ROC curve for multiple class classification problems.
<http://link.springer.com/article/10.1023/A:1010920819831>`_
Machine learning, 45(2), pp.171-186.

.. [FC2009] Ferri, Cèsar & Hernandez-Orallo, Jose & Modroiu, R. (2009).
Review comment (Member): this is not cited. should it be?

Review comment (Member): the provost citation is missing for the OVR.

`An Experimental Comparison of Performance Measures for Classification.
<https://www.math.ucdavis.edu/~saito/data/roc/ferri-class-perf-metrics.pdf>`_
Pattern Recognition Letters. 30. 27-38.

.. [F2006] Fawcett, T., 2006. `An introduction to ROC analysis.
<http://www.sciencedirect.com/science/article/pii/S016786550500303X>`_
Pattern Recognition Letters, 27(8), pp. 861-874.

.. [F2001] Fawcett, T., 2001. `Using rule sets to maximize
ROC performance <http://ieeexplore.ieee.org/document/989510/>`_
In Data Mining, 2001.
Proceedings IEEE International Conference, pp. 131-138.


.. _zero_one_loss:

Zero one loss
7 changes: 7 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -118,6 +118,13 @@ Changelog
requires less memory.
:pr:`14108` and :pr:`14170` by :user:`Alex Henrie <alexhenrie>`.

:mod:`sklearn.metrics`
......................

- |Feature| Added multiclass support to :func:`metrics.roc_auc_score`.
:issue:`12789` by :user:`Kathy Chen <kathyxchen>`,
:user:`Mohamed Maskani <maskani-moh>`, and :user:`Thomas Fan <thomasjpfan>`.

:mod:`sklearn.model_selection`
..............................

42 changes: 33 additions & 9 deletions examples/model_selection/plot_roc.py
@@ -15,24 +15,21 @@
The "steepness" of ROC curves is also important, since it is ideal to maximize
the true positive rate while minimizing the false positive rate.

Multiclass settings
-------------------

ROC curves are typically used in binary classification to study the output of
a classifier. In order to extend ROC curve and ROC area to multi-class
or multi-label classification, it is necessary to binarize the output. One ROC
a classifier. In order to extend ROC curve and ROC area to multi-label
classification, it is necessary to binarize the output. One ROC
curve can be drawn per label, but one can also draw a ROC curve by considering
each element of the label indicator matrix as a binary prediction
(micro-averaging).

Another evaluation measure for multi-class classification is
Another evaluation measure for multi-label classification is
macro-averaging, which gives equal weight to the classification of each
label.

.. note::

See also :func:`sklearn.metrics.roc_auc_score`,
:ref:`sphx_glr_auto_examples_model_selection_plot_roc_crossval.py`.
:ref:`sphx_glr_auto_examples_model_selection_plot_roc_crossval.py`

"""
print(__doc__)
@@ -47,6 +44,7 @@
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from scipy import interp
from sklearn.metrics import roc_auc_score

# Import some data to play with
iris = datasets.load_iris()
@@ -101,8 +99,8 @@


##############################################################################
# Plot ROC curves for the multiclass problem

# Plot ROC curves for the multilabel problem
# ..........................................
# Compute macro-average ROC curve and ROC area

# First aggregate all false positive rates
@@ -146,3 +144,29 @@
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()


##############################################################################
# Area under ROC for the multiclass problem
# .........................................
# The :func:`sklearn.metrics.roc_auc_score` function can be used for
# multi-class classification. The multiclass One-vs-One scheme compares every
# unique pairwise combination of classes. In this section, we calculate the AUC
# using the OvR and OvO schemes. We report a macro average and a
# prevalence-weighted average.
y_prob = classifier.predict_proba(X_test)

macro_roc_auc_ovo = roc_auc_score(y_test, y_prob, multi_class="ovo",
average="macro")
weighted_roc_auc_ovo = roc_auc_score(y_test, y_prob, multi_class="ovo",
average="weighted")
macro_roc_auc_ovr = roc_auc_score(y_test, y_prob, multi_class="ovr",
average="macro")
weighted_roc_auc_ovr = roc_auc_score(y_test, y_prob, multi_class="ovr",
average="weighted")
print("One-vs-One ROC AUC scores:\n{:.6f} (macro),\n{:.6f} "
"(weighted by prevalence)"
.format(macro_roc_auc_ovo, weighted_roc_auc_ovo))
print("One-vs-Rest ROC AUC scores:\n{:.6f} (macro),\n{:.6f} "
"(weighted by prevalence)"
.format(macro_roc_auc_ovr, weighted_roc_auc_ovr))
72 changes: 72 additions & 0 deletions sklearn/metrics/base.py
@@ -12,6 +12,7 @@
# Noel Dawe <[email protected]>
# License: BSD 3 clause

from itertools import combinations

import numpy as np

@@ -123,3 +124,74 @@ def _average_binary_score(binary_metric, y_true, y_score, average,
return np.average(score, weights=average_weight)
else:
return score


def _average_multiclass_ovo_score(binary_metric, y_true, y_score,
average='macro'):
"""Average one-versus-one scores for multiclass classification.

Uses the binary metric for one-vs-one multiclass classification,
where the score is computed according to the Hand & Till (2001) algorithm.

Parameters
----------
Review comment (Member): Ideally the list should follow the same order as the parameters passed to __init__

binary_metric : callable
The binary metric function to use that accepts the following as input
y_true_target : array, shape = [n_samples_target]
Some sub-array of y_true for a pair of classes designated
positive and negative in the one-vs-one scheme.
y_score_target : array, shape = [n_samples_target]
Scores corresponding to the probability estimates
of a sample belonging to the designated positive class label

y_true : array-like, shape = (n_samples, )
True multiclass labels.

y_score : array-like, shape = (n_samples, n_classes)
Target scores corresponding to probability estimates of a sample
belonging to a particular class

average : 'macro' or 'weighted', optional (default='macro')
Determines the type of averaging performed on the pairwise binary
metric scores
``'macro'``:
Calculate metrics for each label, and find their unweighted
mean. This does not take label imbalance into account. Classes
are assumed to be uniformly distributed.
``'weighted'``:
Calculate metrics for each label, taking into account the
prevalence of the classes.

Returns
-------
score : float
Average of the pairwise binary metric scores
"""
check_consistent_length(y_true, y_score)

y_true_unique = np.unique(y_true)
n_classes = y_true_unique.shape[0]
n_pairs = n_classes * (n_classes - 1) // 2
pair_scores = np.empty(n_pairs)

is_weighted = average == "weighted"
prevalence = np.empty(n_pairs) if is_weighted else None

# Compute scores treating a as positive class and b as negative class,
# then b as positive class and a as negative class
for ix, (a, b) in enumerate(combinations(y_true_unique, 2)):
a_mask = y_true == a
b_mask = y_true == b
ab_mask = np.logical_or(a_mask, b_mask)

if is_weighted:
prevalence[ix] = np.average(ab_mask)

a_true = a_mask[ab_mask]
b_true = b_mask[ab_mask]

a_true_score = binary_metric(a_true, y_score[ab_mask, a])
b_true_score = binary_metric(b_true, y_score[ab_mask, b])
pair_scores[ix] = (a_true_score + b_true_score) / 2

return np.average(pair_scores, weights=prevalence)
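The prevalence-weighted branch of the helper above can be exercised end-to-end through the public API; a hedged sketch (iris and logistic regression are illustrative assumptions) mirroring the loop:

```python
# Sketch: the prevalence-weighted OvO average, mirroring the helper's loop,
# checked against roc_auc_score(multi_class='ovo', average='weighted').
from itertools import combinations

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

X, y = load_iris(return_X_y=True)
proba = LogisticRegression(max_iter=1000).fit(X, y).predict_proba(X)

pair_scores, prevalence = [], []
for a, b in combinations(np.unique(y), 2):
    ab_mask = np.isin(y, [a, b])
    prevalence.append(np.mean(ab_mask))  # fraction of samples in class a or b
    auc_a = roc_auc_score(y[ab_mask] == a, proba[ab_mask, a])
    auc_b = roc_auc_score(y[ab_mask] == b, proba[ab_mask, b])
    pair_scores.append((auc_a + auc_b) / 2)

manual = np.average(pair_scores, weights=prevalence)
builtin = roc_auc_score(y, proba, multi_class="ovo", average="weighted")
assert np.isclose(manual, builtin)
```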