Skip to content

Commit 645f9a8

Browse files
authored
Merge pull request data-apis#84 from asmeurer/jax
Add basic JAX support
2 parents dab01be + e7aff0f commit 645f9a8

File tree

9 files changed

+175
-40
lines changed

9 files changed

+175
-40
lines changed

.github/workflows/tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
- name: Install Dependencies
1616
run: |
1717
python -m pip install --upgrade pip
18-
python -m pip install pytest numpy torch dask[array]
18+
python -m pip install pytest numpy torch dask[array] jax[cpu]
1919
2020
- name: Run Tests
2121
run: |

README.md

+23-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
This is a small wrapper around common array libraries that is compatible with
44
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
5-
NumPy, CuPy, and PyTorch are supported. If you want support for other array
5+
NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
66
libraries, or if you encounter any issues, please [open an
77
issue](https://github.com/data-apis/array-api-compat/issues).
88

@@ -56,7 +56,17 @@ import array_api_compat.cupy as cp
5656
import array_api_compat.torch as torch
5757
```
5858

59-
Each will include all the functions from the normal NumPy/CuPy/PyTorch
59+
```py
60+
import array_api_compat.dask as da
61+
```
62+
63+
> [!NOTE]
64+
> There is no `array_api_compat.jax` submodule. JAX support is contained
65+
> in JAX itself in the `jax.experimental.array_api` module. array-api-compat simply
66+
> wraps that submodule. The main JAX support in this module consists of
67+
> supporting it in the [helper functions](#helper-functions) defined below.
68+
69+
Each will include all the functions from the normal NumPy/CuPy/PyTorch/dask.array
6070
namespace, except that functions that are part of the array API are wrapped so
6171
that they have the correct array API behavior. In each case, the array object
6272
used will be the same array object from the wrapped library.
@@ -99,6 +109,11 @@ part of the specification but which are useful for using the array API:
99109
- `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array
100110
object.
101111

112+
- `is_numpy_array(x)`, `is_cupy_array(x)`, `is_torch_array(x)`,
113+
`is_dask_array(x)`, `is_jax_array(x)`: return `True` if `x` is an array from
114+
the corresponding library. These functions do not import the underlying
115+
library if it has not already been imported, so they are cheap to use.
116+
102117
- `array_namespace(*xs)`: Get the corresponding array API namespace for the
103118
arrays `xs`. For example, if the arrays are NumPy arrays, the returned
104119
namespace will be `array_api_compat.numpy`. Note that this function will
@@ -219,6 +234,12 @@ version.
219234

220235
The minimum supported PyTorch version is 1.13.
221236

237+
### JAX
238+
239+
Unlike the other libraries supported here, JAX array API support is contained
240+
entirely in the JAX library. The JAX array API support is tracked at
241+
https://github.com/google/jax/issues/18353.
242+
222243
## Vendoring
223244

224245
This library supports vendoring as an installation method. To vendor the

array_api_compat/common/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
device,
44
get_namespace,
55
is_array_api_obj,
6+
is_cupy_array,
7+
is_dask_array,
8+
is_jax_array,
9+
is_numpy_array,
10+
is_torch_array,
611
size,
712
to_device,
813
)
@@ -12,6 +17,11 @@
1217
"device",
1318
"get_namespace",
1419
"is_array_api_obj",
20+
"is_cupy_array",
21+
"is_dask_array",
22+
"is_jax_array",
23+
"is_numpy_array",
24+
"is_torch_array",
1525
"size",
1626
"to_device",
1727
]

array_api_compat/common/_aliases.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from types import ModuleType
1515
import inspect
1616

17-
from ._helpers import _check_device, _is_numpy_array, array_namespace
17+
from ._helpers import _check_device, is_numpy_array, array_namespace
1818

1919
# These functions are modified from the NumPy versions.
2020

@@ -310,7 +310,7 @@ def _asarray(
310310
raise ValueError("Unrecognized namespace argument to asarray()")
311311

312312
_check_device(xp, device)
313-
if _is_numpy_array(obj):
313+
if is_numpy_array(obj):
314314
import numpy as np
315315
if hasattr(np, '_CopyMode'):
316316
# Not present in older NumPys

array_api_compat/common/_helpers.py

+49-19
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111

1212
if TYPE_CHECKING:
1313
from typing import Optional, Union, Any
14-
from ._typing import Array, Device
14+
from ._typing import Array, Device
1515

1616
import sys
1717
import math
18+
import inspect
1819

19-
def _is_numpy_array(x):
20+
def is_numpy_array(x):
2021
# Avoid importing NumPy if it isn't already
2122
if 'numpy' not in sys.modules:
2223
return False
@@ -26,7 +27,7 @@ def _is_numpy_array(x):
2627
# TODO: Should we reject ndarray subclasses?
2728
return isinstance(x, (np.ndarray, np.generic))
2829

29-
def _is_cupy_array(x):
30+
def is_cupy_array(x):
3031
# Avoid importing NumPy if it isn't already
3132
if 'cupy' not in sys.modules:
3233
return False
@@ -36,7 +37,7 @@ def _is_cupy_array(x):
3637
# TODO: Should we reject ndarray subclasses?
3738
return isinstance(x, (cp.ndarray, cp.generic))
3839

39-
def _is_torch_array(x):
40+
def is_torch_array(x):
4041
# Avoid importing torch if it isn't already
4142
if 'torch' not in sys.modules:
4243
return False
@@ -46,7 +47,7 @@ def _is_torch_array(x):
4647
# TODO: Should we reject ndarray subclasses?
4748
return isinstance(x, torch.Tensor)
4849

49-
def _is_dask_array(x):
50+
def is_dask_array(x):
5051
# Avoid importing dask if it isn't already
5152
if 'dask.array' not in sys.modules:
5253
return False
@@ -55,14 +56,24 @@ def _is_dask_array(x):
5556

5657
return isinstance(x, dask.array.Array)
5758

59+
def is_jax_array(x):
60+
# Avoid importing jax if it isn't already
61+
if 'jax' not in sys.modules:
62+
return False
63+
64+
import jax
65+
66+
return isinstance(x, jax.Array)
67+
5868
def is_array_api_obj(x):
5969
"""
6070
Check if x is an array API compatible array object.
6171
"""
62-
return _is_numpy_array(x) \
63-
or _is_cupy_array(x) \
64-
or _is_torch_array(x) \
65-
or _is_dask_array(x) \
72+
return is_numpy_array(x) \
73+
or is_cupy_array(x) \
74+
or is_torch_array(x) \
75+
or is_dask_array(x) \
76+
or is_jax_array(x) \
6677
or hasattr(x, '__array_namespace__')
6778

6879
def _check_api_version(api_version):
@@ -87,37 +98,43 @@ def your_function(x, y):
8798
"""
8899
namespaces = set()
89100
for x in xs:
90-
if _is_numpy_array(x):
101+
if is_numpy_array(x):
91102
_check_api_version(api_version)
92103
if _use_compat:
93104
from .. import numpy as numpy_namespace
94105
namespaces.add(numpy_namespace)
95106
else:
96107
import numpy as np
97108
namespaces.add(np)
98-
elif _is_cupy_array(x):
109+
elif is_cupy_array(x):
99110
_check_api_version(api_version)
100111
if _use_compat:
101112
from .. import cupy as cupy_namespace
102113
namespaces.add(cupy_namespace)
103114
else:
104115
import cupy as cp
105116
namespaces.add(cp)
106-
elif _is_torch_array(x):
117+
elif is_torch_array(x):
107118
_check_api_version(api_version)
108119
if _use_compat:
109120
from .. import torch as torch_namespace
110121
namespaces.add(torch_namespace)
111122
else:
112123
import torch
113124
namespaces.add(torch)
114-
elif _is_dask_array(x):
125+
elif is_dask_array(x):
115126
_check_api_version(api_version)
116127
if _use_compat:
117128
from ..dask import array as dask_namespace
118129
namespaces.add(dask_namespace)
119130
else:
120131
raise TypeError("_use_compat cannot be False if input array is a dask array!")
132+
elif is_jax_array(x):
133+
_check_api_version(api_version)
134+
# jax.experimental.array_api is already an array namespace. We do
135+
# not have a wrapper submodule for it.
136+
import jax.experimental.array_api as jnp
137+
namespaces.add(jnp)
121138
elif hasattr(x, '__array_namespace__'):
122139
namespaces.add(x.__array_namespace__(api_version=api_version))
123140
else:
@@ -142,7 +159,7 @@ def _check_device(xp, device):
142159
if device not in ["cpu", None]:
143160
raise ValueError(f"Unsupported device for NumPy: {device!r}")
144161

145-
# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray
162+
# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
146163
# or cupy.ndarray. They are not included in array objects of this library
147164
# because this library just reuses the respective ndarray classes without
148165
# wrapping or subclassing them. These helper functions can be used instead of
@@ -162,8 +179,17 @@ def device(x: Array, /) -> Device:
162179
out: device
163180
a ``device`` object (see the "Device Support" section of the array API specification).
164181
"""
165-
if _is_numpy_array(x):
182+
if is_numpy_array(x):
166183
return "cpu"
184+
if is_jax_array(x):
185+
# JAX has .device() as a method, but it is being deprecated so that it
186+
# can become a property, in accordance with the standard. In order for
187+
# this function to not break when JAX makes the flip, we check for
188+
# both here.
189+
if inspect.ismethod(x.device):
190+
return x.device()
191+
else:
192+
return x.device
167193
return x.device
168194

169195
# Based on cupy.array_api.Array.to_device
@@ -231,24 +257,28 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
231257
.. note::
232258
If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
233259
"""
234-
if _is_numpy_array(x):
260+
if is_numpy_array(x):
235261
if stream is not None:
236262
raise ValueError("The stream argument to to_device() is not supported")
237263
if device == 'cpu':
238264
return x
239265
raise ValueError(f"Unsupported device {device!r}")
240-
elif _is_cupy_array(x):
266+
elif is_cupy_array(x):
241267
# cupy does not yet have to_device
242268
return _cupy_to_device(x, device, stream=stream)
243-
elif _is_torch_array(x):
269+
elif is_torch_array(x):
244270
return _torch_to_device(x, device, stream=stream)
245-
elif _is_dask_array(x):
271+
elif is_dask_array(x):
246272
if stream is not None:
247273
raise ValueError("The stream argument to to_device() is not supported")
248274
# TODO: What if our array is on the GPU already?
249275
if device == 'cpu':
250276
return x
251277
raise ValueError(f"Unsupported device {device!r}")
278+
elif is_jax_array(x):
279+
# This import adds to_device to x
280+
import jax.experimental.array_api # noqa: F401
281+
return x.to_device(device, stream=stream)
252282
return x.to_device(device, stream=stream)
253283

254284
def size(x):

tests/_helpers.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
from importlib import import_module
22

3+
import sys
4+
35
import pytest
46

57

6-
def import_or_skip_cupy(library):
7-
if "cupy" in library:
8+
def import_(library, wrapper=False):
9+
if library == 'cupy':
810
return pytest.importorskip(library)
11+
if 'jax' in library and sys.version_info <= (3, 8):
12+
pytest.skip('JAX array API support does not support Python 3.8')
13+
14+
if wrapper:
15+
if 'jax' in library:
16+
library = 'jax.experimental.array_api'
17+
else:
18+
library = 'array_api_compat.' + library
19+
920
return import_module(library)

tests/test_array_namespace.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1+
import subprocess
2+
import sys
3+
14
import numpy as np
25
import pytest
36
import torch
47

58
import array_api_compat
69
from array_api_compat import array_namespace
710

8-
from ._helpers import import_or_skip_cupy
9-
11+
from ._helpers import import_
1012

11-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
13+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
1214
@pytest.mark.parametrize("api_version", [None, "2021.12"])
1315
def test_array_namespace(library, api_version):
14-
xp = import_or_skip_cupy(library)
16+
xp = import_(library)
1517

1618
array = xp.asarray([1.0, 2.0, 3.0])
1719
namespace = array_api_compat.array_namespace(array, api_version=api_version)
@@ -21,9 +23,31 @@ def test_array_namespace(library, api_version):
2123
else:
2224
if library == "dask.array":
2325
assert namespace == array_api_compat.dask.array
26+
elif library == "jax.numpy":
27+
import jax.experimental.array_api
28+
assert namespace == jax.experimental.array_api
2429
else:
2530
assert namespace == getattr(array_api_compat, library)
2631

32+
# Check that array_namespace works even if jax.experimental.array_api
33+
# hasn't been imported yet (it monkeypatches __array_namespace__
34+
# onto JAX arrays, but we should support them regardless). The only way to
35+
# do this is to use a subprocess, since we cannot un-import it and another
36+
# test probably already imported it.
37+
if library == "jax.numpy" and sys.version_info >= (3, 9):
38+
code = f"""\
39+
import sys
40+
import jax.numpy
41+
import array_api_compat
42+
array = jax.numpy.asarray([1.0, 2.0, 3.0])
43+
44+
assert 'jax.experimental.array_api' not in sys.modules
45+
namespace = array_api_compat.array_namespace(array, api_version={api_version!r})
46+
47+
import jax.experimental.array_api
48+
assert namespace == jax.experimental.array_api
49+
"""
50+
subprocess.run([sys.executable, "-c", code], check=True)
2751

2852
def test_array_namespace_errors():
2953
pytest.raises(TypeError, lambda: array_namespace([1]))

0 commit comments

Comments
 (0)