@@ -1086,36 +1086,31 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
10861086 for key , param_result in param_results .items ():
10871087 param_list = list (param_result .values ())
10881088 try :
1089- with warnings .catch_warnings ():
1090- warnings .filterwarnings (
1091- "ignore" ,
1092- message = "in the future the `.dtype` attribute" ,
1093- category = DeprecationWarning ,
1094- )
1095- # Warning raised by NumPy 1.20+
1096- arr_dtype = np .result_type (* param_list )
1089+ arr = np .array (param_list )
1090+ arr_dtype = arr .dtype
10971091 except (TypeError , ValueError ):
10981092 arr_dtype = np .dtype (object )
10991093 else :
1100- if any (np .min_scalar_type (x ) == object for x in param_list ):
1101- # `np.result_type` might get thrown off by `.dtype` properties
1102- # (which some estimators have).
1103- # If finding the result dtype this way would give object,
1104- # then we use object.
1105- # https://github.com/scikit-learn/scikit-learn/issues/29157
1094+ if arr_dtype .kind == "U" or arr .ndim > 1 :
11061095 arr_dtype = np .dtype (object )
1107- if len (param_list ) == n_candidates and arr_dtype != object :
1108- # Exclude `object` else the numpy constructor might infer a list of
1109- # tuples to be a 2d array.
1110- results [key ] = MaskedArray (param_list , mask = False , dtype = arr_dtype )
1111- else :
1112- # Use one MaskedArray and mask all the places where the param is not
1113- # applicable for that candidate (which may not contain all the params).
1114- ma = MaskedArray (np .empty (n_candidates ), mask = True , dtype = arr_dtype )
1115- for index , value in param_result .items ():
1116- # Setting the value at an index unmasks that index
1117- ma [index ] = value
1118- results [key ] = ma
1096+
1097+ if len (param_list ) == n_candidates :
1098+ try :
1099+ ma = MaskedArray (param_list , mask = False , dtype = arr_dtype )
1100+ except ValueError :
1101+ pass
1102+ else :
1103+ if ma .ndim == 1 :
1104+ results [key ] = ma
1105+ continue
1106+
1107+ # Use one MaskedArray and mask all the places where the param is not
1108+ # applicable for that candidate (which may not contain all the params).
1109+ ma = MaskedArray (np .empty (n_candidates ), mask = True , dtype = arr_dtype )
1110+ for index , value in param_result .items ():
1111+ # Setting the value at an index unmasks that index
1112+ ma [index ] = value
1113+ results [key ] = ma
11191114
11201115 # Store a list of param dicts at the key 'params'
11211116 results ["params" ] = candidate_params
0 commit comments