-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
FIX fix regression in gridsearchcv when parameter grids have estimators as values #29179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
06cc76a
4905ef3
176ead1
c318da7
8e73ce4
81b6da7
1975302
77462ae
70ed083
8a65fff
b5f944f
02a3937
58689a0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
from sklearn import config_context | ||
from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier | ||
from sklearn.cluster import KMeans | ||
from sklearn.compose import ColumnTransformer | ||
from sklearn.datasets import ( | ||
make_blobs, | ||
make_classification, | ||
|
@@ -64,7 +65,7 @@ | |
from sklearn.naive_bayes import ComplementNB | ||
from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor | ||
from sklearn.pipeline import Pipeline | ||
from sklearn.preprocessing import StandardScaler | ||
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler | ||
from sklearn.svm import SVC, LinearSVC | ||
from sklearn.tests.metadata_routing_common import ( | ||
ConsumingScorer, | ||
|
@@ -1403,9 +1404,7 @@ def test_search_cv_results_none_param(): | |
est_parameters, | ||
cv=cv, | ||
).fit(X, y) | ||
assert_array_equal( | ||
grid_search.cv_results_["param_random_state"], [0, float("nan")] | ||
) | ||
assert_array_equal(grid_search.cv_results_["param_random_state"], [0, None]) | ||
Comment on lines
-1406
to
+1407
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yup. |
||
|
||
|
||
@ignore_warnings() | ||
|
@@ -2686,3 +2685,36 @@ def score(self, X, y): | |
grid_search.fit(X, y) | ||
for param in param_grid: | ||
assert grid_search.cv_results_[f"param_{param}"].dtype == object | ||
|
||
|
||
def test_search_with_estimators_issue_29157():
    """Searching over estimator-valued parameters keeps an object-dtype column.

    Test tied to issue 29157: the candidate values in the grid are estimator
    instances (e.g. ``OneHotEncoder``, which has a ``dtype`` parameter), and
    the matching ``cv_results_`` entry must be stored with ``dtype == object``.
    """
    pd = pytest.importorskip("pandas")

    frame = pd.DataFrame(
        {
            "numeric_1": [1, 2, 3, 4, 5],
            "object_1": ["a", "a", "a", "a", "a"],
            "target": [1.0, 4.1, 2.0, 3.0, 1.0],
        }
    )
    target = frame["target"]
    features = frame.drop("target", axis=1)

    # Encode the categorical column; remaining columns are passed through.
    preprocessor = ColumnTransformer(
        [("enc", OneHotEncoder(sparse_output=False), ["object_1"])],
        remainder="passthrough",
    )
    model = Pipeline([("enc", preprocessor), ("regressor", LinearRegression())])

    # The grid candidates are estimator instances rather than scalar values.
    candidate_encoders = [OneHotEncoder(sparse_output=False), OrdinalEncoder()]
    search = GridSearchCV(model, {"enc__enc": candidate_encoders}, cv=2)
    search.fit(features, target)

    encoder_column = search.cv_results_["param_enc__enc"]
    assert encoder_column.dtype == object
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is NumPy raising this warning? If so, can we add a comment here?