This is a tracker of the current state of support for more than one device at once in the Array API, its helper libraries, and the libraries that implement it.
Supporting multiple devices at the same time is typically much more fragile than pinning one of the available devices at the interpreter level and using that one exclusively, which generally works as intended.
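For context, "pinning at interpreter level" means selecting a single device for the whole process before any arrays are created. A minimal sketch, assuming a host with more than one CUDA GPU and with CuPy and PyTorch installed (the index 1 is arbitrary); this is not prescribed by the tracker, just one common way to do it:

```python
# Sketch of pinning a single device for the whole interpreter, so that every
# subsequent allocation lands on it and no cross-device logic is needed.
import os

# Option 1: hide all but one GPU from the process. Must be set before the
# CUDA runtime is initialised, i.e. before importing cupy/torch.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import cupy as cp
import torch

# Option 2: pin the current device explicitly through each library
# (either option on its own is usually enough).
cp.cuda.Device(0).use()              # device 0 of the now-restricted view
torch.set_default_device("cuda:0")

x = cp.zeros(10)                     # allocated on the pinned GPU
t = torch.zeros(10)                  # same
```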
Array API
Dictates that the device of the output arrays must always follow from that of the input(s), unless explicitly overridden by a device= kwarg, where allowed. (A sketch of this rule follows this section.)
__array_namespace_info__().default_device() should return: Clarify definitions of "default device" and "current device" #835
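As a concrete illustration of the propagation rule, here is a minimal sketch using array-api-compat's NumPy wrapper, so it runs on CPU only; the same pattern applies to any compliant namespace (assumes a recent array-api-compat and NumPy):

```python
# Illustration of the standard's device semantics: outputs inherit the device
# of their inputs, and creation functions may take an explicit device=.
import array_api_compat.numpy as xp            # NumPy-backed, CPU-only namespace
from array_api_compat import device

x = xp.asarray([1.0, 2.0, 3.0])                # created on the default device
y = xp.sin(x) + x                              # output follows the inputs' device
assert device(y) == device(x)

# Creation functions accept an explicit device= where the standard allows it:
z = xp.zeros(3, device=device(x))
assert device(z) == device(x)
```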
array-api-strict
array-api-tests
array-api-compat
Provides device() and to_device() to work around non-compliance of wrapped libraries. (See the sketch after this section.)
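A small usage sketch of those two helpers together with array_namespace(); shown here with plain NumPy so it runs anywhere:

```python
# device() and to_device() work on any supported array type, papering over
# libraries that do not (yet) expose .device / .to_device() themselves.
import numpy as np
from array_api_compat import array_namespace, device, to_device

x = np.arange(5)

xp = array_namespace(x)        # the compat-wrapped namespace for x's library
d = device(x)                  # "cpu" for NumPy
y = to_device(x, d)            # no-op round-trip on a single-device library
assert device(y) == d
```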
array-api-extra
NumPy
NumPy 2.x natively implements the device interface (a single "cpu" device); array-api-compat backports it to NumPy 1.x.
CuPy
Missing device= parameter to functions.
PyTorch
asarray: device does not propagate from input to output after set_default_device (pytorch/pytorch#150199). (A sketch of the reported behaviour follows.)
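A hedged reproduction sketch of the kind of behaviour tracked in that issue. It assumes a CUDA build of PyTorch with a GPU available; the exact outcome depends on the PyTorch version, and the issue may since have been fixed:

```python
# Sketch of the reported behaviour: after set_default_device, asarray may
# place its output on the default device instead of following the input.
import torch

torch.set_default_device("cuda")              # requires a CUDA build + GPU

x = torch.asarray([1.0, 2.0], device="cpu")   # explicitly on CPU
y = torch.asarray(x)                          # per the standard, should stay on CPU

print(x.device)   # cpu
print(y.device)   # reportedly cuda:0 on affected versions, i.e. no propagation
```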
JAX
__array_namespace_info__: Array API default_device() and devices() are incorrect (jax-ml/jax#27606).
jax.jit: input-to-output device propagation works, but it's impossible to call creation functions (empty, zeros, full, etc.) on a non-default device: Missing .device attribute inside @jax.jit (jax-ml/jax#26000). (See the sketch after this section.)
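A sketch of what does and does not work under jax.jit. The XLA flag below splits the host CPU into two "devices" so the example runs without a GPU; this is a sketch against recent JAX versions, not a definitive statement of JAX's API:

```python
# Sketch of JAX's current behaviour under jit with respect to devices.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"  # fake 2 CPU devices

import jax
import jax.numpy as jnp

dev1 = jax.devices()[1]        # a non-default device (default is jax.devices()[0])

@jax.jit
def f(x):
    # What does NOT work (see jax-ml/jax#26000): x is a tracer here and has no
    # .device attribute, so one cannot write e.g.
    #     jnp.zeros(3, device=x.device)
    # to create a new array on x's (possibly non-default) device.
    return x * 2               # input-to-output propagation itself works

x = jax.device_put(jnp.arange(3.0), dev1)   # commit the input to the non-default device
y = f(x)
assert y.devices() == {dev1}                # the output followed the input's device
```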
Dask
array-api-compat adds stub support that returns "cpu" when wrapping around numpy and a dummy DASK_DEVICE otherwise. Notably, this is stored nowhere and does not survive a round-trip (device(to_device(x, d)) == d can fail; see the sketch after this section).
This is a non-issue when wrapping around numpy, or when wrapping around cupy with both client and workers mounting a single GPU.
Multi-GPU Dask+CuPy support could be achieved by starting separate worker processes on the same host and pinning the GPU at interpreter level. This is extremely inefficient, as it incurs IPC overhead and possibly memory duplication. If a user does so, the client and array-api-compat will never know.
dask-cuda may improve the situation (not investigated).
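A defensive sketch of the round-trip caveat, assuming array-api-compat's Dask wrapper (array_api_compat.dask.array); the exact stub value returned depends on the backing library and the compat version:

```python
# Sketch of the round-trip caveat: the device reported for Dask arrays is a
# stub, so code should not rely on device(to_device(x, d)) == d.
import array_api_compat.dask.array as xp
from array_api_compat import device, to_device

x = xp.zeros(10)          # Dask array wrapping NumPy chunks
d = device(x)             # "cpu" (NumPy-backed) or a DASK_DEVICE sentinel

y = to_device(x, d)
# Defensive pattern: check instead of assuming the stub survived the round-trip.
if device(y) != d:
    # e.g. skip strict device checks for this backend
    ...
```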
SciPy
special.logsumexp on non-default device: scipy/scipy#22756