
[MRG] Enforce n_folds >= 2 for k-fold cross-validation #2054


Merged: 1 commit merged into scikit-learn:master on Jun 21, 2013

Conversation

@ogrisel
Member

ogrisel commented Jun 11, 2013

Users might be confused if they set n_folds=1 in KFold or StratifiedKFold, as they would get empty training sets, which often results in weird error messages when fitting the model.

This can happen if they naively use cv=1 in GridSearchCV, for instance. See #2048.

@jnothman
Member

LGTM

@ogrisel
Member Author

ogrisel commented Jun 11, 2013

Thanks @jnothman.

Any second reviewer? @GaelVaroquaux @larsmans @mblondel? If nobody has an objection I will merge tonight.

@@ -366,6 +372,13 @@ def __init__(self, y, n_folds=3, indices=True, k=None):
        _validate_kfold(n_folds, n)
        _, y_sorted = unique(y, return_inverse=True)
        min_labels = np.min(np.bincount(y_sorted))
        n_folds = int(n_folds)
        if n_folds < 2:
Member
The error message can be factored out to a private method or function for maintainability. Maybe the check too.

Otherwise, 👍 for merge.

Member
Ha! We didn't even notice there is already a helper function _validate_kfold that tests for k <= 0! And then KFold goes on to do more validation that equally applies to StratifiedKFold.

This is the problem of looking at diffs. No longer LGTM for merge :s

Member
(And I would vote for making an ABC to cover both cases, but KFold extracts contiguous slices while stratified uses i::n_folds. The latter would work for both, but I assume we need to ensure backwards compatibility.)

@jnothman
Member

I'm sure this PR isn't the right place for it, but, is the following not bad behaviour of KFold?

>>> for n in range(20, 30):
...   print('%d samples in 10 folds' % n, [len(test) for train, test in KFold(n,10)])
... 
20 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
21 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 3]
22 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 4]
23 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 5]
24 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 6]
25 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 7]
26 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 8]
27 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 9]
28 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 10]
29 samples in 10 folds [2, 2, 2, 2, 2, 2, 2, 2, 2, 11]

The folds here are very unbalanced! Of course one should ideally have n >> k...

@erg
Contributor

erg commented Jun 11, 2013

I'm sure this PR isn't the right place for it, but, is the following not bad behaviour of KFold?

Ooh! I had something for this...

import numpy as np
from sklearn.utils import check_random_state


def _fair_array_counts(n_samples, n_classes, random_state=None):
    """Tries to fairly partition n_samples between n_classes.

    If this cannot be done fairly, +1 is added to the counts of
    `remainder` randomly chosen classes until a total of `n_samples`
    is reached.

    >>> _fair_array_counts(5, 3, random_state=43)
    array([2, 1, 2])
    """
    if n_classes > n_samples:
        raise ValueError("The number of classes is greater"
                         " than the number of samples requested")
    sample_size = n_samples // n_classes
    sample_size_rem = n_samples % n_classes
    counts = np.repeat(sample_size, n_classes)
    if sample_size_rem > 0:
        counts[:sample_size_rem] += 1
        # Shuffle so the class imbalance varies between runs
        random_state = check_random_state(random_state)
        random_state.shuffle(counts)
    return counts


for n in range(20, 31):
    print(_fair_array_counts(n, 10, random_state=3))
[2 2 2 2 2 2 2 2 2 2]
[2 2 2 2 2 2 2 3 2 2]
[2 2 3 2 2 2 2 3 2 2]
[2 2 3 3 2 2 2 3 2 2]
[2 2 3 3 2 2 2 3 3 2]
[2 3 3 3 2 2 2 3 3 2]
[3 3 3 3 2 2 2 3 3 2]
[3 3 3 3 2 3 2 3 3 2]
[3 3 3 3 2 3 3 3 3 2]
[3 3 3 3 2 3 3 3 3 3]
[3 3 3 3 3 3 3 3 3 3]

@jnothman
Member

What is the benefit of distributing the remainder randomly?

@erg
Contributor

erg commented Jun 12, 2013

It supports randomness because I wrote it for a resample function PR.

#1454

@ogrisel
Member Author

ogrisel commented Jun 12, 2013

I addressed all the comments about the initial scope of this PR (I think). The unfair distribution of the folds should probably be addressed in another PR.

@larsmans
Member

Test failure...

@jnothman
Member

Test failure...

Yes, I don't think renaming n to n_samples is within scope of this PR. It would have to be a full deprecation and applied consistently across the module where n is repeatedly used with this meaning.

@ogrisel
Member Author

ogrisel commented Jun 13, 2013

I fixed the broken doctest. Unfortunately it's not possible to provide backward compatibility for positional parameters that callers pass as keyword arguments.

If you really want I can revert the n to n_samples renaming.

@jnothman
Member

If you really want I can revert the n to n_samples renaming.

If not, change it for the rest of the module!

@ogrisel
Member Author

ogrisel commented Jun 13, 2013

Ok I reverted the n_samples renaming. I think this PR is in a fine state now.

"k-folds cross validation requires at least one"
" train / test split by setting n_folds=2 or more,"
" got n_folds=%d."
% n_folds)
Member
This should probably use .format() instead.

Member
Yes, it seems there are a lot of % formats in the codebase. And last time @ogrisel tried otherwise, I had to change {} to {0}, {1} to make Jenkins (i.e. Py2.6) happy...

Member
Yeah, many (I'm probably guilty of it myself in past PRs).

Annoying about the python 2.6 issue. Maybe this is an issue to bring up when we move to python 2.7 as the minimum version?

@ogrisel
Member Author

ogrisel commented Jun 14, 2013

OK, I switched to the new-style format and rebased everything. We want to keep support for Python 2.6 till 2015 (end of Ubuntu 10.04 LTS support) if it's not too cumbersome.

@larsmans
Member

Are we switching to .format now? I still use % even in new code...

@ogrisel
Member Author

ogrisel commented Jun 14, 2013

Are we switching to .format now? I still use % even in new code...

I would not say it's mandatory. I tend to prefer it a bit over the old syntax, as I find it a bit more explicit for Python newcomers, but I don't think the Python community is likely to deprecate the % notation any time soon.

@larsmans
Member

Apparently they don't dare actually deprecate it. But fair enough, I'll start learning str.format :)

@jnothman
Member

I tend to prefer it a bit over the old syntax, as I find it a bit more explicit for Python newcomers

It's also problematic to use % where you might unwittingly have a tuple on the rhs...

@ogrisel
Member Author

ogrisel commented Jun 16, 2013

Shall we merge this PR?

@ogrisel
Member Author

ogrisel commented Jun 16, 2013

It's also problematic to use % where you might unwittingly have a tuple on the rhs...

+1

jnothman added a commit that referenced this pull request Jun 21, 2013
ENH Enforce n_folds >= 2 for k-fold cross-validation
jnothman merged commit aa66b62 into scikit-learn:master Jun 21, 2013
@ogrisel
Member Author

ogrisel commented Jun 21, 2013

Thanks!
