Skip to content

Commit f5e1093

Browse files
committed
re-enable tests
1 parent 3dfa878 commit f5e1093

File tree

3 files changed

+34
-80
lines changed

3 files changed

+34
-80
lines changed

.github/workflows/array-api-tests.yml

+1-3
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ jobs:
4040
runs-on: ubuntu-latest
4141
strategy:
4242
matrix:
43-
python-version: ['3.9',
44-
# '3.10', '3.11', '3.12'
45-
]
43+
python-version: ['3.9', '3.10', '3.11', '3.12']
4644

4745
steps:
4846
- name: Checkout array-api-compat

array_api_compat/jax/__init__.py

+33-5
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,39 @@ def top_k(
1313

1414
# `swapaxes` is used to implement
1515
# the `axis` kwarg
16-
x = numpy.swapaxes(x, axis, -1)
17-
vals, args = lax.top_k(x, k)
18-
vals = numpy.swapaxes(vals, axis, -1)
19-
args = numpy.swapaxes(args, axis, -1)
20-
return vals, args
16+
# x = numpy.swapaxes(x, axis, -1)
17+
# vals, args = lax.top_k(x, k)
18+
# vals = numpy.swapaxes(vals, axis, -1)
19+
# args = numpy.swapaxes(args, axis, -1)
20+
# return vals, args
21+
22+
# The largest keyword can't be implemented with `jax.lax.top_k`
23+
# efficiently so am using `jax.numpy` for now
24+
if k <= 0:
25+
raise ValueError(f'k(={k}) provided must be positive.')
26+
27+
positive_axis: int
28+
_arr = asanyarray(a)
29+
if axis is None:
30+
arr = _arr.ravel()
31+
positive_axis = 0
32+
else:
33+
arr = _arr
34+
positive_axis = axis if axis > 0 else axis % arr.ndim
35+
36+
slice_start = (s_[:],) * positive_axis
37+
if largest:
38+
indices_array = argpartition(arr, -k, axis=axis)
39+
slice = slice_start + (s_[-k:],)
40+
topk_indices = indices_array[slice]
41+
else:
42+
indices_array = argpartition(arr, k-1, axis=axis)
43+
slice = slice_start + (s_[:k],)
44+
topk_indices = indices_array[slice]
45+
46+
topk_values = take_along_axis(arr, topk_indices, axis=axis)
47+
48+
return (topk_values, topk_indices)
2149

2250

2351
__all__ = ['top_k']

array_api_compat/numpy/_aliases.py

-72
Original file line numberDiff line numberDiff line change
@@ -63,78 +63,6 @@
6363

6464

6565
def top_k(a, k, /, axis=-1, *, largest=True):
66-
"""
67-
Returns the ``k`` largest/smallest elements and corresponding
68-
indices along the given ``axis``.
69-
70-
When ``axis`` is None, a flattened array is used.
71-
72-
If ``largest`` is false, then the ``k`` smallest elements are returned.
73-
74-
A tuple of ``(values, indices)`` is returned, where ``values`` and
75-
``indices`` of the largest/smallest elements of each row of the input
76-
array in the given ``axis``.
77-
78-
Parameters
79-
----------
80-
a: array_like
81-
The source array
82-
k: int
83-
The number of largest/smallest elements to return. ``k`` must
84-
be a positive integer and within indexable range specified by
85-
``axis``.
86-
axis: int, optional
87-
Axis along which to find the largest/smallest elements.
88-
The default is -1 (the last axis).
89-
If None, a flattened array is used.
90-
largest: bool, optional
91-
If True, largest elements are returned. Otherwise the smallest
92-
are returned.
93-
94-
Returns
95-
-------
96-
tuple_of_array: tuple
97-
The output tuple of ``(topk_values, topk_indices)``, where
98-
``topk_values`` are returned elements from the source array
99-
(not necessarily in sorted order), and ``topk_indices`` are
100-
the corresponding indices.
101-
102-
See Also
103-
--------
104-
argpartition : Indirect partition.
105-
sort : Full sorting.
106-
107-
Notes
108-
-----
109-
The returned indices are not guaranteed to be sorted according to
110-
the values. Furthermore, the returned indices are not guaranteed
111-
to be the earliest/latest occurrence of the element. E.g.,
112-
``np.top_k([3,3,3], 1)`` can return ``(array([3]), array([1]))``
113-
rather than ``(array([3]), array([0]))`` or
114-
``(array([3]), array([2]))``.
115-
116-
Warning: The treatment of ``np.nan`` in the input array is undefined.
117-
118-
Examples
119-
--------
120-
>>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]])
121-
>>> np.top_k(a, 2)
122-
(array([[4, 5],
123-
[4, 5],
124-
[4, 5]]),
125-
array([[3, 4],
126-
[1, 0],
127-
[1, 2]]))
128-
>>> np.top_k(a, 2, axis=0)
129-
(array([[3, 4, 3, 2, 2],
130-
[5, 4, 5, 4, 5]]),
131-
array([[2, 1, 1, 1, 2],
132-
[1, 2, 2, 0, 0]]))
133-
>>> a.flatten()
134-
array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2])
135-
>>> np.top_k(a, 2, axis=None)
136-
(array([5, 5]), array([ 5, 12]))
137-
"""
13866
if k <= 0:
13967
raise ValueError(f'k(={k}) provided must be positive.')
14068

0 commit comments

Comments
 (0)