FIX Add NaN handling to selection of best parameters for HalvingGridSearchCV
#24539
Conversation
Parameter combinations that have a score of NaN should not be ranked higher than solutions with an actual score.
For a deterministic non-regression test, subclass the test helper `class FastClassifier(DummyClassifier):` and add a `testing_fit_param` that makes `fit` fail when `testing_fit_param=1` and pass for all other values. Then grid-search over `testing_fit_param`, including the value 1. One can do the same for `predict`.
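A rough sketch in the spirit of this suggestion (the name `testing_fit_param` follows the comment above; the real `FastClassifier` test helper may differ from this hypothetical version):

```python
# Hypothetical sketch, not the actual test code: fit() fails deterministically
# for exactly one hyper-parameter value, so a halving search over that
# parameter produces a NaN score for one candidate only.
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.experimental import enable_halving_search_cv  # noqa: F401
from sklearn.model_selection import HalvingGridSearchCV


class FastClassifier(DummyClassifier):
    """Cheap classifier that raises in fit() when testing_fit_param == 1."""

    def __init__(self, strategy="stratified", random_state=None,
                 constant=None, testing_fit_param=0):
        self.testing_fit_param = testing_fit_param
        super().__init__(
            strategy=strategy, random_state=random_state, constant=constant
        )

    def fit(self, X, y, **fit_params):
        if self.testing_fit_param == 1:
            raise ValueError("deliberate failure for testing_fit_param=1")
        return super().fit(X, y, **fit_params)


X, y = make_classification(n_samples=200, random_state=0)
search = HalvingGridSearchCV(
    FastClassifier(), {"testing_fit_param": [0, 1, 2]}, random_state=0
).fit(X, y)
# With this PR's fix, the NaN-scored candidate should never be selected.
assert search.best_params_["testing_fit_param"] != 1
```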
If all results are NaN, then I would raise. We have this behavior when all results fail in GridSearchCV:
scikit-learn/sklearn/model_selection/_validation.py, lines 361 to 367 in 21829b5
```python
all_fits_failed_message = (
    f"\nAll the {num_fits} fits failed.\n"
    "It is very likely that your model is misconfigured.\n"
    "You can try to debug the error by setting error_score='raise'.\n\n"
    f"Below are more details about the failures:\n{fit_errors_summary}"
)
raise ValueError(all_fits_failed_message)
```
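For illustration only (not part of this PR), that existing code path can be triggered by making every fit fail, for example by fitting on data containing NaNs:

```python
# All 2 candidates x 5 folds = 10 fits fail, so GridSearchCV raises the
# "All the 10 fits failed." ValueError quoted above.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X = np.full((20, 2), np.nan)   # NaNs make every LogisticRegression fit fail
y = np.array([0, 1] * 10)
search = GridSearchCV(LogisticRegression(), {"C": [0.1, 1.0]})
try:
    search.fit(X, y)
except ValueError as exc:
    print(exc)  # "All the 10 fits failed. ..."
```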
For successive halving, if one of the iterations completely fails and there is no "top k", then I think we should raise.
This checks that candidates that have a NaN score are always ranked lowest.
I added some tests based on a classifier that fails depending on the value of a hyper-parameter. What would be nice is to have a classifier that fails on some of the CV splits, but not all. But I can't work out how to do that, and I think it is over the top.
Thank you for the update!
I think this needs a what's new entry in v1.2 and an update to the docstring to explain the behavior.
Co-authored-by: Thomas J. Fan <[email protected]>
Otherwise LGTM
Co-authored-by: Thomas J. Fan <[email protected]>
I just merged.
LGTM. Only some nitpicks.
Co-authored-by: Guillaume Lemaitre <[email protected]>
Safely merging since some of the CIs are green and we only made cosmetic changes.
…SearchCV` (scikit-learn#24539)
Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
Reference Issues/PRs
Fix #20678
What does this implement/fix? Explain your changes.
Parameter combinations that have a score of NaN should not be ranked higher than solutions with an actual score.
This switches to using `np.nanargmax()` to find the highest scores that are not NaN. In addition, when selecting the top-k parameter combinations, we now rank combinations with a score of NaN lower than any solution with a score.
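A minimal sketch of these two ideas (not the PR's actual implementation), assuming `scores` is a 1-D array of mean test scores that may contain NaNs:

```python
import numpy as np

scores = np.array([0.7, np.nan, 0.9, np.nan, 0.8])

# Best candidate: np.nanargmax ignores NaN scores, whereas a plain np.argmax
# treats NaN as the maximum and would return the index of the first NaN.
best_index = np.nanargmax(scores)   # -> 2

# Top-k selection: sorting -scores ascending puts NaNs at the end, so
# NaN-scored candidates are ranked below every candidate with a real score.
k = 3
top_k = np.argsort(-scores)[:k]     # -> array([2, 4, 0])
```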
Any other comments?
I am not sure exactly how to provoke the failure. I've tested this with a fake estimator like the following:
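A plausible sketch of such a fake estimator (not the author's actual snippet; the class name and failure probabilities are purely illustrative):

```python
import numpy as np
from sklearn.dummy import DummyClassifier


class RandomlyFailingClassifier(DummyClassifier):
    """Illustrative only: fit() and/or predict() fail with given probabilities."""

    def __init__(self, fit_fail_proba=0.2, predict_fail_proba=0.2):
        self.fit_fail_proba = fit_fail_proba
        self.predict_fail_proba = predict_fail_proba
        super().__init__()

    def fit(self, X, y, **fit_params):
        if np.random.rand() < self.fit_fail_proba:
            raise ValueError("simulated random fit failure")
        return super().fit(X, y, **fit_params)

    def predict(self, X):
        if np.random.rand() < self.predict_fail_proba:
            raise ValueError("simulated random predict failure")
        return super().predict(X)
```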
which allows you to have `fit()` and/or `predict()` fail at random. I think for a non-regression test we need something better. Does someone have an idea?
Another thing I've noticed is that `np.nanargmax` raises an exception if all elements passed to it are NaNs. Not quite sure what we should do in that case. Pick a random order? Raise?
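For reference, the edge case described here:

```python
import numpy as np

# np.nanargmax has no non-NaN value to pick, so it raises.
try:
    np.nanargmax(np.array([np.nan, np.nan, np.nan]))
except ValueError as exc:
    print(exc)  # "All-NaN slice encountered"
```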