Skip to content

Commit 7af0b68

Browse files
authored
Merge pull request #130 from NeilGirdhar/fix_float0
Ensure that Jax float0 array is recognized
2 parents faddb83 + 9cfbb2b commit 7af0b68

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

array_api_compat/common/_helpers.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,20 @@
1818
import inspect
1919
import warnings
2020

21+
def _is_jax_zero_gradient_array(x):
22+
"""Return True if `x` is a zero-gradient array.
23+
24+
These arrays are a design quirk of Jax that may one day be removed.
25+
See https://github.com/google/jax/issues/20620.
26+
"""
27+
if 'numpy' not in sys.modules or 'jax' not in sys.modules:
28+
return False
29+
30+
import numpy as np
31+
import jax
32+
33+
return isinstance(x, np.ndarray) and x.dtype == jax.float0
34+
2135
def is_numpy_array(x):
2236
"""
2337
Return True if `x` is a NumPy array.
@@ -44,7 +58,8 @@ def is_numpy_array(x):
4458
import numpy as np
4559

4660
# TODO: Should we reject ndarray subclasses?
47-
return isinstance(x, (np.ndarray, np.generic))
61+
return (isinstance(x, (np.ndarray, np.generic))
62+
and not _is_jax_zero_gradient_array(x))
4863

4964
def is_cupy_array(x):
5065
"""
@@ -149,7 +164,7 @@ def is_jax_array(x):
149164

150165
import jax
151166

152-
return isinstance(x, jax.Array)
167+
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
153168

154169
def is_array_api_obj(x):
155170
"""

tests/test_array_namespace.py

+7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import warnings
44

5+
import jax
56
import numpy as np
67
import pytest
78
import torch
@@ -55,6 +56,12 @@ def test_array_namespace(library, api_version, use_compat):
5556
"""
5657
subprocess.run([sys.executable, "-c", code], check=True)
5758

59+
def test_jax_zero_gradient():
60+
jx = jax.numpy.arange(4)
61+
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
62+
assert (array_api_compat.get_namespace(jax_zero) is
63+
array_api_compat.get_namespace(jx))
64+
5865
def test_array_namespace_errors():
5966
pytest.raises(TypeError, lambda: array_namespace([1]))
6067
pytest.raises(TypeError, lambda: array_namespace())

0 commit comments

Comments
 (0)