diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index e7a868c9..982b284a 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -18,6 +18,20 @@ import inspect import warnings +def _is_jax_zero_gradient_array(x): + """Return True if `x` is a zero-gradient array. + + These arrays are a design quirk of Jax that may one day be removed. + See https://github.com/google/jax/issues/20620. + """ + if 'numpy' not in sys.modules or 'jax' not in sys.modules: + return False + + import numpy as np + import jax + + return isinstance(x, np.ndarray) and x.dtype == jax.float0 + def is_numpy_array(x): """ Return True if `x` is a NumPy array. @@ -44,7 +58,8 @@ def is_numpy_array(x): import numpy as np # TODO: Should we reject ndarray subclasses? - return isinstance(x, (np.ndarray, np.generic)) + return (isinstance(x, (np.ndarray, np.generic)) + and not _is_jax_zero_gradient_array(x)) def is_cupy_array(x): """ @@ -149,7 +164,7 @@ def is_jax_array(x): import jax - return isinstance(x, jax.Array) + return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) def is_array_api_obj(x): """ diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 78705189..f5454bff 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -2,6 +2,7 @@ import sys import warnings +import jax import numpy as np import pytest import torch @@ -55,6 +56,12 @@ def test_array_namespace(library, api_version, use_compat): """ subprocess.run([sys.executable, "-c", code], check=True) +def test_jax_zero_gradient(): + jx = jax.numpy.arange(4) + jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) + assert (array_api_compat.get_namespace(jax_zero) is + array_api_compat.get_namespace(jx)) + def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace())