
Commit 6428b98

MarcoGorelli authored and jeremiedbb committed
FIX fix regression in GridSearchCV when parameter grids have estimators as values (#29179)
1 parent 05db17e commit 6428b98

File tree

3 files changed: +57 -6 lines changed

doc/whats_new/v1.5.rst (+4)

@@ -38,6 +38,10 @@ Changelog
   grids that have heterogeneous parameter values.
   :pr:`29078` by :user:`Loïc Estève <lesteve>`.
 
+- |Fix| Fix a regression in :class:`model_selection.GridSearchCV` for parameter
+  grids that have estimators as parameter values.
+  :pr:`29179` by :user:`Marco Gorelli <MarcoGorelli>`.
+
 
 .. _changes_1_5:
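
As a hedged sketch of the scenario this entry describes (condensed from the regression test added below; the toy data and pipeline are illustrative only, not part of the patch), a grid whose values are estimator instances should leave the corresponding cv_results_ column with object dtype:

    import numpy as np
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import GridSearchCV
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder

    # Toy data: one categorical feature, a numeric target.
    X = np.array([["a"], ["b"], ["a"], ["b"], ["a"], ["b"]])
    y = np.array([1.0, 2.0, 1.5, 2.5, 1.2, 2.2])

    pipe = Pipeline([("enc", OneHotEncoder()), ("reg", LinearRegression())])
    # Estimator instances as grid values are what triggered the regression
    # (scikit-learn issue #29157); the fix keeps this column object-dtyped.
    search = GridSearchCV(pipe, {"enc": [OneHotEncoder(), OrdinalEncoder()]}, cv=2)
    search.fit(X, y)
    assert search.cv_results_["param_enc"].dtype == object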

sklearn/model_selection/_search.py (+17 -2)

@@ -1089,9 +1089,24 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
         for key, param_result in param_results.items():
             param_list = list(param_result.values())
             try:
-                arr_dtype = np.result_type(*param_list)
+                with warnings.catch_warnings():
+                    warnings.filterwarnings(
+                        "ignore",
+                        message="in the future the `.dtype` attribute",
+                        category=DeprecationWarning,
+                    )
+                    # Warning raised by NumPy 1.20+
+                    arr_dtype = np.result_type(*param_list)
             except (TypeError, ValueError):
-                arr_dtype = object
+                arr_dtype = np.dtype(object)
+            else:
+                if any(np.min_scalar_type(x) == object for x in param_list):
+                    # `np.result_type` might get thrown off by `.dtype` properties
+                    # (which some estimators have).
+                    # If finding the result dtype this way would give object,
+                    # then we use object.
+                    # https://github.com/scikit-learn/scikit-learn/issues/29157
+                    arr_dtype = np.dtype(object)
             if len(param_list) == n_candidates and arr_dtype != object:
                 # Exclude `object` else the numpy constructor might infer a list of
                 # tuples to be a 2d array.
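
The added else branch is the heart of the fix: np.result_type inspects a `.dtype` attribute on its arguments, and many scikit-learn transformers expose one (OneHotEncoder stores its `dtype` constructor argument, np.float64 by default), so a column of estimators could be reported as a numeric dtype instead of object. A minimal sketch of the behaviour the patch guards against, assuming NumPy 1.20+ (the exact result_type output and warning text can vary across NumPy versions):

    import numpy as np
    from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder

    param_list = [OneHotEncoder(), OrdinalEncoder()]

    # np.result_type picks up the estimators' `.dtype` attributes, so it can
    # report a numeric dtype (and NumPy 1.20+ emits the DeprecationWarning
    # that the patch silences) rather than object.
    arr_dtype = np.result_type(*param_list)

    # The per-element guard from the patch: any element whose minimal scalar
    # type is object forces the whole cv_results_ column to object dtype.
    if any(np.min_scalar_type(x) == object for x in param_list):
        arr_dtype = np.dtype(object)

    print(arr_dtype)  # dtype('O'): estimators are stored as plain objects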

sklearn/model_selection/tests/test_search.py (+36 -4)

@@ -17,6 +17,7 @@
 from sklearn import config_context
 from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier
 from sklearn.cluster import KMeans
+from sklearn.compose import ColumnTransformer
 from sklearn.datasets import (
     make_blobs,
     make_classification,
@@ -64,7 +65,7 @@
 from sklearn.naive_bayes import ComplementNB
 from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
 from sklearn.pipeline import Pipeline
-from sklearn.preprocessing import StandardScaler
+from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
 from sklearn.svm import SVC, LinearSVC
 from sklearn.tests.metadata_routing_common import (
     ConsumingScorer,
@@ -1403,9 +1404,7 @@ def test_search_cv_results_none_param():
         est_parameters,
         cv=cv,
     ).fit(X, y)
-    assert_array_equal(
-        grid_search.cv_results_["param_random_state"], [0, float("nan")]
-    )
+    assert_array_equal(grid_search.cv_results_["param_random_state"], [0, None])
 
 
 @ignore_warnings()
@@ -2686,3 +2685,36 @@ def score(self, X, y):
     grid_search.fit(X, y)
     for param in param_grid:
         assert grid_search.cv_results_[f"param_{param}"].dtype == object
+
+
+def test_search_with_estimators_issue_29157():
+    """Check cv_results_ for estimators with a `dtype` parameter, e.g. OneHotEncoder."""
+    pd = pytest.importorskip("pandas")
+    df = pd.DataFrame(
+        {
+            "numeric_1": [1, 2, 3, 4, 5],
+            "object_1": ["a", "a", "a", "a", "a"],
+            "target": [1.0, 4.1, 2.0, 3.0, 1.0],
+        }
+    )
+    X = df.drop("target", axis=1)
+    y = df["target"]
+    enc = ColumnTransformer(
+        [("enc", OneHotEncoder(sparse_output=False), ["object_1"])],
+        remainder="passthrough",
+    )
+    pipe = Pipeline(
+        [
+            ("enc", enc),
+            ("regressor", LinearRegression()),
+        ]
+    )
+    grid_params = {
+        "enc__enc": [
+            OneHotEncoder(sparse_output=False),
+            OrdinalEncoder(),
+        ]
+    }
+    grid_search = GridSearchCV(pipe, grid_params, cv=2)
+    grid_search.fit(X, y)
+    assert grid_search.cv_results_["param_enc__enc"].dtype == object
