From 6e777747afbcfa116f4435c4d79107db6e08114c Mon Sep 17 00:00:00 2001 From: Franck Charras <29153872+fcharras@users.noreply.github.com> Date: Thu, 19 Jan 2023 15:49:33 +0100 Subject: [PATCH] Add `_asarray_fn` override to `check_array` --- sklearn/utils/_array_api.py | 18 +++++++++++++++--- sklearn/utils/tests/test_array_api.py | 23 +++++++++++++++++++++++ sklearn/utils/validation.py | 22 +++++++++++++++++++--- 3 files changed, 57 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index fff8e1ee33a49..91741d8826dd7 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -165,8 +165,11 @@ def _expit(X): return 1.0 / (1.0 + xp.exp(-X)) -def _asarray_with_order(array, dtype=None, order=None, copy=None, xp=None): - """Helper to support the order kwarg only for NumPy-backed arrays +def _asarray_with_order( + array, dtype=None, order=None, copy=None, _asarray_fn=None, xp=None + ): + """Helper to automatically support the order kwarg for NumPy-backed + arrays. Memory layout parameter `order` is not exposed in the Array API standard, however some input validation code in scikit-learn needs to work both @@ -179,7 +182,16 @@ def _asarray_with_order(array, dtype=None, order=None, copy=None, xp=None): is NumPy based, otherwise `order` is just silently ignored. """ if xp is None: - xp, _ = get_namespace(array) + xp, is_array_api = get_namespace(array) + + if _asarray_fn is not None: + if is_array_api: + raise ValueError( + "Passing _asarray_fn is only supported for array namespaces " + "compatible with the Array API" + ) + return _asarray_fn(array, dtype=dtype, copy=copy) + if xp.__name__ in {"numpy", "numpy.array_api"}: # Use NumPy API to support order array = numpy.asarray(array, order=order, dtype=dtype) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 7318382ae9d66..7af5c158c9ff7 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -143,6 +143,29 @@ def test_asarray_with_order_ignored(): assert not X_new_np.flags["F_CONTIGUOUS"] +def test_asarray_with_order_override(): + xp = pytest.importorskip("numpy.array_api") + xp_ = _AdjustableNameAPITestWrapper(xp, "wrapped.array_api") + X = xp.asarray([1.2, 3.4, 5.1]) + + def _asarray_fn(array, copy, dtype): + return numpy.asarray(array, copy, dtype, order="F") + + X_new_asarray_fn = _asarray_with_order(X, xp=xp_, _asarray_fn=_asarray_fn) + assert X_new_asarray_fn.flags["F_CONTIGUOUS"] + + +def test_asarray_with_order_error_on_override(): + xp = pytest.importorskip("numpy.array_api") + X = xp.asarray([1.2, 3.4, 5.1]) + + def _asarray_fn(array, copy, dtype): + return numpy.asarray(array, copy, dtype, order="F") + + with pytest.raises(ValueError, match="_asarray_fn is only supported for"): + _asarray_with_order(X, xp=numpy, _asarray_fn=_asarray_fn) + + def test_convert_to_numpy_error(): """Test convert to numpy errors for unsupported namespaces.""" xp = pytest.importorskip("numpy.array_api") diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index dd0c007602654..90703ce0e0ff5 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -641,6 +641,7 @@ def check_array( ensure_min_features=1, estimator=None, input_name="", + _asarray_fn=None, ): """Input validation on an array, list, sparse matrix or similar. @@ -728,6 +729,17 @@ def check_array( .. versionadded:: 1.1.0 + _asarray_fn : callable or None, default=None + If not None, this callable will be used in place of calls to + `np.asarray` and `xp.asarray` (where `xp` can be any array namespace + implementing the Array API) when the data is converted to an array + object. Its signature must conform to the Array API specification for + `asarray`. This parameter can be used along with array libraries that + implement a superset of the Array API specifications and need some of + the extra arguments for input conversion (such as `order`). + + .. versionadded:: 1.3.0 + Returns ------- array_converted : object @@ -865,7 +877,9 @@ def check_array( # Conversion float -> int should not contain NaN or # inf (numpy#14412). We cannot use casting='safe' because # then conversion float -> int would be disallowed. - array = _asarray_with_order(array, order=order, xp=xp) + array = _asarray_with_order( + array, order=order, _asarray_fn=_asarray_fn, xp=xp + ) if array.dtype.kind == "f": _assert_all_finite( array, @@ -948,12 +962,14 @@ def check_array( # only make a copy if `array` and `array_orig` may share memory` if np.may_share_memory(array, array_orig): array = _asarray_with_order( - array, dtype=dtype, order=order, copy=True, xp=xp + array, dtype=dtype, order=order, copy=True, _asarray_fn=_asarray_fn, + xp=xp ) else: # always make a copy for non-numpy arrays array = _asarray_with_order( - array, dtype=dtype, order=order, copy=True, xp=xp + array, dtype=dtype, order=order, copy=True, _asarray_fn=_asarray_fn, + xp=xp ) return array