-
-
Notifications
You must be signed in to change notification settings - Fork 26k
Fixes #11129: Change cv default to 5 #11139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
009267a
e683777
646cc3d
3fe6f32
6143eed
ee65184
66378db
929c900
3bf45e0
b41e503
64a837b
2d99fe8
fdb34d8
bc26d71
809920d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -367,6 +367,10 @@ class KFold(_BaseKFold): | |
If None, the random number generator is the RandomState instance used | ||
by `np.random`. Used when ``shuffle`` == True. | ||
|
||
.. versionchanged:: 0.20 | ||
The default value ``n_splits=3`` is deprecated in version 0.20 and will | ||
be changed to ``n_splits=5`` in version 0.22. | ||
|
||
Examples | ||
-------- | ||
>>> from sklearn.model_selection import KFold | ||
|
@@ -1859,7 +1863,7 @@ def split(self, X=None, y=None, groups=None): | |
yield train, test | ||
|
||
|
||
def check_cv(cv=3, y=None, classifier=False): | ||
def check_cv(cv=None, y=None, classifier=False): | ||
"""Input checker utility for building a cross-validator | ||
|
||
Parameters | ||
|
@@ -1894,6 +1898,9 @@ def check_cv(cv=3, y=None, classifier=False): | |
splits via the ``split`` method. | ||
""" | ||
if cv is None: | ||
warnings.warn("The default value of n_splits=3 is deprecated" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The default value of cv on line 1866 should be cv=None, otherwise we will never enter this block. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. Thanks! There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe pass an explicit value for cv? |
||
" in version 0.20 and will be changed to " | ||
"n_splits=5 in version 0.22", FutureWarning) | ||
cv = 3 | ||
|
||
if isinstance(cv, numbers.Integral): | ||
|
@@ -2058,6 +2065,7 @@ def train_test_split(*arrays, **options): | |
# Tell nose that train_test_split is not a test | ||
train_test_split.__test__ = False | ||
|
||
|
||
def _build_repr(self): | ||
# XXX This is copied from BaseEstimator's get_params | ||
cls = self.__class__ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,6 +138,7 @@ def test_validate_parameter_grid_input(input, error_type, error_message): | |
with pytest.raises(error_type, message=error_message): | ||
ParameterGrid(input) | ||
|
||
|
||
def test_parameter_grid(): | ||
|
||
# Test basic properties of ParameterGrid. | ||
|
@@ -359,14 +360,17 @@ def test_return_train_score_warn(): | |
estimators = [GridSearchCV(LinearSVC(random_state=0), grid, iid=False), | ||
RandomizedSearchCV(LinearSVC(random_state=0), grid, | ||
n_iter=2, iid=False)] | ||
|
||
msg_nsplit = ("The default value of n_splits=3 is deprecated in " | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't like the message here. It would be better if you can refer to merged PRs so that users can get similar messages for similar issues. |
||
"version 0.20 and will be changed to n_splits=5 " | ||
"in version 0.22") | ||
result = {} | ||
for estimator in estimators: | ||
for val in [True, False, 'warn']: | ||
estimator.set_params(return_train_score=val) | ||
fit_func = ignore_warnings(estimator.fit, | ||
category=ConvergenceWarning) | ||
result[val] = assert_no_warnings(fit_func, X, y).cv_results_ | ||
result[val] = assert_warns_message(FutureWarning, msg_nsplit, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe pass an explicit value for cv? |
||
fit_func, X, y).cv_results_ | ||
|
||
train_keys = ['split0_train_score', 'split1_train_score', | ||
'split2_train_score', 'mean_train_score', 'std_train_score'] | ||
|
@@ -1567,7 +1571,6 @@ def test_deprecated_grid_search_iid(): | |
grid = GridSearchCV(SVC(gamma='scale'), param_grid={'C': [1]}, cv=3) | ||
# no warning with equally sized test sets | ||
assert_no_warnings(grid.fit, X, y) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unrelated change. |
||
grid = GridSearchCV(SVC(gamma='scale'), param_grid={'C': [1]}, cv=5) | ||
# warning because 54 % 5 != 0 | ||
assert_warns_message(DeprecationWarning, depr_message, grid.fit, X, y) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -396,12 +396,14 @@ def test_cross_validate_return_train_score_warn(): | |
|
||
X, y = make_classification(random_state=0) | ||
estimator = MockClassifier() | ||
|
||
msg_nsplit = ("The default value of n_splits=3 is deprecated in " | ||
"version 0.20 and will be changed to n_splits=5 " | ||
"in version 0.22") | ||
result = {} | ||
for val in [False, True, 'warn']: | ||
result[val] = assert_no_warnings(cross_validate, estimator, X, y, | ||
return_train_score=val) | ||
|
||
result[val] = assert_warns_message(FutureWarning, msg_nsplit, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Firstly this needs to assert that no other warnings are issued. Secondly it needs to be clear to the developer tasked with completing this deprecation in a couple of years' time what it used to be asserting, i.e. that it was asserting an absence of warnings. Thirdly, the right solution is actually to pass an explicit value for cv here. Sorry I didn't notice that before. |
||
cross_validate, estimator, X, y, | ||
return_train_score=val) | ||
msg = ( | ||
'You are accessing a training score ({!r}), ' | ||
'which will not be available by default ' | ||
|
@@ -1194,7 +1196,8 @@ def test_validation_curve(): | |
MockEstimatorWithParameter(), X, y, param_name="param", | ||
param_range=param_range, cv=2 | ||
) | ||
if len(w) > 0: | ||
# cv=2 is passed explicitly, so no warning (not even the n_splits FutureWarning) is expected | ||
if len(w) != 0: | ||
raise RuntimeError("Unexpected warning: %r" % w[0].message) | ||
|
||
assert_array_almost_equal(train_scores.mean(axis=1), param_range) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have an API changes summary section for it. Also, I think we need a reason for the API change. See e.g.,