Skip to content

Commit e718c7b

Browse files
committed
MAINT: validate device on numpy and dask
1 parent b6900df commit e718c7b

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

array_api_compat/common/_helpers.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -595,11 +595,29 @@ def your_function(x, y):
595595
# backwards compatibility alias
596596
get_namespace = array_namespace
597597

598-
def _check_device(xp, device):
599-
if xp == sys.modules.get('numpy'):
598+
599+
def _check_device(bare_xp, device):
600+
"""
601+
Validate dummy device on device-less array backends.
602+
603+
Notes
604+
-----
605+
This function is also invoked by CuPy, which does have multiple devices
606+
if there are multiple GPUs available.
607+
However, CuPy multi-device support is currently impossible
608+
without using the global device or a context manager:
609+
610+
https://github.com/data-apis/array-api-compat/pull/293
611+
"""
612+
if bare_xp is sys.modules.get('numpy'):
600613
if device not in ["cpu", None]:
601614
raise ValueError(f"Unsupported device for NumPy: {device!r}")
602615

616+
elif bare_xp is sys.modules.get('dask.array'):
617+
if device not in ("cpu", _DASK_DEVICE):
618+
raise ValueError(f"Unsupported device for Dask: {device!r}")
619+
620+
603621
# Placeholder object to represent the dask device
604622
# when the array backend is not the CPU.
605623
# (since it is not easy to tell which device a dask array is on)

array_api_compat/dask/array/_aliases.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
import dask.array as da
2727

28-
from ...common import _aliases, array_namespace
28+
from ...common import _aliases, _helpers, array_namespace
2929
from ...common._typing import (
3030
Array,
3131
Device,
@@ -56,6 +56,7 @@ def astype(
5656
specification for more details.
5757
"""
5858
# TODO: respect device keyword?
59+
_helpers._check_device(da, device)
5960

6061
if not copy and dtype == x.dtype:
6162
return x
@@ -86,6 +87,7 @@ def arange(
8687
specification for more details.
8788
"""
8889
# TODO: respect device keyword?
90+
_helpers._check_device(da, device)
8991

9092
args = [start]
9193
if stop is not None:
@@ -155,6 +157,7 @@ def asarray(
155157
specification for more details.
156158
"""
157159
# TODO: respect device keyword?
160+
_helpers._check_device(da, device)
158161

159162
if isinstance(obj, da.Array):
160163
if dtype is not None and dtype != obj.dtype:

array_api_compat/numpy/_aliases.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Optional, Union
44

55
from .._internal import get_xp
6-
from ..common import _aliases
6+
from ..common import _aliases, _helpers
77
from ..common._typing import NestedSequence, SupportsBufferProtocol
88
from ._info import __array_namespace_info__
99
from ._typing import Array, Device, DType
@@ -95,8 +95,7 @@ def asarray(
9595
See the corresponding documentation in the array library and/or the array API
9696
specification for more details.
9797
"""
98-
if device not in ["cpu", None]:
99-
raise ValueError(f"Unsupported device for NumPy: {device!r}")
98+
_helpers._check_device(np, device)
10099

101100
if hasattr(np, '_CopyMode'):
102101
if copy is None:
@@ -122,6 +121,7 @@ def astype(
122121
copy: bool = True,
123122
device: Optional[Device] = None,
124123
) -> Array:
124+
_helpers._check_device(np, device)
125125
return x.astype(dtype=dtype, copy=copy)
126126

127127

0 commit comments

Comments
 (0)