Skip to content

Commit 390e9cc

Browse files
authored
feat: add take_long_axis to specifiation
PR-URL: #816 Closes: #808
1 parent b877795 commit 390e9cc

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

spec/draft/API_specification/indexing_functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ Objects in API
2121
:template: method.rst
2222

2323
take
24+
take_along_axis

src/array_api_stubs/_draft/indexing_functions.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__all__ = ["take"]
1+
__all__ = ["take", "take_along_axis"]
22

33
from ._types import Union, Optional, array
44

@@ -38,3 +38,27 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None) -> array:
3838
.. versionchanged:: 2023.12
3939
Out-of-bounds behavior is explicitly left unspecified.
4040
"""
41+
42+
43+
def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array:
44+
"""
45+
Returns elements from an array at the one-dimensional indices specified by ``indices`` along a provided ``axis``.
46+
47+
Parameters
48+
----------
49+
x: array
50+
input array. Must be compatible with ``indices``, except for the axis (dimension) specified by ``axis`` (see :ref:`broadcasting`).
51+
indices: array
52+
array indices. Must have the same rank (i.e., number of dimensions) as ``x``.
53+
54+
.. note::
55+
This specification does not require bounds checking. The behavior for out-of-bounds indices is left unspecified.
56+
57+
axis: int
58+
axis along which to select values. If ``axis`` is negative, the function must determine the axis along which to select values by counting from the last dimension. Default: ``-1``.
59+
60+
Returns
61+
-------
62+
out: array
63+
an array having the same data type as ``x``. Must have the same rank (i.e., number of dimensions) as ``x`` and must have a shape determined according to :ref:`broadcasting`, except for the axis (dimension) specified by ``axis`` whose size must equal the size of the corresponding axis (dimension) in ``indices``.
64+
"""

0 commit comments

Comments
 (0)