|
17 | 17 |
|
18 | 18 | from ..exceptions import DataConversionWarning |
19 | 19 | from . import _joblib, metadata_routing |
| 20 | +from ._array_api import _is_numpy_namespace, get_namespace |
20 | 21 | from ._bunch import Bunch |
21 | 22 | from ._chunking import gen_batches, gen_even_slices |
22 | 23 | from ._estimator_html_repr import estimator_html_repr |
|
89 | 90 |
|
90 | 91 | def _array_indexing(array, key, key_dtype, axis): |
91 | 92 | """Index an array or scipy.sparse consistently across NumPy version.""" |
| 93 | + xp, is_array_api = get_namespace(array) |
| 94 | + if is_array_api and not _is_numpy_namespace(xp): |
| 95 | + return xp.take(array, key, axis=axis) |
92 | 96 | if issparse(array) and key_dtype == "bool": |
93 | 97 | key = np.asarray(key) |
94 | 98 | if isinstance(key, tuple): |
@@ -215,10 +219,19 @@ def _determine_key_type(key, accept_slice=True): |
215 | 219 | raise ValueError(err_msg) |
216 | 220 | return key_type.pop() |
217 | 221 | if hasattr(key, "dtype"): |
218 | | - try: |
219 | | - return array_dtype_to_str[key.dtype.kind] |
220 | | - except KeyError: |
221 | | - raise ValueError(err_msg) |
| 222 | + xp, is_array_api = get_namespace(key) |
| 223 | + if is_array_api and not _is_numpy_namespace(xp): |
| 224 | + if xp.isdtype(key.dtype, "bool"): |
| 225 | + return "bool" |
| 226 | + elif xp.isdtype(key.dtype, "integral"): |
| 227 | + return "int" |
| 228 | + else: |
| 229 | + raise ValueError(err_msg) |
| 230 | + else: |
| 231 | + try: |
| 232 | + return array_dtype_to_str[key.dtype.kind] |
| 233 | + except KeyError: |
| 234 | + raise ValueError(err_msg) |
222 | 235 | raise ValueError(err_msg) |
223 | 236 |
|
224 | 237 |
|
|
0 commit comments