Skip to content

Commit 1f34630

Browse files
committed
WIP: Add top_k compatibility
This references the PR data-apis/array-api-tests#274.
1 parent 51daace commit 1f34630

File tree

6 files changed

+168
-4
lines changed

6 files changed

+168
-4
lines changed
+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
name: Array API Tests (JAX)
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
array-api-tests-jax:
7+
uses: ./.github/workflows/array-api-tests.yml
8+
with:
9+
package-name: jax
10+
pytest-extra-args: -k top_k

.github/workflows/array-api-tests.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ jobs:
5050
- name: Checkout array-api-tests
5151
uses: actions/checkout@v4
5252
with:
53-
repository: data-apis/array-api-tests
53+
repository: JuliaPoo/array-api-tests
54+
ref: wip-topk-tests
5455
submodules: 'true'
5556
path: array-api-tests
5657
- name: Set up Python ${{ matrix.python-version }}

array_api_compat/dask/array/_aliases.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,28 @@ def asarray(
150150

151151
return da.asarray(obj, dtype=dtype, **kwargs)
152152

153+
154+
def top_k(
    x: Array,
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> tuple[Array, Array]:
    """
    Return the ``k`` largest (or smallest) elements of ``x`` and their indices.

    Parameters
    ----------
    x : Array
        Input dask array.
    k : int
        Number of elements to select along ``axis``.
    axis : int, optional
        Axis along which to select. If ``None`` (the default), ``x`` is
        flattened first and the indices refer to the flattened array.
    largest : bool, optional
        If ``True`` (the default) select the largest elements, otherwise
        the smallest.

    Returns
    -------
    tuple of Array
        ``(values, indices)``, both lazy dask arrays.
    """
    if axis is None:
        # ``da.topk`` / ``da.argtopk`` need an integer axis, so flatten
        # explicitly and select along the single remaining axis.
        x = x.ravel()
        axis = 0

    if not largest:
        # dask selects the k smallest elements when k is negative.
        k = -k

    # For now, build the selection twice, since an equivalent to numpy's
    # `take_along_axis` does not exist.
    # See https://github.com/dask/dask/issues/3663.
    # Both results stay lazy: nothing is computed here, so the caller
    # receives dask arrays rather than materialized NumPy arrays.
    args = da.argtopk(x, k, axis=axis)
    vals = da.topk(x, k, axis=axis)
    return vals, args
173+
174+
153175
from dask.array import (
154176
# Element wise aliases
155177
arccos as acos,
@@ -178,6 +200,7 @@ def asarray(
178200
'bitwise_right_shift', 'concat', 'pow',
179201
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
180202
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
181-
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
203+
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type',
204+
'top_k']
182205

183206
_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']

array_api_compat/jax/__init__.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import jax
2+
from typing import TYPE_CHECKING
3+
if TYPE_CHECKING:
4+
from typing import Optional, Tuple
5+
6+
from ..common._typing import Array
7+
8+
9+
def top_k(
    x: "Array",
    k: int,
    /,
    axis: "Optional[int]" = None,
    *,
    largest: bool = True,
) -> "Tuple[Array, Array]":
    """
    Return the ``k`` largest (or smallest) elements of ``x`` and their indices.

    Parameters
    ----------
    x : Array
        Input array.
    k : int
        Number of elements to select along ``axis``.
    axis : int, optional
        Axis along which to select. If ``None`` (the default), ``x`` is
        flattened first and the indices refer to the flattened array.
    largest : bool, optional
        If ``True`` (the default) select the largest elements, otherwise
        the smallest.

    Returns
    -------
    tuple of Array
        ``(values, indices)`` with ``k`` entries along ``axis``.
    """
    # The annotations above are quoted because ``Optional`` and ``Tuple``
    # are only imported under ``TYPE_CHECKING`` in this module.
    if axis is None:
        x = jax.numpy.ravel(x)
        axis = -1

    # `jax.lax.top_k` is called on the last axis, so `swapaxes` is used
    # to implement the `axis` argument.
    x = jax.numpy.swapaxes(x, axis, -1)

    if largest:
        key = x
    elif jax.numpy.issubdtype(x.dtype, jax.numpy.floating):
        # Negation reverses the ordering of floating-point values.
        key = -x
    else:
        # Bitwise inversion reverses the ordering of integers (signed and
        # unsigned) without the overflow that negation would cause.
        key = ~x

    # Only the indices of the (possibly order-reversed) key are used; the
    # values are gathered from the untouched input.
    _, args = jax.lax.top_k(key, k)
    vals = jax.numpy.take_along_axis(x, args, axis=-1)

    vals = jax.numpy.swapaxes(vals, axis, -1)
    args = jax.numpy.swapaxes(args, axis, -1)
    return vals, args
25+
26+
27+
__all__ = ['top_k']

array_api_compat/numpy/_aliases.py

+102-1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,107 @@
6161
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
6262
tensordot = get_xp(np)(_aliases.tensordot)
6363

64+
65+
def top_k(a, k, /, *, axis=-1, largest=True):
    """
    Returns the ``k`` largest/smallest elements and corresponding
    indices along the given ``axis``.

    When ``axis`` is None, a flattened array is used.

    If ``largest`` is false, then the ``k`` smallest elements are returned.

    A tuple of ``(values, indices)`` is returned, holding the values and
    the indices of the largest/smallest elements of the input array along
    the given ``axis``.

    Parameters
    ----------
    a: array_like
        The source array
    k: int
        The number of largest/smallest elements to return. ``k`` must
        be a positive integer and within indexable range specified by
        ``axis``.
    axis: int, optional
        Axis along which to find the largest/smallest elements.
        The default is -1 (the last axis).
        If None, a flattened array is used.
    largest: bool, optional
        If True, largest elements are returned. Otherwise the smallest
        are returned.

    Returns
    -------
    tuple_of_array: tuple
        The output tuple of ``(topk_values, topk_indices)``, where
        ``topk_values`` are returned elements from the source array
        (not necessarily in sorted order), and ``topk_indices`` are
        the corresponding indices.

    Raises
    ------
    ValueError
        If ``k`` is not positive, or if ``a`` is zero-dimensional and
        ``axis`` is not None.

    See Also
    --------
    argpartition : Indirect partition.
    sort : Full sorting.

    Notes
    -----
    The returned indices are not guaranteed to be sorted according to
    the values. Furthermore, the returned indices are not guaranteed
    to be the earliest/latest occurrence of the element. E.g.,
    ``top_k([3,3,3], 1)`` can return ``(array([3]), array([1]))``
    rather than ``(array([3]), array([0]))`` or
    ``(array([3]), array([2]))``.

    Warning: The treatment of ``np.nan`` in the input array is undefined.

    Examples
    --------
    >>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]])
    >>> top_k(a, 2)
    (array([[4, 5],
            [4, 5],
            [4, 5]]),
     array([[3, 4],
            [1, 0],
            [1, 2]]))
    >>> top_k(a, 2, axis=0)
    (array([[3, 4, 3, 2, 2],
            [5, 4, 5, 4, 5]]),
     array([[2, 1, 1, 1, 2],
            [1, 2, 2, 0, 0]]))
    >>> a.flatten()
    array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2])
    >>> top_k(a, 2, axis=None)
    (array([5, 5]), array([ 5, 12]))
    """
    if k <= 0:
        raise ValueError(f'k(={k}) provided must be positive.')

    arr = np.asanyarray(a)
    if axis is None:
        # Work on the flattened data; the selection then happens along
        # the leading (only) axis.
        arr = arr.ravel()
        leading = 0
    elif arr.ndim == 0:
        # ``axis % arr.ndim`` below would divide by zero for a 0-d input.
        raise ValueError(
            'top_k requires an array with at least one dimension '
            'when axis is not None.'
        )
    else:
        leading = axis % arr.ndim

    # Partition so that the wanted elements end up in a contiguous window
    # at the end (largest) or at the start (smallest) of ``axis``.
    if largest:
        kth, window = -k, np.s_[-k:]
    else:
        kth, window = k - 1, np.s_[:k]

    # ``argpartition`` validates ``axis`` and ``kth`` and raises for
    # out-of-range values before the indexing below is attempted.
    partitioned = np.argpartition(arr, kth, axis=axis)
    selector = (np.s_[:],) * leading + (window,)
    topk_indices = partitioned[selector]

    topk_values = np.take_along_axis(arr, topk_indices, axis=axis)

    return (topk_values, topk_indices)
163+
164+
64165
def _supports_buffer_protocol(obj):
65166
try:
66167
memoryview(obj)
@@ -126,6 +227,6 @@ def asarray(
126227
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
127228
'acosh', 'asin', 'asinh', 'atan', 'atan2',
128229
'atanh', 'bitwise_left_shift', 'bitwise_invert',
129-
'bitwise_right_shift', 'concat', 'pow']
230+
'bitwise_right_shift', 'concat', 'pow', 'top_k']
130231

131232
_all_ignore = ['np', 'get_xp']

array_api_compat/torch/_aliases.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
700700
axis = 0
701701
return torch.index_select(x, axis, indices, **kwargs)
702702

703+
def top_k(
    x: torch.Tensor,
    k: int,
    /,
    axis: Optional[int] = -1,
    *,
    largest: bool = True,
    **kwargs,
):
    """
    Return the ``k`` largest (or smallest) elements of ``x`` and their indices.

    Thin wrapper around ``torch.topk`` that accepts ``axis`` instead of
    torch's ``dim`` keyword.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    k : int
        Number of elements to select along ``axis``.
    axis : int, optional
        Dimension along which to select. Defaults to -1 (the last
        dimension), matching ``torch.topk``. If ``None``, ``x`` is
        flattened first and the indices refer to the flattened tensor.
    largest : bool, optional
        If ``True`` (the default) select the largest elements, otherwise
        the smallest.
    **kwargs
        Forwarded unchanged to ``torch.topk``.

    Returns
    -------
    The ``(values, indices)`` result of ``torch.topk``.
    """
    if axis is None:
        x = torch.flatten(x)
        axis = 0
    return torch.topk(x, k, dim=axis, largest=largest, **kwargs)
704+
703705
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
704706
'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
705707
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide',
@@ -713,6 +715,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
713715
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
714716
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
715717
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
716-
'take']
718+
'take', 'top_k']
717719

718720
_all_ignore = ['torch', 'get_xp']

0 commit comments

Comments
 (0)