Skip to content

Commit 05b427e

Browse files
committed
BUG: Fix the signature for np.array_api.take
The array_api take() doesn't flatten the array by default, so the axis argument must be provided for multidimensional arrays. However, it should be optional when the input array is 1-D, which the signature previously did not allow. c.f. data-apis/array-api#644 Original NumPy Commit: 37ba69c7b7404e4ae67ef2e4db9584852baa963a
1 parent dcfae1a commit 05b427e

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

Diff for: array_api_strict/_indexing_functions.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55

66
import numpy as np
77

8-
def take(x: Array, indices: Array, /, *, axis: int) -> Array:
8+
def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
99
"""
1010
Array API compatible wrapper for :py:func:`np.take <numpy.take>`.
1111
1212
See its docstring for more information.
13-
"""
13+
"""
14+
if axis is None and x.ndim != 1:
15+
raise ValueError("axis must be specified when ndim > 1")
1416
if indices.dtype not in _integer_dtypes:
1517
raise TypeError("Only integer dtypes are allowed in indexing")
16-
if indices.ndim != 1:
18+
if indices.ndim != 1:
1719
raise ValueError("Only 1-dim indices array is supported")
1820
return Array._new(np.take(x._array, indices._array, axis=axis))

0 commit comments

Comments
 (0)