-
-
Notifications
You must be signed in to change notification settings - Fork 26.5k
[MRG+1] add option to cross_validate to return estimators fitted on each split #9686
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1b30589
468556e
e690acd
4e65976
b60cd89
8e4ba61
87b52f0
cc4165f
e6e54d7
f166c53
31215b5
03fc79e
e4374a0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,7 +39,8 @@ | |
|
|
||
| def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, | ||
| n_jobs=1, verbose=0, fit_params=None, | ||
| pre_dispatch='2*n_jobs', return_train_score="warn"): | ||
| pre_dispatch='2*n_jobs', return_train_score="warn", | ||
| return_estimator=False): | ||
| """Evaluate metric(s) by cross-validation and also record fit/score times. | ||
|
|
||
| Read more in the :ref:`User Guide <multimetric_cross_validation>`. | ||
|
|
@@ -129,6 +130,9 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, | |
| expensive and is not strictly required to select the parameters that | ||
| yield the best generalization performance. | ||
|
|
||
| return_estimator : boolean, default False | ||
| Whether to return the estimators fitted on each split. | ||
|
|
||
| Returns | ||
| ------- | ||
| scores : dict of float arrays of shape=(n_splits,) | ||
|
|
@@ -150,6 +154,10 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, | |
| The time for scoring the estimator on the test set for each | ||
| cv split. (Note time for scoring on the train set is not | ||
| included even if ``return_train_score`` is set to ``True`` | ||
| ``estimator`` | ||
| The estimator objects for each cv split. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if return_estimator is True.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
| This is available only if ``return_estimator`` parameter | ||
| is set to ``True``. | ||
|
|
||
| Examples | ||
| -------- | ||
|
|
@@ -203,21 +211,26 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, | |
| delayed(_fit_and_score)( | ||
| clone(estimator), X, y, scorers, train, test, verbose, None, | ||
| fit_params, return_train_score=return_train_score, | ||
| return_times=True) | ||
| return_times=True, return_estimator=return_estimator) | ||
| for train, test in cv.split(X, y, groups)) | ||
|
|
||
| zipped_scores = list(zip(*scores)) | ||
| if return_train_score: | ||
| train_scores, test_scores, fit_times, score_times = zip(*scores) | ||
| train_scores = zipped_scores.pop(0) | ||
| train_scores = _aggregate_score_dicts(train_scores) | ||
| else: | ||
| test_scores, fit_times, score_times = zip(*scores) | ||
| if return_estimator: | ||
| fitted_estimators = zipped_scores.pop() | ||
| test_scores, fit_times, score_times = zipped_scores | ||
| test_scores = _aggregate_score_dicts(test_scores) | ||
|
|
||
| # TODO: replace by a dict in 0.21 | ||
| ret = DeprecationDict() if return_train_score == 'warn' else {} | ||
| ret['fit_time'] = np.array(fit_times) | ||
| ret['score_time'] = np.array(score_times) | ||
|
|
||
| if return_estimator: | ||
| ret['estimator'] = fitted_estimators | ||
|
|
||
| for name in scorers: | ||
| ret['test_%s' % name] = np.array(test_scores[name]) | ||
| if return_train_score: | ||
|
|
@@ -347,7 +360,8 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, | |
| def _fit_and_score(estimator, X, y, scorer, train, test, verbose, | ||
| parameters, fit_params, return_train_score=False, | ||
| return_parameters=False, return_n_test_samples=False, | ||
| return_times=False, error_score='raise'): | ||
| return_times=False, return_estimator=False, | ||
| error_score='raise'): | ||
| """Fit estimator and compute scores for a given dataset split. | ||
|
|
||
| Parameters | ||
|
|
@@ -405,6 +419,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, | |
| return_times : boolean, optional, default: False | ||
| Whether to return the fit/score times. | ||
|
|
||
| return_estimator : boolean, optional, default: False | ||
| Whether to return the fitted estimator. | ||
|
|
||
| Returns | ||
| ------- | ||
| train_scores : dict of scorer name -> float, optional | ||
|
|
@@ -425,6 +442,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, | |
|
|
||
| parameters : dict or None, optional | ||
| The parameters that have been evaluated. | ||
|
|
||
| estimator : estimator object | ||
| The fitted estimator | ||
| """ | ||
| if verbose > 1: | ||
| if parameters is None: | ||
|
|
@@ -513,6 +533,8 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, | |
| ret.extend([fit_time, score_time]) | ||
| if return_parameters: | ||
| ret.append(parameters) | ||
| if return_estimator: | ||
| ret.append(estimator) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Apparently coverage is missing from this line!
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have added an
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps put the assert False here just to prove the coverage tool wrong? If it does not fail, you've got some investigation to do...
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, just did this and the test now fails. Should I commit this for the sake of it?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How strange... |
||
| return ret | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optionally?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done