Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a796f33

Browse files
committed
Fix train_test_split array API implementation
1 parent 612d93d commit a796f33

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

sklearn/model_selection/_split.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
indexable,
3030
metadata_routing,
3131
)
32+
from ..utils._array_api import (
33+
_convert_to_numpy,
34+
_is_numpy_namespace,
35+
device,
36+
get_namespace,
37+
)
3238
from ..utils._param_validation import Interval, RealNotInt, validate_params
3339
from ..utils.extmath import _approximate_mode
3440
from ..utils.metadata_routing import _MetadataRequester
@@ -2221,6 +2227,12 @@ def _iter_indices(self, X, y, groups=None):
22212227
default_test_size=self._default_test_size,
22222228
)
22232229

2230+
# Convert to numpy as not all operations are supported by the Array API.
2231+
# `y` is probably never a very large array, which means that converting it
2232+
# should be cheap
2233+
xp, _ = get_namespace(y)
2234+
y = _convert_to_numpy(y, xp=xp)
2235+
22242236
if y.ndim == 2:
22252237
# for multi-label y, map each distinct row to a string repr
22262238
# using join because str(row) uses an ellipsis if len(row) > 1000
@@ -2787,6 +2799,13 @@ def train_test_split(
27872799

27882800
train, test = next(cv.split(X=arrays[0], y=stratify))
27892801

2802+
xp, is_array_api = get_namespace(*arrays)
2803+
if is_array_api and not _is_numpy_namespace(xp):
2804+
# Move train and test indexers to the same namespace and device as the
2805+
# arrays we are indexing. Assumes that all arrays are on the same device
2806+
train = xp.asarray(train, device=device(arrays[0]))
2807+
test = xp.asarray(test, device=device(arrays[0]))
2808+
27902809
return list(
27912810
chain.from_iterable(
27922811
(_safe_indexing(a, train), _safe_indexing(a, test)) for a in arrays

sklearn/utils/__init__.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from ..exceptions import DataConversionWarning
1919
from . import _joblib, metadata_routing
20+
from ._array_api import _is_numpy_namespace, get_namespace
2021
from ._bunch import Bunch
2122
from ._chunking import gen_batches, gen_even_slices
2223
from ._estimator_html_repr import estimator_html_repr
@@ -89,6 +90,9 @@
8990

9091
def _array_indexing(array, key, key_dtype, axis):
9192
"""Index an array or scipy.sparse consistently across NumPy version."""
93+
xp, is_array_api = get_namespace(array)
94+
if is_array_api and not _is_numpy_namespace(xp):
95+
return xp.take(array, key, axis=axis)
9296
if issparse(array) and key_dtype == "bool":
9397
key = np.asarray(key)
9498
if isinstance(key, tuple):
@@ -215,10 +219,19 @@ def _determine_key_type(key, accept_slice=True):
215219
raise ValueError(err_msg)
216220
return key_type.pop()
217221
if hasattr(key, "dtype"):
218-
try:
219-
return array_dtype_to_str[key.dtype.kind]
220-
except KeyError:
221-
raise ValueError(err_msg)
222+
xp, is_array_api = get_namespace(key)
223+
if is_array_api and not _is_numpy_namespace(xp):
224+
if xp.isdtype(key.dtype, "bool"):
225+
return "bool"
226+
elif xp.isdtype(key.dtype, "integral"):
227+
return "int"
228+
else:
229+
raise ValueError(err_msg)
230+
else:
231+
try:
232+
return array_dtype_to_str[key.dtype.kind]
233+
except KeyError:
234+
raise ValueError(err_msg)
222235
raise ValueError(err_msg)
223236

224237

0 commit comments

Comments
 (0)