From ede5ec51394d5fe70197425048e4e9c04944d022 Mon Sep 17 00:00:00 2001 From: Franck Charras <29153872+fcharras@users.noreply.github.com> Date: Mon, 11 Dec 2023 11:21:33 +0100 Subject: [PATCH] MNT remove `take` fn in array_api wrapper --- sklearn/utils/_array_api.py | 24 --------------- sklearn/utils/tests/test_array_api.py | 44 +-------------------------- 2 files changed, 1 insertion(+), 67 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 24534faa931e8..6072c0fab8580 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -205,30 +205,6 @@ def __getattr__(self, name): def __eq__(self, other): return self._namespace == other._namespace - def take(self, X, indices, *, axis=0): - # When array_api supports `take` we can use this directly - # https://github.com/data-apis/array-api/issues/177 - if self._namespace.__name__ == "numpy.array_api": - X_np = numpy.take(X, indices, axis=axis) - return self._namespace.asarray(X_np) - - # We only support axis in (0, 1) and ndim in (1, 2) because that is all we need - # in scikit-learn - if axis not in {0, 1}: - raise ValueError(f"Only axis in (0, 1) is supported. Got {axis}") - - if X.ndim not in {1, 2}: - raise ValueError(f"Only X.ndim in (1, 2) is supported. Got {X.ndim}") - - if axis == 0: - if X.ndim == 1: - selected = [X[i] for i in indices] - else: # X.ndim == 2 - selected = [X[i, :] for i in indices] - else: # axis == 1 - selected = [X[:, i] for i in indices] - return self._namespace.stack(selected, axis=axis) - def isdtype(self, dtype, kind): return isdtype(dtype, kind, xp=self._namespace) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 866fd0e1d56f3..1c4ff748e6455 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -2,7 +2,7 @@ import numpy import pytest -from numpy.testing import assert_allclose, assert_array_equal +from numpy.testing import assert_allclose from sklearn._config import config_context from sklearn.base import BaseEstimator @@ -101,48 +101,6 @@ def test_array_api_wrapper_astype(): assert X_converted.dtype == xp.float32 -def test_array_api_wrapper_take_for_numpy_api(): - """Test that fast path is called for numpy.array_api.""" - numpy_array_api = pytest.importorskip("numpy.array_api") - # USe the same name as numpy.array_api - xp_ = _AdjustableNameAPITestWrapper(numpy_array_api, "numpy.array_api") - xp = _ArrayAPIWrapper(xp_) - - X = xp.asarray(([[1, 2, 3], [3, 4, 5]]), dtype=xp.float64) - X_take = xp.take(X, xp.asarray([1]), axis=0) - assert hasattr(X_take, "__array_namespace__") - assert_array_equal(X_take, numpy.take(X, [1], axis=0)) - - -def test_array_api_wrapper_take(): - """Test _ArrayAPIWrapper API for take.""" - numpy_array_api = pytest.importorskip("numpy.array_api") - xp_ = _AdjustableNameAPITestWrapper(numpy_array_api, "wrapped_numpy.array_api") - xp = _ArrayAPIWrapper(xp_) - - # Check take compared to NumPy's with axis=0 - X_1d = xp.asarray([1, 2, 3], dtype=xp.float64) - X_take = xp.take(X_1d, xp.asarray([1]), axis=0) - assert hasattr(X_take, "__array_namespace__") - assert_array_equal(X_take, numpy.take(X_1d, [1], axis=0)) - - X = xp.asarray(([[1, 2, 3], [3, 4, 5]]), dtype=xp.float64) - X_take = xp.take(X, xp.asarray([0]), axis=0) - assert hasattr(X_take, "__array_namespace__") - assert_array_equal(X_take, numpy.take(X, [0], axis=0)) - - # Check take compared to NumPy's with axis=1 - X_take = xp.take(X, xp.asarray([0, 2]), axis=1) - assert hasattr(X_take, "__array_namespace__") - assert_array_equal(X_take, numpy.take(X, [0, 2], axis=1)) - - with pytest.raises(ValueError, match=r"Only axis in \(0, 1\) is supported"): - xp.take(X, xp.asarray([0]), axis=2) - - with pytest.raises(ValueError, match=r"Only X.ndim in \(1, 2\) is supported"): - xp.take(xp.asarray([[[0]]]), xp.asarray([0]), axis=0) - - @pytest.mark.parametrize("array_api", ["numpy", "numpy.array_api"]) def test_asarray_with_order(array_api): """Test _asarray_with_order passes along order for NumPy arrays."""