Skip to content

Commit cfd838d

Browse files
committed
meow
1 parent daaeaa5 commit cfd838d

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

array_api_compat/jax/__init__.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import jax
1+
from jax import *
22

33
def top_k(
44
x,
@@ -11,10 +11,10 @@ def top_k(
1111

1212
# `swapaxes` is used to implement
1313
# the `axis` kwarg
14-
x = jax.numpy.swapaxes(x, axis, -1)
15-
vals, args = jax.lax.top_k(x, k)
16-
vals = jax.numpy.swapaxes(vals, axis, -1)
17-
args = jax.numpy.swapaxes(args, axis, -1)
14+
x = numpy.swapaxes(x, axis, -1)
15+
vals, args = lax.top_k(x, k)
16+
vals = numpy.swapaxes(vals, axis, -1)
17+
args = numpy.swapaxes(args, axis, -1)
1818
return vals, args
1919

2020

0 commit comments

Comments
 (0)