diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index 14203b1dad2d5..18b55ac69ce85 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -412,6 +412,16 @@ def __init__(self, labels, n_folds=3, shuffle=False, random_state=None):
                      " than the number of labels: {1}.").format(n_folds,
                                                                 n_labels))
 
+        if shuffle:
+            # In case of ties in label weights, label names are indirectly
+            # used to assign samples to folds. When shuffle=True, label names
+            # are randomized to obtain random fold assignments.
+            rng = check_random_state(self.random_state)
+            unique_labels = np.arange(n_labels, dtype=np.int)
+            rng.shuffle(unique_labels)
+            labels = unique_labels[labels]
+            unique_labels, labels = np.unique(labels, return_inverse=True)
+
         # Weight labels by their number of occurences
         n_samples_per_label = np.bincount(labels)
 
@@ -433,13 +443,9 @@ def __init__(self, labels, n_folds=3, shuffle=False, random_state=None):
 
         self.idxs = label_to_fold[labels]
 
-        if shuffle:
-            rng = check_random_state(self.random_state)
-            rng.shuffle(self.idxs)
-
     def _iter_test_indices(self):
-        for i in range(self.n_folds):
-            yield (self.idxs == i)
+        for f in range(self.n_folds):
+            yield np.where(self.idxs == f)[0]
 
     def __repr__(self):
         return '{0}.{1}(n_labels={2}, n_folds={3})'.format(
@@ -1211,7 +1217,7 @@ def cross_val_predict(estimator, X, y=None, cv=None, n_jobs=1,
         - An iterable yielding train/test splits.
 
         For integer/None inputs, if ``y`` is binary or multiclass,
-        :class:`StratifiedKFold` used. If the estimator is a classifier 
+        :class:`StratifiedKFold` used. If the estimator is a classifier
         or if ``y`` is neither binary nor multiclass, :class:`KFold` is used.
 
         Refer :ref:`User Guide <cross_validation>` for the various
@@ -1385,7 +1391,7 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
         - An iterable yielding train/test splits.
 
         For integer/None inputs, if ``y`` is binary or multiclass,
-        :class:`StratifiedKFold` used. If the estimator is a classifier 
+        :class:`StratifiedKFold` used. If the estimator is a classifier
         or if ``y`` is neither binary nor multiclass, :class:`KFold` is used.
 
         Refer :ref:`User Guide <cross_validation>` for the various
@@ -1649,7 +1655,7 @@ def check_cv(cv, X=None, y=None, classifier=False):
         - An iterable yielding train/test splits.
 
        For integer/None inputs, if ``y`` is binary or multiclass,
-        :class:`StratifiedKFold` used. If the estimator is a classifier 
+        :class:`StratifiedKFold` used. If the estimator is a classifier
         or if ``y`` is neither binary nor multiclass, :class:`KFold` is used.
 
         Refer :ref:`User Guide <cross_validation>` for the various
@@ -1722,7 +1728,7 @@ def permutation_test_score(estimator, X, y, cv=None,
         - An iterable yielding train/test splits.
 
         For integer/None inputs, if ``y`` is binary or multiclass,
-        :class:`StratifiedKFold` used. If the estimator is a classifier 
+        :class:`StratifiedKFold` used. If the estimator is a classifier
         or if ``y`` is neither binary nor multiclass, :class:`KFold` is used.
 
         Refer :ref:`User Guide <cross_validation>` for the various
diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py
index f78b9d5d05a7e..2f6cf3142be30 100644
--- a/sklearn/tests/test_cross_validation.py
+++ b/sklearn/tests/test_cross_validation.py
@@ -359,7 +359,7 @@ def test_kfold_can_detect_dependent_samples_on_digits():  # see #2372
     assert_greater(mean_score, 0.85)
 
 
-def test_label_kfold():
+def check_label_kfold(shuffle):
     rng = np.random.RandomState(0)
 
     # Parameters of the test
@@ -370,7 +370,10 @@ def test_label_kfold():
     # Construct the test data
     tolerance = 0.05 * n_samples  # 5 percent error allowed
     labels = rng.randint(0, n_labels, n_samples)
-    folds = cval.LabelKFold(labels, n_folds).idxs
+    folds = cval.LabelKFold(labels,
+                            n_folds=n_folds,
+                            shuffle=shuffle,
+                            random_state=rng).idxs
     ideal_n_labels_per_fold = n_samples // n_folds
 
     # Check that folds have approximately the same size
@@ -385,7 +388,10 @@ def test_label_kfold():
 
     # Check that no label is on both sides of the split
     labels = np.asarray(labels, dtype=object)
-    for train, test in cval.LabelKFold(labels, n_folds=n_folds):
+    for train, test in cval.LabelKFold(labels,
+                                       n_folds=n_folds,
+                                       shuffle=shuffle,
+                                       random_state=rng):
         assert_equal(len(np.intersect1d(labels[train], labels[test])), 0)
 
     # Construct the test data
@@ -402,7 +408,10 @@ def test_label_kfold():
     n_samples = len(labels)
     n_folds = 5
     tolerance = 0.05 * n_samples  # 5 percent error allowed
-    folds = cval.LabelKFold(labels, n_folds).idxs
+    folds = cval.LabelKFold(labels,
+                            n_folds=n_folds,
+                            shuffle=shuffle,
+                            random_state=rng).idxs
     ideal_n_labels_per_fold = n_samples // n_folds
 
     # Check that folds have approximately the same size
@@ -416,7 +425,10 @@ def test_label_kfold():
         assert_equal(len(np.unique(folds[labels == label])), 1)
 
     # Check that no label is on both sides of the split
-    for train, test in cval.LabelKFold(labels, n_folds=n_folds):
+    for train, test in cval.LabelKFold(labels,
+                                       n_folds=n_folds,
+                                       shuffle=shuffle,
+                                       random_state=rng):
         assert_equal(len(np.intersect1d(labels[train], labels[test])), 0)
 
     # Should fail if there are more folds than labels
@@ -424,6 +436,11 @@ def test_label_kfold():
     assert_raises(ValueError, cval.LabelKFold, labels, n_folds=3)
 
 
+def test_label_kfold():
+    for shuffle in [False, True]:
+        yield check_label_kfold, shuffle
+
+
 def test_shuffle_split():
     ss1 = cval.ShuffleSplit(10, test_size=0.2, random_state=0)
     ss2 = cval.ShuffleSplit(10, test_size=2, random_state=0)
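
The first hunk above is the heart of the fix: previously, shuffle=True shuffled self.idxs (the per-sample fold assignments) in place, which could scatter one label's samples across several folds; now the integer label names are permuted before the weight-balancing assignment, so ties in label weights are broken at random while every label still maps to exactly one fold. Below is a self-contained sketch of that relabeling trick on toy data; it is illustrative only, not the sklearn source.

import numpy as np

rng = np.random.RandomState(42)
labels = np.array(["a", "b", "a", "c", "b", "c"])

# Encode labels as contiguous integer ids in [0, n_labels).
unique_labels, encoded = np.unique(labels, return_inverse=True)
n_labels = len(unique_labels)

# Permute the integer label names; samples that shared a label
# before the permutation still share one afterwards.
permutation = np.arange(n_labels)
rng.shuffle(permutation)
shuffled = permutation[encoded]

# Re-encode so the shuffled names are again contiguous ids from 0.
_, relabeled = np.unique(shuffled, return_inverse=True)

print(encoded)    # [0 1 0 2 1 2]
print(relabeled)  # same grouping, label names in randomized order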
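
For completeness, a short usage sketch in the spirit of the updated tests. It assumes a build that includes this branch; the legacy sklearn.cross_validation module was later removed in favor of sklearn.model_selection, so the import below only works on those older versions.

import numpy as np
from sklearn import cross_validation as cval

rng = np.random.RandomState(0)
labels = rng.randint(0, 15, 300)

# With shuffle=True the fold layout depends on random_state, but no
# label may ever appear on both sides of a split.
for train, test in cval.LabelKFold(labels, n_folds=3,
                                   shuffle=True, random_state=0):
    assert len(np.intersect1d(labels[train], labels[test])) == 0

# Different seeds should (usually) produce different fold assignments.
idxs_a = cval.LabelKFold(labels, n_folds=3, shuffle=True, random_state=0).idxs
idxs_b = cval.LabelKFold(labels, n_folds=3, shuffle=True, random_state=1).idxs
print((idxs_a != idxs_b).any())  # typically True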