Skip to content

Commit 6dfc00b

Browse files
authored
Merge pull request #6 from data-apis/cupy
CuPy support
2 parents 391b08b + 732b493 commit 6dfc00b

26 files changed

+1169
-517
lines changed

README.md

+157-7
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
1-
# NumPy Array API compatibility library
1+
# Array API compatibility library
22

3-
This is a small wrapper around NumPy that is compatible with the [Array API
4-
standard](https://data-apis.org/array-api/latest/). See also [NEP 47](https://numpy.org/neps/nep-0047-array-api-standard.html).
3+
This is a small wrapper around NumPy and CuPy that is compatible with the
4+
[Array API standard](https://data-apis.org/array-api/latest/). See also [NEP
5+
47](https://numpy.org/neps/nep-0047-array-api-standard.html).
56

67
Unlike `numpy.array_api`, this is not a strict minimal implementation of the
7-
Array API, but rather just an extension of the main NumPy namespace with
8-
changes needed to be compliant with the Array API. See
9-
https://numpy.org/doc/stable/reference/array_api.html for a full list of
8+
Array API, but rather just an extension of the main NumPy and CuPy namespaces
9+
with changes needed to be compliant with the Array API.
10+
11+
Library authors using the Array API may wish to test against `numpy.array_api`
12+
to ensure they are not using functionality outside of the standard, but prefer
13+
this implementation for the default when working with NumPy or CuPy arrays.
14+
15+
See https://numpy.org/doc/stable/reference/array_api.html for a full list of
1016
changes. In particular, unlike `numpy.array_api`, this package does not use a
1117
separate Array object, but rather just uses `numpy.ndarray` directly.
1218

1319
Note that some of the functionality in this library is backwards incompatible
1420
with NumPy.
1521

22+
This library also supports CuPy in addition to NumPy. If you want support for
23+
other array libraries, please [open an
24+
issue](https://github.com/data-apis/array-api-compat/issues).
25+
1626
Library authors using the Array API may wish to test against `numpy.array_api`
1727
to ensure they are not using functionality outside of the standard, but prefer
1828
this implementation for end users who use NumPy arrays.
@@ -28,5 +38,145 @@ import numpy as np
2838
with
2939

3040
```py
31-
import numpy_array_api_compat as np
41+
import array_api_compat.numpy as np
3242
```
43+
44+
and replace
45+
46+
```py
47+
import cupy as cp
48+
```
49+
50+
with
51+
52+
```py
53+
import array_api_compat.cupy as cp
54+
```
55+
56+
Each will include all the functions from the normal NumPy/CuPy namespace,
57+
except that functions that are part of the array API are wrapped so that they
58+
have the correct array API behavior. In each case, the array object used will
59+
be thew same array object from the wrapped library.
60+
61+
62+
## Helper Functions
63+
64+
In addition to the default NumPy/CuPy namespace and functions in the array API
65+
specification, there are several helper functions
66+
included that aren't part of the specification but which are useful for using
67+
the array API:
68+
69+
- `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array
70+
object.
71+
72+
- `get_namespace(*xs)`: Get the corresponding array API namespace for the
73+
arrays `xs`. If the arrays are NumPy or CuPy arrays, the returned namespace
74+
will be `array_api_compat.numpy` or `array_api_compat.cupy` so that it is
75+
array API compatible.
76+
77+
- `device(x)`: Equivalent to
78+
[`x.device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.device.html)
79+
in the array API specification. Included because `numpy.ndarray` does not
80+
include the `device` attribute and this library does not wrap or extend the
81+
array object. Note that for NumPy, `device` is always `"cpu"`.
82+
83+
- `to_device(x, device, /, *, stream=None)`: Equivalent to
84+
[`x.to_device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.to_device.html).
85+
Included because neither NumPy's nor CuPy's ndarray objects include this
86+
method. For NumPy, this function effectively does nothing since the only
87+
supported device is the CPU, but for CuPy, this method supports CuPy CUDA
88+
[Device](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Device.html)
89+
and
90+
[Stream](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Stream.html)
91+
objects.
92+
93+
## Known Differences from the Array API Specification
94+
95+
There are some known differences between this library and the array API
96+
specification:
97+
98+
- The array methods `__array_namespace__`, `device` (for NumPy), `to_device`,
99+
and `mT` are not defined. This reuses `np.ndarray` and `cp.ndarray` and we
100+
don't want to monkeypatch or wrap it. The helper functions `device()` and
101+
`to_device()` are provided to work around these missing methods (see above).
102+
`x.mT` can be replaced with `xp.linalg.matrix_transpose(x)`.
103+
`get_namespace(x)` should be used instead of `x.__array_namespace__`.
104+
105+
- NumPy value-based casting for scalars will be in effect unless explicitly
106+
disabled with the environment variable NPY_PROMOTION_STATE=weak or
107+
np._set_promotion_state('weak') (requires NumPy 1.24 or newer, see NEP 50
108+
and https://github.com/numpy/numpy/issues/22341)
109+
110+
- Functions which are not wrapped may not have the same type annotations
111+
as the spec.
112+
113+
- Functions which are not wrapped may not use positional-only arguments.
114+
115+
## Vendoring
116+
117+
This library supports vendoring as an installation method. To vendor the
118+
library, simply copy `array_api_compat` into the appropriate place in the
119+
library, like
120+
121+
```
122+
cp -R array_api_compat/ mylib/vendored/array_api_compat
123+
```
124+
125+
You may also rename it to something else if you like (nowhere in the code
126+
references the name "array_api_compat").
127+
128+
Alternatively, the library may be installed as dependency on PyPI.
129+
130+
## Implementation
131+
132+
As noted before, the goal of this library is to reuse the NumPy and CuPy array
133+
objects, rather than wrapping or extending them. This means that the functions
134+
need to accept and return `np.ndarray` for NumPy and `cp.ndarray` for CuPy.
135+
136+
Each namespace (`array_api_compat.numpy` and `array_api_compat.cupy`) is
137+
populated with the normal library namespace (like `from numpy import *`). Then
138+
specific functions are replaced with wrapped variants. Wrapped functions that
139+
have the same logic between NumPy and CuPy (which is most functions) are in
140+
`array_api_compat/common/`. These functions are defined like
141+
142+
```py
143+
# In array_api_compat/common/_aliases.py
144+
145+
def acos(x, /, xp):
146+
return xp.arccos(x)
147+
```
148+
149+
The `xp` argument refers to the original array namespace (either `numpy` or
150+
`cupy`). Then in the specific `array_api_compat/numpy` and
151+
`array_api_compat/cupy` namespace, the `get_xp` decorator is applied to these
152+
functions, which automatically removes the `xp` argument from the function
153+
signature and replaces it with the corresponding array library, like
154+
155+
```py
156+
# In array_api_compat/numpy/_aliases.py
157+
158+
from ..common import _aliases
159+
160+
import numpy as np
161+
162+
acos = get_xp(np)(_aliases.acos)
163+
```
164+
165+
This `acos` now has the signature `acos(x, /)` and calls `numpy.arccos`.
166+
167+
Similarly, for CuPy:
168+
169+
```py
170+
# In array_api_compat/cupy/_aliases.py
171+
172+
from ..common import _aliases
173+
174+
import cupy as cp
175+
176+
acos = get_xp(cp)(_aliases.acos)
177+
```
178+
179+
Since NumPy and CuPy are nearly identical in their behaviors, this allows
180+
writing the wrapping logic for both libraries only once. If support is added
181+
for other libraries which differ significantly from NumPy, their wrapper code
182+
should go in their specific sub-namespace instead of `common/`.

array_api_compat/__init__.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""
2+
NumPy Array API compatibility library
3+
4+
This is a small wrapper around NumPy and CuPy that is compatible with the
5+
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
6+
https://numpy.org/neps/nep-0047-array-api-standard.html.
7+
8+
Unlike numpy.array_api, this is not a strict minimal implementation of the
9+
Array API, but rather just an extension of the main NumPy namespace with
10+
changes needed to be compliant with the Array API. See
11+
https://numpy.org/doc/stable/reference/array_api.html for a full list of
12+
changes. In particular, unlike numpy.array_api, this package does not use a
13+
separate Array object, but rather just uses numpy.ndarray directly.
14+
15+
Library authors using the Array API may wish to test against numpy.array_api
16+
to ensure they are not using functionality outside of the standard, but prefer
17+
this implementation for the default when working with NumPy arrays.
18+
19+
"""
20+
from .common import *

array_api_compat/_internal.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
2+
Internal helpers
3+
"""
4+
5+
from functools import wraps
6+
from inspect import signature
7+
8+
def get_xp(xp):
9+
"""
10+
Decorator to automatically replace xp with the corresponding array module.
11+
12+
Use like
13+
14+
import numpy as np
15+
16+
@get_xp(np)
17+
def func(x, /, xp, kwarg=None):
18+
return xp.func(x, kwarg=kwarg)
19+
20+
Note that xp must be a keyword argument and come after all non-keyword
21+
arguments.
22+
23+
"""
24+
def inner(f):
25+
@wraps(f)
26+
def wrapped_f(*args, **kwargs):
27+
return f(*args, xp=xp, **kwargs)
28+
29+
sig = signature(f)
30+
new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp'])
31+
32+
if wrapped_f.__doc__ is None:
33+
wrapped_f.__doc__ = f"""\
34+
Array API compatibility wrapper for {f.__name__}.
35+
36+
See the corresponding documentation in NumPy/CuPy and/or the array API
37+
specification for more details.
38+
39+
"""
40+
wrapped_f.__signature__ = new_sig
41+
return wrapped_f
42+
43+
return inner

array_api_compat/common/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._helpers import *

0 commit comments

Comments
 (0)