From 84b055150ef46f8945823619d8eddd823181b977 Mon Sep 17 00:00:00 2001 From: crusaderky <crusaderky@gmail.com> Date: Thu, 3 Apr 2025 10:14:01 +0100 Subject: [PATCH] ENH: torch.asarray device propagation --- array_api_compat/common/_aliases.py | 2 +- array_api_compat/cupy/_info.py | 14 +++++++++++-- array_api_compat/dask/array/_info.py | 4 ++-- array_api_compat/numpy/_info.py | 4 ++-- array_api_compat/torch/_aliases.py | 31 ++++++++++++++++++++++------ array_api_compat/torch/_info.py | 25 +++++++++++++++------- array_api_compat/torch/_typing.py | 5 ++--- 7 files changed, 62 insertions(+), 23 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 03910681..b8b18c6e 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -18,7 +18,7 @@ # These functions are modified from the NumPy versions. -# Creation functions add the device keyword (which does nothing for NumPy) +# Creation functions add the device keyword (which does nothing for NumPy and Dask) def arange( start: Union[int, float], diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py index 790621e4..66d3c4ae 100644 --- a/array_api_compat/cupy/_info.py +++ b/array_api_compat/cupy/_info.py @@ -26,6 +26,7 @@ complex128, ) + class __array_namespace_info__: """ Get the array API inspection namespace for CuPy. @@ -117,7 +118,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new CuPy arrays. Examples @@ -126,6 +127,15 @@ def default_device(self): >>> info.default_device() Device(0) + Notes + ----- + This method returns the static default device when CuPy is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed globally or with a context manager. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return cuda.Device(0) @@ -312,7 +322,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by CuPy. See Also diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index fc70b5a2..97ceec67 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -130,7 +130,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new Dask arrays. Examples @@ -335,7 +335,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by Dask. See Also diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index e706d118..a30ee352 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -119,7 +119,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new NumPy arrays. Examples @@ -326,7 +326,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by NumPy. See Also diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 982500b0..0891525a 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,12 +2,13 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from typing import List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from .._internal import get_xp from ..common import _aliases +from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ from ._typing import Array, Device, DType @@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: remainder = _two_arg(torch.remainder) subtract = _two_arg(torch.subtract) + +def asarray( + obj: ( + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol + ), + /, + *, + dtype: DType | None = None, + device: Device | None = None, + copy: bool | None = None, + **kwargs: Any, +) -> Array: + # torch.asarray does not respect input->output device propagation + # https://github.com/pytorch/pytorch/issues/150199 + if device is None and isinstance(obj, torch.Tensor): + device = obj.device + return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs) + + # These wrappers are mostly based on the fact that pytorch uses 'dim' instead # of 'axis'. @@ -282,7 +305,6 @@ def prod(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic @@ -318,7 +340,6 @@ def sum(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim # https://github.com/pytorch/pytorch/issues/29137. @@ -348,7 +369,6 @@ def any(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim if axis == (): return x.to(torch.bool) @@ -373,7 +393,6 @@ def all(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim if axis == (): return x.to(torch.bool) @@ -816,7 +835,7 @@ def sign(x: Array, /) -> Array: return out -__all__ = ['__array_namespace_info__', 'result_type', 'can_cast', +__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 34fbcb21..b0486a58 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -102,15 +102,24 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new PyTorch arrays. Examples -------- >>> info = np.__array_namespace_info__() >>> info.default_device() - 'cpu' + device(type='cpu') + Notes + ----- + This method returns the static default device when PyTorch is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed at runtime. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return torch.device("cpu") @@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None): Parameters ---------- - device : str, optional - The device to get the default data types for. For PyTorch, only - ``'cpu'`` is allowed. + device : Device, optional + The device to get the default data types for. + Unused for PyTorch, as all devices use the same default dtypes. Returns ------- @@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None): Parameters ---------- - device : str, optional + device : Device, optional The device to get the data types for. + Unused for PyTorch, as all devices use the same dtypes. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned. @@ -310,7 +320,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by PyTorch. See Also @@ -333,6 +343,7 @@ def devices(self): # device: try: torch.device('notadevice') + raise AssertionError("unreachable") # pragma: nocover except RuntimeError as e: # The error message is something like: # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice" diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py index 29ad3fa7..52670871 100644 --- a/array_api_compat/torch/_typing.py +++ b/array_api_compat/torch/_typing.py @@ -1,4 +1,3 @@ -__all__ = ["Array", "DType", "Device"] +__all__ = ["Array", "Device", "DType"] -from torch import dtype as DType, Tensor as Array -from ..common._typing import Device +from torch import device as Device, dtype as DType, Tensor as Array