1
+ import subprocess
2
+ import sys
3
+
1
4
import numpy as np
2
5
import pytest
3
6
import torch
7
10
8
11
from ._helpers import import_
9
12
10
-
11
- @pytest .mark .parametrize ("library" , ["cupy" , "numpy" , "torch" , "dask.array" ])
13
+ @pytest .mark .parametrize ("library" , ["cupy" , "numpy" , "torch" , "dask.array" , "jax.numpy" ])
12
14
@pytest .mark .parametrize ("api_version" , [None , "2021.12" ])
13
15
def test_array_namespace (library , api_version ):
14
16
xp = import_ (library )
@@ -21,9 +23,31 @@ def test_array_namespace(library, api_version):
21
23
else :
22
24
if library == "dask.array" :
23
25
assert namespace == array_api_compat .dask .array
26
+ elif library == "jax.numpy" :
27
+ import jax .experimental .array_api
28
+ assert namespace == jax .experimental .array_api
24
29
else :
25
30
assert namespace == getattr (array_api_compat , library )
26
31
32
+ # Check that array_namespace works even if jax.experimental.array_api
33
+ # hasn't been imported yet (it monkeypatches __array_namespace__
34
+ # onto JAX arrays, but we should support them regardless). The only way to
35
+ # do this is to use a subprocess, since we cannot un-import it and another
36
+ # test probably already imported it.
37
+ if library == "jax.numpy" :
38
+ code = f"""\
39
+ import sys
40
+ import jax.numpy
41
+ import array_api_compat
42
+ array = jax.numpy.asarray([1.0, 2.0, 3.0])
43
+
44
+ assert 'jax.experimental.array_api' not in sys.modules
45
+ namespace = array_api_compat.array_namespace(array, api_version={ api_version !r} )
46
+
47
+ import jax.experimental.array_api
48
+ assert namespace == jax.experimental.array_api
49
+ """
50
+ subprocess .run ([sys .executable , "-c" , code ], check = True )
27
51
28
52
def test_array_namespace_errors ():
29
53
pytest .raises (TypeError , lambda : array_namespace ([1 ]))
0 commit comments