Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit d1d30b9

Browse files
committed
FIX use labels is not needed in stratified k fold
1 parent 10b6b98 commit d1d30b9

1 file changed

Lines changed: 25 additions & 27 deletions

File tree

sklearn/model_selection/_split.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def split(self, X, y=None, labels=None):
6363
y : array-like, shape (n_samples,)
6464
The target variable for supervised learning problems.
6565
66-
labels : array-like of int with shape (n_samples,), optional
66+
labels : array-like, with shape (n_samples,), optional
6767
Arbitrary domain-specific stratification of the data to be used
6868
to draw the splits.
6969
"""
@@ -261,21 +261,19 @@ def split(self, X, y=None, labels=None):
261261
Training data, where n_samples is the number of samples
262262
and n_features is the number of features.
263263
264-
y : array-like, shape (n_samples,)
264+
y : array-like, shape (n_samples,), optional
265265
The target variable for supervised learning problems.
266266
267-
labels : array-like of int with shape (n_samples,), optional
268-
Arbitrary domain-specific stratification of the data to be used
269-
to draw the splits.
267+
labels : (Ignored, exists for compatibility.)
270268
"""
271-
X, y, labels = indexable(X, y, labels)
269+
X, y = indexable(X, y)
272270
n = _num_samples(X)
273271
if self.n_folds > n:
274272
raise ValueError(
275273
("Cannot have number of folds n_folds={0} greater"
276274
" than the number of samples: {1}.").format(self.n_folds, n))
277275

278-
for train, test in super(_BaseKFold, self).split(X, y, labels):
276+
for train, test in super(_BaseKFold, self).split(X, y):
279277
yield train, test
280278

281279
def n_splits(self, X=None, y=None, labels=None):
@@ -424,9 +422,9 @@ def _make_test_folds(self, X, y=None, labels=None):
424422
rng = self.random_state
425423
y = np.asarray(y)
426424
n_samples = y.shape[0]
427-
unique_labels, y_inversed = np.unique(y, return_inverse=True)
428-
label_counts = bincount(y_inversed)
429-
min_labels = np.min(label_counts)
425+
unique_y, y_inversed = np.unique(y, return_inverse=True)
426+
y_counts = bincount(y_inversed)
427+
min_labels = np.min(y_counts)
430428
if self.n_folds > min_labels:
431429
warnings.warn(("The least populated class in y has only %d"
432430
" members, which is too few. The minimum"
@@ -435,33 +433,33 @@ def _make_test_folds(self, X, y=None, labels=None):
435433
% (min_labels, self.n_folds)), Warning)
436434

437435
# pre-assign each sample to a test fold index using individual KFold
438-
# splitting strategies for each label so as to respect the balance of
439-
# labels
440-
# NOTE: Passing the data corresponding to ith label say X[y==label_i]
441-
# will break when the data is not 100% stratifiable for all labels.
436+
# splitting strategies for each class so as to respect the balance of
437+
# classes
438+
# NOTE: Passing the data corresponding to ith class say X[y==class_i]
439+
# will break when the data is not 100% stratifiable for all classes.
442440
# So we pass np.zeros(max(c, n_folds)) as data to the KFold
443-
per_label_cvs = [
441+
per_cls_cvs = [
444442
KFold(self.n_folds, shuffle=self.shuffle,
445443
random_state=rng).split(np.zeros(max(c, self.n_folds)))
446-
for c in label_counts]
444+
for c in y_counts]
447445

448446
test_folds = np.zeros(n_samples, dtype=np.int)
449-
for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)):
450-
for label, (_, test_split) in zip(unique_labels, per_label_splits):
451-
label_test_folds = test_folds[y == label]
447+
for test_fold_idx, per_cls_splits in enumerate(zip(*per_cls_cvs)):
448+
for cls, (_, test_split) in zip(unique_y, per_cls_splits):
449+
cls_test_folds = test_folds[y == cls]
452450
# the test split can be too big because we used
453451
# KFold(...).split(X[:max(c, n_folds)]) when data is not 100%
454-
# stratifiable for all the labels
452+
# stratifiable for all the classes
455453
# (we use a warning instead of raising an exception)
456454
# If this is the case, let's trim it:
457-
test_split = test_split[test_split < len(label_test_folds)]
458-
label_test_folds[test_split] = test_fold_idx
459-
test_folds[y == label] = label_test_folds
455+
test_split = test_split[test_split < len(cls_test_folds)]
456+
cls_test_folds[test_split] = test_fold_idx
457+
test_folds[y == cls] = cls_test_folds
460458

461459
return test_folds
462460

463461
def _iter_test_masks(self, X, y=None, labels=None):
464-
test_folds = self._make_test_folds(X, y, labels)
462+
test_folds = self._make_test_folds(X, y)
465463
for i in range(self.n_folds):
466464
yield test_folds == i
467465

@@ -520,7 +518,7 @@ def n_splits(self, X, y, labels):
520518
X : (Ignored, exists for compatibility.)
521519
y : (Ignored, exists for compatibility.)
522520
523-
labels : array-like of int with shape (n_samples,)
521+
labels : array-like, with shape (n_samples,)
524522
Arbitrary domain-specific stratification of the data to be used
525523
to draw the splits.
526524
"""
@@ -598,7 +596,7 @@ def n_splits(self, X, y, labels):
598596
X : (Ignored, exists for compatibility.)
599597
y : (Ignored, exists for compatibility.)
600598
601-
labels : array-like of int with shape (n_samples,)
599+
labels : array-like, with shape (n_samples,)
602600
Arbitrary domain-specific stratification of the data to be used
603601
to draw the splits.
604602
"""
@@ -628,7 +626,7 @@ def split(self, X, y=None, labels=None):
628626
y : array-like, shape (n_samples,)
629627
The target variable for supervised learning problems.
630628
631-
labels : array-like of int with shape (n_samples,), optional
629+
labels : array-like, with shape (n_samples,), optional
632630
Arbitrary domain-specific stratification of the data to be used
633631
to draw the splits.
634632
"""

0 commit comments

Comments
 (0)