Skip to content

Commit 020356f

Browse files
committed
Ensure that Jax float0 array is recognized
See jax-ml/jax#20620.
1 parent faddb83 commit 020356f

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

array_api_compat/common/_helpers.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,20 @@
1818
import inspect
1919
import warnings
2020

21+
def is_jax_zero_gradient_array(x):
22+
"""Return True if `x` is a zero-gradient array.
23+
24+
These arrays are a design quirk of Jax that may one day be removed.
25+
See https://github.com/google/jax/issues/20620.
26+
"""
27+
if 'numpy' not in sys.modules or 'jax' not in sys.modules:
28+
return False
29+
30+
import numpy as np
31+
import jax
32+
33+
return isinstance(x, np.ndarray) and x.dtype == jax.float0
34+
2135
def is_numpy_array(x):
2236
"""
2337
Return True if `x` is a NumPy array.
@@ -44,7 +58,8 @@ def is_numpy_array(x):
4458
import numpy as np
4559

4660
# TODO: Should we reject ndarray subclasses?
47-
return isinstance(x, (np.ndarray, np.generic))
61+
return (isinstance(x, (np.ndarray, np.generic))
62+
and not is_jax_zero_gradient_array(x))
4863

4964
def is_cupy_array(x):
5065
"""
@@ -149,7 +164,7 @@ def is_jax_array(x):
149164

150165
import jax
151166

152-
return isinstance(x, jax.Array)
167+
return isinstance(x, jax.Array) or is_jax_zero_gradient_array(x)
153168

154169
def is_array_api_obj(x):
155170
"""

0 commit comments

Comments
 (0)