File tree 1 file changed +17
-2
lines changed
1 file changed +17
-2
lines changed Original file line number Diff line number Diff line change 18
18
import inspect
19
19
import warnings
20
20
21
+ def is_jax_zero_gradient_array (x ):
22
+ """Return True if `x` is a zero-gradient array.
23
+
24
+ These arrays are a design quirk of Jax that may one day be removed.
25
+ See https://github.com/google/jax/issues/20620.
26
+ """
27
+ if 'numpy' not in sys .modules or 'jax' not in sys .modules :
28
+ return False
29
+
30
+ import numpy as np
31
+ import jax
32
+
33
+ return isinstance (x , np .ndarray ) and x .dtype == jax .float0
34
+
21
35
def is_numpy_array (x ):
22
36
"""
23
37
Return True if `x` is a NumPy array.
@@ -44,7 +58,8 @@ def is_numpy_array(x):
44
58
import numpy as np
45
59
46
60
# TODO: Should we reject ndarray subclasses?
47
- return isinstance (x , (np .ndarray , np .generic ))
61
+ return (isinstance (x , (np .ndarray , np .generic ))
62
+ and not is_jax_zero_gradient_array (x ))
48
63
49
64
def is_cupy_array (x ):
50
65
"""
@@ -149,7 +164,7 @@ def is_jax_array(x):
149
164
150
165
import jax
151
166
152
- return isinstance (x , jax .Array )
167
+ return isinstance (x , jax .Array ) or is_jax_zero_gradient_array ( x )
153
168
154
169
def is_array_api_obj (x ):
155
170
"""
You can’t perform that action at this time.
0 commit comments