Skip to content

Commit a2e33f9

Browse files
committed
1 parent 50ca679 commit a2e33f9

File tree

2 files changed

+133
-2
lines changed

2 files changed

+133
-2
lines changed

Diff for: spec/draft/API_specification/searching_functions.rst

+3
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,7 @@ Objects in API
2323
argmax
2424
argmin
2525
nonzero
26+
top_k
27+
top_k_indices
28+
top_k_values
2629
where

Diff for: src/array_api_stubs/_draft/searching_functions.py

+130-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1-
__all__ = ["argmax", "argmin", "nonzero", "where"]
1+
__all__ = [
2+
"argmax",
3+
"argmin",
4+
"nonzero",
5+
"top_k",
6+
"top_k_values",
7+
"top_k_indices",
8+
"where",
9+
]
210

311

4-
from ._types import Optional, Tuple, array
12+
from ._types import Optional, Literal, Tuple, array
513

614

715
def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array:
@@ -87,6 +95,126 @@ def nonzero(x: array, /) -> Tuple[array, ...]:
8795
"""
8896

8997

98+
def top_k(
99+
x: array,
100+
k: int,
101+
/,
102+
*,
103+
axis: Optional[int] = None,
104+
mode: Literal["largest", "smallest"] = "largest",
105+
) -> Tuple[array, array]:
106+
"""
107+
Returns the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension.
108+
109+
Parameters
110+
----------
111+
x: array
112+
input array. Should have a real-valued data type.
113+
k: int
114+
number of elements to find. Must be a positive integer value.
115+
axis: Optional[int]
116+
axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``.
117+
mode: Literal['largest', 'smallest']
118+
search mode. Must be one of the following modes:
119+
120+
- ``'largest'``: return the ``k`` largest elements.
121+
- ``'smallest'``: return the ``k`` smallest elements.
122+
123+
Returns
124+
-------
125+
out: Tuple[array, array]
126+
a namedtuple ``(values, indices)`` whose
127+
128+
- first element must have the field name ``values`` and must be an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``.
129+
- second element must have the field name ``indices`` and must be an array containing indices of ``x`` that result in ``values``. The array must have the same shape as ``values`` and must have the default array index data type. If ``axis`` is ``None``, ``indices`` must be the indices of a flattened ``x``.
130+
131+
Notes
132+
-----
133+
134+
- If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all elements.
135+
- The order of the returned values and indices is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values.
136+
- Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`).
137+
"""
138+
139+
140+
def top_k_indices(
141+
x: array,
142+
k: int,
143+
/,
144+
*,
145+
axis: Optional[int] = None,
146+
mode: Literal["largest", "smallest"] = "largest",
147+
) -> array:
148+
"""
149+
Returns the indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension.
150+
151+
Parameters
152+
----------
153+
x: array
154+
input array. Should have a real-valued data type.
155+
k: int
156+
number of elements to find. Must be a positive integer value.
157+
axis: Optional[int]
158+
axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``.
159+
mode: Literal['largest', 'smallest']
160+
search mode. Must be one of the following modes:
161+
162+
- ``'largest'``: return the indices of the ``k`` largest elements.
163+
- ``'smallest'``: return the indices of the ``k`` smallest elements.
164+
165+
Returns
166+
-------
167+
out: array
168+
an array containing indices corresponding to the ``k`` largest (or smallest) elements of ``x``. The array must have the default array index data type. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)`` and contain the indices of a flattened ``x``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``.
169+
170+
Notes
171+
-----
172+
173+
- If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all indices.
174+
- The order of the returned indices is left unspecified and thus implementation-dependent. Conforming implementations may return indices corresponding to sorted or unsorted values.
175+
- Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`).
176+
"""
177+
178+
179+
def top_k_values(
180+
x: array,
181+
k: int,
182+
/,
183+
*,
184+
axis: Optional[int] = None,
185+
mode: Literal["largest", "smallest"] = "largest",
186+
) -> array:
187+
"""
188+
Returns the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension.
189+
190+
Parameters
191+
----------
192+
x: array
193+
input array. Should have a real-valued data type.
194+
k: int
195+
number of elements to find. Must be a positive integer value.
196+
axis: Optional[int]
197+
axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``.
198+
mode: Literal['largest', 'smallest']
199+
search mode. Must be one of the following modes:
200+
201+
- ``'largest'``: return the indices of the ``k`` largest elements.
202+
- ``'smallest'``: return the indices of the ``k`` smallest elements.
203+
204+
Returns
205+
-------
206+
out: array
207+
an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``.
208+
209+
Notes
210+
-----
211+
212+
- If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all indices.
213+
- The order of the returned values is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values.
214+
- Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`).
215+
"""
216+
217+
90218
def where(condition: array, x1: array, x2: array, /) -> array:
91219
"""
92220
Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.

0 commit comments

Comments
 (0)