diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 982500b0..508865bb 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -21,9 +21,9 @@
 try:
     # torch >=2.3
     _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
+    _HAS_LARGE_UINT = True
 except AttributeError:
-    pass
-
+    _HAS_LARGE_UINT = False
 
 _array_api_dtypes = {
     torch.bool,
@@ -34,47 +34,23 @@
     torch.complex128,
 }
 
-_promotion_table = {
-    # bool
-    (torch.bool, torch.bool): torch.bool,
+_promotion_table = {
     # ints
-    (torch.int8, torch.int8): torch.int8,
     (torch.int8, torch.int16): torch.int16,
     (torch.int8, torch.int32): torch.int32,
     (torch.int8, torch.int64): torch.int64,
-    (torch.int16, torch.int8): torch.int16,
-    (torch.int16, torch.int16): torch.int16,
     (torch.int16, torch.int32): torch.int32,
     (torch.int16, torch.int64): torch.int64,
-    (torch.int32, torch.int8): torch.int32,
-    (torch.int32, torch.int16): torch.int32,
-    (torch.int32, torch.int32): torch.int32,
     (torch.int32, torch.int64): torch.int64,
-    (torch.int64, torch.int8): torch.int64,
-    (torch.int64, torch.int16): torch.int64,
-    (torch.int64, torch.int32): torch.int64,
-    (torch.int64, torch.int64): torch.int64,
-    # uints
-    (torch.uint8, torch.uint8): torch.uint8,
     # ints and uints (mixed sign)
-    (torch.int8, torch.uint8): torch.int16,
-    (torch.int16, torch.uint8): torch.int16,
-    (torch.int32, torch.uint8): torch.int32,
-    (torch.int64, torch.uint8): torch.int64,
     (torch.uint8, torch.int8): torch.int16,
     (torch.uint8, torch.int16): torch.int16,
     (torch.uint8, torch.int32): torch.int32,
     (torch.uint8, torch.int64): torch.int64,
     # floats
-    (torch.float32, torch.float32): torch.float32,
     (torch.float32, torch.float64): torch.float64,
-    (torch.float64, torch.float32): torch.float64,
-    (torch.float64, torch.float64): torch.float64,
     # complexes
-    (torch.complex64, torch.complex64): torch.complex64,
     (torch.complex64, torch.complex128): torch.complex128,
-    (torch.complex128, torch.complex64): torch.complex128,
-    (torch.complex128, torch.complex128): torch.complex128,
     # Mixed float and complex
     (torch.float32, torch.complex64): torch.complex64,
     (torch.float32, torch.complex128): torch.complex128,
@@ -82,6 +58,31 @@
     (torch.float64, torch.complex128): torch.complex128,
 }
 
+if _HAS_LARGE_UINT:  # torch >=2.3
+    _promotion_table.update(
+        {
+            # uints
+            (torch.uint8, torch.uint16): torch.uint16,
+            (torch.uint8, torch.uint32): torch.uint32,
+            (torch.uint8, torch.uint64): torch.uint64,
+            (torch.uint16, torch.uint32): torch.uint32,
+            (torch.uint16, torch.uint64): torch.uint64,
+            (torch.uint32, torch.uint64): torch.uint64,
+            # ints and uints (mixed sign)
+            (torch.uint16, torch.int8): torch.int32,
+            (torch.uint16, torch.int16): torch.int32,
+            (torch.uint16, torch.int32): torch.int32,
+            (torch.uint16, torch.int64): torch.int64,
+            (torch.uint32, torch.int8): torch.int64,
+            (torch.uint32, torch.int16): torch.int64,
+            (torch.uint32, torch.int32): torch.int64,
+            (torch.uint32, torch.int64): torch.int64,
+        }
+    )
+
+_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
+_promotion_table.update({(a, a): a for a in _array_api_dtypes})
+
 
 def _two_arg(f):
     @_wraps(f)
@@ -275,6 +276,31 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
             out = torch.unsqueeze(out, a)
     return out
 
+
+def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
+    """
+    Implements `sum(..., axis=())` and `prod(..., axis=())`.
+
+    Works around https://github.com/pytorch/pytorch/issues/29137
+    """
+    if dtype is not None:
+        return x.clone() if dtype == x.dtype else x.to(dtype)
+
+    if x.dtype in (torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
+        return x.to(torch.uint64)
+
+    if x.dtype == torch.uint8:
+        # On torch <2.3 there is no torch.uint64, so we can't upcast uint8
+        # according to the spec; at least upcast to int64, which is what
+        # sum/prod do when axis=None.
+        return x.to(torch.int64)
+
+    return x.clone()
+
+
 def prod(x: Array,
          /,
          *,
@@ -283,20 +309,9 @@ def prod(x: Array,
          keepdims: bool = False,
          **kwargs) -> Array:
     x = torch.asarray(x)
-    ndim = x.ndim
 
-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     # torch.prod doesn't support multiple axes
     # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
@@ -305,7 +320,7 @@ def prod(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.prod(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
 
     return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -319,25 +334,14 @@ def sum(x: Array,
         keepdims: bool = False,
         **kwargs) -> Array:
     x = torch.asarray(x)
-    ndim = x.ndim
 
-    # https://github.com/pytorch/pytorch/issues/29137.
-    # Make sure it upcasts.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     if axis is None:
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.sum(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
 
     return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
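
For reference, a minimal sketch (not part of the patch) of the behaviour this change targets, assuming a checkout of array_api_compat with this patch applied and a torch build installed; the uint16/uint32/uint64 entries only exist on torch >= 2.3:

# Illustrative check only; `xp` is the torch compat namespace shipped by
# array_api_compat, not something added by this patch.
import torch
import array_api_compat.torch as xp

x = xp.asarray([1, 2, 3], dtype=torch.int16)

# axis=() is a no-op reduction, but the result still upcasts to the default
# integer dtype, matching what sum/prod do with axis=None.
assert xp.sum(x, axis=()).dtype == torch.int64
assert xp.prod(x, axis=()).dtype == torch.int64
assert xp.sum(x, axis=()).shape == x.shape

# With an explicit dtype the input is only converted (or cloned), never reduced.
assert xp.sum(x, axis=(), dtype=torch.float32).dtype == torch.float32

# Mixed-sign promotion now comes from the symmetrised half-table rather than
# an explicit (int8, uint8) entry.
assert xp.result_type(torch.uint8, torch.int8) == torch.int16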