Skip to content

Commit 3cae752

Browse files
committed
ENH: torch dtype promotions
1 parent b6900df commit 3cae752

File tree

1 file changed

+56
-50
lines changed

1 file changed

+56
-50
lines changed

array_api_compat/torch/_aliases.py

+56-50
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
try:
2222
# torch >=2.3
2323
_int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
24+
_HAS_LARGE_UINT = True
2425
except AttributeError:
25-
pass
26-
26+
_HAS_LARGE_UINT = False
2727

2828
_array_api_dtypes = {
2929
torch.bool,
@@ -34,54 +34,55 @@
3434
torch.complex128,
3535
}
3636

37-
_promotion_table = {
38-
# bool
39-
(torch.bool, torch.bool): torch.bool,
37+
_promotion_table = {
4038
# ints
41-
(torch.int8, torch.int8): torch.int8,
4239
(torch.int8, torch.int16): torch.int16,
4340
(torch.int8, torch.int32): torch.int32,
4441
(torch.int8, torch.int64): torch.int64,
45-
(torch.int16, torch.int8): torch.int16,
46-
(torch.int16, torch.int16): torch.int16,
4742
(torch.int16, torch.int32): torch.int32,
4843
(torch.int16, torch.int64): torch.int64,
49-
(torch.int32, torch.int8): torch.int32,
50-
(torch.int32, torch.int16): torch.int32,
51-
(torch.int32, torch.int32): torch.int32,
5244
(torch.int32, torch.int64): torch.int64,
53-
(torch.int64, torch.int8): torch.int64,
54-
(torch.int64, torch.int16): torch.int64,
55-
(torch.int64, torch.int32): torch.int64,
56-
(torch.int64, torch.int64): torch.int64,
57-
# uints
58-
(torch.uint8, torch.uint8): torch.uint8,
5945
# ints and uints (mixed sign)
60-
(torch.int8, torch.uint8): torch.int16,
61-
(torch.int16, torch.uint8): torch.int16,
62-
(torch.int32, torch.uint8): torch.int32,
63-
(torch.int64, torch.uint8): torch.int64,
6446
(torch.uint8, torch.int8): torch.int16,
6547
(torch.uint8, torch.int16): torch.int16,
6648
(torch.uint8, torch.int32): torch.int32,
6749
(torch.uint8, torch.int64): torch.int64,
6850
# floats
69-
(torch.float32, torch.float32): torch.float32,
7051
(torch.float32, torch.float64): torch.float64,
71-
(torch.float64, torch.float32): torch.float64,
72-
(torch.float64, torch.float64): torch.float64,
7352
# complexes
74-
(torch.complex64, torch.complex64): torch.complex64,
7553
(torch.complex64, torch.complex128): torch.complex128,
76-
(torch.complex128, torch.complex64): torch.complex128,
77-
(torch.complex128, torch.complex128): torch.complex128,
7854
# Mixed float and complex
7955
(torch.float32, torch.complex64): torch.complex64,
8056
(torch.float32, torch.complex128): torch.complex128,
8157
(torch.float64, torch.complex64): torch.complex128,
8258
(torch.float64, torch.complex128): torch.complex128,
8359
}
8460

61+
if _HAS_LARGE_UINT: # torch >=2.3
62+
_promotion_table.update(
63+
{
64+
# uints
65+
(torch.uint8, torch.uint16): torch.uint16,
66+
(torch.uint8, torch.uint32): torch.uint32,
67+
(torch.uint8, torch.uint64): torch.uint64,
68+
(torch.uint16, torch.uint32): torch.uint32,
69+
(torch.uint16, torch.uint64): torch.uint64,
70+
(torch.uint32, torch.uint64): torch.uint64,
71+
# ints and uints (mixed sign)
72+
(torch.uint16, torch.int8): torch.int32,
73+
(torch.uint16, torch.int16): torch.int32,
74+
(torch.uint16, torch.int32): torch.int32,
75+
(torch.uint16, torch.int64): torch.int64,
76+
(torch.uint32, torch.int8): torch.int64,
77+
(torch.uint32, torch.int16): torch.int64,
78+
(torch.uint32, torch.int32): torch.int64,
79+
(torch.uint32, torch.int64): torch.int64,
80+
}
81+
)
82+
83+
_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
84+
_promotion_table.update({(a, a): a for a in _array_api_dtypes})
85+
8586

8687
def _two_arg(f):
8788
@_wraps(f)
@@ -275,6 +276,31 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
275276
out = torch.unsqueeze(out, a)
276277
return out
277278

279+
280+
def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
281+
"""
282+
Implements `sum(..., axis=())` and `prod(..., axis=())`.
283+
284+
Works around https://github.com/pytorch/pytorch/issues/29137
285+
"""
286+
if dtype is not None:
287+
return x.clone() if dtype == x.dtype else x.to(dtype)
288+
289+
if x.dtype in (torch.int8, torch.int16, torch.int32):
290+
return x.to(torch.int64)
291+
292+
if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
293+
return x.to(torch.uint64)
294+
295+
if x.dtype == torch.uint8:
296+
# We can't upcast uint8 according to the spec because there is no
297+
# torch.uint64, so at least upcast to int64 which is what prod does
298+
# when axis=None.
299+
return x.to(torch.int64)
300+
301+
return x.clone()
302+
303+
278304
def prod(x: Array,
279305
/,
280306
*,
@@ -283,20 +309,9 @@ def prod(x: Array,
283309
keepdims: bool = False,
284310
**kwargs) -> Array:
285311
x = torch.asarray(x)
286-
ndim = x.ndim
287312

288-
# https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
289-
# below because it still needs to upcast.
290313
if axis == ():
291-
if dtype is None:
292-
# We can't upcast uint8 according to the spec because there is no
293-
# torch.uint64, so at least upcast to int64 which is what sum does
294-
# when axis=None.
295-
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
296-
return x.to(torch.int64)
297-
return x.clone()
298-
return x.to(dtype)
299-
314+
return _sum_prod_no_axis(x, dtype)
300315
# torch.prod doesn't support multiple axes
301316
# (https://github.com/pytorch/pytorch/issues/56586).
302317
if isinstance(axis, tuple):
@@ -305,7 +320,7 @@ def prod(x: Array,
305320
# torch doesn't support keepdims with axis=None
306321
# (https://github.com/pytorch/pytorch/issues/71209)
307322
res = torch.prod(x, dtype=dtype, **kwargs)
308-
res = _axis_none_keepdims(res, ndim, keepdims)
323+
res = _axis_none_keepdims(res, x.ndim, keepdims)
309324
return res
310325

311326
return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -321,17 +336,8 @@ def sum(x: Array,
321336
x = torch.asarray(x)
322337
ndim = x.ndim
323338

324-
# https://github.com/pytorch/pytorch/issues/29137.
325-
# Make sure it upcasts.
326339
if axis == ():
327-
if dtype is None:
328-
# We can't upcast uint8 according to the spec because there is no
329-
# torch.uint64, so at least upcast to int64 which is what sum does
330-
# when axis=None.
331-
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
332-
return x.to(torch.int64)
333-
return x.clone()
334-
return x.to(dtype)
340+
return _sum_prod_no_axis(x, dtype)
335341

336342
if axis is None:
337343
# torch doesn't support keepdims with axis=None

0 commit comments

Comments
 (0)