Skip to content

Commit 2c87d36

Browse files
committed
MAINT: clarify default_device output
1 parent b6900df commit 2c87d36

File tree

5 files changed

+35
-14
lines changed

5 files changed

+35
-14
lines changed

Diff for: array_api_compat/common/_aliases.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
# These functions are modified from the NumPy versions.
2020

21-
# Creation functions add the device keyword (which does nothing for NumPy)
21+
# Creation functions add the device keyword (which does nothing for NumPy and Dask)
2222

2323
def arange(
2424
start: Union[int, float],

Diff for: array_api_compat/cupy/_info.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
complex128,
2727
)
2828

29+
2930
class __array_namespace_info__:
3031
"""
3132
Get the array API inspection namespace for CuPy.
@@ -117,7 +118,7 @@ def default_device(self):
117118
118119
Returns
119120
-------
120-
device : str
121+
device : Device
121122
The default device used for new CuPy arrays.
122123
123124
Examples
@@ -126,6 +127,15 @@ def default_device(self):
126127
>>> info.default_device()
127128
Device(0)
128129
130+
Notes
131+
-----
132+
This method returns the static default device when CuPy is initialized.
133+
However, the *current* device used by creation functions (``empty`` etc.)
134+
can be changed globally or with a context manager.
135+
136+
See Also
137+
--------
138+
https://github.com/data-apis/array-api/issues/835
129139
"""
130140
return cuda.Device(0)
131141

@@ -312,7 +322,7 @@ def devices(self):
312322
313323
Returns
314324
-------
315-
devices : list of str
325+
devices : list[Device]
316326
The devices supported by CuPy.
317327
318328
See Also

Diff for: array_api_compat/dask/array/_info.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def default_device(self):
130130
131131
Returns
132132
-------
133-
device : str
133+
device : Device
134134
The default device used for new Dask arrays.
135135
136136
Examples
@@ -335,7 +335,7 @@ def devices(self):
335335
336336
Returns
337337
-------
338-
devices : list of str
338+
devices : list[Device]
339339
The devices supported by Dask.
340340
341341
See Also

Diff for: array_api_compat/numpy/_info.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def default_device(self):
119119
120120
Returns
121121
-------
122-
device : str
122+
device : Device
123123
The default device used for new NumPy arrays.
124124
125125
Examples
@@ -326,7 +326,7 @@ def devices(self):
326326
327327
Returns
328328
-------
329-
devices : list of str
329+
devices : list[Device]
330330
The devices supported by NumPy.
331331
332332
See Also

Diff for: array_api_compat/torch/_info.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,24 @@ def default_device(self):
102102
103103
Returns
104104
-------
105-
device : str
105+
device : Device
106106
The default device used for new PyTorch arrays.
107107
108108
Examples
109109
--------
110110
>>> info = np.__array_namespace_info__()
111111
>>> info.default_device()
112-
'cpu'
112+
device(type='cpu')
113113
114+
Notes
115+
-----
116+
This method returns the static default device when PyTorch is initialized.
117+
However, the *current* device used by creation functions (``empty`` etc.)
118+
can be changed at runtime.
119+
120+
See Also
121+
--------
122+
https://github.com/data-apis/array-api/issues/835
114123
"""
115124
return torch.device("cpu")
116125

@@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None):
120129
121130
Parameters
122131
----------
123-
device : str, optional
124-
The device to get the default data types for. For PyTorch, only
125-
``'cpu'`` is allowed.
132+
device : Device, optional
133+
The device to get the default data types for.
134+
Unused for PyTorch, as all devices use the same default dtypes.
126135
127136
Returns
128137
-------
@@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None):
250259
251260
Parameters
252261
----------
253-
device : str, optional
262+
device : Device, optional
254263
The device to get the data types for.
264+
Unused for PyTorch, as all devices use the same dtypes.
255265
kind : str or tuple of str, optional
256266
The kind of data types to return. If ``None``, all data types are
257267
returned. If a string, only data types of that kind are returned.
@@ -310,7 +320,7 @@ def devices(self):
310320
311321
Returns
312322
-------
313-
devices : list of str
323+
devices : list[Device]
314324
The devices supported by PyTorch.
315325
316326
See Also
@@ -333,6 +343,7 @@ def devices(self):
333343
# device:
334344
try:
335345
torch.device('notadevice')
346+
raise AssertionError("unreachable") # pragma: nocover
336347
except RuntimeError as e:
337348
# The error message is something like:
338349
# "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"

0 commit comments

Comments
 (0)