Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 08b6b27

Browse files
committed
fix cv_results_ in GridSearch when params are arrays of varying sizes
1 parent 65b2571 commit 08b6b27

File tree

3 files changed

+65
-28
lines changed

3 files changed

+65
-28
lines changed

doc/whats_new/v1.5.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ Changes impacting many modules
4848
grids that have estimators as parameter values.
4949
:pr:`29179` by :user:`Marco Gorelli<MarcoGorelli>`.
5050

51+
- |Fix| Fix a regression in :class:`model_selection.GridSearchCV` for parameter
52+
grids that have arrays of different sizes as parameter values.
53+
:pr:`29314` by :user:`Marco Gorelli<MarcoGorelli>`.
54+
5155
:mod:`sklearn.metrics`
5256
..............................
5357

sklearn/model_selection/_search.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,36 +1086,31 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
10861086
for key, param_result in param_results.items():
10871087
param_list = list(param_result.values())
10881088
try:
1089-
with warnings.catch_warnings():
1090-
warnings.filterwarnings(
1091-
"ignore",
1092-
message="in the future the `.dtype` attribute",
1093-
category=DeprecationWarning,
1094-
)
1095-
# Warning raised by NumPy 1.20+
1096-
arr_dtype = np.result_type(*param_list)
1089+
arr = np.array(param_list)
1090+
arr_dtype = arr.dtype
10971091
except (TypeError, ValueError):
10981092
arr_dtype = np.dtype(object)
10991093
else:
1100-
if any(np.min_scalar_type(x) == object for x in param_list):
1101-
# `np.result_type` might get thrown off by `.dtype` properties
1102-
# (which some estimators have).
1103-
# If finding the result dtype this way would give object,
1104-
# then we use object.
1105-
# https://github.com/scikit-learn/scikit-learn/issues/29157
1094+
if arr_dtype.kind == "U" or arr.ndim > 1:
11061095
arr_dtype = np.dtype(object)
1107-
if len(param_list) == n_candidates and arr_dtype != object:
1108-
# Exclude `object` else the numpy constructor might infer a list of
1109-
# tuples to be a 2d array.
1110-
results[key] = MaskedArray(param_list, mask=False, dtype=arr_dtype)
1111-
else:
1112-
# Use one MaskedArray and mask all the places where the param is not
1113-
# applicable for that candidate (which may not contain all the params).
1114-
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr_dtype)
1115-
for index, value in param_result.items():
1116-
# Setting the value at an index unmasks that index
1117-
ma[index] = value
1118-
results[key] = ma
1096+
1097+
if len(param_list) == n_candidates:
1098+
try:
1099+
ma = MaskedArray(param_list, mask=False, dtype=arr_dtype)
1100+
except ValueError:
1101+
pass
1102+
else:
1103+
if ma.ndim == 1:
1104+
results[key] = ma
1105+
continue
1106+
1107+
# Use one MaskedArray and mask all the places where the param is not
1108+
# applicable for that candidate (which may not contain all the params).
1109+
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr_dtype)
1110+
for index, value in param_result.items():
1111+
# Setting the value at an index unmasks that index
1112+
ma[index] = value
1113+
results[key] = ma
11191114

11201115
# Store a list of param dicts at the key 'params'
11211116
results["params"] = candidate_params

sklearn/model_selection/tests/test_search.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,13 @@
6565
from sklearn.model_selection.tests.common import OneTimeSplitter
6666
from sklearn.naive_bayes import ComplementNB
6767
from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
68-
from sklearn.pipeline import Pipeline
69-
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
68+
from sklearn.pipeline import Pipeline, make_pipeline
69+
from sklearn.preprocessing import (
70+
OneHotEncoder,
71+
OrdinalEncoder,
72+
SplineTransformer,
73+
StandardScaler,
74+
)
7075
from sklearn.svm import SVC, LinearSVC
7176
from sklearn.tests.metadata_routing_common import (
7277
ConsumingScorer,
@@ -2724,6 +2729,39 @@ def test_search_with_estimators_issue_29157():
27242729
assert grid_search.cv_results_["param_enc__enc"].dtype == object
27252730

27262731

2732+
def test_cv_results_multi_size_array_29277():
2733+
x = np.linspace(-np.pi * 2, np.pi * 5, 1000)
2734+
y_true = np.sin(x)
2735+
y_train = y_true[(0 < x) & (x < np.pi * 2)]
2736+
2737+
x_train = x[(0 < x) & (x < np.pi * 2)]
2738+
y_train_noise = y_train + np.random.normal(size=y_train.shape, scale=0.5)
2739+
2740+
x = x.reshape((-1, 1))
2741+
x_train = x_train.reshape((-1, 1))
2742+
2743+
spline_reg_pipe = make_pipeline(
2744+
SplineTransformer(extrapolation="periodic"),
2745+
LinearRegression(fit_intercept=False),
2746+
)
2747+
2748+
spline_reg_pipe_cv = GridSearchCV(
2749+
estimator=spline_reg_pipe,
2750+
param_grid={
2751+
"splinetransformer__knots": [
2752+
np.linspace(0, np.pi * 2, n_knots).reshape((-1, 1))
2753+
for n_knots in range(10, 21, 5)
2754+
],
2755+
},
2756+
verbose=1,
2757+
)
2758+
2759+
spline_reg_pipe_cv.fit(X=x_train, y=y_train_noise)
2760+
assert (
2761+
spline_reg_pipe_cv.cv_results_["param_splinetransformer__knots"].dtype == object
2762+
)
2763+
2764+
27272765
@pytest.mark.parametrize(
27282766
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
27292767
)

0 commit comments

Comments
 (0)