Skip to content

Commit daaeaa5

Browse files
committed
idk what im doing
1 parent 1f34630 commit daaeaa5

File tree

2 files changed

+6
-12
lines changed

2 files changed

+6
-12
lines changed

array_api_compat/jax/__init__.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,13 @@
11
import jax
2-
from typing import TYPE_CHECKING
3-
if TYPE_CHECKING:
4-
from typing import Optional, Tuple
5-
6-
from ..common._typing import Array
7-
82

93
def top_k(
10-
x: Array,
11-
k: int,
4+
x,
5+
k,
126
/,
13-
axis: Optional[int] = None,
7+
axis=None,
148
*,
15-
largest: bool = True,
16-
) -> Tuple[Array, Array]:
9+
largest=True,
10+
):
1711

1812
# `swapaxes` is used to implement
1913
# the `axis` kwarg

array_api_compat/numpy/_aliases.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
tensordot = get_xp(np)(_aliases.tensordot)
6363

6464

65-
def top_k(a, k, /, *, axis=-1, largest=True):
65+
def top_k(a, k, /, axis=-1, *, largest=True):
6666
"""
6767
Returns the ``k`` largest/smallest elements and corresponding
6868
indices along the given ``axis``.

0 commit comments

Comments
 (0)