diff --git a/numpy/array_api/_indexing_functions.py b/numpy/array_api/_indexing_functions.py index ba56bcd6f004..baf23f7f0b69 100644 --- a/numpy/array_api/_indexing_functions.py +++ b/numpy/array_api/_indexing_functions.py @@ -5,14 +5,16 @@ import numpy as np -def take(x: Array, indices: Array, /, *, axis: int) -> Array: +def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.take `. See its docstring for more information. - """ + """ + if axis is None and x.ndim != 1: + raise ValueError("axis must be specified when ndim > 1") if indices.dtype not in _integer_dtypes: raise TypeError("Only integer dtypes are allowed in indexing") - if indices.ndim != 1: + if indices.ndim != 1: raise ValueError("Only 1-dim indices array is supported") return Array._new(np.take(x._array, indices._array, axis=axis))