We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent db06ede commit 6a5863fCopy full SHA for 6a5863f
array_api_compat/jax/__init__.py
@@ -1,5 +1,42 @@
1
-from jax.numpy import * # quick hack
2
-from jax import *
+from jax.numpy import (
+ # Constants
3
+ e,
4
+ inf,
5
+ nan,
6
+ pi,
7
+ newaxis,
8
+ # Dtypes
9
+ bool,
10
+ float32,
11
+ float64,
12
+ int8,
13
+ int16,
14
+ int32,
15
+ int64,
16
+ uint8,
17
+ uint16,
18
+ uint32,
19
+ uint64,
20
+ complex64,
21
+ complex128,
22
+ iinfo,
23
+ finfo,
24
+ can_cast,
25
+ result_type,
26
+ # functions
27
+ zeros,
28
+ all,
29
+ isnan,
30
+ isfinite,
31
+ reshape
32
+)
33
34
+ asarray,
35
+ s_,
36
+ int_,
37
+ argpartition,
38
+ take_along_axis
39
40
41
42
def top_k(
@@ -39,4 +76,4 @@ def top_k(
76
return (topk_values, topk_indices)
77
78
-__all__ = ['top_k']
79
+__all__ = ['top_k']
0 commit comments