[MRG+2] ENH multiclass balanced accuracy #10587
Changes from all commits
@@ -417,66 +417,67 @@ In the multilabel case with binary label indicators: ::

 Balanced accuracy score
 -----------------------

-The :func:`balanced_accuracy_score` function computes the
-`balanced accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_, which
-avoids inflated performance estimates on imbalanced datasets. It is defined as the
-arithmetic mean of `sensitivity <https://en.wikipedia.org/wiki/Sensitivity_and_specificity>`_
-(true positive rate) and `specificity <https://en.wikipedia.org/wiki/Sensitivity_and_specificity>`_
-(true negative rate), or the average of `recall scores <https://en.wikipedia.org/wiki/Precision_and_recall>`_
-obtained on either class.
-
-If the classifier performs equally well on either class, this term reduces to the
-conventional accuracy (i.e., the number of correct predictions divided by the total
-number of predictions). In contrast, if the conventional accuracy is above chance only
-because the classifier takes advantage of an imbalanced test set, then the balanced
-accuracy, as appropriate, will drop to 50%.
-
-If :math:`\hat{y}_i\in\{0,1\}` is the predicted value of
-the :math:`i`-th sample and :math:`y_i\in\{0,1\}` is the corresponding true value,
-then the balanced accuracy is defined as
+The :func:`balanced_accuracy_score` function computes the `balanced accuracy
+<https://en.wikipedia.org/wiki/Accuracy_and_precision>`_, which avoids inflated
+performance estimates on imbalanced datasets. It is the macro-average of recall
+scores per class or, equivalently, raw accuracy where each sample is weighted
+according to the inverse prevalence of its true class.
+Thus for balanced datasets, the score is equal to accuracy.
+
+In the binary case, balanced accuracy is equal to the arithmetic mean of
+`sensitivity <https://en.wikipedia.org/wiki/Sensitivity_and_specificity>`_
+(true positive rate) and `specificity
+<https://en.wikipedia.org/wiki/Sensitivity_and_specificity>`_ (true negative
+rate), or the area under the ROC curve with binary predictions rather than
+scores.
+
+If the classifier performs equally well on either class, this term reduces to
+the conventional accuracy (i.e., the number of correct predictions divided by
+the total number of predictions).
+
+In contrast, if the conventional accuracy is above chance only because the
+classifier takes advantage of an imbalanced test set, then the balanced
+accuracy, as appropriate, will drop to :math:`\frac{1}{\text{n_classes}}`.
+
+The score ranges from 0 to 1, or when ``adjusted=True`` is used, it is rescaled
+to the range :math:`\frac{1}{1 - \text{n_classes}}` to 1, inclusive, with
+performance at random scoring 0.
+
+If :math:`y_i` is the true value of the :math:`i`-th sample, and :math:`w_i`
+is the corresponding sample weight, then we adjust the sample weight to:

 .. math::

-   \texttt{balanced-accuracy}(y, \hat{y}) = \frac{1}{2} \left(\frac{\sum_i 1(\hat{y}_i = 1 \land y_i = 1)}{\sum_i 1(y_i = 1)} + \frac{\sum_i 1(\hat{y}_i = 0 \land y_i = 0)}{\sum_i 1(y_i = 0)}\right)
+   \hat{w}_i = \frac{w_i}{\sum_j{1(y_j = y_i) w_j}}

Review thread on this equation:

> Should I give the equation assuming ``w_i = 1``?

> I think it's fine if we leave the general formula.

 where :math:`1(x)` is the `indicator function <https://en.wikipedia.org/wiki/Indicator_function>`_.
+Given predicted :math:`\hat{y}_i` for sample :math:`i`, balanced accuracy is
+defined as:

-Under this definition, the balanced accuracy coincides with :func:`roc_auc_score`
-given binary ``y_true`` and ``y_pred``:
-
-  >>> import numpy as np
-  >>> from sklearn.metrics import balanced_accuracy_score, roc_auc_score
-  >>> y_true = [0, 1, 0, 0, 1, 0]
-  >>> y_pred = [0, 1, 0, 0, 0, 1]
-  >>> balanced_accuracy_score(y_true, y_pred)
-  0.625
-  >>> roc_auc_score(y_true, y_pred)
-  0.625
-
-(but in general, :func:`roc_auc_score` takes as its second argument non-binary scores).
+.. math::
+
+   \texttt{balanced-accuracy}(y, \hat{y}, w) = \frac{1}{\sum{\hat{w}_i}} \sum_i 1(\hat{y}_i = y_i) \hat{w}_i
+
+With ``adjusted=True``, balanced accuracy reports the relative increase from
+:math:`\texttt{balanced-accuracy}(y, \mathbf{0}, w) =
+\frac{1}{\text{n_classes}}`. In the binary case, this is also known as
+`Youden's J statistic <https://en.wikipedia.org/wiki/Youden%27s_J_statistic>`_,
+or *informedness*.

 .. note::

-   Currently this score function is only defined for binary classification problems, you
-   may need to wrap it by yourself if you want to use it for multilabel problems.
-
-   There is no clear consensus on the definition of a balanced accuracy for the
-   multiclass setting. Here are some definitions that can be found in the literature:
-
-   * Macro-average recall as described in [Mosley2013]_, [Kelleher2015]_ and [Guyon2015]_:
-     the recall for each class is computed independently and the average is taken over all classes.
-     In [Guyon2015]_, the macro-average recall is then adjusted to ensure that random predictions
-     have a score of :math:`0` while perfect predictions have a score of :math:`1`.
-     One can compute the macro-average recall using ``recall_score(average="macro")``.
+   The multiclass definition here seems the most reasonable extension of the
+   metric used in binary classification, though there is no certain consensus
+   in the literature:
+
+   * Our definition: [Mosley2013]_, [Kelleher2015]_ and [Guyon2015]_, where
+     [Guyon2015]_ adopts the adjusted version to ensure that random predictions
+     have a score of :math:`0` and perfect predictions have a score of :math:`1`.
    * Class balanced accuracy as described in [Mosley2013]_: the minimum between the precision
      and the recall for each class is computed. Those values are then averaged over the total
      number of classes to get the balanced accuracy.
-   * Balanced Accuracy as described in [Urbanowicz2015]_: the average of sensitivity and selectivity
+   * Balanced Accuracy as described in [Urbanowicz2015]_: the average of sensitivity and specificity
      is computed for each class and then averaged over total number of classes.
-
-   Note that none of these different definitions are currently implemented within
-   the :func:`balanced_accuracy_score` function.

 .. topic:: References:

   .. [Guyon2015] I. Guyon, K. Bennett, G. Cawley, H.J. Escalante, S. Escalera, T.K. Ho, N. Macià,
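The new documentation defines balanced accuracy as the macro-average of per-class recalls, with an ``adjusted=True`` variant that rescales the score so that random performance maps to 0. That definition can be sketched in plain Python; ``balanced_accuracy`` below is a hypothetical helper for illustration, not scikit-learn's implementation:

```python
def balanced_accuracy(y_true, y_pred, adjusted=False):
    """Macro-average of per-class recalls; optionally chance-adjusted."""
    classes = sorted(set(y_true))
    recalls = []
    for c in classes:
        # Indices of samples whose true class is c.
        members = [i for i, y in enumerate(y_true) if y == c]
        # Recall for class c: fraction of its samples predicted correctly.
        recalls.append(sum(y_pred[i] == c for i in members) / len(members))
    score = sum(recalls) / len(recalls)
    if adjusted:
        # Rescale so a constant (chance-level) predictor scores 0.
        chance = 1.0 / len(classes)
        score = (score - chance) / (1 - chance)
    return score


# Same toy data as the doctest removed from the docs above.
print(balanced_accuracy([0, 1, 0, 0, 1, 0], [0, 1, 0, 0, 0, 1]))  # 0.625
```

On this binary example the recalls are 3/4 for class 0 and 1/2 for class 1, so the mean is 0.625, matching both the removed doctest and the ROC-AUC equivalence the old text mentioned.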
@@ -1,11 +1,11 @@
 from __future__ import division, print_function

-import numpy as np
-from scipy import linalg
 from functools import partial
 from itertools import product
 import warnings

+import numpy as np
+from scipy import linalg
 import pytest

 from sklearn import datasets

@@ -31,6 +31,7 @@

 from sklearn.metrics import accuracy_score
 from sklearn.metrics import average_precision_score
+from sklearn.metrics import balanced_accuracy_score
 from sklearn.metrics import classification_report
 from sklearn.metrics import cohen_kappa_score
 from sklearn.metrics import confusion_matrix

@@ -1675,3 +1676,26 @@ def test_brier_score_loss():
     # calculate even if only single class in y_true (#6980)
     assert_almost_equal(brier_score_loss([0], [0.5]), 0.25)
     assert_almost_equal(brier_score_loss([1], [0.5]), 0.25)
+
+
+def test_balanced_accuracy_score_unseen():
+    assert_warns_message(UserWarning, 'y_pred contains classes not in y_true',
+                         balanced_accuracy_score, [0, 0, 0], [0, 0, 1])
+
+
+@pytest.mark.parametrize('y_true,y_pred',
+                         [
+                             (['a', 'b', 'a', 'b'], ['a', 'a', 'a', 'b']),
+                             (['a', 'b', 'c', 'b'], ['a', 'a', 'a', 'b']),
+                             (['a', 'a', 'a', 'b'], ['a', 'b', 'c', 'b']),
+                         ])
+def test_balanced_accuracy_score(y_true, y_pred):
+    macro_recall = recall_score(y_true, y_pred, average='macro',
+                                labels=np.unique(y_true))
+    with ignore_warnings():
+        # Warnings are tested in test_balanced_accuracy_score_unseen
+        balanced = balanced_accuracy_score(y_true, y_pred)
+    assert balanced == pytest.approx(macro_recall)
+    adjusted = balanced_accuracy_score(y_true, y_pred, adjusted=True)
+    chance = balanced_accuracy_score(y_true, np.full_like(y_true, y_true[0]))
+    assert adjusted == (balanced - chance) / (1 - chance)

Review comment on this test:

> Any reason we can't use
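The tests above check that ``balanced_accuracy_score`` matches macro-averaged recall. The equivalence with the documentation's sample-weight formulation (rescale each weight by the total weight of its true class, then take weighted accuracy) can be verified directly; the function below is an illustrative reimplementation, not the PR's code:

```python
from collections import Counter


def weighted_balanced_accuracy(y_true, y_pred, sample_weight=None):
    """Weighted accuracy under inverse-class-prevalence weights (w-hat)."""
    if sample_weight is None:
        sample_weight = [1.0] * len(y_true)
    # Total weight carried by each true class.
    class_total = Counter()
    for y, w in zip(y_true, sample_weight):
        class_total[y] += w
    # The docs' adjusted weight: w_i divided by its class's total weight,
    # so every class contributes exactly 1 unit of weight overall.
    w_hat = [w / class_total[y] for y, w in zip(y_true, sample_weight)]
    correct = sum(w for yt, yp, w in zip(y_true, y_pred, w_hat) if yt == yp)
    return correct / sum(w_hat)


# First case of the parametrized test: recalls are 1.0 (class 'a')
# and 0.5 (class 'b'), so the macro average is 0.75.
print(weighted_balanced_accuracy(['a', 'b', 'a', 'b'], ['a', 'a', 'a', 'b']))  # 0.75
```

With unit sample weights each class's adjusted weights sum to 1, so the numerator is the sum of per-class recalls and the denominator is the number of classes, recovering the macro-recall definition exactly.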
> Perhaps we should have a "weight-balanced" option for ``class_weight``. It would be interesting to see if that improved imbalanced boosting.

> Apparently my phone wrote "weight-loss card" (!) there. Amended.
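The "weight-balanced" idea suggested here corresponds to the inverse-prevalence weighting that scikit-learn exposes as ``compute_sample_weight(class_weight='balanced', ...)``. A dependency-free sketch of that heuristic (``balanced_sample_weights`` is a hypothetical name for illustration):

```python
from collections import Counter


def balanced_sample_weights(y):
    """The 'balanced' heuristic: n_samples / (n_classes * count(class)).

    Every class ends up carrying the same total weight, which is the same
    reweighting balanced accuracy applies implicitly.
    """
    counts = Counter(y)
    n, k = len(y), len(counts)
    return [n / (k * counts[label]) for label in y]


# Three samples of class 0 each get weight 2/3; the lone class-1 sample
# gets weight 2.0, so both classes carry a total weight of 2.0.
weights = balanced_sample_weights([0, 0, 0, 1])
```

Feeding such weights into a boosting estimator's ``fit`` would up-weight minority-class errors at every stage, which is the experiment the comment proposes.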