File tree 3 files changed +27
-6
lines changed
3 files changed +27
-6
lines changed Original file line number Diff line number Diff line change @@ -595,11 +595,29 @@ def your_function(x, y):
595
595
# backwards compatibility alias
596
596
get_namespace = array_namespace
597
597
598
- def _check_device (xp , device ):
599
- if xp == sys .modules .get ('numpy' ):
598
+
599
+ def _check_device (bare_xp , device ):
600
+ """
601
+ Validate dummy device on device-less array backends.
602
+
603
+ Notes
604
+ -----
605
+ This function is also invoked by CuPy, which does have multiple devices
606
+ if there are multiple GPUs available.
607
+ However, CuPy multi-device support is currently impossible
608
+ without using the global device or a context manager:
609
+
610
+ https://github.com/data-apis/array-api-compat/pull/293
611
+ """
612
+ if bare_xp is sys .modules .get ('numpy' ):
600
613
if device not in ["cpu" , None ]:
601
614
raise ValueError (f"Unsupported device for NumPy: { device !r} " )
602
615
616
+ elif bare_xp is sys .modules .get ('dask.array' ):
617
+ if device not in ("cpu" , _DASK_DEVICE ):
618
+ raise ValueError (f"Unsupported device for Dask: { device !r} " )
619
+
620
+
603
621
# Placeholder object to represent the dask device
604
622
# when the array backend is not the CPU.
605
623
# (since it is not easy to tell which device a dask array is on)
Original file line number Diff line number Diff line change 25
25
)
26
26
import dask .array as da
27
27
28
- from ...common import _aliases , array_namespace
28
+ from ...common import _aliases , _helpers , array_namespace
29
29
from ...common ._typing import (
30
30
Array ,
31
31
Device ,
@@ -56,6 +56,7 @@ def astype(
56
56
specification for more details.
57
57
"""
58
58
# TODO: respect device keyword?
59
+ _helpers ._check_device (da , device )
59
60
60
61
if not copy and dtype == x .dtype :
61
62
return x
@@ -86,6 +87,7 @@ def arange(
86
87
specification for more details.
87
88
"""
88
89
# TODO: respect device keyword?
90
+ _helpers ._check_device (da , device )
89
91
90
92
args = [start ]
91
93
if stop is not None :
@@ -155,6 +157,7 @@ def asarray(
155
157
specification for more details.
156
158
"""
157
159
# TODO: respect device keyword?
160
+ _helpers ._check_device (da , device )
158
161
159
162
if isinstance (obj , da .Array ):
160
163
if dtype is not None and dtype != obj .dtype :
Original file line number Diff line number Diff line change 3
3
from typing import Optional , Union
4
4
5
5
from .._internal import get_xp
6
- from ..common import _aliases
6
+ from ..common import _aliases , _helpers
7
7
from ..common ._typing import NestedSequence , SupportsBufferProtocol
8
8
from ._info import __array_namespace_info__
9
9
from ._typing import Array , Device , DType
@@ -95,8 +95,7 @@ def asarray(
95
95
See the corresponding documentation in the array library and/or the array API
96
96
specification for more details.
97
97
"""
98
- if device not in ["cpu" , None ]:
99
- raise ValueError (f"Unsupported device for NumPy: { device !r} " )
98
+ _helpers ._check_device (np , device )
100
99
101
100
if hasattr (np , '_CopyMode' ):
102
101
if copy is None :
@@ -122,6 +121,7 @@ def astype(
122
121
copy : bool = True ,
123
122
device : Optional [Device ] = None ,
124
123
) -> Array :
124
+ _helpers ._check_device (np , device )
125
125
return x .astype (dtype = dtype , copy = copy )
126
126
127
127
You can’t perform that action at this time.
0 commit comments