Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 41658f3

Browse files
committed
BUG: use appropriate dtype in cv_results as opposed to always using object
1 parent a33ade3 commit 41658f3

File tree

4 files changed

+39
-21
lines changed

4 files changed

+39
-21
lines changed

doc/whats_new/v1.5.rst

+3
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ Changelog
117117

118118
- |Enhancement| :term:`CV splitters <CV splitter>` that ignore the group parameter now
119119
raise a warning when groups are passed in to :term:`split`. :pr:`28210` by
120+
- |Fix| the ``cv_results_`` attribute (of :class:`model_selection.GridSearchCV`) now
121+
returns masked arrays of the appropriate NumPy dtype, as opposed to always returning
122+
dtype ``object``. :pr:`28352` by :user:`Marco Gorelli<MarcoGorelli>`.
120123

121124
:mod:`sklearn.utils`
122125
....................

sklearn/model_selection/_search.py

+13-16
Original file line numberDiff line numberDiff line change
@@ -1081,25 +1081,22 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
10811081

10821082
_store("fit_time", out["fit_time"])
10831083
_store("score_time", out["score_time"])
1084-
# Use one MaskedArray and mask all the places where the param is not
1085-
# applicable for that candidate. Use defaultdict as each candidate may
1086-
# not contain all the params
1087-
param_results = defaultdict(
1088-
partial(
1089-
MaskedArray,
1090-
np.empty(
1091-
n_candidates,
1092-
),
1093-
mask=True,
1094-
dtype=object,
1095-
)
1096-
)
1084+
param_results = defaultdict(dict)
10971085
for cand_idx, params in enumerate(candidate_params):
10981086
for name, value in params.items():
1099-
# An all masked empty array gets created for the key
1100-
# `"param_%s" % name` at the first occurrence of `name`.
1101-
# Setting the value at an index also unmasks that index
11021087
param_results["param_%s" % name][cand_idx] = value
1088+
for key in param_results:
1089+
arr = np.array(list(param_results[key].values()))
1090+
if len(arr) == n_candidates:
1091+
param_results[key] = MaskedArray(arr, mask=False)
1092+
else:
1093+
# Use one MaskedArray and mask all the places where the param is not
1094+
# applicable for that candidate (which may not contain all the params).
1095+
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr.dtype)
1096+
for index, value in param_results[key].items():
1097+
# Setting the value at an index unmasks that index
1098+
ma[index] = value
1099+
param_results[key] = ma
11031100

11041101
results.update(param_results)
11051102
# Store a list of param dicts at the key 'params'

sklearn/model_selection/tests/test_search.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -898,11 +898,11 @@ def test_param_sampler():
898898
assert [x for x in sampler] == [x for x in sampler]
899899

900900

901-
def check_cv_results_array_types(search, param_keys, score_keys):
901+
def check_cv_results_array_types(search, param_keys, score_keys, expected_dtypes):
902902
# Check that the arrays in the search's `cv_results_` are of the correct types
903903
cv_results = search.cv_results_
904904
assert all(isinstance(cv_results[param], np.ma.MaskedArray) for param in param_keys)
905-
assert all(cv_results[key].dtype == object for key in param_keys)
905+
assert {key: cv_results[key].dtype for key in param_keys} == expected_dtypes
906906
assert not any(isinstance(cv_results[key], np.ma.MaskedArray) for key in score_keys)
907907
assert all(
908908
cv_results[key].dtype == np.float64
@@ -975,7 +975,13 @@ def test_grid_search_cv_results():
975975
if "time" not in k and k != "rank_test_score"
976976
)
977977
# Check cv_results structure
978-
check_cv_results_array_types(search, param_keys, score_keys)
978+
expected_dtypes = {
979+
"param_C": "int64",
980+
"param_degree": "int64",
981+
"param_gamma": "float64",
982+
"param_kernel": "<U4",
983+
}
984+
check_cv_results_array_types(search, param_keys, score_keys, expected_dtypes)
979985
check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates)
980986
# Check masking
981987
cv_results = search.cv_results_
@@ -1044,7 +1050,13 @@ def test_random_search_cv_results():
10441050
search.fit(X, y)
10451051
cv_results = search.cv_results_
10461052
# Check results structure
1047-
check_cv_results_array_types(search, param_keys, score_keys)
1053+
expected_dtypes = {
1054+
"param_C": "float64",
1055+
"param_degree": "int64",
1056+
"param_gamma": "float64",
1057+
"param_kernel": "<U4",
1058+
}
1059+
check_cv_results_array_types(search, param_keys, score_keys, expected_dtypes)
10481060
check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates)
10491061
assert all(
10501062
(

sklearn/model_selection/tests/test_successive_halving.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,13 @@ def test_halving_random_search_list_of_dicts():
826826
cv_results = search.cv_results_
827827
# Check results structure
828828
check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates, extra_keys)
829-
check_cv_results_array_types(search, param_keys, score_keys)
829+
expected_dtypes = {
830+
"param_C": "float64",
831+
"param_degree": "int64",
832+
"param_gamma": "float64",
833+
"param_kernel": "<U4",
834+
}
835+
check_cv_results_array_types(search, param_keys, score_keys, expected_dtypes)
830836

831837
assert all(
832838
(

0 commit comments

Comments
 (0)