Skip to content

Commit 84b0551

Browse files
committed
ENH: torch.asarray device propagation
1 parent b2af137 commit 84b0551

File tree

7 files changed

+62
-23
lines changed

7 files changed

+62
-23
lines changed

array_api_compat/common/_aliases.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
# These functions are modified from the NumPy versions.
2020

21-
# Creation functions add the device keyword (which does nothing for NumPy)
21+
# Creation functions add the device keyword (which does nothing for NumPy and Dask)
2222

2323
def arange(
2424
start: Union[int, float],

array_api_compat/cupy/_info.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
complex128,
2727
)
2828

29+
2930
class __array_namespace_info__:
3031
"""
3132
Get the array API inspection namespace for CuPy.
@@ -117,7 +118,7 @@ def default_device(self):
117118
118119
Returns
119120
-------
120-
device : str
121+
device : Device
121122
The default device used for new CuPy arrays.
122123
123124
Examples
@@ -126,6 +127,15 @@ def default_device(self):
126127
>>> info.default_device()
127128
Device(0)
128129
130+
Notes
131+
-----
132+
This method returns the static default device when CuPy is initialized.
133+
However, the *current* device used by creation functions (``empty`` etc.)
134+
can be changed globally or with a context manager.
135+
136+
See Also
137+
--------
138+
https://github.com/data-apis/array-api/issues/835
129139
"""
130140
return cuda.Device(0)
131141

@@ -312,7 +322,7 @@ def devices(self):
312322
313323
Returns
314324
-------
315-
devices : list of str
325+
devices : list[Device]
316326
The devices supported by CuPy.
317327
318328
See Also

array_api_compat/dask/array/_info.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def default_device(self):
130130
131131
Returns
132132
-------
133-
device : str
133+
device : Device
134134
The default device used for new Dask arrays.
135135
136136
Examples
@@ -335,7 +335,7 @@ def devices(self):
335335
336336
Returns
337337
-------
338-
devices : list of str
338+
devices : list[Device]
339339
The devices supported by Dask.
340340
341341
See Also

array_api_compat/numpy/_info.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def default_device(self):
119119
120120
Returns
121121
-------
122-
device : str
122+
device : Device
123123
The default device used for new NumPy arrays.
124124
125125
Examples
@@ -326,7 +326,7 @@ def devices(self):
326326
327327
Returns
328328
-------
329-
devices : list of str
329+
devices : list[Device]
330330
The devices supported by NumPy.
331331
332332
See Also

array_api_compat/torch/_aliases.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from functools import reduce as _reduce, wraps as _wraps
44
from builtins import all as _builtin_all, any as _builtin_any
5-
from typing import List, Optional, Sequence, Tuple, Union
5+
from typing import Any, List, Optional, Sequence, Tuple, Union
66

77
import torch
88

99
from .._internal import get_xp
1010
from ..common import _aliases
11+
from ..common._typing import NestedSequence, SupportsBufferProtocol
1112
from ._info import __array_namespace_info__
1213
from ._typing import Array, Device, DType
1314

@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
207208
remainder = _two_arg(torch.remainder)
208209
subtract = _two_arg(torch.subtract)
209210

211+
212+
def asarray(
213+
obj: (
214+
Array
215+
| bool | int | float | complex
216+
| NestedSequence[bool | int | float | complex]
217+
| SupportsBufferProtocol
218+
),
219+
/,
220+
*,
221+
dtype: DType | None = None,
222+
device: Device | None = None,
223+
copy: bool | None = None,
224+
**kwargs: Any,
225+
) -> Array:
226+
# torch.asarray does not respect input->output device propagation
227+
# https://github.com/pytorch/pytorch/issues/150199
228+
if device is None and isinstance(obj, torch.Tensor):
229+
device = obj.device
230+
return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
231+
232+
210233
# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
211234
# of 'axis'.
212235

@@ -282,7 +305,6 @@ def prod(x: Array,
282305
dtype: Optional[DType] = None,
283306
keepdims: bool = False,
284307
**kwargs) -> Array:
285-
x = torch.asarray(x)
286308
ndim = x.ndim
287309

288310
# https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -318,7 +340,6 @@ def sum(x: Array,
318340
dtype: Optional[DType] = None,
319341
keepdims: bool = False,
320342
**kwargs) -> Array:
321-
x = torch.asarray(x)
322343
ndim = x.ndim
323344

324345
# https://github.com/pytorch/pytorch/issues/29137.
@@ -348,7 +369,6 @@ def any(x: Array,
348369
axis: Optional[Union[int, Tuple[int, ...]]] = None,
349370
keepdims: bool = False,
350371
**kwargs) -> Array:
351-
x = torch.asarray(x)
352372
ndim = x.ndim
353373
if axis == ():
354374
return x.to(torch.bool)
@@ -373,7 +393,6 @@ def all(x: Array,
373393
axis: Optional[Union[int, Tuple[int, ...]]] = None,
374394
keepdims: bool = False,
375395
**kwargs) -> Array:
376-
x = torch.asarray(x)
377396
ndim = x.ndim
378397
if axis == ():
379398
return x.to(torch.bool)
@@ -816,7 +835,7 @@ def sign(x: Array, /) -> Array:
816835
return out
817836

818837

819-
__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
838+
__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
820839
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
821840
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
822841
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',

array_api_compat/torch/_info.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,24 @@ def default_device(self):
102102
103103
Returns
104104
-------
105-
device : str
105+
device : Device
106106
The default device used for new PyTorch arrays.
107107
108108
Examples
109109
--------
110110
>>> info = np.__array_namespace_info__()
111111
>>> info.default_device()
112-
'cpu'
112+
device(type='cpu')
113113
114+
Notes
115+
-----
116+
This method returns the static default device when PyTorch is initialized.
117+
However, the *current* device used by creation functions (``empty`` etc.)
118+
can be changed at runtime.
119+
120+
See Also
121+
--------
122+
https://github.com/data-apis/array-api/issues/835
114123
"""
115124
return torch.device("cpu")
116125

@@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None):
120129
121130
Parameters
122131
----------
123-
device : str, optional
124-
The device to get the default data types for. For PyTorch, only
125-
``'cpu'`` is allowed.
132+
device : Device, optional
133+
The device to get the default data types for.
134+
Unused for PyTorch, as all devices use the same default dtypes.
126135
127136
Returns
128137
-------
@@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None):
250259
251260
Parameters
252261
----------
253-
device : str, optional
262+
device : Device, optional
254263
The device to get the data types for.
264+
Unused for PyTorch, as all devices use the same dtypes.
255265
kind : str or tuple of str, optional
256266
The kind of data types to return. If ``None``, all data types are
257267
returned. If a string, only data types of that kind are returned.
@@ -310,7 +320,7 @@ def devices(self):
310320
311321
Returns
312322
-------
313-
devices : list of str
323+
devices : list[Device]
314324
The devices supported by PyTorch.
315325
316326
See Also
@@ -333,6 +343,7 @@ def devices(self):
333343
# device:
334344
try:
335345
torch.device('notadevice')
346+
raise AssertionError("unreachable") # pragma: nocover
336347
except RuntimeError as e:
337348
# The error message is something like:
338349
# "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"

array_api_compat/torch/_typing.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
__all__ = ["Array", "DType", "Device"]
1+
__all__ = ["Array", "Device", "DType"]
22

3-
from torch import dtype as DType, Tensor as Array
4-
from ..common._typing import Device
3+
from torch import device as Device, dtype as DType, Tensor as Array

0 commit comments

Comments
 (0)