Skip to content

Commit 6a5863f

Browse files
committed
fix ruff errors
1 parent db06ede commit 6a5863f

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

array_api_compat/jax/__init__.py

+40-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,42 @@
1-
from jax.numpy import * # quick hack
2-
from jax import *
1+
from jax.numpy import (
2+
# Constants
3+
e,
4+
inf,
5+
nan,
6+
pi,
7+
newaxis,
8+
# Dtypes
9+
bool,
10+
float32,
11+
float64,
12+
int8,
13+
int16,
14+
int32,
15+
int64,
16+
uint8,
17+
uint16,
18+
uint32,
19+
uint64,
20+
complex64,
21+
complex128,
22+
iinfo,
23+
finfo,
24+
can_cast,
25+
result_type,
26+
# functions
27+
zeros,
28+
all,
29+
isnan,
30+
isfinite,
31+
reshape
32+
)
33+
from jax.numpy import (
34+
asarray,
35+
s_,
36+
int_,
37+
argpartition,
38+
take_along_axis
39+
)
340

441

542
def top_k(
@@ -39,4 +76,4 @@ def top_k(
3976
return (topk_values, topk_indices)
4077

4178

42-
__all__ = ['top_k']
79+
__all__ = ['top_k']

0 commit comments

Comments
 (0)