[MRG] Allow for refit=callable in *SearchCV to add flexibility in identifying the best estimator #11269 #11354
Conversation
sklearn/model_selection/_search.py
Outdated
| "refit should be set to False " | ||
| "explicitly. %r was passed" | ||
| % self.refit) | ||
| refit_metric = scorer_key |
I don't get why you need this. You don't use refit_metric if refit is callable below. I also think making inferences from the name of the function is inappropriate.
I think refit_metric is needed to compute self.best_score_ as shown here in the original code base.
Ah, I see. I would just disable best_score_ when refit is callable. Please test and document that behaviour.
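A minimal sketch of the behaviour being requested here, not code from the PR: a pytest-style test asserting that best_score_ is not exposed when refit is a callable. The dataset, estimator, grid and test name are illustrative assumptions.

from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC


def refit_min_score(cv_results):
    # Dummy refit callable: always pick the lowest-scoring candidate.
    return cv_results['mean_test_score'].argmin()


def test_best_score_disabled_with_callable_refit():
    X, y = make_classification(n_samples=100, random_state=0)
    search = GridSearchCV(LinearSVC(random_state=0), {'C': [0.1, 1]},
                          refit=refit_min_score, cv=3)
    search.fit(X, y)
    # best_index_ comes from the callable; best_score_ is not set.
    assert search.best_index_ == refit_min_score(search.cv_results_)
    assert not hasattr(search, 'best_score_')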
sklearn/model_selection/_search.py
Outdated
Where there are considerations other than maximum model performance in
choosing a best estimator, ``refit`` can be set to a function which returns
thre selected ``best_index_`` given ``cv_results_``.
thre -> the
sklearn/model_selection/_search.py
Outdated
scorer is used to find the best parameters for refitting the estimator
at the end.
Where there are considerations other than maximum model performance in
model performance -> score
sklearn/model_selection/_search.py
Outdated
scorer that would be used to find the best parameters for refitting
the estimator at the end.
Where there are considerations other than maximum model performance in
Please use the same text and formatting in both places.
sklearn/model_selection/_search.py
Outdated
``best_score_`` and ``best_parameters_`` will only be available if
``refit`` is set and all of them will be determined w.r.t this specific
scorer.
scorer. If a callable is passed to parameter refit, the function's name
This is an unnecessary and unhelpful condition.
For multi-metric evaluation, the name of refit callable function must
end with a scorer key(`_<scorer_name>`).
"""
def refit_prec(cv_results):
We should have a realistic example in examples/model_selection/ rather than here.
As a simple example, I would consider using maximising score while minimising the number of selected features or PCA components.
Here we should merely be testing interface, and a dummy function (for instance, one that always chooses the lowest-score model) is sufficient / most appropriate, as it is then easy for us to be sure what correct behaviour is.
@jnothman Would you say a dummy function like below is good enough to test our interface?
def refit_callable(cv_results):
    return cv_results['mean_test_score'].argmin()
It seems that you're suggesting two things here :(
Yes, that looks good. I might add to that an assertion that all the keys we expect to be in results are in there.
Yes, I am indeed suggesting a second thing here. An example in examples/model_selection will hugely increase the visibility and practical usability of this feature. The example gallery is how we advise users how to use the features described in technical detail in the docstrings (and before StackOverflow has all the answers).
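One way the dummy callable plus the suggested key assertion might look, as a sketch rather than the code that was merged; the exact set of expected keys is an assumption for single-metric evaluation.

def refit_callable(cv_results):
    # Check that the keys we expect are present before selecting.
    expected_keys = {'mean_fit_time', 'mean_score_time',
                     'mean_test_score', 'rank_test_score', 'params'}
    assert expected_keys <= set(cv_results)
    # Deliberately perverse choice: the lowest-scoring candidate, so a
    # test can tell the callable (not the default rule) was used.
    return cv_results['mean_test_score'].argmin()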
Thanks! I'm adding an example from examples/model_selection for this feature in the docstring.
@jnothman Is it appropriate to add one more example for refit=callable in the docstring under the GridSearchCV class, after this one?
scikit-learn/sklearn/model_selection/_search.py
Lines 931 to 958 in 3b5abf7
Examples
--------
>>> from sklearn import svm, datasets
>>> from sklearn.model_selection import GridSearchCV
>>> iris = datasets.load_iris()
>>> parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
>>> svc = svm.SVC(gamma="scale")
>>> clf = GridSearchCV(svc, parameters)
>>> clf.fit(iris.data, iris.target)
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
GridSearchCV(cv=None, error_score=...,
       estimator=SVC(C=1.0, cache_size=..., class_weight=..., coef0=...,
                     decision_function_shape='ovr', degree=..., gamma=...,
                     kernel='rbf', max_iter=-1, probability=False,
                     random_state=None, shrinking=True, tol=...,
                     verbose=False),
       fit_params=None, iid=..., n_jobs=1,
       param_grid=..., pre_dispatch=..., refit=..., return_train_score=...,
       scoring=..., verbose=...)
>>> sorted(clf.cv_results_.keys())
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
['mean_fit_time', 'mean_score_time', 'mean_test_score',...
 'mean_train_score', 'param_C', 'param_kernel', 'params',...
 'rank_test_score', 'split0_test_score',...
 'split0_train_score', 'split1_test_score', 'split1_train_score',...
 'split2_test_score', 'split2_train_score',...
 'std_fit_time', 'std_score_time', 'std_test_score', 'std_train_score'...]
I think a meaningful example is too large, and too much of a power-user feature, to be in the docstring.
@jnothman It seems that we don't need to write test cases for our example under the examples directory, right? ;)
Feel free to use GitHub's todo list feature in the PR description.
@jnothman Thanks for your input! I'll improve my implementation based on your feedback.
              enumerate(cv_results['mean_test_prec'])}
# Select models which have test precisions within 1 standard deviation
# of the best 'mean_test_prec'
candidates = dict(filter(lambda i: (i[1] >= test_prec_lower
btw, a dict comprehension is much easier to read than this
So is test_prec_upper > i[1] >= test_prec_lower
👍
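A sketch of the rewrite being suggested, wrapped in a helper so it is self-contained; the bound names follow the snippet quoted above and are assumptions about the surrounding example code.

def select_candidates(cv_results, test_prec_lower, test_prec_upper):
    """Indices of candidates whose mean test precision lies in the band."""
    return {i: score
            for i, score in enumerate(cv_results['mean_test_prec'])
            if test_prec_upper > score >= test_prec_lower}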
             enumerate(cv_results['mean_fit_time'])}
fit_time_rank = sorted(fit_time)
for i in fit_time_rank:
    if fit_time[i] in candidates:
This isn't working in AppVeyor. The function is returning None there.
Yes, I'm replacing these two test cases with simpler ones.
jnothman left a comment
Circle CI should fail if the example does
jnothman left a comment
Please reference the example from doc/modules/grid_search.rst. You should probably put the motivation / use case there more than in the example.
Documentation is rendered at https://26300-843222-gh.circle-artifacts.com/0/doc/_changed.html
    }
]

grid = GridSearchCV(pipe, cv=3, n_jobs=1, param_grid=param_grid,
I don't think we should be encouraging users to calculate a standard deviation over 3 samples. Make cv=10.
interface can also be used in multiple metrics evaluation.
This example balances model complexity and cross-validated score by
finding a decent accuracy within 1 standard deviation of the best accuracy
You might want to say that this is a rule of thumb for insignificant difference.
We could determine insignificant difference in a more proper way, such as with a Wilcoxon rank-sum test.
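Not part of the PR, but one possible realisation of this suggestion as a sketch: compare the per-split test scores of two candidates with a rank-sum test. The helper names, the 0.05 threshold, and n_splits are all illustrative assumptions.

import numpy as np
from scipy.stats import ranksums


def split_scores(cv_results, index, n_splits):
    # Per-split test scores of one parameter candidate.
    return np.array([cv_results['split%d_test_score' % s][index]
                     for s in range(n_splits)])


def indistinguishable(cv_results, i, j, n_splits, alpha=0.05):
    # True if a Wilcoxon rank-sum test cannot separate candidates i and j.
    _, p_value = ranksums(split_scores(cv_results, i, n_splits),
                          split_scores(cv_results, j, n_splits))
    return p_value > alpha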
@@ -0,0 +1,125 @@
"""
=======================================================================
Balance model complexity and cross-validated score using refit=callable
Drop "using *"
upper/lower bounds within 1 standard deviation of the
best `mean_test_scores`.
"""
std_test_score = np.std(scores)
Should be using std_test_score: you want standard deviation across cv splits, not across parameter candidates
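A minimal sketch of the fix being asked for, assuming the example's threshold function looks roughly like this; it uses the per-candidate std over CV splits (std_test_score) rather than a std computed across parameter candidates.

import numpy as np


def lower_bound(cv_results):
    """Best mean test score minus its std across CV splits."""
    best_idx = np.argmax(cv_results['mean_test_score'])
    return (cv_results['mean_test_score'][best_idx]
            - cv_results['std_test_score'][best_idx])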
@jnothman @adrinjalali Probably need your help to fix the Travis CI issue... :-/
adrinjalali left a comment
Thanks @jiaowoshabi, LGTM!
Awesome, @jiaowoshabi! Please add an entry to the change log at doc/whats_new/v0.21.rst.
doc/whats_new/v0.21.rst
Outdated
:func:`~model_selection.validation_curve` only the latter is required.
:issue:`12613` and :issue:`12669` by :user:`Marc Torrellas <marctorrellas>`.

- |Enhancement| :class:`~model_selection.BaseSearchCV` now allows for
BaseSearchCV is not listed in doc/modules/classes.rst, so this link won't work. Ordinarily we'd reference GridSearchCV and RandomizedSearchCV. You could also consider referencing the user guide rather than the example?
sklearn/model_selection/_search.py
Outdated
See ``scoring`` parameter to know more about multiple metric
evaluation.
.. versionadded:: 0.20
I think versionchanged may be more appropriate, since the parameter was not added.
sklearn/model_selection/_search.py
Outdated
evaluation.
.. versionadded:: 0.20
    GridSearchCV supports ``refit`` = callable to add flexibility in
Don't mention GridSearchCV here. Simply say "Support for callable added." The rest is documented above.
Thanks @jiaowoshabi!!
self.best_index_ = self.refit(results)
if not isinstance(self.best_index_, (int, np.integer)):
    raise TypeError('best_index_ returned is not an integer')
if self.best_index_ < 0 or self.best_index_ >= len(results):
Pretty sure this is a bug: results is a dictionary of things, and each value is an array the size of the grid.
opened #13413
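A sketch of the check the reviewer is pointing at, written as a standalone helper rather than the code that was eventually merged: since results is a dict of arrays, the number of candidates is the length of one of those arrays (e.g. 'params'), not len(results), which counts the keys. The helper name is hypothetical.

import numbers


def check_refit_index(best_index, cv_results):
    # Validate the index returned by a refit callable.
    if not isinstance(best_index, numbers.Integral):
        raise TypeError('best_index_ returned is not an integer')
    if not 0 <= best_index < len(cv_results['params']):
        raise IndexError('best_index_ index out of range')
    return best_index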
Revert "[MRG] Allow for refit=callable in *SearchCV to add flexibility in identifying the best estimator (scikit-learn#11354)" This reverts commit b4f76cf.
Reference Issues/PRs
Fixes #11269. Fixes #12865. See also #9499
What does this implement/fix? Explain your changes.
Allow a callable to be passed to refit in *SearchCV to balance score and model complexity. This interface adds flexibility in identifying the "best" estimator. The function passed to parameter refit incorporates the choice of which metric to optimise, so users can use multi-metric evaluation with this interface.
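A hedged usage sketch of the interface described above; the dataset, estimator, grid and the "smallest C within one standard deviation of the best score" rule are illustrative assumptions, not the PR's own example. The callable receives the cv_results_ dict and returns best_index_.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC


def best_low_complexity(cv_results):
    # Among candidates within one std of the best mean test score,
    # pick the one with the smallest C (a stand-in for complexity).
    best = np.argmax(cv_results['mean_test_score'])
    threshold = (cv_results['mean_test_score'][best]
                 - cv_results['std_test_score'][best])
    candidates = np.flatnonzero(cv_results['mean_test_score'] >= threshold)
    c_values = np.asarray(cv_results['param_C'][candidates], dtype=float)
    return candidates[np.argmin(c_values)]


X, y = load_iris(return_X_y=True)
search = GridSearchCV(SVC(gamma='scale'), {'C': [0.1, 1, 10, 100]},
                      cv=5, refit=best_low_complexity)
search.fit(X, y)
print(search.best_index_,
      search.cv_results_['params'][search.best_index_])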
Any other comments?
- Added a dummy refit callable based on mean_test_score (in _search.py under the model_selection directory).
- Added an example (plot_grid_search_refit_callable.py) demonstrating the usage of this interface under examples/model_selection/.

Checklist:
- Test refit=callable using a simple dummy refit function.
- Test refit=callable using a similar example in multi-metric eval settings.
- Modify _search.py to pass the above tests.