
FEA Add RepeatedStratifiedGroupKFold as a new splitter #24227

Open
wants to merge 10 commits into
base: main
8 changes: 6 additions & 2 deletions doc/modules/cross_validation.rst
@@ -392,8 +392,9 @@ Example of 2-fold K-Fold repeated 2 times::
[1 3] [0 2]


Similarly, :class:`RepeatedStratifiedKFold` repeats Stratified K-Fold n times
with different randomization in each repetition.
Similarly, :class:`RepeatedStratifiedKFold` and
:class:`RepeatedStratifiedGroupKFold` repeat Stratified K-Fold and Stratified
Group K-Fold, respectively, n times with different randomization in each
repetition.

.. _leave_one_out:

@@ -724,6 +725,9 @@ Here is a visualization of cross-validation behavior for uneven groups:
:align: center
:scale: 75%

:class:`RepeatedStratifiedGroupKFold` can be used to repeat Stratified Group K-Fold
n times with different randomization in each repetition.

.. _leave_one_group_out:

Leave One Group Out
7 changes: 7 additions & 0 deletions doc/whats_new/v1.6.rst
@@ -28,6 +28,13 @@ Version 1.6

.. rubric:: Code and documentation contributors

:mod:`sklearn.model_selection`
..............................

- |Feature| Added :class:`model_selection.RepeatedStratifiedGroupKFold`, which repeats
  :class:`model_selection.StratifiedGroupKFold` `n_repeats` times with different
  randomization in each repetition.
  :pr:`24227` by :user:`Kevin Arvai <arvkevi>`.

Thanks to everyone who has contributed to the maintenance and improvement of
the project since version 1.5, including:

2 changes: 2 additions & 0 deletions sklearn/model_selection/__init__.py
@@ -23,6 +23,7 @@
LeavePOut,
PredefinedSplit,
RepeatedKFold,
RepeatedStratifiedGroupKFold,
RepeatedStratifiedKFold,
ShuffleSplit,
StratifiedGroupKFold,
@@ -64,6 +65,7 @@
"LeavePOut",
"RepeatedKFold",
"RepeatedStratifiedKFold",
"RepeatedStratifiedGroupKFold",
"ParameterGrid",
"ParameterSampler",
"PredefinedSplit",
72 changes: 72 additions & 0 deletions sklearn/model_selection/_split.py
@@ -45,6 +45,7 @@
"LeavePOut",
"RepeatedStratifiedKFold",
"RepeatedKFold",
"RepeatedStratifiedGroupKFold",
"ShuffleSplit",
"GroupShuffleSplit",
"StratifiedKFold",
@@ -1833,6 +1834,77 @@ def split(self, X, y, groups=None):
return super().split(X, y, groups=groups)


class RepeatedStratifiedGroupKFold(_RepeatedSplits):
    """Repeated Stratified Group K-Fold cross validator.

    Repeats Stratified Group K-Fold n times with different randomization in each
    repetition.

    This cross-validation object is a variation of RepeatedStratifiedKFold that
    attempts to return stratified folds with non-overlapping groups. The folds are
    made by preserving the percentage of samples for each class as much as
    possible given the constraint of non-overlapping groups.

    Within each repetition, each group appears exactly once in the test set
    across all folds (the number of distinct groups has to be at least equal to
    the number of folds).

    Read more in the :ref:`User Guide <repeated_k_fold>`.

    .. versionadded:: 1.6

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    n_repeats : int, default=10
        Number of times cross-validator needs to be repeated.

    random_state : int, RandomState instance or None, default=None
        Controls the generation of the random states for each repetition.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import RepeatedStratifiedGroupKFold
    >>> X = np.random.randn(10, 1)
    >>> y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    >>> groups = np.array([1, 1, 2, 2, 2, 3, 4, 4, 5, 5])
    >>> rsgkf = RepeatedStratifiedGroupKFold(n_splits=3, n_repeats=2, random_state=42)
    >>> for train_idxs, test_idxs in rsgkf.split(X, y, groups):
    ...     # print the group assignment for the train/test indices
    ...     print("TRAIN:", groups[train_idxs], "TEST:", groups[test_idxs])
    ...     X_train, X_test = X[train_idxs], X[test_idxs]
    ...     y_train, y_test = y[train_idxs], y[test_idxs]
    TRAIN: [2 2 2 4 4 5 5] TEST: [1 1 3]
    TRAIN: [1 1 3 4 4 5 5] TEST: [2 2 2]
    TRAIN: [1 1 2 2 2 3] TEST: [4 4 5 5]
    TRAIN: [1 1 4 4 5 5] TEST: [2 2 2 3]
    TRAIN: [2 2 2 3 4 4 5 5] TEST: [1 1]
    TRAIN: [1 1 2 2 2 3] TEST: [4 4 5 5]

    Notes
    -----
    Randomized CV splitters may return different results for each call of
    split. You can make the results identical by setting `random_state`
    to an integer.

    See Also
    --------
    RepeatedStratifiedKFold : Repeats stratified K-Fold n times.

    GroupKFold : K-Fold iterator variant with non-overlapping groups.
    """

    def __init__(self, *, n_splits=5, n_repeats=10, random_state=None):
        super().__init__(
            StratifiedGroupKFold,
            n_repeats=n_repeats,
            n_splits=n_splits,
            random_state=random_state,
        )


class BaseShuffleSplit(_MetadataRequester, metaclass=ABCMeta):
"""Base class for *ShuffleSplit.

61 changes: 58 additions & 3 deletions sklearn/model_selection/tests/test_split.py
@@ -24,6 +24,7 @@
LeavePOut,
PredefinedSplit,
RepeatedKFold,
RepeatedStratifiedGroupKFold,
RepeatedStratifiedKFold,
ShuffleSplit,
StratifiedGroupKFold,
@@ -1177,14 +1178,16 @@ def test_leave_one_p_group_out_error_on_fewer_number_of_groups():

def test_repeated_cv_value_errors():
# n_repeats is not integer or <= 0
for cv in (RepeatedKFold, RepeatedStratifiedKFold):
for cv in (RepeatedKFold, RepeatedStratifiedKFold, RepeatedStratifiedGroupKFold):
with pytest.raises(ValueError):
cv(n_repeats=0)
with pytest.raises(ValueError):
cv(n_repeats=1.5)


@pytest.mark.parametrize("RepeatedCV", [RepeatedKFold, RepeatedStratifiedKFold])
@pytest.mark.parametrize(
"RepeatedCV", [RepeatedKFold, RepeatedStratifiedKFold, RepeatedStratifiedGroupKFold]
)
def test_repeated_cv_repr(RepeatedCV):
n_splits, n_repeats = 2, 6
repeated_cv = RepeatedCV(n_splits=n_splits, n_repeats=n_repeats)
@@ -1239,7 +1242,15 @@ def test_get_n_splits_for_repeated_stratified_kfold():
assert expected_n_splits == rskf.get_n_splits()


def test_repeated_stratified_kfold_determinstic_split():
def test_get_n_splits_for_repeated_stratified_group_kfold():
    n_splits = 3
    n_repeats = 4
    rsgkf = RepeatedStratifiedGroupKFold(n_splits=n_splits, n_repeats=n_repeats)
    expected_n_splits = n_splits * n_repeats
    assert expected_n_splits == rsgkf.get_n_splits()
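
The `n_splits * n_repeats` contract checked here is the one the already-released repeated splitters follow. As a quick sketch using only existing classes (the `counts` dictionary is just an illustrative name):

```python
from sklearn.model_selection import RepeatedKFold, RepeatedStratifiedKFold

n_splits, n_repeats = 3, 4

# get_n_splits() reports the total number of (train, test) pairs,
# i.e. folds per repetition times the number of repetitions.
counts = {
    cls.__name__: cls(n_splits=n_splits, n_repeats=n_repeats).get_n_splits()
    for cls in (RepeatedKFold, RepeatedStratifiedKFold)
}
print(counts)
```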


def test_repeated_stratified_kfold_deterministic_split():
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
y = [1, 1, 1, 0, 0]
random_state = 1944695409
@@ -1269,6 +1280,48 @@ def test_repeated_stratified_kfold_determinstic_split():
next(splits)


@pytest.mark.parametrize("random_state", [0, 1])
def test_repeated_stratified_group_kfold_deterministic_split(random_state):
    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    y = [1, 1, 1, 0, 0]
    groups = [0, 0, 1, 1, 1]
    rsgkf = RepeatedStratifiedGroupKFold(
        n_splits=2, n_repeats=2, random_state=random_state
    )

    # The expected splits depend on the seed: the two seeds must not
    # produce the same sequence of splits
    expected_train_splits = [[0, 1], [2, 3, 4], [2, 3, 4], [0, 1]]
    expected_test_splits = [[2, 3, 4], [0, 1], [0, 1], [2, 3, 4]]

    if random_state == 1:
        expected_train_splits = [[2, 3, 4], [0, 1], [2, 3, 4], [0, 1]]
        expected_test_splits = [[0, 1], [2, 3, 4], [0, 1], [2, 3, 4]]

    # split should produce the same deterministic splits on each call
    for _ in range(3):
        splits = rsgkf.split(X, y, groups)
        for expected_train, expected_test in zip(
            expected_train_splits, expected_test_splits
        ):
            train, test = next(splits)
            assert_array_equal(train, expected_train)
            assert_array_equal(test, expected_test)

        # the generator is exhausted after n_splits * n_repeats splits
        with pytest.raises(StopIteration):
            next(splits)
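
The property under test, that an integer `random_state` makes repeated calls to `split` return identical folds, can be illustrated with the existing `RepeatedStratifiedKFold`. This is a small sketch with made-up data; `collect_splits` is a hypothetical helper name.

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold

X = np.arange(20).reshape(10, 2)
y = np.array([0] * 5 + [1] * 5)


def collect_splits(cv):
    """Materialize every (train, test) pair as plain lists for comparison."""
    return [(train.tolist(), test.tolist()) for train, test in cv.split(X, y)]


# With an integer seed, every call to split() restarts from the same state.
cv = RepeatedStratifiedKFold(n_splits=2, n_repeats=2, random_state=0)
first_call = collect_splits(cv)
second_call = collect_splits(cv)
print(first_call == second_call)
```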


def test_train_test_split_errors():
pytest.raises(ValueError, train_test_split)

@@ -2042,6 +2095,8 @@ def test_random_state_shuffle_false(Klass):
(RepeatedKFold(random_state=np.random.RandomState(0)), False),
(RepeatedStratifiedKFold(random_state=None), False),
(RepeatedStratifiedKFold(random_state=np.random.RandomState(0)), False),
(RepeatedStratifiedGroupKFold(random_state=None), False),
(RepeatedStratifiedGroupKFold(random_state=np.random.RandomState(0)), False),
(ShuffleSplit(random_state=None), False),
(ShuffleSplit(random_state=np.random.RandomState(0)), False),
(GroupShuffleSplit(random_state=None), False),