Skip to content

Commit 264e6c3

Browse files
committed
Add jax.numpy to the test_array_namespace()
This requires using subprocess to test that it works even if the side-effecting jax.experimental.array_api hasn't been imported yet.
1 parent bff9bf2 commit 264e6c3

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

tests/test_array_namespace.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import subprocess
2+
import sys
3+
14
import numpy as np
25
import pytest
36
import torch
@@ -7,8 +10,7 @@
710

811
from ._helpers import import_
912

10-
11-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
13+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
1214
@pytest.mark.parametrize("api_version", [None, "2021.12"])
1315
def test_array_namespace(library, api_version):
1416
xp = import_(library)
@@ -21,9 +23,31 @@ def test_array_namespace(library, api_version):
2123
else:
2224
if library == "dask.array":
2325
assert namespace == array_api_compat.dask.array
26+
elif library == "jax.numpy":
27+
import jax.experimental.array_api
28+
assert namespace == jax.experimental.array_api
2429
else:
2530
assert namespace == getattr(array_api_compat, library)
2631

32+
# Check that array_namespace works even if jax.experimental.array_api
33+
# hasn't been imported yet (it monkeypatches __array_namespace__
34+
# onto JAX arrays, but we should support them regardless). The only way to
35+
# do this is to use a subprocess, since we cannot un-import it and another
36+
# test probably already imported it.
37+
if library == "jax.numpy":
38+
code = f"""\
39+
import sys
40+
import jax.numpy
41+
import array_api_compat
42+
array = jax.numpy.asarray([1.0, 2.0, 3.0])
43+
44+
assert 'jax.experimental.array_api' not in sys.modules
45+
namespace = array_api_compat.array_namespace(array, api_version={api_version!r})
46+
47+
import jax.experimental.array_api
48+
assert namespace == jax.experimental.array_api
49+
"""
50+
subprocess.run([sys.executable, "-c", code], check=True)
2751

2852
def test_array_namespace_errors():
2953
pytest.raises(TypeError, lambda: array_namespace([1]))

0 commit comments

Comments
 (0)