Skip to content

Commit db06ede

Browse files
committed
fix jax top_k
1 parent 3820460 commit db06ede

File tree

2 files changed

+2
-9
lines changed

2 files changed

+2
-9
lines changed

.github/workflows/array-api-tests-jax.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@ jobs:
77
uses: ./.github/workflows/array-api-tests.yml
88
with:
99
package-name: jax
10+
# See https://github.com/google/jax/issues/22137 for reason behind skipped dtypes
1011
extra-env-vars: |
1112
JAX_ENABLE_X64=1
13+
ARRAY_API_TESTS_SKIP_DTYPES=uint8,uint16,uint32,uint64

array_api_compat/jax/__init__.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,6 @@ def top_k(
1010
*,
1111
largest=True,
1212
):
13-
14-
# `swapaxes` is used to implement
15-
# the `axis` kwarg
16-
# x = numpy.swapaxes(x, axis, -1)
17-
# vals, args = lax.top_k(x, k)
18-
# vals = numpy.swapaxes(vals, axis, -1)
19-
# args = numpy.swapaxes(args, axis, -1)
20-
# return vals, args
21-
2213
# The largest keyword can't be implemented with `jax.lax.top_k`
2314
# efficiently so am using `jax.numpy` for now
2415
if k <= 0:

0 commit comments

Comments
 (0)