-
-
Notifications
You must be signed in to change notification settings - Fork 26.3k
StratifiedGroupShuffleSplit and StratifiedGroupKFold #15239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
hermidalc
wants to merge
9
commits into
scikit-learn:main
from
hermidalc:stratified-groupshufflesplit
Closed
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
4df86c6
Initial implementation
hermidalc 6be3594
Forgot to add to second __add__ list
hermidalc 2f28673
Update split method parameter doc
hermidalc 2365735
Added example; changed default test_size to 0.1; added to author list
hermidalc b3d2b5a
Merge branch 'master' of github.com:scikit-learn/scikit-learn into st…
hermidalc aa8f288
StratifiedGroupKFold impl and other improvements
hermidalc 647a97e
Add class to __all__ spec
hermidalc 36babe5
Remove random_state when no shuffle
hermidalc 32e502a
Tighter formatting
hermidalc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,10 +3,11 @@ | |
functions to split the data based on a preset strategy. | ||
""" | ||
|
||
# Author: Alexandre Gramfort <[email protected]>, | ||
# Gael Varoquaux <[email protected]>, | ||
# Author: Alexandre Gramfort <[email protected]> | ||
# Gael Varoquaux <[email protected]> | ||
# Olivier Grisel <[email protected]> | ||
# Raghav RV <[email protected]> | ||
# Leandro Hermida <[email protected]> | ||
# License: BSD 3 clause | ||
|
||
from collections.abc import Iterable | ||
|
@@ -39,7 +40,9 @@ | |
'ShuffleSplit', | ||
'GroupShuffleSplit', | ||
'StratifiedKFold', | ||
'StratifiedGroupKFold', | ||
'StratifiedShuffleSplit', | ||
'StratifiedGroupShuffleSplit', | ||
'PredefinedSplit', | ||
'train_test_split', | ||
'check_cv'] | ||
|
@@ -417,10 +420,9 @@ class KFold(_BaseKFold): | |
|
||
See also | ||
-------- | ||
StratifiedKFold | ||
Takes group information into account to avoid building folds with | ||
imbalanced class distributions (for binary or multiclass | ||
classification tasks). | ||
StratifiedKFold: Takes class information into account to build folds which | ||
retain class distributions (for binary or multiclass classification | ||
tasks). | ||
|
||
GroupKFold: K-fold iterator variant with non-overlapping groups. | ||
|
||
|
@@ -733,6 +735,133 @@ def split(self, X, y, groups=None): | |
return super().split(X, y, groups) | ||
|
||
|
||
class StratifiedGroupKFold(StratifiedKFold): | ||
"""Stratified K-Folds iterator variant with non-overlapping groups. | ||
|
||
This cross-validation object is a variation of StratifiedKFold that returns | ||
folds stratified by group class. The folds are made by preserving the | ||
percentage of groups for each class. | ||
|
||
The same group will not appear in two different folds (the number of | ||
distinct groups has to be at least equal to the number of folds). | ||
|
||
The difference between GroupKFold and StratifiedGroupKFold is that | ||
the former attempts to create balanced folds such that the number of | ||
distinct groups is approximately the same in each fold, whereas | ||
StratifiedGroupKFold attempts to create folds which preserve the | ||
percentage of groups for each class. | ||
|
||
Read more in the :ref:`User Guide <cross_validation>`. | ||
|
||
Parameters | ||
---------- | ||
n_splits : int, default=5 | ||
Number of folds. Must be at least 2. | ||
|
||
shuffle : bool, default=False | ||
Whether to shuffle each class's samples before splitting into batches. | ||
Note that the samples within each split will not be shuffled. | ||
|
||
random_state : int or RandomState instance, default=None | ||
When `shuffle` is True, `random_state` affects the ordering of the | ||
indices, which controls the randomness of each fold for each class. | ||
Otherwise, leave `random_state` as `None`. | ||
Pass an int for reproducible output across multiple function calls. | ||
See :term:`Glossary <random_state>`. | ||
|
||
Examples | ||
-------- | ||
>>> import numpy as np | ||
>>> from sklearn.model_selection import StratifiedGroupKFold | ||
>>> X = np.ones((17, 2)) | ||
>>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]) | ||
>>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8]) | ||
>>> cv = StratifiedGroupKFold(n_splits=3) | ||
>>> for train_idxs, test_idxs in cv.split(X, y, groups): | ||
... print("TRAIN:", groups[train_idxs]) | ||
... print(" ", y[train_idxs]) | ||
... print(" TEST:", groups[test_idxs]) | ||
... print(" ", y[test_idxs]) | ||
TRAIN: [3 3 3 4 6 6 7 8 8] | ||
[1 1 1 1 0 0 0 0 0] | ||
TEST: [1 1 2 2 5 5 5 5] | ||
[0 0 1 1 0 0 0 0] | ||
TRAIN: [1 1 2 2 4 5 5 5 5 8 8] | ||
[0 0 1 1 1 0 0 0 0 0 0] | ||
TEST: [3 3 3 6 6 7] | ||
[1 1 1 0 0 0] | ||
TRAIN: [1 1 2 2 3 3 3 5 5 5 5 6 6 7] | ||
[0 0 1 1 1 1 1 0 0 0 0 0 0 0] | ||
TEST: [4 8 8] | ||
[1 0 0] | ||
>>> cv = GroupKFold(n_splits=3) | ||
>>> for train_idxs, test_idxs in cv.split(X, y, groups): | ||
... print("TRAIN:", groups[train_idxs]) | ||
... print(" ", y[train_idxs]) | ||
... print(" TEST:", groups[test_idxs]) | ||
... print(" ", y[test_idxs]) | ||
TRAIN: [2 2 3 3 3 4 6 6 7 8 8] | ||
[1 1 1 1 1 1 0 0 0 0 0] | ||
TEST: [1 1 5 5 5 5] | ||
[0 0 0 0 0 0] | ||
TRAIN: [1 1 5 5 5 5 6 6 7 8 8] | ||
[0 0 0 0 0 0 0 0 0 0 0] | ||
TEST: [2 2 3 3 3 4] | ||
[1 1 1 1 1 1] | ||
TRAIN: [1 1 2 2 3 3 3 4 5 5 5 5] | ||
[0 0 1 1 1 1 1 1 0 0 0 0] | ||
TEST: [6 6 7 8 8] | ||
[0 0 0 0 0] | ||
|
||
Notes | ||
----- | ||
The implementation is designed to: | ||
|
||
* Generate test sets such that all contain the same distribution of | ||
group classes, or as close as possible. | ||
* Be invariant to class label: relabelling ``y = ["Happy", "Sad"]`` to | ||
``y = [1, 0]`` should not change the indices generated. | ||
* Preserve order dependencies in the dataset ordering, when | ||
``shuffle=False``: all samples from class k in some test set were | ||
contiguous in y, or separated in y by samples from classes other than k. | ||
* Generate test sets where the smallest and largest differ by at most one | ||
group. | ||
|
||
See also | ||
-------- | ||
StratifiedKFold: Takes class information into account to build folds which | ||
retain class distributions (for binary or multiclass classification | ||
tasks). | ||
|
||
GroupKFold: K-fold iterator variant with non-overlapping groups. | ||
""" | ||
|
||
def __init__(self, n_splits=5, shuffle=False, random_state=None): | ||
super().__init__(n_splits=n_splits, shuffle=shuffle, | ||
random_state=random_state) | ||
|
||
def _iter_test_masks(self, X, y, groups): | ||
y = check_array(y, ensure_2d=False, dtype=None) | ||
if groups is None: | ||
raise ValueError("The 'groups' parameter should not be None.") | ||
groups = check_array(groups, ensure_2d=False, dtype=None) | ||
(unique_groups, unique_groups_y), group_indices = np.unique( | ||
np.stack((groups, y)), axis=1, return_inverse=True) | ||
n_groups = len(unique_groups) | ||
if self.n_splits > n_groups: | ||
raise ValueError("Cannot have number of splits n_splits=%d greater" | ||
" than the number of groups: %d." | ||
% (self.n_splits, n_groups)) | ||
if unique_groups.shape[0] != np.unique(groups).shape[0]: | ||
raise ValueError("Members of each group must all be of the same " | ||
"class.") | ||
for group_test in super()._iter_test_masks(X=unique_groups, | ||
y=unique_groups_y): | ||
# this is the mask of unique_groups in the partition invert it into | ||
# a data mask | ||
yield np.in1d(group_indices, np.where(group_test)) | ||
|
||
|
||
class TimeSeriesSplit(_BaseKFold): | ||
"""Time Series cross-validator | ||
|
||
|
@@ -1735,6 +1864,148 @@ def split(self, X, y, groups=None): | |
return super().split(X, y, groups) | ||
|
||
|
||
class StratifiedGroupShuffleSplit(StratifiedShuffleSplit): | ||
"""Stratified GroupShuffleSplit cross-validator | ||
|
||
Provides randomized train/test indices to split data according to a | ||
third-party provided group. This group information can be used to encode | ||
arbitrary domain specific stratifications of the samples as integers. | ||
|
||
This cross-validation object is a merge of GroupShuffleSplit and | ||
StratifiedShuffleSplit, which returns randomized folds stratified by group | ||
class. The folds are made by preserving the percentage of groups for each | ||
class. | ||
|
||
Note: like the StratifiedShuffleSplit strategy, stratified random group | ||
splits do not guarantee that all folds will be different, although this is | ||
still very likely for sizeable datasets. | ||
|
||
Read more in the :ref:`User Guide <cross_validation>`. | ||
|
||
Parameters | ||
---------- | ||
n_splits : int, default=5 | ||
Number of re-shuffling & splitting iterations. | ||
|
||
test_size : float, int, None, default=None | ||
If float, should be between 0.0 and 1.0 and represent the proportion | ||
of groups to include in the test split (rounded up). If int, | ||
represents the absolute number of test groups. If None, the value is | ||
set to the complement of the train size. By default, the value is set | ||
to 0.1. | ||
|
||
train_size : float, int, or None, default=None | ||
If float, should be between 0.0 and 1.0 and represent the | ||
proportion of the groups to include in the train split. If | ||
int, represents the absolute number of train groups. If None, | ||
the value is automatically set to the complement of the test size. | ||
|
||
random_state : int, RandomState instance or None, default=None | ||
If int, random_state is the seed used by the random number generator; | ||
If RandomState instance, random_state is the random number generator; | ||
If None, the random number generator is the RandomState instance used | ||
by `np.random`. | ||
|
||
Examples | ||
-------- | ||
>>> import numpy as np | ||
>>> from sklearn.model_selection import StratifiedGroupShuffleSplit | ||
>>> X = np.ones(shape=(15, 2)) | ||
>>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0]) | ||
>>> groups = np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6]) | ||
>>> print(groups.shape) | ||
(15,) | ||
>>> sgss = StratifiedGroupShuffleSplit(n_splits=3, train_size=.7, | ||
... random_state=43) | ||
>>> sgss.get_n_splits() | ||
3 | ||
>>> for train_idx, test_idx in sgss.split(X, y, groups): | ||
... print("TRAIN:", groups[train_idx]) | ||
... print(" ", y[train_idx]) | ||
... print(" TEST:", groups[test_idx]) | ||
... print(" ", y[test_idx]) | ||
TRAIN: [2 2 2 4 5 5 5 5 6 6] | ||
[1 1 1 0 1 1 1 1 0 0] | ||
TEST: [1 1 3 3 3] | ||
[0 0 1 1 1] | ||
TRAIN: [1 1 2 2 2 3 3 3 4] | ||
[0 0 1 1 1 1 1 1 0] | ||
TEST: [5 5 5 5 6 6] | ||
[1 1 1 1 0 0] | ||
TRAIN: [1 1 2 2 2 3 3 3 6 6] | ||
[0 0 1 1 1 1 1 1 0 0] | ||
TEST: [4 5 5 5 5] | ||
[0 1 1 1 1] | ||
|
||
See also | ||
-------- | ||
GroupShuffleSplit: Shuffle-Group(s)-Out iterator. | ||
|
||
StratifiedShuffleSplit: Stratified ShuffleSplit iterator. | ||
""" | ||
|
||
def __init__(self, n_splits=5, test_size=None, train_size=None, | ||
random_state=None): | ||
super().__init__(n_splits=n_splits, test_size=test_size, | ||
train_size=train_size, random_state=random_state) | ||
self._default_test_size = 0.1 | ||
|
||
def _iter_indices(self, X, y, groups): | ||
y = check_array(y, ensure_2d=False, dtype=None) | ||
if groups is None: | ||
raise ValueError("The 'groups' parameter should not be None.") | ||
groups = check_array(groups, ensure_2d=False, dtype=None) | ||
(unique_groups, unique_groups_y), group_indices = np.unique( | ||
np.stack((groups, y)), axis=1, return_inverse=True) | ||
if unique_groups.shape[0] != np.unique(groups).shape[0]: | ||
raise ValueError("Members of each group must all be of the same " | ||
"class.") | ||
for group_train, group_test in super()._iter_indices( | ||
X=unique_groups, y=unique_groups_y): | ||
# these are the indices of unique_groups in the partition invert | ||
# them into data indices | ||
train = np.flatnonzero(np.in1d(group_indices, group_train)) | ||
test = np.flatnonzero(np.in1d(group_indices, group_test)) | ||
yield train, test | ||
|
||
def split(self, X, y, groups=None): | ||
"""Generate indices to split data into training and test set. | ||
|
||
Parameters | ||
---------- | ||
X : array-like, shape (n_samples, n_features) | ||
Training data, where n_samples is the number of samples | ||
and n_features is the number of features. | ||
|
||
Note that providing ``y`` is sufficient to generate the splits and | ||
hence ``np.zeros(n_samples)`` may be used as a placeholder for | ||
``X`` instead of actual training data. | ||
|
||
y : array-like, shape (n_samples,) | ||
The target variable for supervised learning problems. | ||
Stratification is done based on the y labels. | ||
|
||
groups : array-like, with shape (n_samples,) | ||
Group labels for the samples used while splitting the dataset into | ||
train/test set. | ||
|
||
Yields | ||
------ | ||
train : ndarray | ||
The training set indices for that split. | ||
|
||
test : ndarray | ||
The testing set indices for that split. | ||
|
||
Notes | ||
----- | ||
Randomized CV splitters may return different results for each call of | ||
split. You can make the results identical by setting ``random_state`` | ||
to an integer. | ||
""" | ||
return super().split(X, y, groups) | ||
|
||
|
||
def _validate_shuffle_split(n_samples, test_size, train_size, | ||
default_test_size=None): | ||
""" | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand this sentence. What do you mean by "folds not being different"?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That text is copied from
StratifiedShuffleSplit
and the meaning behind it is that shuffle splitting does not guarantee that each split will be different than another.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah. I think in sklearn we call folds the partitions in the split, not the repetitions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think partitions/splits/folds that's what is meant here and I believe in
StratifiedShuffleSplit
. With randomized splits there is no guarantee that a partition will be different than other ones.