
FIX fix regression in gridsearchcv when parameter grids have estimators as values #29179


Merged: 13 commits merged on Jun 5, 2024

MarcoGorelli
Contributor

Reference Issues/PRs

closes #29157

What does this implement/fix? Explain your changes.

Fixes regression. Constructs array, and gets the dtype from there, as suggested here, but sets 'U' kinds to object in keeping with this comment

Per discussion in #29157, alternatives to creating an array may not be acceptable

Any other comments?

github-actions bot commented Jun 4, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 58689a0. Link to the linter CI: here

Comment on lines -1406 to +1408
    -    assert_array_equal(
    -        grid_search.cv_results_["param_random_state"], [0, float("nan")]
    -    )
    +    assert_array_equal(grid_search.cv_results_["param_random_state"], [0, None])
Contributor Author

random_state is documented to accept an integer or None, but not float - so I think the new output looks more correct?

Member

Yup. random_state should not be a float.

Member

@adrinjalali adrinjalali left a comment

@MarcoGorelli MarcoGorelli marked this pull request as ready for review June 5, 2024 06:21
Comment on lines 2691 to 2692
"ignore:in the future the `.dtype` attribute of a given datatype object must "
"be a valid dtype instance:DeprecationWarning"
Member

Who's raising this? As in, are users going to see this now?

Contributor Author

NumPy raises it in the line np.result_type(*param_list)

It's a DeprecationWarning, so it wouldn't ordinarily be visible to end users, which is why running the example in the linked issue (#29157) doesn't show any warning.

Still, doesn't hurt to silence it, I've gone with that 👍
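
The approach described in this reply can be sketched roughly as follows. This is a minimal illustration and not the merged scikit-learn code: `infer_param_dtype` is a hypothetical helper name, and the fallback to `object` on `TypeError`/`ValueError` is an assumption based on the later remark in this thread that those exceptions are already caught.

```python
import warnings

import numpy as np


def infer_param_dtype(param_list):
    """Hypothetical helper: infer a common dtype for a list of parameter values."""
    try:
        with warnings.catch_warnings():
            # Silence the NumPy DeprecationWarning quoted in this thread; it
            # can be emitted when a value exposes a `.dtype` attribute that
            # is not a dtype instance.
            warnings.filterwarnings(
                "ignore",
                message="in the future the `.dtype` attribute",
                category=DeprecationWarning,
            )
            return np.result_type(*param_list)
    except (TypeError, ValueError):
        # Values NumPy cannot interpret as scalars or dtypes fall back to object.
        return np.dtype(object)


numeric_dtype = infer_param_dtype([1, 2.5])
string_dtype = infer_param_dtype(["linear", "rbf"])
```

Here the numeric list resolves to a floating-point dtype, while the strings (which are not valid dtype names) trigger the `object` fallback.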

Member

@adrinjalali adrinjalali left a comment

@ogrisel ogrisel added this to the 1.5.1 milestone Jun 5, 2024
@ogrisel ogrisel added the To backport PR merged in master that need a backport to a release branch defined based on the milestone. label Jun 5, 2024
Member

@thomasjpfan thomasjpfan left a comment

Thank you for the fix @MarcoGorelli !

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="in the future the `.dtype` attribute",
Member

Is NumPy raising this warning? If so, can we add a comment here?

Comment on lines 2690 to 2691
    def test_search_with_estimators_issue_29157():
        pd = pytest.importorskip("pandas")
Member

So we have a short description in the code itself:

Suggested change
    -def test_search_with_estimators_issue_29157():
    -    pd = pytest.importorskip("pandas")
    +def test_search_with_estimators_issue_29157():
    +    """Check cv_results_ for estimators with a `dtype` parameter such as OneHotEncoder."""
    +    pd = pytest.importorskip("pandas")

@thomasjpfan thomasjpfan merged commit b375b7b into scikit-learn:main Jun 5, 2024
30 checks passed
@lesteve
Member

lesteve commented Jun 6, 2024

Thanks for the fix @MarcoGorelli!

It kind of feels like this is getting more and more complicated though 😅 ... see below for some issues I can imagine.

I was wondering why the strategy of creating an array and using the automatically inferred dtype was dropped after all?

I guess one of the reasons was that in #28352 (comment), @thomasjpfan, you said:

(Scikit-learn does not really like using fixed length string dtypes "<U4", so using object here keeps the original behavior.)

Is there anything else? If that's the only reason, maybe we can do .astype(object) if the automatically inferred dtype is a string dtype? I guess that makes a copy, not sure how crucial this is ...

I think the underlying issue is that np.result_type is a bit ambiguous: it can take both dtype-like objects (OrdinalEncoder(), 'float64', np.int32) and values (3.2, np.array([1, 2], dtype=np.float64), ...). See numpy/numpy#26612 (comment) for a proposed way to make this more explicit.

Here are some possible issues I can imagine with the code as it is in this PR:

  • at one point, the warning will turn into an error:

    import numpy as np
    from sklearn.preprocessing import OrdinalEncoder
    
    np.result_type(OrdinalEncoder(), OrdinalEncoder())

    For completeness the warning is:

    <ipython-input-1-5c4a0f627be2>:4: DeprecationWarning: in the future the `.dtype` attribute of a given datatype object must be a valid dtype instance. `data_type.dtype` may need to be coerced using `np.dtype(data_type.dtype)`. (Deprecated NumPy 1.20)
    

    The reason for the warning is that OrdinalEncoder().dtype is np.float64 and not np.dtype(np.float64)

  • you could imagine other edge cases, e.g. a grid search over the OrdinalEncoder dtype parameter, where the values could be something like ['float64', 'float32']. You would expect the dtype to be object, except that:

    np.result_type('float64', 'float32') # dtype('float64') not object
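
The ambiguity described in this comment can be seen in a small sketch. It uses only plain NumPy calls; the variable names are illustrative, and the `.astype(object)` step is the idea floated earlier in this comment rather than what the PR does.

```python
import numpy as np

values = ["float64", "float32"]

# np.result_type reads the strings as dtype specifications, not as values.
as_dtypes = np.result_type(*values)

# Building an array treats them as values, giving a fixed-length string dtype.
as_values = np.array(values)

# Converting to object avoids the fixed-length string dtype (this makes a copy).
as_objects = as_values.astype(object)
```

The same two strings therefore lead to a float dtype, a string dtype, or an object dtype depending on which route is taken.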

@MarcoGorelli
Contributor Author

thanks!

I was wondering why the strategy of creating an array and using the automatically inferred dtype was dropped after all?

another issue is that then, a list of tuples would be detected as a 2D array instead of an object 1D array of tuples
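
A short sketch of the list-of-tuples pitfall, using only NumPy. The variable names are illustrative, and filling a pre-allocated object array is just one possible way to keep the tuples intact, not necessarily what scikit-learn does.

```python
import numpy as np

params = [(1, 2), (3, 4)]

# Letting NumPy infer the shape turns the tuples into rows of a 2D array,
# even when dtype=object is requested.
inferred = np.array(params)
inferred_object = np.array(params, dtype=object)

# Pre-allocating a 1D object array and filling it element by element keeps
# each tuple as a single entry.
as_objects = np.empty(len(params), dtype=object)
for i, value in enumerate(params):
    as_objects[i] = value
```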

at one point, the warning will turn into an error:

true, but TypeError and ValueError are already caught - hopefully that'll catch whatever error this turns into too? 🤞

@lesteve
Member

lesteve commented Jun 6, 2024

Good points indeed. Oh well, I guess I don't have a better solution, so let's say it is OK enough for now.

If there is another bug found in this slightly tricky code we can at least think about moving the code to a function that can be more easily tested with edge cases.

another issue is that then, a list of tuples would be detected as a 2D array instead of an object 1D array of tuples

Indeed, I have seen you added a test for this in #28571 so 👍.

About the warning that may one day turn into an error in NumPy: I guess our scipy-dev CI (which tests the development versions of our dependencies) will detect it if the error is neither a TypeError nor a ValueError, and then we can ask NumPy to consider choosing an exception that does not break our (slightly brittle) code.

Labels
module:model_selection Regression To backport PR merged in master that need a backport to a release branch defined based on the milestone.
Projects
None yet
5 participants