From 84b055150ef46f8945823619d8eddd823181b977 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Thu, 3 Apr 2025 10:14:01 +0100
Subject: [PATCH] ENH: torch.asarray device propagation

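torch.asarray does not always propagate the device of an input tensor to the
output when device=None is passed
(https://github.com/pytorch/pytorch/issues/150199). Add a thin asarray wrapper
to the torch namespace that forwards obj.device explicitly in that case, drop
the now-redundant torch.asarray(x) calls from prod/sum/any/all, and update the
Device annotations in the inspection namespaces.

A minimal sketch of the intended behaviour (illustrative only; it assumes a
CUDA-enabled build, and the exact failure mode without this patch is the one
tracked in the upstream issue above):

    import torch
    import array_api_compat.torch as xp

    x = torch.ones(3, device="cuda")
    y = xp.asarray(x)                 # with this patch: y.device == x.device
    z = xp.asarray(x, device="cpu")   # an explicit device still wins
    assert z.device.type == "cpu"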
---
 array_api_compat/common/_aliases.py  |  2 +-
 array_api_compat/cupy/_info.py       | 14 +++++++++++--
 array_api_compat/dask/array/_info.py |  4 ++--
 array_api_compat/numpy/_info.py      |  4 ++--
 array_api_compat/torch/_aliases.py   | 31 ++++++++++++++++++++++------
 array_api_compat/torch/_info.py      | 25 +++++++++++++++-------
 array_api_compat/torch/_typing.py    |  5 ++---
 7 files changed, 62 insertions(+), 23 deletions(-)

diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 03910681..b8b18c6e 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -18,7 +18,7 @@
 
 # These functions are modified from the NumPy versions.
 
-# Creation functions add the device keyword (which does nothing for NumPy)
+# Creation functions add the device keyword (which does nothing for NumPy and Dask)
 
 def arange(
     start: Union[int, float],
diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py
index 790621e4..66d3c4ae 100644
--- a/array_api_compat/cupy/_info.py
+++ b/array_api_compat/cupy/_info.py
@@ -26,6 +26,7 @@
     complex128,
 )
 
+
 class __array_namespace_info__:
     """
     Get the array API inspection namespace for CuPy.
@@ -117,7 +118,7 @@ def default_device(self):
 
         Returns
         -------
-        device : str
+        device : Device
             The default device used for new CuPy arrays.
 
         Examples
@@ -126,6 +127,15 @@ def default_device(self):
         >>> info.default_device()
         Device(0)
 
+        Notes
+        -----
+        This method returns the static default device set when CuPy is
+        initialized. However, the *current* device used by creation functions
+        (``empty`` etc.) can be changed globally or with a context manager.
+
+        See Also
+        --------
+        https://github.com/data-apis/array-api/issues/835
         """
         return cuda.Device(0)
 
@@ -312,7 +322,7 @@ def devices(self):
 
         Returns
         -------
-        devices : list of str
+        devices : list[Device]
             The devices supported by CuPy.
 
         See Also
diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py
index fc70b5a2..97ceec67 100644
--- a/array_api_compat/dask/array/_info.py
+++ b/array_api_compat/dask/array/_info.py
@@ -130,7 +130,7 @@ def default_device(self):
 
         Returns
         -------
-        device : str
+        device : Device
             The default device used for new Dask arrays.
 
         Examples
@@ -335,7 +335,7 @@ def devices(self):
 
         Returns
         -------
-        devices : list of str
+        devices : list[Device]
             The devices supported by Dask.
 
         See Also
diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index e706d118..a30ee352 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -119,7 +119,7 @@ def default_device(self):
 
         Returns
         -------
-        device : str
+        device : Device
             The default device used for new NumPy arrays.
 
         Examples
@@ -326,7 +326,7 @@ def devices(self):
 
         Returns
         -------
-        devices : list of str
+        devices : list[Device]
             The devices supported by NumPy.
 
         See Also
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 982500b0..0891525a 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -2,12 +2,13 @@
 
 from functools import reduce as _reduce, wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import torch
 
 from .._internal import get_xp
 from ..common import _aliases
+from ..common._typing import NestedSequence, SupportsBufferProtocol
 from ._info import __array_namespace_info__
 from ._typing import Array, Device, DType
 
@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
 remainder = _two_arg(torch.remainder)
 subtract = _two_arg(torch.subtract)
 
+
+def asarray(
+    obj: (
+        Array
+        | bool | int | float | complex
+        | NestedSequence[bool | int | float | complex]
+        | SupportsBufferProtocol
+    ),
+    /,
+    *,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    copy: bool | None = None,
+    **kwargs: Any,
+) -> Array:
+    # torch.asarray does not respect input->output device propagation
+    # https://github.com/pytorch/pytorch/issues/150199
+    if device is None and isinstance(obj, torch.Tensor):
+        device = obj.device
+    return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
+
+
 # These wrappers are mostly based on the fact that pytorch uses 'dim' instead
 # of 'axis'.
 
@@ -282,7 +305,6 @@ def prod(x: Array,
          dtype: Optional[DType] = None,
          keepdims: bool = False,
          **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
 
     # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -318,7 +340,6 @@ def sum(x: Array,
          dtype: Optional[DType] = None,
          keepdims: bool = False,
          **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
 
     # https://github.com/pytorch/pytorch/issues/29137.
@@ -348,7 +369,6 @@ def any(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
         return x.to(torch.bool)
@@ -373,7 +393,6 @@ def all(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
         return x.to(torch.bool)
@@ -816,7 +835,7 @@ def sign(x: Array, /) -> Array:
         return out
 
 
-__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
+__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
            'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
            'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
            'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py
index 34fbcb21..b0486a58 100644
--- a/array_api_compat/torch/_info.py
+++ b/array_api_compat/torch/_info.py
@@ -102,15 +102,24 @@ def default_device(self):
 
         Returns
         -------
-        device : str
+        device : Device
             The default device used for new PyTorch arrays.
 
         Examples
         --------
         >>> info = np.__array_namespace_info__()
         >>> info.default_device()
-        'cpu'
+        device(type='cpu')
 
+        Notes
+        -----
+        This method returns the static default device set when PyTorch is
+        initialized. However, the *current* device used by creation functions
+        (``empty`` etc.) can be changed at runtime.
+
+        See Also
+        --------
+        https://github.com/data-apis/array-api/issues/835
         """
         return torch.device("cpu")
 
@@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None):
 
         Parameters
         ----------
-        device : str, optional
-            The device to get the default data types for. For PyTorch, only
-            ``'cpu'`` is allowed.
+        device : Device, optional
+            The device to get the default data types for.
+            Unused for PyTorch, as all devices use the same default dtypes.
 
         Returns
         -------
@@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None):
 
         Parameters
         ----------
-        device : str, optional
+        device : Device, optional
             The device to get the data types for.
+            Unused for PyTorch, as all devices use the same dtypes.
         kind : str or tuple of str, optional
             The kind of data types to return. If ``None``, all data types are
             returned. If a string, only data types of that kind are returned.
@@ -310,7 +320,7 @@ def devices(self):
 
         Returns
         -------
-        devices : list of str
+        devices : list[Device]
             The devices supported by PyTorch.
 
         See Also
@@ -333,6 +343,7 @@ def devices(self):
         # device:
         try:
             torch.device('notadevice')
+            raise AssertionError("unreachable")  # pragma: no cover
         except RuntimeError as e:
             # The error message is something like:
             # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"
diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py
index 29ad3fa7..52670871 100644
--- a/array_api_compat/torch/_typing.py
+++ b/array_api_compat/torch/_typing.py
@@ -1,4 +1,3 @@
-__all__ = ["Array", "DType", "Device"]
+__all__ = ["Array", "Device", "DType"]
 
-from torch import dtype as DType, Tensor as Array
-from ..common._typing import Device
+from torch import device as Device, dtype as DType, Tensor as Array