[MRG] Enforce n_folds >= 2 for k-fold cross-validation #2054
Conversation
LGTM

Thanks @jnothman. Any second reviewer? @GaelVaroquaux @larsmans @mblondel? If nobody has an objection I will merge tonight.
```diff
@@ -366,6 +372,13 @@ def __init__(self, y, n_folds=3, indices=True, k=None):
         _validate_kfold(n_folds, n)
         _, y_sorted = unique(y, return_inverse=True)
         min_labels = np.min(np.bincount(y_sorted))
         n_folds = int(n_folds)
         if n_folds < 2:
```
The error message can be factored out to a private method or function for maintainability. Maybe the check too.
Otherwise, 👍 for merge.
Ha! We didn't even notice there is already a helper function `_validate_kfold` that tests for `k <= 0`! And then `KFold` goes on to do more validation that equally applies to `StratifiedKFold`.

This is the problem of looking at diffs. No longer LGTM for merge :s
(And I would vote for making an ABC to cover both cases, but `KFold` extracts contiguous slices while stratified uses `i::n_folds`. The latter would work for both, but I assume we need to ensure backwards compatibility.)
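The two slicing strategies being compared can be sketched side by side (illustrative only, not the actual scikit-learn implementations):

```python
import numpy as np

def contiguous_folds(n, n_folds):
    # KFold-style: consecutive slices of floor(n / n_folds) samples each;
    # the entire remainder lands in the last fold.
    size = n // n_folds
    folds = [np.arange(i * size, (i + 1) * size) for i in range(n_folds - 1)]
    folds.append(np.arange((n_folds - 1) * size, n))
    return folds

def strided_folds(n, n_folds):
    # StratifiedKFold-style i::n_folds striding: fold sizes differ by
    # at most one, so the remainder is spread evenly.
    return [np.arange(i, n, n_folds) for i in range(n_folds)]
```

With `n=29, n_folds=10` the contiguous variant yields nine folds of 2 and one fold of 11, while the strided variant yields nine folds of 3 and one fold of 2.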
I'm sure this PR isn't the right place for it, but is the following not bad behaviour of `KFold`?

```python
>>> for n in range(20, 30):
...     print('%d samples in 10 folds' % n, [len(test) for train, test in KFold(n, 10)])
...
20 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
21 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 3]
22 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 4]
23 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 5]
24 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 6]
25 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 7]
26 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 8]
27 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 9]
28 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 10]
29 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 11]
```

The folds here are very unbalanced! Of course one should ideally have
Ooh! I had something for this...

```python
def _fair_array_counts(n_samples, n_classes, random_state=None):
    """Tries to fairly partition n_samples between n_classes.

    If this cannot be done fairly, +1 is added `remainder` times
    to the counts for random arrays until a total of `n_samples` is
    reached.

    >>> _fair_array_counts(5, 3, random_state=43)
    array([2, 1, 2])
    """
    if n_classes > n_samples:
        raise ValueError("The number of classes is greater"
                         " than the number of samples requested")
    sample_size = n_samples // n_classes
    sample_size_rem = n_samples % n_classes
    counts = np.repeat(sample_size, n_classes)
    if sample_size_rem > 0:
        counts[:sample_size_rem] += 1
    # Shuffle so the class imbalance varies between runs
    random_state = check_random_state(random_state)
    random_state.shuffle(counts)
    return counts
```

```python
>>> for n in range(20, 31):
...     print ml._fair_array_counts(n, 10, random_state=3)
[2 2 2 2 2 2 2 2 2 2]
[2 2 2 2 2 2 2 3 2 2]
[2 2 3 2 2 2 2 3 2 2]
[2 2 3 3 2 2 2 3 2 2]
[2 2 3 3 2 2 2 3 3 2]
[2 3 3 3 2 2 2 3 3 2]
[3 3 3 3 2 2 2 3 3 2]
[3 3 3 3 2 3 2 3 3 2]
[3 3 3 3 2 3 3 3 3 2]
[3 3 3 3 2 3 3 3 3 3]
[3 3 3 3 3 3 3 3 3 3]
```
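The balancing guarantee of the helper above can be checked with a reduced deterministic version (`fair_counts` below is a hypothetical sketch of its core, with the shuffle step omitted): counts always sum to `n_samples` and differ by at most one.

```python
import numpy as np

def fair_counts(n_samples, n_classes):
    # Base size for every class, then +1 for the first `remainder` classes.
    counts = np.repeat(n_samples // n_classes, n_classes)
    counts[: n_samples % n_classes] += 1
    return counts

# Every total between 20 and 30 splits into 10 near-equal parts.
for n in range(20, 31):
    c = fair_counts(n, 10)
    assert c.sum() == n and c.max() - c.min() <= 1
```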
What is the benefit of distributing the remainder randomly?

It supports randomness because I wrote it for a resample function PR.
I addressed all the comments about the initial scope of this PR (I think). The unfair distribution of the folds should probably be addressed in another PR, I guess.

Test failure...

Yes, I don't think renaming

I fixed the broken doctest. Unfortunately it's not possible to provide backward compat for positional arguments that are called as kwargs. If you really want I can revert the `n` to `n_samples` renaming.

If not, change it for the rest of the module!

Ok I reverted the `n_samples` renaming. I think this PR is in a fine state now.
```python
                "k-folds cross validation requires at least one"
                " train / test split by setting n_folds=2 or more,"
                " got n_folds=%d."
                % n_folds)
```
This should probably use `.format()` instead.
Yes, it seems there are a lot of `%` formats in the codebase. And last time @ogrisel tried otherwise, I had to change `{}` to `{0}`, `{1}` to make Jenkins (i.e. Py2.6) happy...
Yeah, many (I'm probably guilty of it myself as well in past PRs).
Annoying about the python 2.6 issue. Maybe this is an issue to bring up when we move to python 2.7 as the minimum version?
Ok I switched to the new-style format and rebased everything. We want to keep support for 2.6 till 2015 (end of the Ubuntu 10.04 LTS support) if it's not too cumbersome.

Are we switching to
I would not say it's mandatory. I tend to prefer it a bit over the old syntax as I find it a bit more explicit for Python newcomers, but I don't think the Python community is likely to deprecate the `%` notation any time soon.
Apparently they don't dare actually deprecate it. But fair enough, I'll start learning

It's also problematic to use

Shall we merge this PR?

+1
ENH Enforce n_folds >= 2 for k-fold cross-validation
Thanks!
Users might be confused if they set `n_folds=1` in `KFold` or `StratifiedKFold`, as they would get empty training sets that can often result in weird model fit error messages. This can happen if they naively use `cv=1` in `GridSearchCV`, for instance. See: #2048.