From cfbd75b4a51de8d9f9167189815649159d2f6229 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Tue, 15 Apr 2025 14:18:23 +0100
Subject: [PATCH 1/2] ENH: `torch.result_type` for uint types

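Extend the wrapper's _promotion_table with the uint16/uint32/uint64
dtypes added in torch 2.3, so that result_type follows the Array API
promotion rules for them. A minimal sketch of the intended behaviour
(hypothetical session, assuming torch >=2.3 so the large uint dtypes
exist):

    import array_api_compat.torch as xp

    # same-kind promotion picks the wider unsigned type
    assert xp.result_type(xp.uint16, xp.uint32) == xp.uint32
    # mixed-sign promotion widens to a signed type that fits both
    assert xp.result_type(xp.uint16, xp.int8) == xp.int32
    assert xp.result_type(xp.uint32, xp.int32) == xp.int64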
---
 array_api_compat/torch/_aliases.py | 40 +++++++++++++++++++++++++-----
 1 file changed, 34 insertions(+), 6 deletions(-)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 027a0261..bd4d5f8e 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -22,9 +22,9 @@
 try:
     # torch >=2.3
     _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
+    _HAS_LARGE_UINT = True
 except AttributeError:
-    pass
-
+    _HAS_LARGE_UINT = False
 
 _array_api_dtypes = {
     torch.bool,
@@ -59,6 +59,28 @@
     (torch.float64, torch.complex128): torch.complex128,
 }
 
+if _HAS_LARGE_UINT:  # torch >=2.3
+    _promotion_table.update(
+        {
+            # uints
+            (torch.uint8, torch.uint16): torch.uint16,
+            (torch.uint8, torch.uint32): torch.uint32,
+            (torch.uint8, torch.uint64): torch.uint64,
+            (torch.uint16, torch.uint32): torch.uint32,
+            (torch.uint16, torch.uint64): torch.uint64,
+            (torch.uint32, torch.uint64): torch.uint64,
+            # ints and uints (mixed sign)
+            (torch.uint16, torch.int8): torch.int32,
+            (torch.uint16, torch.int16): torch.int32,
+            (torch.uint16, torch.int32): torch.int32,
+            (torch.uint16, torch.int64): torch.int64,
+            (torch.uint32, torch.int8): torch.int64,
+            (torch.uint32, torch.int16): torch.int64,
+            (torch.uint32, torch.int32): torch.int64,
+            (torch.uint32, torch.int64): torch.int64,
+        }
+    )
+
 _promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
 _promotion_table.update({(a, a): a for a in _array_api_dtypes})
 
@@ -295,10 +317,16 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
     if dtype is not None:
         return x.clone() if dtype == x.dtype else x.to(dtype)
 
-    # We can't upcast uint8 according to the spec because there is no
-    # torch.uint64, so at least upcast to int64 which is what prod does
-    # when axis=None.
-    if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
+    if x.dtype in (torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
+        return x.to(torch.uint64)
+
+    if x.dtype == torch.uint8:
+        # On torch <2.3 there is no torch.uint64, so we can't upcast uint8
+        # according to the spec; at least upcast to int64, which is what
+        # prod does when axis=None.
         return x.to(torch.int64)
 
     return x.clone()

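For illustration, a hedged sketch of the upcast that the new
_sum_prod_no_axis branches apply (assuming torch >=2.3 and that the
installed build implements no-axis reductions for the large uint
dtypes):

    import array_api_compat.torch as xp

    x = xp.asarray([1, 2, 3], dtype=xp.uint16)
    # unsigned ints now upcast to uint64 per the spec (torch >=2.3);
    # on torch <2.3, uint8 still falls back to int64, matching what
    # torch.prod does when axis=None
    assert xp.sum(x).dtype == xp.uint64
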
From ff6ea189b0735e0c6c772cfea6ab897d9e2b2e49 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Tue, 15 Apr 2025 14:35:21 +0100
Subject: [PATCH 2/2] Torch info for large uints

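Report the uint16/uint32/uint64 dtypes from
__array_namespace_info__().dtypes() on torch >=2.3, and stop skipping
them in the torch CI job. A hedged usage sketch (assuming torch >=2.3):

    import array_api_compat.torch as xp

    info = xp.__array_namespace_info__()
    # the large uint dtypes now show up alongside uint8
    assert "uint64" in info.dtypes(kind="unsigned integer")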
---
 .github/workflows/array-api-tests-torch.yml |  2 -
 array_api_compat/torch/_info.py             | 95 ++++++++-------------
 2 files changed, 37 insertions(+), 60 deletions(-)

diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml
index ac20df25..7a228812 100644
--- a/.github/workflows/array-api-tests-torch.yml
+++ b/.github/workflows/array-api-tests-torch.yml
@@ -8,6 +8,4 @@ jobs:
     with:
       package-name: torch
       extra-requires: '--index-url https://download.pytorch.org/whl/cpu'
-      extra-env-vars: |
-        ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64
       python-versions: '[''3.10'', ''3.13'']'
diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py
index 818e5d37..3835f024 100644
--- a/array_api_compat/torch/_info.py
+++ b/array_api_compat/torch/_info.py
@@ -170,78 +170,58 @@ def default_dtypes(self, *, device=None):
             "indexing": default_integral,
         }
 
-
     def _dtypes(self, kind):
-        bool = torch.bool
-        int8 = torch.int8
-        int16 = torch.int16
-        int32 = torch.int32
-        int64 = torch.int64
-        uint8 = torch.uint8
-        # uint16, uint32, and uint64 are present in newer versions of pytorch,
-        # but they aren't generally supported by the array API functions, so
-        # we omit them from this function.
-        float32 = torch.float32
-        float64 = torch.float64
-        complex64 = torch.complex64
-        complex128 = torch.complex128
-
         if kind is None:
-            return {
-                "bool": bool,
-                "int8": int8,
-                "int16": int16,
-                "int32": int32,
-                "int64": int64,
-                "uint8": uint8,
-                "float32": float32,
-                "float64": float64,
-                "complex64": complex64,
-                "complex128": complex128,
-            }
+            return self._dtypes(
+                (
+                    "bool",
+                    "signed integer",
+                    "unsigned integer",
+                    "real floating",
+                    "complex floating",
+                )
+            )
         if kind == "bool":
-            return {"bool": bool}
+            return {"bool": torch.bool}
         if kind == "signed integer":
             return {
-                "int8": int8,
-                "int16": int16,
-                "int32": int32,
-                "int64": int64,
+                "int8": torch.int8,
+                "int16": torch.int16,
+                "int32": torch.int32,
+                "int64": torch.int64,
             }
         if kind == "unsigned integer":
-            return {
-                "uint8": uint8,
-            }
+            try:
+                # torch >=2.3
+                return {
+                    "uint8": torch.uint8,
+                    "uint16": torch.uint16,
+                    "uint32": torch.uint32,
+                    "uint64": torch.uint32,
+                }
+            except AttributeError:  # torch <2.3
+                return {"uint8": torch.uint8}
         if kind == "integral":
-            return {
-                "int8": int8,
-                "int16": int16,
-                "int32": int32,
-                "int64": int64,
-                "uint8": uint8,
-            }
+            return self._dtypes(("signed integer", "unsigned integer"))
         if kind == "real floating":
             return {
-                "float32": float32,
-                "float64": float64,
+                "float32": torch.float32,
+                "float64": torch.float64,
             }
         if kind == "complex floating":
             return {
-                "complex64": complex64,
-                "complex128": complex128,
+                "complex64": torch.complex64,
+                "complex128": torch.complex128,
             }
         if kind == "numeric":
-            return {
-                "int8": int8,
-                "int16": int16,
-                "int32": int32,
-                "int64": int64,
-                "uint8": uint8,
-                "float32": float32,
-                "float64": float64,
-                "complex64": complex64,
-                "complex128": complex128,
-            }
+            return self._dtypes(
+                (
+                    "signed integer",
+                    "unsigned integer",
+                    "real floating",
+                    "complex floating",
+                )
+            )
         if isinstance(kind, tuple):
             res = {}
             for k in kind:
@@ -261,7 +241,6 @@ def dtypes(self, *, device=None, kind=None):
         ----------
         device : Device, optional
             The device to get the data types for.
-            Unused for PyTorch, as all devices use the same dtypes.
         kind : str or tuple of str, optional
             The kind of data types to return. If ``None``, all data types are
             returned. If a string, only data types of that kind are returned.
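
For completeness, a usage sketch of the tuple form of kind, which the
refactored _dtypes now routes the None, "integral", and "numeric" cases
through (hypothetical session, assuming torch >=2.3):

    >>> import array_api_compat.torch as xp
    >>> info = xp.__array_namespace_info__()
    >>> list(info.dtypes(kind=("bool", "unsigned integer")))
    ['bool', 'uint8', 'uint16', 'uint32', 'uint64']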