From 3fcae98efc6ebc7fb370cc93812312a2eea9a957 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 31 Mar 2023 11:44:04 -0400 Subject: [PATCH] CLN Make _NumPyAPIWrapper naming consistent to _ArrayAPIWrapper --- sklearn/utils/_array_api.py | 10 +++++----- sklearn/utils/tests/test_array_api.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index fff8e1ee33a49..2afa6aba5d715 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -49,7 +49,7 @@ def take(self, X, indices, *, axis): return self._namespace.stack(selected, axis=axis) -class _NumPyApiWrapper: +class _NumPyAPIWrapper: """Array API compat wrapper for any numpy version NumPy < 1.22 does not expose the numpy.array_api namespace. This @@ -98,7 +98,7 @@ def get_namespace(*arrays): See: https://numpy.org/neps/nep-0047-array-api-standard.html If `arrays` are regular numpy arrays, an instance of the - `_NumPyApiWrapper` compatibility wrapper is returned instead. + `_NumPyAPIWrapper` compatibility wrapper is returned instead. Namespace support is not enabled by default. To enabled it call: @@ -110,7 +110,7 @@ def get_namespace(*arrays): with sklearn.config_context(array_api_dispatch=True): # your code here - Otherwise an instance of the `_NumPyApiWrapper` + Otherwise an instance of the `_NumPyAPIWrapper` compatibility wrapper is always returned irrespective of the fact that arrays implement the `__array_namespace__` protocol or not. @@ -133,7 +133,7 @@ def get_namespace(*arrays): # Returns a tuple: (array_namespace, is_array_api) if not get_config()["array_api_dispatch"]: - return _NumPyApiWrapper(), False + return _NumPyAPIWrapper(), False namespaces = { x.__array_namespace__() if hasattr(x, "__array_namespace__") else None @@ -152,7 +152,7 @@ def get_namespace(*arrays): (xp,) = namespaces if xp is None: # Use numpy as default - return _NumPyApiWrapper(), False + return _NumPyAPIWrapper(), False return _ArrayAPIWrapper(xp), True diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 7318382ae9d66..9a88153a25615 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -4,7 +4,7 @@ from sklearn.base import BaseEstimator from sklearn.utils._array_api import get_namespace -from sklearn.utils._array_api import _NumPyApiWrapper +from sklearn.utils._array_api import _NumPyAPIWrapper from sklearn.utils._array_api import _ArrayAPIWrapper from sklearn.utils._array_api import _asarray_with_order from sklearn.utils._array_api import _convert_to_numpy @@ -27,7 +27,7 @@ def test_get_namespace_ndarray(): with config_context(array_api_dispatch=array_api_dispatch): xp_out, is_array_api = get_namespace(X_np) assert not is_array_api - assert isinstance(xp_out, _NumPyApiWrapper) + assert isinstance(xp_out, _NumPyAPIWrapper) def test_get_namespace_array_api():