try:
    # torch >=2.3
    _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
+    _HAS_LARGE_UINT = True
except AttributeError:
-    pass
-
+    _HAS_LARGE_UINT = False

_array_api_dtypes = {
    torch.bool,
    torch.complex128,
}

-_promotion_table = {
-    # bool
-    (torch.bool, torch.bool): torch.bool,
+_promotion_table = {
    # ints
-    (torch.int8, torch.int8): torch.int8,
    (torch.int8, torch.int16): torch.int16,
    (torch.int8, torch.int32): torch.int32,
    (torch.int8, torch.int64): torch.int64,
-    (torch.int16, torch.int8): torch.int16,
-    (torch.int16, torch.int16): torch.int16,
    (torch.int16, torch.int32): torch.int32,
    (torch.int16, torch.int64): torch.int64,
-    (torch.int32, torch.int8): torch.int32,
-    (torch.int32, torch.int16): torch.int32,
-    (torch.int32, torch.int32): torch.int32,
    (torch.int32, torch.int64): torch.int64,
-    (torch.int64, torch.int8): torch.int64,
-    (torch.int64, torch.int16): torch.int64,
-    (torch.int64, torch.int32): torch.int64,
-    (torch.int64, torch.int64): torch.int64,
-    # uints
-    (torch.uint8, torch.uint8): torch.uint8,
    # ints and uints (mixed sign)
-    (torch.int8, torch.uint8): torch.int16,
-    (torch.int16, torch.uint8): torch.int16,
-    (torch.int32, torch.uint8): torch.int32,
-    (torch.int64, torch.uint8): torch.int64,
    (torch.uint8, torch.int8): torch.int16,
    (torch.uint8, torch.int16): torch.int16,
    (torch.uint8, torch.int32): torch.int32,
    (torch.uint8, torch.int64): torch.int64,
    # floats
-    (torch.float32, torch.float32): torch.float32,
    (torch.float32, torch.float64): torch.float64,
-    (torch.float64, torch.float32): torch.float64,
-    (torch.float64, torch.float64): torch.float64,
    # complexes
-    (torch.complex64, torch.complex64): torch.complex64,
    (torch.complex64, torch.complex128): torch.complex128,
-    (torch.complex128, torch.complex64): torch.complex128,
-    (torch.complex128, torch.complex128): torch.complex128,
    # Mixed float and complex
    (torch.float32, torch.complex64): torch.complex64,
    (torch.float32, torch.complex128): torch.complex128,
    (torch.float64, torch.complex64): torch.complex128,
    (torch.float64, torch.complex128): torch.complex128,
}

+if _HAS_LARGE_UINT:  # torch >=2.3
+    _promotion_table.update(
+        {
+            # uints
+            (torch.uint8, torch.uint16): torch.uint16,
+            (torch.uint8, torch.uint32): torch.uint32,
+            (torch.uint8, torch.uint64): torch.uint64,
+            (torch.uint16, torch.uint32): torch.uint32,
+            (torch.uint16, torch.uint64): torch.uint64,
+            (torch.uint32, torch.uint64): torch.uint64,
+            # ints and uints (mixed sign)
+            (torch.uint16, torch.int8): torch.int32,
+            (torch.uint16, torch.int16): torch.int32,
+            (torch.uint16, torch.int32): torch.int32,
+            (torch.uint16, torch.int64): torch.int64,
+            (torch.uint32, torch.int8): torch.int64,
+            (torch.uint32, torch.int16): torch.int64,
+            (torch.uint32, torch.int32): torch.int64,
+            (torch.uint32, torch.int64): torch.int64,
+        }
+    )
+
+_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
+_promotion_table.update({(a, a): a for a in _array_api_dtypes})
+

def _two_arg(f):
    @_wraps(f)
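Aside (editor's note, not part of the diff): the two trailing update() calls mean only one ordering of each mixed-dtype pair, and no same-dtype pairs, need to be spelled out by hand; the reversed and identity entries are derived. A minimal standalone sketch of that pattern, with strings standing in for torch dtypes:

    table = {("int8", "int16"): "int16"}
    table.update({(b, a): c for (a, b), c in table.items()})  # add reversed pairs
    table.update({(a, a): a for a in ("int8", "int16")})      # add identity pairs
    assert table[("int16", "int8")] == "int16"
    assert table[("int8", "int8")] == "int8"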
@@ -275,6 +276,31 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
            out = torch.unsqueeze(out, a)
    return out

+
+def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
+    """
+    Implements `sum(..., axis=())` and `prod(..., axis=())`.
+
+    Works around https://github.com/pytorch/pytorch/issues/29137
+    """
+    if dtype is not None:
+        return x.clone() if dtype == x.dtype else x.to(dtype)
+
+    if x.dtype in (torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
+        return x.to(torch.uint64)
+
+    if x.dtype == torch.uint8:
+        # We can't upcast uint8 according to the spec because there is no
+        # torch.uint64, so at least upcast to int64 which is what prod does
+        # when axis=None.
+        return x.to(torch.int64)
+
+    return x.clone()
+
+
def prod(x: Array,
         /,
         *,
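Aside (editor's note, not part of the diff): _sum_prod_no_axis exists because torch treats dim=() as "reduce over every axis", while the array API standard requires axis=() to reduce over no axes (pytorch/pytorch#29137), so the wrappers must return an (upcast) copy themselves. With dtype=None the helper maps int8/int16/int32 to int64, the small unsigned dtypes to uint64 when torch >= 2.3 provides it, uint8 to int64 on older torch, and leaves everything else as an unchanged copy. A small sketch of the dtype=None case with an int16 input:

    import torch
    x = torch.ones(3, dtype=torch.int16)
    expected = x.to(torch.int64)  # what _sum_prod_no_axis(x, None) should return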
@@ -283,20 +309,9 @@ def prod(x: Array,
         keepdims: bool = False,
         **kwargs) -> Array:
    x = torch.asarray(x)
-    ndim = x.ndim

-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
    if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
    # torch.prod doesn't support multiple axes
    # (https://github.com/pytorch/pytorch/issues/56586).
    if isinstance(axis, tuple):
@@ -305,7 +320,7 @@ def prod(x: Array,
        # torch doesn't support keepdims with axis=None
        # (https://github.com/pytorch/pytorch/issues/71209)
        res = torch.prod(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
        return res

    return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -321,17 +336,8 @@ def sum(x: Array,
    x = torch.asarray(x)
    ndim = x.ndim

-    # https://github.com/pytorch/pytorch/issues/29137.
-    # Make sure it upcasts.
    if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
+        return _sum_prod_no_axis(x, dtype)

    if axis is None:
        # torch doesn't support keepdims with axis=None
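Aside (editor's note, not part of the diff): after this change both wrappers route axis=() through the shared helper. A rough usage sketch, assuming this file is the torch backend of array-api-compat (the module path is an assumption; it is not shown in the diff):

    import array_api_compat.torch as xp
    import torch

    x = torch.ones((2, 3), dtype=torch.int16)
    s = xp.sum(x, axis=())   # element-wise no-op, upcast to int64 by the helper
    p = xp.prod(x, axis=())  # same shape as x, not a scalar
    assert s.shape == p.shape == (2, 3)
    assert s.dtype == p.dtype == torch.int64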