FIX Fix regression in GridSearchCV when parameter grids have estimators as values #29179

Merged
merged 13 commits into from
Jun 5, 2024
4 changes: 4 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -38,6 +38,10 @@ Changelog
   grids that have heterogeneous parameter values.
   :pr:`29078` by :user:`Loïc Estève <lesteve>`.

+- |Fix| Fix a regression in :class:`model_selection.GridSearchCV` for parameter
+  grids that have estimators as parameter values.
+  :pr:`29179` by :user:`Marco Gorelli<MarcoGorelli>`.
+

 .. _changes_1_5:

19 changes: 17 additions & 2 deletions sklearn/model_selection/_search.py
@@ -1089,9 +1089,24 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
         for key, param_result in param_results.items():
             param_list = list(param_result.values())
             try:
-                arr_dtype = np.result_type(*param_list)
+                with warnings.catch_warnings():
+                    warnings.filterwarnings(
+                        "ignore",
+                        message="in the future the `.dtype` attribute",

Member

Is NumPy raising this warning? If so, we can add a comment here?

+                        category=DeprecationWarning,
+                    )
+                    # Warning raised by NumPy 1.20+
+                    arr_dtype = np.result_type(*param_list)
             except (TypeError, ValueError):
-                arr_dtype = object
+                arr_dtype = np.dtype(object)
+            else:
+                if any(np.min_scalar_type(x) == object for x in param_list):
+                    # `np.result_type` might get thrown off by `.dtype` properties
+                    # (which some estimators have).
+                    # If finding the result dtype this way would give object,
+                    # then we use object.
+                    # https://github.com/scikit-learn/scikit-learn/issues/29157
+                    arr_dtype = np.dtype(object)
             if len(param_list) == n_candidates and arr_dtype != object:
                 # Exclude `object` else the numpy constructor might infer a list of
                 # tuples to be a 2d array.
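
The dtype selection in this hunk can be exercised outside scikit-learn. The sketch below is a stand-alone approximation, not the library's code: `infer_param_dtype` and `FakeEstimator` are hypothetical names introduced here, and `FakeEstimator` only stands in for an estimator that exposes a `dtype` attribute (as `OneHotEncoder` does). How `np.result_type` reacts to such an object may differ between NumPy versions, which is why no output is shown.

```python
import warnings

import numpy as np


class FakeEstimator:
    """Hypothetical stand-in for an estimator with a `dtype` parameter."""

    def __init__(self, dtype=np.float64):
        self.dtype = dtype


def infer_param_dtype(param_list):
    """Approximate the dtype chosen for one `param_*` column (illustrative)."""
    try:
        with warnings.catch_warnings():
            # Same filter as in the diff: silence NumPy's deprecation warning
            # about objects whose `.dtype` attribute is not a dtype instance.
            warnings.filterwarnings(
                "ignore",
                message="in the future the `.dtype` attribute",
                category=DeprecationWarning,
            )
            arr_dtype = np.result_type(*param_list)
    except (TypeError, ValueError):
        arr_dtype = np.dtype(object)
    else:
        # If any single value can only be stored as a Python object,
        # fall back to object dtype for the whole column.
        if any(np.min_scalar_type(x) == object for x in param_list):
            arr_dtype = np.dtype(object)
    return arr_dtype


print(infer_param_dtype([1, 2]))
print(infer_param_dtype([FakeEstimator(), FakeEstimator()]))
print(infer_param_dtype([0, None]))
```

Under these assumptions, the list of estimator-like objects and the `[0, None]` list should both resolve to `object`, while the all-integer list keeps an integer dtype.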
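
The context comment about excluding `object` refers to how the NumPy array constructor infers shape. A short illustration with plain NumPy (the variable names are made up for this example):

```python
import numpy as np

params = [(1, 2), (3, 4)]

# Letting the constructor infer the shape unpacks the tuples into a 2-D array.
inferred = np.array(params)

# Pre-allocating a 1-D object array and filling it element by element
# keeps each tuple intact as a single cell.
column = np.empty(len(params), dtype=object)
for i, value in enumerate(params):
    column[i] = value

print(inferred.shape)  # (2, 2)
print(column.shape)    # (2,)
```

This is why a column of tuple-valued parameters has to be filled item by item instead of being handed to the constructor directly.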
40 changes: 36 additions & 4 deletions sklearn/model_selection/tests/test_search.py
@@ -17,6 +17,7 @@
 from sklearn import config_context
 from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier
 from sklearn.cluster import KMeans
+from sklearn.compose import ColumnTransformer
 from sklearn.datasets import (
     make_blobs,
     make_classification,
@@ -64,7 +65,7 @@
 from sklearn.naive_bayes import ComplementNB
 from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
 from sklearn.pipeline import Pipeline
-from sklearn.preprocessing import StandardScaler
+from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
 from sklearn.svm import SVC, LinearSVC
 from sklearn.tests.metadata_routing_common import (
     ConsumingScorer,
@@ -1403,9 +1404,7 @@ def test_search_cv_results_none_param():
             est_parameters,
             cv=cv,
         ).fit(X, y)
-        assert_array_equal(
-            grid_search.cv_results_["param_random_state"], [0, float("nan")]
-        )
+        assert_array_equal(grid_search.cv_results_["param_random_state"], [0, None])

Comment on lines -1406 to +1407

Contributor Author

random_state is documented to accept an integer or None, but not float - so I think the new output looks more correct?

Member

Yup. random_state should not be a float.



 @ignore_warnings()
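
The change from `float("nan")` to `None` in the assertion above can be reproduced with plain NumPy. This is a small sketch, assuming the old `nan` came from storing `None` in a float-typed column. The variable names are illustrative and not taken from scikit-learn.

```python
import numpy as np

values = [0, None]

# In a float-typed array, None cannot be represented and is coerced to nan.
as_float = np.array(values, dtype=float)

# In an object-typed array, None is stored unchanged.
as_object = np.array(values, dtype=object)

print(as_float)
print(as_object)
```

With an object column, the reported `param_random_state` values match what was actually passed in the parameter grid.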
@@ -2686,3 +2685,36 @@ def score(self, X, y):
     grid_search.fit(X, y)
     for param in param_grid:
         assert grid_search.cv_results_[f"param_{param}"].dtype == object
+
+
+def test_search_with_estimators_issue_29157():
+    """Check cv_results_ for estimators with a `dtype` parameter, e.g. OneHotEncoder."""
+    pd = pytest.importorskip("pandas")
+    df = pd.DataFrame(
+        {
+            "numeric_1": [1, 2, 3, 4, 5],
+            "object_1": ["a", "a", "a", "a", "a"],
+            "target": [1.0, 4.1, 2.0, 3.0, 1.0],
+        }
+    )
+    X = df.drop("target", axis=1)
+    y = df["target"]
+    enc = ColumnTransformer(
+        [("enc", OneHotEncoder(sparse_output=False), ["object_1"])],
+        remainder="passthrough",
+    )
+    pipe = Pipeline(
+        [
+            ("enc", enc),
+            ("regressor", LinearRegression()),
+        ]
+    )
+    grid_params = {
+        "enc__enc": [
+            OneHotEncoder(sparse_output=False),
+            OrdinalEncoder(),
+        ]
+    }
+    grid_search = GridSearchCV(pipe, grid_params, cv=2)
+    grid_search.fit(X, y)
+    assert grid_search.cv_results_["param_enc__enc"].dtype == object