BUG: use appropriate dtype in cv_results as opposed to always using object #28352
@MarcoGorelli haven't forgotten about this. But this is touching VERY OLD code, so I need to spend some time to get into it 😉
no hurry at all, I understand this is low priority!
LGTM. Thanks!
sklearn/model_selection/_search.py
Outdated
for index, value in param_results[key].items():
    # Setting the value at an index unmasks that index
    ma[index] = value
I don't love the nested for loop here, but I can't think of a better way.
@thomasjpfan could you maybe have a look?
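For readers who have not used masked arrays, here is a minimal, standalone sketch of the behaviour the inline comment relies on: assigning to an index of a fully masked `MaskedArray` unmasks that index. The `param_result` dict and `n_candidates` value are made up for illustration and are not taken from `_search.py`.

```python
import numpy as np
from numpy.ma import MaskedArray

n_candidates = 4
# Hypothetical sparse results: candidate index -> parameter value.
# Candidates 1 and 3 do not use this parameter at all.
param_result = {0: 0.1, 2: 10.0}

# Start fully masked, then fill in only the candidates that have a value.
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=float)
for index, value in param_result.items():
    ma[index] = value  # assignment clears the mask at this position

print(ma.mask.tolist())  # [False, True, False, True]
```

Positions that were never assigned stay masked, which is how a candidate without the parameter is represented.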
Thank you for the PR!
sklearn/model_selection/_search.py
Outdated
for index, value in param_results[key].items():
    # Setting the value at an index unmasks that index
    ma[index] = value
param_results[key] = ma
Instead of overwriting the param_results in the loop, can we directly add the new results into the results dict?
- param_results[key] = ma
+ results[key] = ma
A few lines down, we can remove the results.update(param_results).
sklearn/model_selection/_search.py
Outdated
# Use one MaskedArray and mask all the places where the param is not
# applicable for that candidate (which may not contain all the params).
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr.dtype)
for index, value in param_results[key].items():
- for index, value in param_results[key].items():
+ for index, value in param_result.items():
sklearn/model_selection/_search.py
Outdated
for key in param_results:
    arr = np.array(list(param_results[key].values()))
    if len(arr) == n_candidates:
        param_results[key] = MaskedArray(arr, mask=False)
    else:
        # Use one MaskedArray and mask all the places where the param is not
        # applicable for that candidate (which may not contain all the params).
        ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr.dtype)
We can avoid creating a new NumPy array:
- for key in param_results:
-     arr = np.array(list(param_results[key].values()))
-     if len(arr) == n_candidates:
-         param_results[key] = MaskedArray(arr, mask=False)
-     else:
-         # Use one MaskedArray and mask all the places where the param is not
-         # applicable for that candidate (which may not contain all the params).
-         ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr.dtype)
+ for key, param_result in param_results.items():
+     param_list = list(param_result.values())
+     try:
+         arr_dtype = np.result_type(*param_list)
+     except TypeError:
+         arr_dtype = object
+     if len(param_list) == n_candidates:
+         results[key] = MaskedArray(param_list, mask=False, dtype=arr_dtype)
+     else:
+         # Use one MaskedArray and mask all the places where the param is not
+         # applicable for that candidate (which may not contain all the params).
+         ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr_dtype)
(Scikit-learn does not really like using fixed-length string dtypes such as "<U4", so using object here keeps the original behavior.)
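As a rough illustration of the dtype inference suggested above, the sketch below wraps the `np.result_type` call and its `object` fallback in a helper. The name `infer_param_dtype` and the sample values are hypothetical and exist only for this example; this is not the code that was merged.

```python
import numpy as np
from numpy.ma import MaskedArray


def infer_param_dtype(param_list):
    """Hypothetical helper: common dtype of the values, else object."""
    try:
        # Numeric Python scalars promote to a common NumPy dtype.
        return np.result_type(*param_list)
    except TypeError:
        # Values NumPy cannot promote (e.g. arbitrary strings) keep dtype object.
        return object


numeric_values = [0.1, 1, 10]
string_values = ["rbf", "linear"]

numeric = MaskedArray(
    numeric_values, mask=False, dtype=infer_param_dtype(numeric_values)
)
strings = MaskedArray(
    string_values, mask=False, dtype=infer_param_dtype(string_values)
)
print(numeric.dtype, strings.dtype)
```

One caveat with this approach: `np.result_type` also accepts dtype specifications, so a string value that happens to be a valid dtype name (for example "f8") is read as a dtype rather than as a string and does not raise `TypeError`.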
Thanks both for your reviews! That looks better, thanks Thomas, have updated (unrelated, but |
LGTM
It seems like this broke the doc build on The error is:
|
Arviz uses scipy.signal.gaussian, which was removed in 1.13. Most recent arviz uses scipy.signal.windows.gaussian.
scikit-learn 1.5.0 contained a regression (scikit-learn/scikit-learn#28352) that has been fixed in scikit-learn/scikit-learn#29078.
* BLD: Fix broken versions
  Arviz uses scipy.signal.gaussian, which was removed in 1.13. Most recent arviz uses scipy.signal.windows.gaussian.
  scikit-learn 1.5.0 contained a regression (scikit-learn/scikit-learn#28352) that has been fixed in scikit-learn/scikit-learn#29078.
* BLD: Chasing errors, limit scipy/arviz in SBR
  To test the notebooks, need to install SBR extras, which included a version of arviz that isn't available on 3.9. Earlier version works, but restricts scipy version.
Reference Issues/PRs
closes #28350
What does this implement/fix? Explain your changes.
Instead of always using dtype object, use a more appropriate dtype (the one detected by numpy).
Any other comments?
I noticed this when trying to use Polars, which is pickier about object dtype than pandas, for #28345
The existing tests already cover this functionality, so I've just updated them rather than increasing the test suite's running time. I can add a new test if desired though
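To make the change described under "What does this implement/fix?" concrete, here is a small sketch contrasting a masked array forced to dtype object with one that keeps the dtype numpy detects. The `values` list is made up for illustration; the real arrays are built inside `_search.py`.

```python
import numpy as np
from numpy.ma import MaskedArray

values = [0.1, 1.0, 10.0]  # hypothetical values of one numeric parameter

# Always using object: every element is stored as a generic Python object.
as_object = MaskedArray(values, mask=False, dtype=object)

# Letting numpy detect the dtype keeps the column numeric.
as_inferred = MaskedArray(np.array(values), mask=False)

print(as_object.dtype)    # object
print(as_inferred.dtype)  # float64
```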