diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst
index db4aa1b3250a3..a0334a4b36c4b 100644
--- a/doc/whats_new/v0.21.rst
+++ b/doc/whats_new/v0.21.rst
@@ -371,6 +371,11 @@ Support for Python 3.4 and below has been officially dropped.
   :func:`~model_selection.validation_curve` only the latter is required.
   :issue:`12613` and :issue:`12669` by :user:`Marc Torrellas <marctorsoc>`.
 
+- |Enhancement| Some :term:`CV splitter` classes and
+  :func:`model_selection.train_test_split` now raise ``ValueError`` when the
+  resulting train set is empty. :issue:`12861` by :user:`Nicolas Hug
+  <NicolasHug>`.
+
 - |Fix| Fixed a bug where :class:`model_selection.StratifiedKFold`
   shuffles each class's samples with the same ``random_state``,
   making ``shuffle=True`` ineffective.
diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index a3df4460cd261..fd6d5eb9c83e8 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -162,7 +162,13 @@ class LeaveOneOut(BaseCrossValidator):
     """
 
     def _iter_test_indices(self, X, y=None, groups=None):
-        return range(_num_samples(X))
+        n_samples = _num_samples(X)
+        if n_samples <= 1:
+            raise ValueError(
+                'Cannot perform LeaveOneOut with n_samples={}.'.format(
+                    n_samples)
+            )
+        return range(n_samples)
 
     def get_n_splits(self, X, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
@@ -209,7 +215,8 @@ class LeavePOut(BaseCrossValidator):
     Parameters
     ----------
     p : int
-        Size of the test sets.
+        Size of the test sets. Must be strictly less than the number of
+        samples.
 
     Examples
     --------
@@ -238,7 +245,13 @@ def __init__(self, p):
         self.p = p
 
     def _iter_test_indices(self, X, y=None, groups=None):
-        for combination in combinations(range(_num_samples(X)), self.p):
+        n_samples = _num_samples(X)
+        if n_samples <= self.p:
+            raise ValueError(
+                'p={} must be strictly less than the number of '
+                'samples={}'.format(self.p, n_samples)
+            )
+        for combination in combinations(range(n_samples), self.p):
             yield np.array(combination)
 
     def get_n_splits(self, X, y=None, groups=None):
@@ -1862,7 +1875,17 @@ def _validate_shuffle_split(n_samples, test_size, train_size):
                          'samples %d. Reduce test_size and/or '
                          'train_size.' % (n_train + n_test, n_samples))
 
-    return int(n_train), int(n_test)
+    n_train, n_test = int(n_train), int(n_test)
+
+    if n_train == 0:
+        raise ValueError(
+            'With n_samples={}, test_size={} and train_size={}, the '
+            'resulting train set will be empty. Adjust any of the '
+            'aforementioned parameters.'.format(n_samples, test_size,
+                                                train_size)
+        )
+
+    return n_train, n_test
 
 
 class PredefinedSplit(BaseCrossValidator):
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index bdf466b92a7b6..e909bc791b489 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -1488,3 +1488,51 @@ def __repr__(self):
         return _build_repr(self)
 
     assert_equal(repr(MockSplitter(5, 6)), "MockSplitter(a=5, b=6, c=None)")
+
+
+@pytest.mark.parametrize('CVSplitter', (ShuffleSplit, GroupShuffleSplit,
+                                        StratifiedShuffleSplit))
+def test_shuffle_split_empty_trainset(CVSplitter):
+    cv = CVSplitter(test_size=.99)
+    X, y = [[1]], [0]  # 1 sample
+    with pytest.raises(
+            ValueError,
+            match='With n_samples=1, test_size=0.99 and train_size=None, '
+                  'the resulting train set will be empty'):
+        next(cv.split(X, y, groups=[1]))
+
+
+def test_train_test_split_empty_trainset():
+    X, = [[1]]  # 1 sample
+    with pytest.raises(
+            ValueError,
+            match='With n_samples=1, test_size=0.99 and train_size=None, '
+                  'the resulting train set will be empty'):
+        train_test_split(X, test_size=.99)
+
+    X = [[1], [1], [1]]  # 3 samples, ask for more than 2 thirds
+    with pytest.raises(
+            ValueError,
+            match='With n_samples=3, test_size=0.67 and train_size=None, '
+                  'the resulting train set will be empty'):
+        train_test_split(X, test_size=.67)
+
+
+def test_leave_one_out_empty_trainset():
+    # LeaveOneGroupOut expects at least 2 groups, so no need to check it
+    cv = LeaveOneOut()
+    X, y = [[1]], [0]  # 1 sample
+    with pytest.raises(
+            ValueError,
+            match='Cannot perform LeaveOneOut with n_samples=1'):
+        next(cv.split(X, y))
+
+
+def test_leave_p_out_empty_trainset():
+    # No need to check LeavePGroupsOut
+    cv = LeavePOut(p=2)
+    X, y = [[1], [2]], [0, 3]  # 2 samples
+    with pytest.raises(
+            ValueError,
+            match='p=2 must be strictly less than the number of samples=2'):
+        next(cv.split(X, y, groups=[1, 2]))
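
For reference, a minimal sketch of the new fail-fast behaviour (illustrative only, not part of the patch; it assumes a scikit-learn build with this change applied). The messages printed are the ones constructed in `_validate_shuffle_split` and `LeaveOneOut._iter_test_indices` above:

    from sklearn.model_selection import LeaveOneOut, train_test_split

    # One sample with test_size=0.99 leaves 0 samples for training;
    # the split now raises instead of returning an empty train set.
    try:
        train_test_split([[1]], test_size=0.99)
    except ValueError as exc:
        print(exc)  # With n_samples=1, test_size=0.99 and train_size=None, ...

    # LeaveOneOut needs at least 2 samples: one to leave out, one to train on.
    # split() is a generator, so the error surfaces on the first next() call.
    try:
        next(LeaveOneOut().split([[1]], [0]))
    except ValueError as exc:
        print(exc)  # Cannot perform LeaveOneOut with n_samples=1.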