forked from data-apis/array-api-compat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_array_namespace.py
110 lines (91 loc) · 4.07 KB
/
test_array_namespace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import subprocess
import sys
import warnings
import jax
import numpy as np
import pytest
import torch
import array_api_compat
from array_api_compat import array_namespace
from ._helpers import import_, all_libraries, wrapped_libraries
@pytest.mark.parametrize("use_compat", [True, False, None])
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"])
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
def test_array_namespace(library, api_version, use_compat):
    """Check which namespace ``array_namespace`` returns for each combination.

    * ``use_compat=True`` for ``array_api_strict``, ``jax.numpy`` or
      ``sparse`` must raise ``ValueError``.
    * ``use_compat=False``, or ``use_compat=None`` for a library that is not
      in ``wrapped_libraries``, returns the library module itself, except
      ``jax.numpy`` with ``use_compat=None``, which must return
      ``jax.experimental.array_api``.
    * Every other combination returns the matching ``array_api_compat``
      wrapper (``array_api_compat.dask.array`` for ``dask.array``).

    For ``jax.numpy`` it additionally checks, in a fresh interpreter, that the
    lookup works before ``jax.experimental.array_api`` has been imported.
    """
    xp = import_(library)
    array = xp.asarray([1.0, 2.0, 3.0])

    if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
        with pytest.raises(ValueError):
            array_namespace(array, use_compat=use_compat)
        return

    namespace = array_api_compat.array_namespace(
        array, api_version=api_version, use_compat=use_compat)

    # Parentheses only make Python's precedence explicit (``and`` binds
    # tighter than ``or``); the evaluation is the same as without them.
    if use_compat is False or (use_compat is None and library not in wrapped_libraries):
        if library == "jax.numpy" and use_compat is None:
            import jax.experimental.array_api
            assert namespace == jax.experimental.array_api
        else:
            assert namespace == xp
    elif library == "dask.array":
        assert namespace == array_api_compat.dask.array
    else:
        assert namespace == getattr(array_api_compat, library)

    # Check that array_namespace works even if jax.experimental.array_api
    # hasn't been imported yet (it monkeypatches __array_namespace__
    # onto JAX arrays, but we should support them regardless). The only way to
    # do this is to use a subprocess, since we cannot un-import it and another
    # test probably already imported it.
    # The subprocess script never uses ``use_compat``, so it is run for a
    # single ``use_compat`` value (None) rather than once per value: each
    # run starts a new interpreter and imports JAX, which is slow.
    if library == "jax.numpy" and use_compat is None and sys.version_info >= (3, 9):
        code = f"""\
import sys
import jax.numpy
import array_api_compat
array = jax.numpy.asarray([1.0, 2.0, 3.0])

assert 'jax.experimental.array_api' not in sys.modules
namespace = array_api_compat.array_namespace(array, api_version={api_version!r})

import jax.experimental.array_api
assert namespace == jax.experimental.array_api
"""
        subprocess.run([sys.executable, "-c", code], check=True)
def test_jax_zero_gradient():
    """A JAX zero gradient resolves to the same namespace as its input.

    ``jax.grad`` (with ``allow_int=True``) of ``jax.numpy.float32`` is mapped
    over an integer ``arange``. The namespace of that result must be the very
    same object that ``get_namespace`` returns for the input array.
    """
    ints = jax.numpy.arange(4)
    grad_of_cast = jax.grad(jax.numpy.float32, allow_int=True)
    zero_grad = jax.vmap(grad_of_cast)(ints)

    grad_ns = array_api_compat.get_namespace(zero_grad)
    input_ns = array_api_compat.get_namespace(ints)
    assert grad_ns is input_ns
def test_array_namespace_errors():
    """Each of these inputs must make ``array_namespace`` raise TypeError.

    The inputs are a plain list, no arguments at all, a tuple of arrays, and
    a tuple passed next to an array.
    """
    with pytest.raises(TypeError):
        array_namespace([1])
    with pytest.raises(TypeError):
        array_namespace()

    arr = np.asarray([1, 2])
    pair = (arr, arr)
    with pytest.raises(TypeError):
        array_namespace(pair)
    with pytest.raises(TypeError):
        array_namespace(arr, pair)
def test_array_namespace_errors_torch():
    """A NumPy array mixed with a torch tensor must raise TypeError."""
    torch_arr = torch.asarray([1, 2])
    numpy_arr = np.asarray([1, 2])
    with pytest.raises(TypeError):
        array_namespace(numpy_arr, torch_arr)
def test_api_version():
    """Check ``api_version`` handling for a torch tensor.

    * ``"2022.12"``, ``None`` and an omitted ``api_version`` return the
      wrapped torch namespace.
    * ``"2021.12"`` also returns it, but must emit exactly one warning whose
      message mentions ``"2021.12"``.
    * ``"2020.12"`` must raise ``ValueError``.
    """
    x = torch.asarray([1, 2])
    torch_ = import_("torch", wrapper=True)
    assert array_namespace(x, api_version="2022.12") == torch_
    assert array_namespace(x, api_version=None) == torch_
    assert array_namespace(x) == torch_

    # Should issue a warning. The "always" action is forced inside the
    # context so that the warning is recorded no matter what:
    # - a global "ignore" filter would otherwise hide it;
    # - the once-per-location registry could suppress a repeat;
    # - an "error" filter (``-W error`` or pytest ``filterwarnings = error``)
    #   would raise it instead of recording it.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        assert array_namespace(x, api_version="2021.12") == torch_
    # With "always" active, unrelated warnings raised during the call are
    # recorded too, so count only the version warning.
    version_warnings = [rec for rec in w if "2021.12" in str(rec.message)]
    assert len(version_warnings) == 1

    with pytest.raises(ValueError):
        array_namespace(x, api_version="2020.12")
def test_get_namespace():
    """``get_namespace`` must be the same object as ``array_namespace``.

    The old name is kept as an alias for backwards compatibility.
    """
    legacy_alias = array_api_compat.get_namespace
    assert legacy_alias is array_api_compat.array_namespace
def test_python_scalars():
    """Check how Python scalars and ``None`` are treated by ``array_namespace``.

    Passed alone, each of ``1``, ``1.0``, ``1j``, ``True`` and ``None`` must
    raise TypeError. Passed together with a torch tensor, the result must be
    the wrapped torch namespace.
    """
    tensor = torch.asarray([1, 2])
    wrapped_torch = import_("torch", wrapper=True)
    scalars = (1, 1.0, 1j, True, None)

    for value in scalars:
        with pytest.raises(TypeError):
            array_namespace(value)

    for value in scalars:
        assert array_namespace(tensor, value) == wrapped_torch