diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 8f0751f1cfa50..fd7a83a873af8 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -216,6 +216,11 @@ Preprocessing Model evaluation and meta-estimators +- The default of ``n_splits`` parameter of :class:`model_selection.KFold` is + deprecated in version 0.20 and will be changed from ``n_splits=3`` to + ``n_splits=5`` in version 0.22. + :issue:`11129` by :user:`Mohammad Shahebaz `. + - A scorer based on :func:`metrics.brier_score_loss` is also available. :issue:`9521` by :user:`Hanmin Qin `. diff --git a/sklearn/linear_model/tests/test_least_angle.py b/sklearn/linear_model/tests/test_least_angle.py index e41df9cce1178..c20ca68a873f9 100644 --- a/sklearn/linear_model/tests/test_least_angle.py +++ b/sklearn/linear_model/tests/test_least_angle.py @@ -424,6 +424,7 @@ def test_lars_cv(): def test_lars_cv_max_iter(): with warnings.catch_warnings(record=True) as w: + warnings.simplefilter(action='ignore', category=FutureWarning) X = diabetes.data y = diabetes.target rng = np.random.RandomState(42) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 866cb4cc53aa8..bca41c0eeb8c8 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -367,6 +367,10 @@ class KFold(_BaseKFold): If None, the random number generator is the RandomState instance used by `np.random`. Used when ``shuffle`` == True. + .. versionchanged:: 0.20 + The default value ``n_splits=3`` is deprecated in version 0.20 and will + be changed to ``n_splits=5`` in version 0.22. 
+ Examples -------- >>> from sklearn.model_selection import KFold @@ -1859,7 +1863,7 @@ def split(self, X=None, y=None, groups=None): yield train, test -def check_cv(cv=3, y=None, classifier=False): +def check_cv(cv=None, y=None, classifier=False): """Input checker utility for building a cross-validator Parameters @@ -1894,6 +1898,9 @@ def check_cv(cv=3, y=None, classifier=False): splits via the ``split`` method. """ if cv is None: + warnings.warn("The default value of n_splits=3 is deprecated" + " in version 0.20 and will be changed to " + "n_splits=5 in version 0.22", FutureWarning) cv = 3 if isinstance(cv, numbers.Integral): @@ -2058,6 +2065,7 @@ def train_test_split(*arrays, **options): # Tell nose that train_test_split is not a test train_test_split.__test__ = False + def _build_repr(self): # XXX This is copied from BaseEstimator's get_params cls = self.__class__ diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index f436c7b55cf36..f54f9b98d18f7 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -138,6 +138,7 @@ def test_validate_parameter_grid_input(input, error_type, error_message): with pytest.raises(error_type, message=error_message): ParameterGrid(input) + def test_parameter_grid(): # Test basic properties of ParameterGrid. 
@@ -359,14 +360,17 @@ def test_return_train_score_warn(): estimators = [GridSearchCV(LinearSVC(random_state=0), grid, iid=False), RandomizedSearchCV(LinearSVC(random_state=0), grid, n_iter=2, iid=False)] - + msg_nsplit = ("The default value of n_splits=3 is deprecated in " + "version 0.20 and will be changed to n_splits=5 " + "in version 0.22") result = {} for estimator in estimators: for val in [True, False, 'warn']: estimator.set_params(return_train_score=val) fit_func = ignore_warnings(estimator.fit, category=ConvergenceWarning) - result[val] = assert_no_warnings(fit_func, X, y).cv_results_ + result[val] = assert_warns_message(FutureWarning, msg_nsplit, + fit_func, X, y).cv_results_ train_keys = ['split0_train_score', 'split1_train_score', 'split2_train_score', 'mean_train_score', 'std_train_score'] @@ -1567,7 +1571,6 @@ def test_deprecated_grid_search_iid(): grid = GridSearchCV(SVC(gamma='scale'), param_grid={'C': [1]}, cv=3) # no warning with equally sized test sets assert_no_warnings(grid.fit, X, y) - grid = GridSearchCV(SVC(gamma='scale'), param_grid={'C': [1]}, cv=5) # warning because 54 % 5 != 0 assert_warns_message(DeprecationWarning, depr_message, grid.fit, X, y) diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 2929916619769..6c7718379886b 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -396,12 +396,14 @@ def test_cross_validate_return_train_score_warn(): X, y = make_classification(random_state=0) estimator = MockClassifier() - + msg_nsplit = ("The default value of n_splits=3 is deprecated in " + "version 0.20 and will be changed to n_splits=5 " + "in version 0.22") result = {} for val in [False, True, 'warn']: - result[val] = assert_no_warnings(cross_validate, estimator, X, y, - return_train_score=val) - + result[val] = assert_warns_message(FutureWarning, msg_nsplit, + cross_validate, estimator, X, y, + 
return_train_score=val) msg = ( 'You are accessing a training score ({!r}), ' 'which will not be available by default ' @@ -1194,7 +1196,8 @@ def test_validation_curve(): MockEstimatorWithParameter(), X, y, param_name="param", param_range=param_range, cv=2 ) - if len(w) > 0: + # No warning is expected here since cv=2 is passed explicitly + if len(w) != 0: raise RuntimeError("Unexpected warning: %r" % w[0].message) assert_array_almost_equal(train_scores.mean(axis=1), param_range)