Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit d1696a7

Browse files
committed
ENH Various enhancements to the model_selection module
1 parent 3f8743f commit d1696a7

3 files changed

Lines changed: 29 additions & 36 deletions

File tree

sklearn/model_selection/_split.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,3 +1528,21 @@ def _build_repr(self):
15281528
params[key] = value
15291529

15301530
return '%s(%s)' % (class_name, _pprint(params, offset=len(class_name)))
1531+
1532+
1533+
# Registry mapping each cross-validator class name to the class itself,
# so callers (e.g. the test-suite) can enumerate the splitters by name.
ALL_CVS = {
    'KFold': KFold,
    'LabelKFold': LabelKFold,
    'LeaveOneLabelOut': LeaveOneLabelOut,
    'LeaveOneOut': LeaveOneOut,
    'LeavePLabelOut': LeavePLabelOut,
    'LeavePOut': LeavePOut,
    'ShuffleSplit': ShuffleSplit,
    'LabelShuffleSplit': LabelShuffleSplit,
    'StratifiedKFold': StratifiedKFold,
    'StratifiedShuffleSplit': StratifiedShuffleSplit,
    'PredefinedSplit': PredefinedSplit,
}
1544+
1545+
# The label-based cross-validators only: the four entries of ALL_CVS whose
# names contain "Label", keyed by class name.
LABEL_CVS = dict(
    LabelKFold=LabelKFold,
    LeaveOneLabelOut=LeaveOneLabelOut,
    LeavePLabelOut=LeavePLabelOut,
    LabelShuffleSplit=LabelShuffleSplit,
)

sklearn/model_selection/_validation.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,39 +27,9 @@
2727
from ..metrics.scorer import check_scoring
2828
from ..exceptions import FitFailedWarning
2929

30-
from ._split import KFold
31-
from ._split import LabelKFold
32-
from ._split import LeaveOneLabelOut
33-
from ._split import LeaveOneOut
34-
from ._split import LeavePLabelOut
35-
from ._split import LeavePOut
36-
from ._split import ShuffleSplit
37-
from ._split import LabelShuffleSplit
38-
from ._split import StratifiedKFold
39-
from ._split import StratifiedShuffleSplit
40-
from ._split import PredefinedSplit
41-
from ._split import check_cv, _safe_split
42-
4330
__all__ = ['cross_val_score', 'cross_val_predict', 'permutation_test_score',
4431
'learning_curve', 'validation_curve']
4532

46-
ALL_CVS = {'KFold': KFold,
47-
'LabelKFold': LabelKFold,
48-
'LeaveOneLabelOut': LeaveOneLabelOut,
49-
'LeaveOneOut': LeaveOneOut,
50-
'LeavePLabelOut': LeavePLabelOut,
51-
'LeavePOut': LeavePOut,
52-
'ShuffleSplit': ShuffleSplit,
53-
'LabelShuffleSplit': LabelShuffleSplit,
54-
'StratifiedKFold': StratifiedKFold,
55-
'StratifiedShuffleSplit': StratifiedShuffleSplit,
56-
'PredefinedSplit': PredefinedSplit}
57-
58-
LABEL_CVS = {'LabelKFold': LabelKFold,
59-
'LeaveOneLabelOut': LeaveOneLabelOut,
60-
'LeavePLabelOut': LeavePLabelOut,
61-
'LabelShuffleSplit': LabelShuffleSplit}
62-
6333

6434
def cross_val_score(estimator, X, y=None, labels=None, scoring=None, cv=None,
6535
n_jobs=1, verbose=0, fit_params=None,

sklearn/model_selection/tests/test_search.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
# TODO Import from sklearn.exceptions once merged.
4747
from sklearn.base import ChangedBehaviorWarning
4848
from sklearn.model_selection._validation import FitFailedWarning
49+
from sklearn.model_selection._split import ALL_CVS, LABEL_CVS
4950

5051
from sklearn.svm import LinearSVC, SVC
5152
from sklearn.tree import DecisionTreeRegressor
@@ -59,6 +60,14 @@
5960
from sklearn.pipeline import Pipeline
6061

6162

63+
def initialize_cross_validators(CVClass):
    """Instantiate a cross-validator class with parameters usable in tests.

    Parameters
    ----------
    CVClass : class
        One of the cross-validator classes registered in ``ALL_CVS``.

    Returns
    -------
    cv : instance of ``CVClass``
        ``LeavePLabelOut`` is built with ``n_labels=2`` and ``LeavePOut``
        with ``p=2``; any other class is built with no arguments.
    """
    # Two classes are given an explicit size here; compare by identity
    # against the registry so subclasses are not matched by accident.
    if CVClass is ALL_CVS['LeavePLabelOut']:
        return CVClass(n_labels=2)
    # The original looked up the key 'LeaveO', which is not in ALL_CVS and
    # raised KeyError; ``p=2`` is the LeavePOut parameter.
    if CVClass is ALL_CVS['LeavePOut']:
        return CVClass(p=2)
    # NOTE(review): assumes every remaining class can be constructed without
    # arguments -- confirm for PredefinedSplit, which may need a test_fold.
    return CVClass()
69+
70+
6271
# Neither of the following two estimators inherit from BaseEstimator,
6372
# to test hyperparameter search on user-defined classifiers.
6473
class MockClassifier(object):
@@ -234,18 +243,14 @@ def test_grid_search_labels():
234243
clf = LinearSVC(random_state=0)
235244
grid = {'C': [1]}
236245

237-
label_cvs = [LeaveOneLabelOut(), LeavePLabelOut(2), LabelKFold(),
238-
LabelShuffleSplit()]
239-
for cv in label_cvs:
246+
for _, cv in LABEL_CVS.iteritems():
240247
gs = GridSearchCV(clf, grid, cv=cv)
241248
assert_raise_message(ValueError,
242249
"The labels parameter should not be None",
243250
gs.fit, X, y)
244251
gs.fit(X, y, labels)
245252

246-
non_label_cvs = [StratifiedKFold(), StratifiedShuffleSplit()]
247-
for cv in non_label_cvs:
248-
print(cv)
253+
for _, cv in (set(ALL_CVS.iteritems()) - set(LABEL_CVS.iteritems())):
249254
gs = GridSearchCV(clf, grid, cv=cv)
250255
# Should not raise an error
251256
gs.fit(X, y)

0 commit comments

Comments
 (0)