[MRG+2] ENH multiclass balanced accuracy #10587
Conversation
Includes computationally simpler implementation and logically simpler description.
Ahh... passing tests.
@@ -1357,6 +1357,8 @@ functions or non-estimator constructors.
    equal weight by giving each sample a weight inversely related
    to its class's prevalence in the training data:
    ``n_samples / (n_classes * np.bincount(y))``.
    **Note** however that this rebalancing does not take the weight of
    samples in each class into account.
Perhaps we should have a "weight-balanced" option for `class_weight`. It would be interesting to see if that improved imbalanced boosting.
Apparently my phone wrote "weight-loss card" (!) there. Amended.
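For reference, the count-based `class_weight='balanced'` heuristic quoted in the diff above can be reproduced in a few lines (a minimal sketch; the toy labels are mine):

```python
import numpy as np

# The class_weight='balanced' heuristic quoted above: per-class weights
# inversely related to class prevalence, ignoring per-sample weights.
y = np.array([0, 0, 0, 1])
n_classes = len(np.unique(y))
class_weights = len(y) / (n_classes * np.bincount(y))
print(class_weights)  # [0.6667, 2.0] -- each class ends up with total weight 2
```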
.. math::

   \texttt{balanced-accuracy}(y, \hat{y}) = \frac{1}{2} \left(\frac{\sum_i 1(\hat{y}_i = 1 \land y_i = 1)}{\sum_i 1(y_i = 1)} + \frac{\sum_i 1(\hat{y}_i = 0 \land y_i = 0)}{\sum_i 1(y_i = 0)}\right)

   \hat{w}_i = \frac{w_i}{\sum_j{1(y_j = y_i) w_j}}
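Read literally, the general formula says: renormalize each sample's weight by the total weight of its class, then take the weighted accuracy. A minimal sketch transcribing that definition (the function name is mine, not the PR's code):

```python
import numpy as np

def balanced_accuracy_by_definition(y_true, y_pred, sample_weight=None):
    # w_hat_i = w_i / sum_j 1(y_j == y_i) w_j, then weighted accuracy.
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    w = (np.ones(len(y_true)) if sample_weight is None
         else np.asarray(sample_weight, dtype=float))
    w_hat = np.empty_like(w)
    for c in np.unique(y_true):
        mask = y_true == c
        w_hat[mask] = w[mask] / w[mask].sum()  # each class sums to 1
    # sum(w_hat) == n_classes, so this is the mean per-class weighted recall
    return ((y_pred == y_true) * w_hat).sum() / w_hat.sum()
```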
Should I give the equation assuming w_i=1?
I think it's fine if we keep the general formula.
sklearn/metrics/classification.py
Outdated
sensitivity (true positive rate) and specificity (true negative rate),
or the average recall obtained on either class. It is also equal to the
ROC AUC score given binary inputs.
The balanced accuracy in binary and muitclass classification problems to
typo multiclass
While I'm interested in your critique of the docs and implementation, @maskani-moh, I'd mostly like you to verify that this interpretation of balanced accuracy, as accuracy with sample weights assigned to give equal total weight to each class, makes the choice of a multiclass generalisation clear.
sklearn/metrics/classification.py
Outdated
    minlength=n_classes)
if sample_weight is None:
    sample_weight = 1
sample_weight = class_weight.take(encoded_y_true) * sample_weight
What is the reason to apply `sample_weight` a second time? I thought it was already taken into account when computing the `class_weight`. Which paper should I check for references?
I don't think weighted balanced accuracy is reported anywhere, but:
- the PR's implementation matches the incumbent
- the implementation matches the invariance tests for weighting in metrics.tests.test_common that are really pretty good, if I must say so myself as the architect...

Generally when we have a value for `class_weight` as well as `sample_weight`, we weight the samples per the `class_weight` (i.e. `class_weight.take(y)`), and then we multiply by each sample's weight. Exactly what's happening here. However, currently our handling of `class_weight='balanced'` counts the number, not the total weight, of samples in each class, then assigns each sample the reciprocal as its weight. I initially used that in this implementation, and was not surprised to find that it failed the tests: repetition of samples was no longer equivalent to integer weights. So here we use the total weight (not the cardinality) in determining the class weight, reciprocate that, but still multiply in each sample's own weight so that we can correctly calculate the weighted confusion matrix.
Which makes me think: we should be able to implement this even more simply from the confusion matrix... I'll play with that another time soon.
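A sketch of the weight-based "balanced" scheme just described, using total weight per class rather than the count (`balanced_sample_weight` is a hypothetical helper, not library API):

```python
import numpy as np

def balanced_sample_weight(y, sample_weight=None):
    # Rescale so every class carries equal total weight, while each
    # sample keeps its weight relative to the others in its class.
    y = np.asarray(y)
    sw = (np.ones(len(y)) if sample_weight is None
          else np.asarray(sample_weight, dtype=float))
    classes, encoded = np.unique(y, return_inverse=True)
    # total weight (not cardinality) per class, then its reciprocal
    class_totals = np.bincount(encoded, weights=sw, minlength=len(classes))
    class_weight = 1.0 / class_totals
    # weight samples per the class weight, then by their own weight
    return class_weight.take(encoded) * sw
```

With this scheme, repeating a sample is equivalent to giving it an integer weight, which is the invariance the common tests check.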
The implementation with the confusion matrix seems really straightforward. It looks like an average of the TPR per class. The generalization from binary to multiclass looks good to me. I don't see a case where it would not be correct.
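A sketch of that confusion-matrix route (average of per-class recalls read off a weighted confusion matrix; this mirrors the idea discussed here, not necessarily the final merged code):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

def balanced_accuracy_from_cm(y_true, y_pred, sample_weight=None,
                              adjusted=False):
    C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
    with np.errstate(divide="ignore", invalid="ignore"):
        per_class = np.diag(C) / C.sum(axis=1)   # recall of each class
    per_class = per_class[~np.isnan(per_class)]  # drop classes absent from y_true
    score = per_class.mean()
    if adjusted:
        chance = 1 / len(per_class)  # score of any constant predictor
        score = (score - chance) / (1 - chance)
    return score
```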
doc/modules/model_evaluation.rst
Outdated
In contrast, if the conventional accuracy is above chance only because the
classifier takes advantage of an imbalanced test set, then the balanced
accuracy, as appropriate, will drop to 1/(number of classes).
Could we use a math environment?
:math:`\frac{1}{# classes}`
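To illustrate the drop described above, using the `balanced_accuracy_score` this PR adds: a majority-class predictor on a 90/10 test set looks strong under plain accuracy but falls to chance under balanced accuracy.

```python
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score

y_true = np.array([0] * 90 + [1] * 10)  # imbalanced: 90% class 0
y_pred = np.zeros_like(y_true)          # always predict the majority class
print(accuracy_score(y_true, y_pred))           # 0.9 -- flattering
print(balanced_accuracy_score(y_true, y_pred))  # 0.5 == 1 / n_classes
```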
doc/modules/model_evaluation.rst
Outdated
accuracy, as appropriate, will drop to 1/(number of classes).

The score ranges from 0 to 1, or when ``adjusted=True`` is used, it is rescaled
to the range [1 / (1 - number of classes), 1] with performance at random being
I also find the range difficult to read in the doc. I would go for a math environment.
sklearn/metrics/classification.py
Outdated
adjusted : bool, default=False
    When true, the result is adjusted for chance, so that random
    performance would score 0, and perfect performance scores 1.

Returns
-------
balanced_accuracy : float
We might change the sensitivity/specificity explanation.
Good catch
doc/modules/model_evaluation.rst
Outdated
With ``adjusted=True``, balanced accuracy reports the relative increase from
:math:`\texttt{balanced-accuracy}(y, \mathbf{0}, w) =
\frac{1}{\text{n classes}}`. In the binary case, this is also known as
*Youden's J statistic*, or *informedness*.
maybe link to https://en.wikipedia.org/wiki/Youden%27s_J_statistic
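For the record, the binary equivalence is a one-line rescaling: with chance level :math:`\frac{1}{2}`,

.. math::

   \frac{\texttt{balanced-accuracy} - \frac{1}{2}}{1 - \frac{1}{2}} = \mathrm{TPR} + \mathrm{TNR} - 1,

which is exactly Youden's :math:`J` (informedness).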
LGTM. @maskani-moh Could you have a look and tell us WYT?
This should be quick to review if someone (other than @glemaitre who has given his +1) is keen to throw it into 0.20.
LGTM at a glance. I need (and promise) to double check the code and refs tomorrow.
Some small comments, feel free to ignore if you think current version is fine.
My LGTM on the PR is based on the fact that the function is there. Honestly, I don't like the idea of including such a function, which can simply be implemented using recall.
Tagging 0.20.
doc/modules/model_evaluation.rst
Outdated
In contrast, if the conventional accuracy is above chance only because the
classifier takes advantage of an imbalanced test set, then the balanced
accuracy, as appropriate, will drop to :math:`\frac{1}{\text{n\_classes}}`.
`\text{n\_classes}` -> `\text{n_classes}`? Or maybe some other way to get rid of the extra `\` here. Same comment for similar places below.
doc/modules/model_evaluation.rst
Outdated
The score ranges from 0 to 1, or when ``adjusted=True`` is used, it is rescaled to
the range :math:`\frac{1}{1 - \text{n\_classes}}` to 1, inclusive, with
performance at random scoring 0.
Seems strange. "Rescaled to the range A to B, with performance at random scoring 0". But 0 is actually not in [A, B]? I'd prefer a clearer explanation of the scaling strategy we use when `adjusted=True`.
Sorry. I realized that I'm wrong here.
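(Spelling out the arithmetic that resolves this: the test further down pins the mapping as `adjusted = (balanced - chance) / (1 - chance)` with `chance = 1 / n_classes`, so a raw score of 0 maps to :math:`\frac{1}{1 - \text{n\_classes}}` (e.g. -1 in the binary case), a raw score of :math:`\frac{1}{\text{n\_classes}}`, i.e. random performance, maps to 0, and 1 maps to 1. So 0 does lie inside the adjusted range.)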
doc/modules/model_evaluation.rst
Outdated
have a score of :math:`0` while perfect predictions have a score of :math:`1`.
One can compute the macro-average recall using ``recall_score(average="macro")`` in :func:`recall_score`.
* Our definition: [Mosley2013]_, [Kelleher2015]_ and [Guyon2015]_, where
  [Guyon2015]_ adopt the adjusted version to score chance as 0.
What is "score chance as 0"?
sklearn/metrics/classification.py
Outdated
ROC AUC score given binary inputs.
The balanced accuracy in binary and multiclass classification problems to
deal with imbalanced datasets. It is defined as the average of recall
obtained on each class.

The best value is 1 and the worst value is 0.
No longer the case when `adjusted=True`?
assert balanced == pytest.approx(macro_recall)
adjusted = balanced_accuracy_score(y_true, y_pred, adjusted=True)
chance = balanced_accuracy_score(y_true, np.full_like(y_true, y_true[0]))
assert adjusted == (balanced - chance) / (1 - chance)
Any reason we can't use `==` when `adjusted=False`?
doc/modules/model_evaluation.rst
Outdated
0.625
>>> roc_auc_score(y_true, y_pred)
0.625

.. math::

   \texttt{balanced-accuracy}(y, \hat{y}, w) = \frac{1}{\sum{\hat{w}_i}} \sum_i 1(\hat{y}_i == y_i) \hat{w}_i
Is it common to use `==` in the indicator function?
LGTM apart from the comments above.
doc/whats_new/v0.20.rst
Outdated
:issue:`8066` by :user:`xyguo` and :user:`Aman Dalmia <dalmia>`, and
:issue:`10587` by `Joel Nothman`_.

- Added :class:`multioutput.RegressorChain` for multi-target
This entry should be removed.
> Honestly, I don't like the idea of including such a function, which can be simply implemented by recall.

The adjusted metric can't just be implemented by recall. But really, we've had years of people asking for balanced accuracy, and not realising that they could implement it with recall....
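Spelling out that equivalence with a toy example (the 0.625 value matches the doctest quoted earlier):

```python
from sklearn.metrics import balanced_accuracy_score, recall_score

y_true = [0, 1, 0, 0, 1, 0]
y_pred = [0, 1, 0, 0, 0, 1]
print(balanced_accuracy_score(y_true, y_pred))        # 0.625
print(recall_score(y_true, y_pred, average="macro"))  # 0.625, identical
# adjusted=True is the part recall_score does not offer:
print(balanced_accuracy_score(y_true, y_pred, adjusted=True))  # 0.25
```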
I don't have time to fix these up right away...
@jnothman Do you mind if I push some cosmetic changes and merge this one?
I don't mind if you're confident about them.
LGTM, thanks @jnothman
Removing those backslashes broke CircleCI on master.
See also #10040. Ping @maskani-moh, @amueller.