Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit fd28ffd

Browse files
authored
BUG: use appropriate dtype in cv_results as opposed to always using object (#28352)
1 parent 6f17e09 commit fd28ffd

File tree

4 files changed

+56
-23
lines changed

4 files changed

+56
-23
lines changed

doc/whats_new/v1.5.rst

+3
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ Changelog
217217

218218
- |Enhancement| :term:`CV splitters <CV splitter>` that ignore the group parameter now
219219
raise a warning when groups are passed in to :term:`split`. :pr:`28210` by
220+
- |Fix| the ``cv_results_`` attribute (of :class:`model_selection.GridSearchCV`) now
221+
returns masked arrays of the appropriate NumPy dtype, as opposed to always returning
222+
dtype ``object``. :pr:`28352` by :user:`Marco Gorelli<MarcoGorelli>`.
220223

221224
:mod:`sklearn.multioutput`
222225
..........................

sklearn/model_selection/_search.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -1073,27 +1073,27 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
10731073

10741074
_store("fit_time", out["fit_time"])
10751075
_store("score_time", out["score_time"])
1076-
# Use one MaskedArray and mask all the places where the param is not
1077-
# applicable for that candidate. Use defaultdict as each candidate may
1078-
# not contain all the params
1079-
param_results = defaultdict(
1080-
partial(
1081-
MaskedArray,
1082-
np.empty(
1083-
n_candidates,
1084-
),
1085-
mask=True,
1086-
dtype=object,
1087-
)
1088-
)
1076+
param_results = defaultdict(dict)
10891077
for cand_idx, params in enumerate(candidate_params):
10901078
for name, value in params.items():
1091-
# An all masked empty array gets created for the key
1092-
# `"param_%s" % name` at the first occurrence of `name`.
1093-
# Setting the value at an index also unmasks that index
10941079
param_results["param_%s" % name][cand_idx] = value
1080+
for key, param_result in param_results.items():
1081+
param_list = list(param_result.values())
1082+
try:
1083+
arr_dtype = np.result_type(*param_list)
1084+
except TypeError:
1085+
arr_dtype = object
1086+
if len(param_list) == n_candidates:
1087+
results[key] = MaskedArray(param_list, mask=False, dtype=arr_dtype)
1088+
else:
1089+
# Use one MaskedArray and mask all the places where the param is not
1090+
# applicable for that candidate (which may not contain all the params).
1091+
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr_dtype)
1092+
for index, value in param_result.items():
1093+
# Setting the value at an index unmasks that index
1094+
ma[index] = value
1095+
results[key] = ma
10951096

1096-
results.update(param_results)
10971097
# Store a list of param dicts at the key 'params'
10981098
results["params"] = candidate_params
10991099

sklearn/model_selection/tests/test_search.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -898,11 +898,15 @@ def test_param_sampler():
898898
assert [x for x in sampler] == [x for x in sampler]
899899

900900

901-
def check_cv_results_array_types(search, param_keys, score_keys):
901+
def check_cv_results_array_types(
902+
search, param_keys, score_keys, expected_cv_results_kinds
903+
):
902904
# Check that the arrays in the search's `cv_results_` are of the correct types
903905
cv_results = search.cv_results_
904906
assert all(isinstance(cv_results[param], np.ma.MaskedArray) for param in param_keys)
905-
assert all(cv_results[key].dtype == object for key in param_keys)
907+
assert {
908+
key: cv_results[key].dtype.kind for key in param_keys
909+
} == expected_cv_results_kinds
906910
assert not any(isinstance(cv_results[key], np.ma.MaskedArray) for key in score_keys)
907911
assert all(
908912
cv_results[key].dtype == np.float64
@@ -975,7 +979,15 @@ def test_grid_search_cv_results():
975979
if "time" not in k and k != "rank_test_score"
976980
)
977981
# Check cv_results structure
978-
check_cv_results_array_types(search, param_keys, score_keys)
982+
expected_cv_results_kinds = {
983+
"param_C": "i",
984+
"param_degree": "i",
985+
"param_gamma": "f",
986+
"param_kernel": "O",
987+
}
988+
check_cv_results_array_types(
989+
search, param_keys, score_keys, expected_cv_results_kinds
990+
)
979991
check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates)
980992
# Check masking
981993
cv_results = search.cv_results_
@@ -1044,7 +1056,15 @@ def test_random_search_cv_results():
10441056
search.fit(X, y)
10451057
cv_results = search.cv_results_
10461058
# Check results structure
1047-
check_cv_results_array_types(search, param_keys, score_keys)
1059+
expected_cv_results_kinds = {
1060+
"param_C": "f",
1061+
"param_degree": "i",
1062+
"param_gamma": "f",
1063+
"param_kernel": "O",
1064+
}
1065+
check_cv_results_array_types(
1066+
search, param_keys, score_keys, expected_cv_results_kinds
1067+
)
10481068
check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates)
10491069
assert all(
10501070
(
@@ -1378,7 +1398,9 @@ def test_search_cv_results_none_param():
13781398
est_parameters,
13791399
cv=cv,
13801400
).fit(X, y)
1381-
assert_array_equal(grid_search.cv_results_["param_random_state"], [0, None])
1401+
assert_array_equal(
1402+
grid_search.cv_results_["param_random_state"], [0, float("nan")]
1403+
)
13821404

13831405

13841406
@ignore_warnings()

sklearn/model_selection/tests/test_successive_halving.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,15 @@ def test_halving_random_search_list_of_dicts():
826826
cv_results = search.cv_results_
827827
# Check results structure
828828
check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates, extra_keys)
829-
check_cv_results_array_types(search, param_keys, score_keys)
829+
expected_cv_results_kinds = {
830+
"param_C": "f",
831+
"param_degree": "i",
832+
"param_gamma": "f",
833+
"param_kernel": "O",
834+
}
835+
check_cv_results_array_types(
836+
search, param_keys, score_keys, expected_cv_results_kinds
837+
)
830838

831839
assert all(
832840
(

0 commit comments

Comments
 (0)