Skip to content

Commit 158622e

Browse files
authored
Merge pull request #172 from jakevdp/jax-native
Future-proof JAX array API import
2 parents d57c671 + d570ebc commit 158622e

File tree

4 files changed

+30
-11
lines changed

4 files changed

+30
-11
lines changed

Diff for: array_api_compat/common/_helpers.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,13 @@ def your_function(x, y):
377377
elif use_compat is False:
378378
import jax.numpy as jnp
379379
else:
380-
# jax.experimental.array_api is already an array namespace. We do
381-
# not have a wrapper submodule for it.
382-
import jax.experimental.array_api as jnp
380+
# JAX v0.4.32 and newer implements the array API directly in jax.numpy.
381+
# For older JAX versions, it is available via jax.experimental.array_api.
382+
import jax.numpy
383+
if hasattr(jax.numpy, "__array_api_version__"):
384+
jnp = jax.numpy
385+
else:
386+
import jax.experimental.array_api as jnp
383387
namespaces.add(jnp)
384388
elif is_pydata_sparse_array(x):
385389
if use_compat is True:
@@ -613,8 +617,9 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
613617
return x
614618
raise ValueError(f"Unsupported device {device!r}")
615619
elif is_jax_array(x):
616-
# This import adds to_device to x
617-
import jax.experimental.array_api # noqa: F401
620+
if not hasattr(x, "__array_namespace__"):
621+
# In JAX v0.4.31 and older, this import adds to_device method to x.
622+
import jax.experimental.array_api # noqa: F401
618623
return x.to_device(device, stream=stream)
619624
elif is_pydata_sparse_array(x) and device == _device(x):
620625
# Perform trivial check to return the same array if

Diff for: docs/index.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ import array_api_compat.dask as da
6363
```{note}
6464
There are no `array_api_compat` submodules for JAX, sparse, or ndonnx. These
6565
support for these libraries is contained in the libraries themselves (JAX
66-
support is in the `jax.experimental.array_api` module). The
66+
support is in the `jax.numpy` module in JAX v0.4.32 or newer, and in the
67+
`jax.experimental.array_api` module for older JAX versions). The
6768
array-api-compat support for these libraries consists of supporting them in
6869
the [helper functions](helper-functions).
6970
```

Diff for: tests/_helpers.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ def import_(library, wrapper=False):
1818
pytest.importorskip(library)
1919
if wrapper:
2020
if 'jax' in library:
21-
library = 'jax.experimental.array_api'
21+
# JAX v0.4.32 implements the array API directly in jax.numpy
22+
# Older jax versions use jax.experimental.array_api
23+
jax_numpy = import_module("jax.numpy")
24+
if not hasattr(jax_numpy, "__array_api_version__"):
25+
library = 'jax.experimental.array_api'
2226
elif library.startswith('sparse'):
2327
library = 'sparse'
2428
else:

Diff for: tests/test_array_namespace.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,14 @@ def test_array_namespace(library, api_version, use_compat):
2626

2727
if use_compat is False or use_compat is None and library not in wrapped_libraries:
2828
if library == "jax.numpy" and use_compat is None:
29-
import jax.experimental.array_api
30-
assert namespace == jax.experimental.array_api
29+
import jax.numpy
30+
if hasattr(jax.numpy, "__array_api_version__"):
31+
# JAX v0.4.32 or later uses jax.numpy directly
32+
assert namespace == jax.numpy
33+
else:
34+
# JAX v0.4.31 or earlier uses jax.experimental.array_api
35+
import jax.experimental.array_api
36+
assert namespace == jax.experimental.array_api
3137
else:
3238
assert namespace == xp
3339
else:
@@ -58,8 +64,11 @@ def test_array_namespace(library, api_version, use_compat):
5864
assert 'jax.experimental.array_api' not in sys.modules
5965
namespace = array_api_compat.array_namespace(array, api_version={api_version!r})
6066
61-
import jax.experimental.array_api
62-
assert namespace == jax.experimental.array_api
67+
if hasattr(jax.numpy, '__array_api_version__'):
68+
assert namespace == jax.numpy
69+
else:
70+
import jax.experimental.array_api
71+
assert namespace == jax.experimental.array_api
6372
"""
6473
subprocess.run([sys.executable, "-c", code], check=True)
6574

0 commit comments

Comments
 (0)