From 37ba69c7b7404e4ae67ef2e4db9584852baa963a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Sat, 15 Jul 2023 10:29:36 -0500 Subject: [PATCH] BUG: Fix the signature for np.array_api.take The array_api take() doesn't flatten the array by default, so the axis argument must be provided for multidimensional arrays. However, it should be optional when the input array is 1-D, which the signature previously did not allow. c.f. https://github.com/data-apis/array-api/pull/644 --- numpy/array_api/_indexing_functions.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/numpy/array_api/_indexing_functions.py b/numpy/array_api/_indexing_functions.py index ba56bcd6f004..baf23f7f0b69 100644 --- a/numpy/array_api/_indexing_functions.py +++ b/numpy/array_api/_indexing_functions.py @@ -5,14 +5,16 @@ import numpy as np -def take(x: Array, indices: Array, /, *, axis: int) -> Array: +def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.take `. See its docstring for more information. - """ + """ + if axis is None and x.ndim != 1: + raise ValueError("axis must be specified when ndim > 1") if indices.dtype not in _integer_dtypes: raise TypeError("Only integer dtypes are allowed in indexing") - if indices.ndim != 1: + if indices.ndim != 1: raise ValueError("Only 1-dim indices array is supported") return Array._new(np.take(x._array, indices._array, axis=axis))