|
34 | 34 |
|
35 | 35 | from sklearn.model_selection import KFold |
36 | 36 | from sklearn.model_selection import StratifiedKFold |
37 | | -from sklearn.model_selection import StratifiedShuffleSplit |
38 | | -from sklearn.model_selection import LeaveOneLabelOut |
39 | | -from sklearn.model_selection import LeavePLabelOut |
40 | | -from sklearn.model_selection import LabelKFold |
41 | | -from sklearn.model_selection import LabelShuffleSplit |
42 | 37 | from sklearn.model_selection import GridSearchCV |
43 | 38 | from sklearn.model_selection import RandomizedSearchCV |
44 | 39 | from sklearn.model_selection import ParameterGrid |
|
47 | 42 | # TODO Import from sklearn.exceptions once merged. |
48 | 43 | from sklearn.base import ChangedBehaviorWarning |
49 | 44 | from sklearn.model_selection._validation import FitFailedWarning |
| 45 | +from sklearn.model_selection._split import ALL_CVS, LABEL_CVS |
50 | 46 |
|
51 | 47 | from sklearn.svm import LinearSVC, SVC |
52 | 48 | from sklearn.tree import DecisionTreeRegressor |
|
60 | 56 | from sklearn.pipeline import Pipeline |
61 | 57 |
|
62 | 58 |
|
| 59 | +def initialize_cross_validators(CVClass): |
| 60 | + # set parameters to initialize the cross-validators |
| 61 | + if CVClass is ALL_CVS['LeavePLabelOut']: |
| 62 | + return CVClass(n_labels=2) |
| 63 | + if CVClass is ALL_CVS['LeavePOut']: |
| 64 | + return CVClass(p=2) |
| 65 | + |
| 66 | + |
63 | 67 | # Neither of the following two estimators inherit from BaseEstimator, |
64 | 68 | # to test hyperparameter search on user-defined classifiers. |
65 | 69 | class MockClassifier(object): |
@@ -235,17 +239,16 @@ def test_grid_search_labels(): |
235 | 239 | clf = LinearSVC(random_state=0) |
236 | 240 | grid = {'C': [1]} |
237 | 241 |
|
238 | | - label_cvs = [LeaveOneLabelOut(), LeavePLabelOut(2), LabelKFold(), |
239 | | - LabelShuffleSplit()] |
240 | | - for cv in label_cvs: |
| 242 | + for _, CVClass in LABEL_CVS.iteritems(): |
| 243 | + cv = initialize_cross_validators(CVClass) |
241 | 244 | gs = GridSearchCV(clf, grid, cv=cv) |
242 | 245 | assert_raise_message(ValueError, |
243 | 246 | "The labels parameter should not be None", |
244 | 247 | gs.fit, X, y) |
245 | 248 | gs.fit(X, y, labels) |
246 | 249 |
|
247 | | - non_label_cvs = [StratifiedKFold(), StratifiedShuffleSplit()] |
248 | | - for cv in non_label_cvs: |
| 250 | + for _, CVClass in (set(ALL_CVS.iteritems()) - set(LABEL_CVS.iteritems())): |
| 251 | + cv = initialize_cross_validators(CVClass) |
249 | 252 | gs = GridSearchCV(clf, grid, cv=cv) |
250 | 253 | # Should not raise an error |
251 | 254 | gs.fit(X, y) |
|
0 commit comments