11
11
12
12
if TYPE_CHECKING :
13
13
from typing import Optional , Union , Any
14
- from ._typing import Array , Device
14
+ from ._typing import Array , Device
15
15
16
16
import sys
17
17
import math
18
+ import inspect
18
19
19
- def _is_numpy_array (x ):
20
+ def is_numpy_array (x ):
20
21
# Avoid importing NumPy if it isn't already
21
22
if 'numpy' not in sys .modules :
22
23
return False
@@ -26,7 +27,7 @@ def _is_numpy_array(x):
26
27
# TODO: Should we reject ndarray subclasses?
27
28
return isinstance (x , (np .ndarray , np .generic ))
28
29
29
- def _is_cupy_array (x ):
30
+ def is_cupy_array (x ):
30
31
# Avoid importing NumPy if it isn't already
31
32
if 'cupy' not in sys .modules :
32
33
return False
@@ -36,7 +37,7 @@ def _is_cupy_array(x):
36
37
# TODO: Should we reject ndarray subclasses?
37
38
return isinstance (x , (cp .ndarray , cp .generic ))
38
39
39
- def _is_torch_array (x ):
40
+ def is_torch_array (x ):
40
41
# Avoid importing torch if it isn't already
41
42
if 'torch' not in sys .modules :
42
43
return False
@@ -46,7 +47,7 @@ def _is_torch_array(x):
46
47
# TODO: Should we reject ndarray subclasses?
47
48
return isinstance (x , torch .Tensor )
48
49
49
- def _is_dask_array (x ):
50
+ def is_dask_array (x ):
50
51
# Avoid importing dask if it isn't already
51
52
if 'dask.array' not in sys .modules :
52
53
return False
@@ -55,14 +56,24 @@ def _is_dask_array(x):
55
56
56
57
return isinstance (x , dask .array .Array )
57
58
59
def is_jax_array(x):
    """
    Return True if ``x`` is a JAX array, without ever importing JAX
    ourselves.

    JAX is only consulted when the caller's process has already loaded
    it, so this check never adds the dependency on its own.
    """
    if 'jax' in sys.modules:
        import jax
        return isinstance(x, jax.Array)
    return False
67
+
58
68
def is_array_api_obj (x ):
59
69
"""
60
70
Check if x is an array API compatible array object.
61
71
"""
62
- return _is_numpy_array (x ) \
63
- or _is_cupy_array (x ) \
64
- or _is_torch_array (x ) \
65
- or _is_dask_array (x ) \
72
+ return is_numpy_array (x ) \
73
+ or is_cupy_array (x ) \
74
+ or is_torch_array (x ) \
75
+ or is_dask_array (x ) \
76
+ or is_jax_array (x ) \
66
77
or hasattr (x , '__array_namespace__' )
67
78
68
79
def _check_api_version (api_version ):
@@ -87,37 +98,43 @@ def your_function(x, y):
87
98
"""
88
99
namespaces = set ()
89
100
for x in xs :
90
- if _is_numpy_array (x ):
101
+ if is_numpy_array (x ):
91
102
_check_api_version (api_version )
92
103
if _use_compat :
93
104
from .. import numpy as numpy_namespace
94
105
namespaces .add (numpy_namespace )
95
106
else :
96
107
import numpy as np
97
108
namespaces .add (np )
98
- elif _is_cupy_array (x ):
109
+ elif is_cupy_array (x ):
99
110
_check_api_version (api_version )
100
111
if _use_compat :
101
112
from .. import cupy as cupy_namespace
102
113
namespaces .add (cupy_namespace )
103
114
else :
104
115
import cupy as cp
105
116
namespaces .add (cp )
106
- elif _is_torch_array (x ):
117
+ elif is_torch_array (x ):
107
118
_check_api_version (api_version )
108
119
if _use_compat :
109
120
from .. import torch as torch_namespace
110
121
namespaces .add (torch_namespace )
111
122
else :
112
123
import torch
113
124
namespaces .add (torch )
114
- elif _is_dask_array (x ):
125
+ elif is_dask_array (x ):
115
126
_check_api_version (api_version )
116
127
if _use_compat :
117
128
from ..dask import array as dask_namespace
118
129
namespaces .add (dask_namespace )
119
130
else :
120
131
raise TypeError ("_use_compat cannot be False if input array is a dask array!" )
132
+ elif is_jax_array (x ):
133
+ _check_api_version (api_version )
134
+ # jax.experimental.array_api is already an array namespace. We do
135
+ # not have a wrapper submodule for it.
136
+ import jax .experimental .array_api as jnp
137
+ namespaces .add (jnp )
121
138
elif hasattr (x , '__array_namespace__' ):
122
139
namespaces .add (x .__array_namespace__ (api_version = api_version ))
123
140
else :
@@ -142,7 +159,7 @@ def _check_device(xp, device):
142
159
if device not in ["cpu" , None ]:
143
160
raise ValueError (f"Unsupported device for NumPy: { device !r} " )
144
161
145
- # device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray
162
+ # device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
146
163
# or cupy.ndarray. They are not included in array objects of this library
147
164
# because this library just reuses the respective ndarray classes without
148
165
# wrapping or subclassing them. These helper functions can be used instead of
@@ -162,8 +179,17 @@ def device(x: Array, /) -> Device:
162
179
out: device
163
180
a ``device`` object (see the "Device Support" section of the array API specification).
164
181
"""
165
- if _is_numpy_array (x ):
182
+ if is_numpy_array (x ):
166
183
return "cpu"
184
+ if is_jax_array (x ):
185
+ # JAX has .device() as a method, but it is being deprecated so that it
186
+ # can become a property, in accordance with the standard. In order for
187
+ # this function to not break when JAX makes the flip, we check for
188
+ # both here.
189
+ if inspect .ismethod (x .device ):
190
+ return x .device ()
191
+ else :
192
+ return x .device
167
193
return x .device
168
194
169
195
# Based on cupy.array_api.Array.to_device
@@ -231,24 +257,28 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
231
257
.. note::
232
258
If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
233
259
"""
234
- if _is_numpy_array (x ):
260
+ if is_numpy_array (x ):
235
261
if stream is not None :
236
262
raise ValueError ("The stream argument to to_device() is not supported" )
237
263
if device == 'cpu' :
238
264
return x
239
265
raise ValueError (f"Unsupported device { device !r} " )
240
- elif _is_cupy_array (x ):
266
+ elif is_cupy_array (x ):
241
267
# cupy does not yet have to_device
242
268
return _cupy_to_device (x , device , stream = stream )
243
- elif _is_torch_array (x ):
269
+ elif is_torch_array (x ):
244
270
return _torch_to_device (x , device , stream = stream )
245
- elif _is_dask_array (x ):
271
+ elif is_dask_array (x ):
246
272
if stream is not None :
247
273
raise ValueError ("The stream argument to to_device() is not supported" )
248
274
# TODO: What if our array is on the GPU already?
249
275
if device == 'cpu' :
250
276
return x
251
277
raise ValueError (f"Unsupported device { device !r} " )
278
+ elif is_jax_array (x ):
279
+ # This import adds to_device to x
280
+ import jax .experimental .array_api # noqa: F401
281
+ return x .to_device (device , stream = stream )
252
282
return x .to_device (device , stream = stream )
253
283
254
284
def size (x ):
0 commit comments