diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 6d95069d..67c619b8 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -595,11 +595,29 @@ def your_function(x, y): # backwards compatibility alias get_namespace = array_namespace -def _check_device(xp, device): - if xp == sys.modules.get('numpy'): - if device not in ["cpu", None]: + +def _check_device(bare_xp, device): + """ + Validate dummy device on device-less array backends. + + Notes + ----- + This function is also invoked by CuPy, which does have multiple devices + if there are multiple GPUs available. + However, CuPy multi-device support is currently impossible + without using the global device or a context manager: + + https://github.com/data-apis/array-api-compat/pull/293 + """ + if bare_xp is sys.modules.get('numpy'): + if device not in ("cpu", None): raise ValueError(f"Unsupported device for NumPy: {device!r}") + elif bare_xp is sys.modules.get('dask.array'): + if device not in ("cpu", _DASK_DEVICE, None): + raise ValueError(f"Unsupported device for Dask: {device!r}") + + # Placeholder object to represent the dask device # when the array backend is not the CPU. # (since it is not easy to tell which device a dask array is on) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e6eff359..c5cd7489 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -25,7 +25,7 @@ ) import dask.array as da -from ...common import _aliases, array_namespace +from ...common import _aliases, _helpers, array_namespace from ...common._typing import ( Array, Device, @@ -56,6 +56,7 @@ def astype( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) if not copy and dtype == x.dtype: return x @@ -86,6 +87,7 @@ def arange( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) args = [start] if stop is not None: @@ -155,6 +157,7 @@ def asarray( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) if isinstance(obj, da.Array): if dtype is not None and dtype != obj.dtype: diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 1d084b2b..d5b7feac 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -3,7 +3,7 @@ from typing import Optional, Union from .._internal import get_xp -from ..common import _aliases +from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ from ._typing import Array, Device, DType @@ -95,8 +95,7 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. """ - if device not in ["cpu", None]: - raise ValueError(f"Unsupported device for NumPy: {device!r}") + _helpers._check_device(np, device) if hasattr(np, '_CopyMode'): if copy is None: @@ -122,6 +121,7 @@ def astype( copy: bool = True, device: Optional[Device] = None, ) -> Array: + _helpers._check_device(np, device) return x.astype(dtype=dtype, copy=copy)