diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst
index 591dbbd439644..479103163339f 100644
--- a/doc/modules/cross_validation.rst
+++ b/doc/modules/cross_validation.rst
@@ -182,7 +182,8 @@ The ``cross_validate`` function differs from ``cross_val_score`` in two ways -
 
 - It allows specifying multiple metrics for evaluation.
 
-- It returns a dict containing training scores, fit-times and score-times in
+- It returns a dict containing fit-times, score-times
+  (and optionally training scores as well as fitted estimators) in
   addition to the test score.
 
 For single metric evaluation, where the scoring parameter is a string,
@@ -196,6 +197,9 @@ following keys -
 for all the scorers. If train scores are not needed, this should be set to
 ``False`` explicitly.
 
+You may also retain the estimator fitted on each training set by setting
+``return_estimator=True``.
+
 The multiple metrics can be specified either as a list, tuple or set of
 predefined scorer names::
 
@@ -226,9 +230,10 @@ Or as a dict mapping scorer name to a predefined or custom scoring function::
 Here is an example of ``cross_validate`` using a single metric::
 
     >>> scores = cross_validate(clf, iris.data, iris.target,
-    ...                         scoring='precision_macro')
+    ...                         scoring='precision_macro',
+    ...                         return_estimator=True)
     >>> sorted(scores.keys())
-    ['fit_time', 'score_time', 'test_score', 'train_score']
+    ['estimator', 'fit_time', 'score_time', 'test_score', 'train_score']
 
 Obtaining predictions by cross-validation
 =========================================
diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst
index 193204002664a..621fe0d99ea89 100644
--- a/doc/whats_new/v0.20.rst
+++ b/doc/whats_new/v0.20.rst
@@ -163,6 +163,10 @@ Model evaluation and meta-estimators
   group-based CV strategies. :issue:`9085` by :user:`Laurent Direr `
   and `Andreas Müller`_.
 
+- Add `return_estimator` parameter in :func:`model_selection.cross_validate` to
+  return estimators fitted on each split. :issue:`9686` by :user:`Aurélien Bellet
+  `.
+
 Metrics
 
 - :func:`metrics.roc_auc_score` now supports binary ``y_true`` other than
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index d6d4c0924b350..6597720ac6a76 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -39,7 +39,8 @@
 def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
                    n_jobs=1, verbose=0, fit_params=None,
-                   pre_dispatch='2*n_jobs', return_train_score="warn"):
+                   pre_dispatch='2*n_jobs', return_train_score="warn",
+                   return_estimator=False):
     """Evaluate metric(s) by cross-validation and also record fit/score times.
 
     Read more in the :ref:`User Guide `.
 
@@ -129,6 +130,9 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
         expensive and is not strictly required to select the parameters that
         yield the best generalization performance.
 
+    return_estimator : boolean, default False
+        Whether to return the estimators fitted on each split.
+
     Returns
     -------
     scores : dict of float arrays of shape=(n_splits,)
@@ -150,6 +154,10 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
                 The time for scoring the estimator on the test set for each
                 cv split. (Note time for scoring on the train set is not
                 included even if ``return_train_score`` is set to ``True``
+            ``estimator``
+                The estimator objects for each cv split.
+                This is available only if the ``return_estimator`` parameter
+                is set to ``True``.
 
     Examples
     --------
@@ -203,14 +211,16 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
         delayed(_fit_and_score)(
             clone(estimator), X, y, scorers, train, test, verbose, None,
             fit_params, return_train_score=return_train_score,
-            return_times=True)
+            return_times=True, return_estimator=return_estimator)
         for train, test in cv.split(X, y, groups))
 
+    zipped_scores = list(zip(*scores))
     if return_train_score:
-        train_scores, test_scores, fit_times, score_times = zip(*scores)
+        train_scores = zipped_scores.pop(0)
         train_scores = _aggregate_score_dicts(train_scores)
-    else:
-        test_scores, fit_times, score_times = zip(*scores)
+    if return_estimator:
+        fitted_estimators = zipped_scores.pop()
+    test_scores, fit_times, score_times = zipped_scores
     test_scores = _aggregate_score_dicts(test_scores)
 
     # TODO: replace by a dict in 0.21
@@ -218,6 +228,9 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
     ret['fit_time'] = np.array(fit_times)
     ret['score_time'] = np.array(score_times)
 
+    if return_estimator:
+        ret['estimator'] = fitted_estimators
+
     for name in scorers:
         ret['test_%s' % name] = np.array(test_scores[name])
         if return_train_score:
@@ -347,7 +360,8 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
 def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
                    parameters, fit_params, return_train_score=False,
                    return_parameters=False, return_n_test_samples=False,
-                   return_times=False, error_score='raise'):
+                   return_times=False, return_estimator=False,
+                   error_score='raise'):
     """Fit estimator and compute scores for a given dataset split.
 
     Parameters
@@ -405,6 +419,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
     return_times : boolean, optional, default: False
         Whether to return the fit/score times.
 
+    return_estimator : boolean, optional, default: False
+        Whether to return the fitted estimator.
+
     Returns
     -------
     train_scores : dict of scorer name -> float, optional
@@ -425,6 +442,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
 
     parameters : dict or None, optional
         The parameters that have been evaluated.
+
+    estimator : estimator object
+        The fitted estimator
     """
     if verbose > 1:
         if parameters is None:
@@ -513,6 +533,8 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
         ret.extend([fit_time, score_time])
     if return_parameters:
         ret.append(parameters)
+    if return_estimator:
+        ret.append(estimator)
     return ret
 
 
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index 19b8b6510ca20..4aab9dadef76a 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -368,20 +368,23 @@ def test_cross_validate():
     test_mse_scores = []
     train_r2_scores = []
     test_r2_scores = []
+    fitted_estimators = []
     for train, test in cv.split(X, y):
         est = clone(reg).fit(X[train], y[train])
         train_mse_scores.append(mse_scorer(est, X[train], y[train]))
         train_r2_scores.append(r2_scorer(est, X[train], y[train]))
         test_mse_scores.append(mse_scorer(est, X[test], y[test]))
         test_r2_scores.append(r2_scorer(est, X[test], y[test]))
+        fitted_estimators.append(est)
 
     train_mse_scores = np.array(train_mse_scores)
     test_mse_scores = np.array(test_mse_scores)
     train_r2_scores = np.array(train_r2_scores)
     test_r2_scores = np.array(test_r2_scores)
+    fitted_estimators = np.array(fitted_estimators)
 
     scores = (train_mse_scores, test_mse_scores, train_r2_scores,
-              test_r2_scores)
+              test_r2_scores, fitted_estimators)
 
     yield check_cross_validate_single_metric, est, X, y, scores
     yield check_cross_validate_multi_metric, est, X, y, scores
@@ -411,7 +414,7 @@ def test_cross_validate_return_train_score_warn():
 
 def check_cross_validate_single_metric(clf, X, y, scores):
     (train_mse_scores, test_mse_scores, train_r2_scores,
-     test_r2_scores) = scores
+     test_r2_scores, fitted_estimators) = scores
     # Test single metric evaluation when scoring is string or singleton list
     for (return_train_score, dict_len) in ((True, 4), (False, 3)):
         # Single metric passed as a string
@@ -443,11 +446,19 @@ def check_cross_validate_single_metric(clf, X, y, scores):
         assert_equal(len(r2_scores_dict), dict_len)
         assert_array_almost_equal(r2_scores_dict['test_r2'], test_r2_scores)
 
+    # Test return_estimator option
+    mse_scores_dict = cross_validate(clf, X, y, cv=5,
+                                     scoring='neg_mean_squared_error',
+                                     return_estimator=True)
+    for k, est in enumerate(mse_scores_dict['estimator']):
+        assert_almost_equal(est.coef_, fitted_estimators[k].coef_)
+        assert_almost_equal(est.intercept_, fitted_estimators[k].intercept_)
+
 
 def check_cross_validate_multi_metric(clf, X, y, scores):
     # Test multimetric evaluation when scoring is a list / dict
     (train_mse_scores, test_mse_scores, train_r2_scores,
-     test_r2_scores) = scores
+     test_r2_scores, fitted_estimators) = scores
     all_scoring = (('r2', 'neg_mean_squared_error'),
                    {'r2': make_scorer(r2_score),
                     'neg_mean_squared_error': 'neg_mean_squared_error'})
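Note (not part of the patch): a minimal usage sketch of the ``return_estimator`` option this diff introduces, in the spirit of the doctest updated above. The iris dataset and linear SVC are arbitrary stand-ins; any estimator/data pair behaves the same way::

    from sklearn import datasets, svm
    from sklearn.model_selection import cross_validate

    iris = datasets.load_iris()
    clf = svm.SVC(kernel='linear', C=1)

    # return_estimator=True adds an 'estimator' entry holding one fitted
    # clone of clf per CV split, alongside the usual fit/score time arrays.
    cv_results = cross_validate(clf, iris.data, iris.target, cv=5,
                                scoring='precision_macro',
                                return_estimator=True)

    for split_idx, fitted in enumerate(cv_results['estimator']):
        # each per-split estimator can be inspected or reused directly
        print(split_idx, fitted.score(iris.data, iris.target))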