
[MRG+2] Nonrepeating ROC thresholds #3268


Merged

Conversation

jblackburne
Contributor

Added a unit test to ensure that there are no spurious repeating values in the thresholds returned by roc_curve because of machine precision, and a quick stab at a fix.

Without the fix, the thresholds array produced in the unit test starts with the values [1, 0.99, 0.98, 0.98, ...]. There should be only one 0.98. The undesirable behavior stems from machine epsilon differences between predicted probabilities that are mathematically equivalent.

I am not sure that my suggested fix is the best approach, so I welcome comments.
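
For illustration (my snippet, not from the PR), the classic one-ulp mismatch between mathematically equal doubles is exactly the kind of difference at play here:

import numpy as np

a = 0.1 + 0.2  # 0.30000000000000004 in IEEE-754 doubles
b = 0.3
print(a == b)            # False: the two values differ by one ulp
print(np.isclose(a, b))  # True: equal up to floating-point tolerance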

…es in the thresholds returned by roc_curve because of machine precision, and a quick stab at a fix.
@coveralls

Coverage Status

Coverage increased (+0.0%) when pulling 84cac35 on jblackburne:nonrepeating_roc_thresholds into 662fe00 on scikit-learn:master.


# How well can the classifier predict whether a digit is less than 5?
# This task contributes floating point roundoff errors to the probabilities
probas_pred = clf.fit(X[::2], y[::2]).predict_proba(X[1::2])
Member

Could you please use an explicit train/test split instead of slices with steps?

from sklearn.cross_validation import train_test_split  # sklearn.model_selection in modern scikit-learn

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=0)

I find the code more readable with explicit variable names instead of slices.

Contributor Author

Using train_test_split would necessitate adding a new import to the test file, which I'd rather not do. What if I were to switch to named slice objects? That would provide the same readability.
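
For instance, named slices read almost like an explicit split while needing no extra import (a sketch using the X and y from the test):

train, test = slice(None, None, 2), slice(1, None, 2)  # even/odd samples
X_train, X_test = X[train], X[test]
y_train, y_test = y[train], y[test]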

…lds test. Added named slice variables to same, for readability.
… ROC thresholds. Added a comment explaining why the treatment is needed.
distinct_value_indices = np.where(np.diff(y_score))[0]
# We need to use isclose to avoid spurious repeated thresholds
# stemming from floating point roundoff errors.
distinct_value_indices = np.where(np.logical_not(np.isclose(
    np.diff(y_score), 0)))[0]
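
A standalone sketch (my illustration, not part of the diff) of how the isclose test collapses scores that differ only by roundoff:

import numpy as np

# Sorted scores containing two values one ulp apart around 0.98:
y_score = np.array([1.0, 0.99, 0.98 + 1e-16, 0.98, 0.90])

old = np.where(np.diff(y_score))[0]  # any nonzero diff counts as distinct
new = np.where(np.logical_not(np.isclose(np.diff(y_score), 0)))[0]
print(old)  # [0 1 2 3]: a spurious boundary between the two 0.98s
print(new)  # [0 1 3]: the near-equal values are treated as one threshold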
Contributor Author

Hm, np.isclose was added in numpy 1.7, so this will not work with earlier versions. I will look into fixing this tomorrow.

@arjoly
Member

arjoly commented Jun 12, 2014

I have reproduced your test case:

from sklearn import datasets, ensemble
from sklearn.metrics import roc_curve

dataset = datasets.load_digits()
X = dataset['data']
y = dataset['target']

clf = ensemble.RandomForestClassifier(n_estimators=100, random_state=0)
train, test = slice(None, None, 2), slice(1, None, 2)
probas_pred = clf.fit(X[train], y[train]).predict_proba(X[test])
probas_pred = probas_pred[:, :5].sum(axis=1)
y_true = y[test] < 5
fpr, tpr, thresholds = roc_curve(y_true, probas_pred)

From my understanding, the issue comes from floating point rounding in RandomForestClassifier and not from the roc_curve function.

In [28]: numpy.unique(probas_pred).size
Out[28]: 168

In [32]: numpy.unique(thresholds).size
Out[32]: 168

Furthermore, the AUC barely changes even with two-digit rounding:

In [35]: roc_auc_score(y_true, probas_pred)
Out[35]: 0.99765378147925865

In [37]: probas_pred_rounded = numpy.round(probas_pred, 2)
In [38]: roc_auc_score(y_true, probas_pred_rounded)
Out[38]: 0.99765378147925854

@jblackburne
Contributor Author

Arnaud,

There is no issue in RandomForestClassifier. It returns 101 or fewer unique values, as expected for a 100-tree forest. The roundoff errors arise from the sum I perform over 5 columns of probas_pred in the unit test. This (perhaps somewhat unusual) use case interacts with roc_curve to produce effectively duplicate thresholds. I agree that the AUC will not change, but it does seem undesirable to surprise a user of roc_curve with duplicate thresholds.
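
For example (my snippet), accumulating the same terms in a different order can change the result in the last bit:

terms = [0.1, 0.2, 0.3]
print(sum(terms))            # 0.6000000000000001
print(sum(reversed(terms)))  # 0.6
print(sum(terms) == sum(reversed(terms)))  # False on IEEE-754 doubles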

Tonight I'll push a changeset that fixes the double use of probas_pred in the unit test, to avoid confusion.

@arjoly
Member

arjoly commented Jun 12, 2014

Still, the issue is not in roc_curve, which works as expected (168 unique values in, 168 thresholds out), but in the summation probas_pred = probas_pred[:, :5].sum(axis=1).

On a side note, if this really matters, you could do that sort of summation using the Kahan summation algorithm, or even better, the Kahan-Babuska summation algorithm.
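
For reference, a minimal sketch of plain Kahan (compensated) summation; the Kahan-Babuska variant additionally handles terms larger than the running total:

def kahan_sum(values):
    """Compensated summation: re-injects the low-order bits lost at each step."""
    total = 0.0
    compensation = 0.0
    for v in values:
        y = v - compensation            # re-add previously lost low bits
        t = total + y                   # low bits of y may be lost here
        compensation = (t - total) - y  # recover exactly what was lost
        total = t
    return total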

@jblackburne
Contributor Author

I see your point. 168 unique values in, 168 thresholds out indeed. But in my opinion, a better algorithm for selecting ROC curve thresholds would be to use only the y_score values that differ by "significant" amounts, for some definition of "significant". It adds no value to have separate thresholds equal to 0.98 and 0.9800000000000001, and it seems inelegant to allow floating point roundoff to influence the threshold selection.

However, if the prevailing opinion is not to modify the threshold selection algorithm, I will withdraw this pull request.

@ogrisel
Member

ogrisel commented Jun 13, 2014

I also find machine-precision-level duplicated thresholds curious, and they might cause artifacts in a matplotlib display of the curve. On the other hand, I don't think it's worth slowing down the computation of predict_proba by implementing a more numerically precise sum there. So rounding the scores to 7 significant digits in the roc_curve function sounds OK to me. Maybe we could add the kwarg rounding=7 (by default) to control the number of significant digits and let the user set rounding=None to disable the rounding altogether.
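
A hypothetical sketch of what such a kwarg could look like (not the merged API; for probability scores in [0, 1], decimal places approximate significant digits):

import numpy as np

def round_scores(y_score, rounding=7):
    # Hypothetical helper, not part of scikit-learn: round scores before
    # threshold selection; rounding=None disables the rounding entirely.
    if rounding is None:
        return np.asarray(y_score)
    return np.round(y_score, rounding)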

@coveralls

Coverage Status

Coverage decreased (-0.06%) when pulling 58c0a75 on jblackburne:nonrepeating_roc_thresholds into 662fe00 on scikit-learn:master.

@jblackburne
Contributor Author

Are there any further concerns or obstacles to merging this one?

@ogrisel
Member

ogrisel commented Jun 17, 2014

This looks good to me. +1 on my side. @arjoly, have you changed your mind? Others' opinions? @GaelVaroquaux @larsmans @jnothman @mblondel @agramfort?

@arjoly
Member

arjoly commented Jun 17, 2014

I haven't changed my mind, but it's OK if you can find other people to support this pull request.

@jnothman
Member

I am fine with this. +1

@jnothman changed the title from "Nonrepeating ROC thresholds" to "[MRG+2] Nonrepeating ROC thresholds" on Jun 29, 2014
@jnothman
Member

(although I do wonder whether the test is robust to changes in the RandomForestClassifier implementation)

@ogrisel
Member

ogrisel commented Jul 1, 2014

Merging, then. If the random forest implementation changes in a way that makes this test fail, we can always rewrite it to use fixed values. However, I think it's best to exercise a real case that stems from the library.

ogrisel added a commit that referenced this pull request Jul 1, 2014
@ogrisel merged commit 777123d into scikit-learn:master on Jul 1, 2014
@jblackburne deleted the nonrepeating_roc_thresholds branch on November 27, 2014