-
-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[MRG+1] Ensure correct LabelKFold folds when shuffle=True #5300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -412,6 +412,16 @@ def __init__(self, labels, n_folds=3, shuffle=False, random_state=None): | |
" than the number of labels: {1}.").format(n_folds, | ||
n_labels)) | ||
|
||
if shuffle: | ||
# In case of ties in label weights, label names are indirectly | ||
# used to assign samples to folds. When shuffle=True, label names | ||
# are randomized to obtain random fold assigments. | ||
rng = check_random_state(self.random_state) | ||
unique_labels = np.arange(n_labels, dtype=np.int) | ||
rng.shuffle(unique_labels) | ||
labels = unique_labels[labels] | ||
unique_labels, labels = np.unique(labels, return_inverse=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line is useless: at this point |
||
|
||
# Weight labels by their number of occurences | ||
n_samples_per_label = np.bincount(labels) | ||
|
||
|
@@ -433,13 +443,9 @@ def __init__(self, labels, n_folds=3, shuffle=False, random_state=None): | |
|
||
self.idxs = label_to_fold[labels] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This has a float dtype. It should be initialized with |
||
|
||
if shuffle: | ||
rng = check_random_state(self.random_state) | ||
rng.shuffle(self.idxs) | ||
|
||
def _iter_test_indices(self): | ||
for i in range(self.n_folds): | ||
yield (self.idxs == i) | ||
for f in range(self.n_folds): | ||
yield np.where(self.idxs == f)[0] | ||
|
||
def __repr__(self): | ||
return '{0}.{1}(n_labels={2}, n_folds={3})'.format( | ||
|
@@ -1211,7 +1217,7 @@ def cross_val_predict(estimator, X, y=None, cv=None, n_jobs=1, | |
- An iterable yielding train/test splits. | ||
|
||
For integer/None inputs, if ``y`` is binary or multiclass, | ||
:class:`StratifiedKFold` used. If the estimator is a classifier | ||
:class:`StratifiedKFold` used. If the estimator is a classifier | ||
or if ``y`` is neither binary nor multiclass, :class:`KFold` is used. | ||
|
||
Refer :ref:`User Guide <cross_validation>` for the various | ||
|
@@ -1385,7 +1391,7 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, | |
- An iterable yielding train/test splits. | ||
|
||
For integer/None inputs, if ``y`` is binary or multiclass, | ||
:class:`StratifiedKFold` used. If the estimator is a classifier | ||
:class:`StratifiedKFold` used. If the estimator is a classifier | ||
or if ``y`` is neither binary nor multiclass, :class:`KFold` is used. | ||
|
||
Refer :ref:`User Guide <cross_validation>` for the various | ||
|
@@ -1649,7 +1655,7 @@ def check_cv(cv, X=None, y=None, classifier=False): | |
- An iterable yielding train/test splits. | ||
|
||
For integer/None inputs, if ``y`` is binary or multiclass, | ||
:class:`StratifiedKFold` used. If the estimator is a classifier | ||
:class:`StratifiedKFold` used. If the estimator is a classifier | ||
or if ``y`` is neither binary nor multiclass, :class:`KFold` is used. | ||
|
||
Refer :ref:`User Guide <cross_validation>` for the various | ||
|
@@ -1722,7 +1728,7 @@ def permutation_test_score(estimator, X, y, cv=None, | |
- An iterable yielding train/test splits. | ||
|
||
For integer/None inputs, if ``y`` is binary or multiclass, | ||
:class:`StratifiedKFold` used. If the estimator is a classifier | ||
:class:`StratifiedKFold` used. If the estimator is a classifier | ||
or if ``y`` is neither binary nor multiclass, :class:`KFold` is used. | ||
|
||
Refer :ref:`User Guide <cross_validation>` for the various | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -359,7 +359,7 @@ def test_kfold_can_detect_dependent_samples_on_digits(): # see #2372 | |
assert_greater(mean_score, 0.85) | ||
|
||
|
||
def test_label_kfold(): | ||
def check_label_kfold(shuffle): | ||
rng = np.random.RandomState(0) | ||
|
||
# Parameters of the test | ||
|
@@ -370,7 +370,10 @@ def test_label_kfold(): | |
# Construct the test data | ||
tolerance = 0.05 * n_samples # 5 percent error allowed | ||
labels = rng.randint(0, n_labels, n_samples) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test seems to easy to me: the labels are balanced, hence the rebalancing logic and it's interaction with shuffling is not properly evaluated. It would be better to have a couple of labels with 2 to 3 times more samples than the others for instance. |
||
folds = cval.LabelKFold(labels, n_folds).idxs | ||
folds = cval.LabelKFold(labels, | ||
n_folds=n_folds, | ||
shuffle=shuffle, | ||
random_state=rng).idxs | ||
ideal_n_labels_per_fold = n_samples // n_folds | ||
|
||
# Check that folds have approximately the same size | ||
|
@@ -385,7 +388,10 @@ def test_label_kfold(): | |
|
||
# Check that no label is on both sides of the split | ||
labels = np.asarray(labels, dtype=object) | ||
for train, test in cval.LabelKFold(labels, n_folds=n_folds): | ||
for train, test in cval.LabelKFold(labels, | ||
n_folds=n_folds, | ||
shuffle=shuffle, | ||
random_state=rng): | ||
assert_equal(len(np.intersect1d(labels[train], labels[test])), 0) | ||
|
||
# Construct the test data | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This rest of this function seems to be independent of the beginning of the test. IMHO it would make the test more readable by having two check functions:
|
||
|
@@ -402,7 +408,10 @@ def test_label_kfold(): | |
n_samples = len(labels) | ||
n_folds = 5 | ||
tolerance = 0.05 * n_samples # 5 percent error allowed | ||
folds = cval.LabelKFold(labels, n_folds).idxs | ||
folds = cval.LabelKFold(labels, | ||
n_folds=n_folds, | ||
shuffle=shuffle, | ||
random_state=rng).idxs | ||
ideal_n_labels_per_fold = n_samples // n_folds | ||
|
||
# Check that folds have approximately the same size | ||
|
@@ -416,14 +425,22 @@ def test_label_kfold(): | |
assert_equal(len(np.unique(folds[labels == label])), 1) | ||
|
||
# Check that no label is on both sides of the split | ||
for train, test in cval.LabelKFold(labels, n_folds=n_folds): | ||
for train, test in cval.LabelKFold(labels, | ||
n_folds=n_folds, | ||
shuffle=shuffle, | ||
random_state=rng): | ||
assert_equal(len(np.intersect1d(labels[train], labels[test])), 0) | ||
|
||
# Should fail if there are more folds than labels | ||
labels = np.array([1, 1, 1, 2, 2]) | ||
assert_raises(ValueError, cval.LabelKFold, labels, n_folds=3) | ||
|
||
|
||
def test_label_kfold(): | ||
for shuffle in [False, True]: | ||
yield check_label_kfold, shuffle | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need an additional check that tests that the indices returned when |
||
|
||
|
||
def test_shuffle_split(): | ||
ss1 = cval.ShuffleSplit(10, test_size=0.2, random_state=0) | ||
ss2 = cval.ShuffleSplit(10, test_size=2, random_state=0) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better never to use Python dtype (that have an implicitly defined size / precision level) but instead use numpy dtypes only, e.g.
np.intp
in this case:np.intp
is the smallest integer dtype that is big enough to index any array.