Skip to content

Commit fe83ca7

Browse files
committed
torch: rewrite result_type to handle scalars, add a test
Test torch-specific behavior which is unspecified in the spec.
1 parent cf282bc commit fe83ca7

File tree

2 files changed

+85
-28
lines changed

2 files changed

+85
-28
lines changed

array_api_compat/torch/_aliases.py

+21-28
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from functools import wraps as _wraps
3+
from functools import reduce as _reduce, wraps as _wraps
44
from builtins import all as _builtin_all, any as _builtin_any
55

66
from ..common import _aliases
@@ -124,43 +124,35 @@ def _fix_promotion(x1, x2, only_scalar=True):
124124

125125

126126
def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
127-
if len(arrays_and_dtypes) == 0:
128-
raise TypeError("At least one array or dtype must be provided")
129-
if len(arrays_and_dtypes) == 1:
127+
num = len(arrays_and_dtypes)
128+
129+
if num == 0:
130+
raise ValueError("At least one array or dtype must be provided")
131+
132+
elif num == 1:
130133
x = arrays_and_dtypes[0]
131134
if isinstance(x, torch.dtype):
132135
return x
133136
return x.dtype
134137

135-
if len(arrays_and_dtypes) > 2:
136-
# sort the scalars to the left so that they are treated last
137-
scalars, others = [], []
138-
for x in arrays_and_dtypes:
139-
if isinstance(x, _py_scalars):
140-
scalars.append(x)
141-
else:
142-
others.append(x)
143-
if len(scalars) == len(arrays_and_dtypes):
144-
raise ValueError("At least one array or dtype is required.")
138+
if num == 2:
139+
x, y = arrays_and_dtypes
140+
return _result_type(x, y)
145141

146-
arrays_and_dtypes = scalars + others
147-
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
142+
else:
143+
if _builtin_all(isinstance(x, _py_scalars) for x in arrays_and_dtypes):
144+
raise ValueError("At least one array or dtype must be provided")
148145

149-
# the binary case
150-
x, y = arrays_and_dtypes
146+
return _reduce(_result_type, arrays_and_dtypes)
151147

152-
if isinstance(x, _py_scalars):
153-
if isinstance(y, _py_scalars):
154-
raise ValueError("At least one array or dtype is required.")
155-
return y
156-
elif isinstance(y, _py_scalars):
157-
return x
158148

159-
xdt = x.dtype if not isinstance(x, torch.dtype) else x
160-
ydt = y.dtype if not isinstance(y, torch.dtype) else y
149+
def _result_type(x, y):
150+
if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
151+
xdt = x.dtype if not isinstance(x, torch.dtype) else x
152+
ydt = y.dtype if not isinstance(y, torch.dtype) else y
161153

162-
if (xdt, ydt) in _promotion_table:
163-
return _promotion_table[xdt, ydt]
154+
if (xdt, ydt) in _promotion_table:
155+
return _promotion_table[xdt, ydt]
164156

165157
# This doesn't result_type(dtype, dtype) for non-array API dtypes
166158
# because torch.result_type only accepts tensors. This does however, allow
@@ -169,6 +161,7 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple
169161
y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
170162
return torch.result_type(x, y)
171163

164+
172165
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
173166
if not isinstance(from_, torch.dtype):
174167
from_ = from_.dtype

tests/test_torch.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
2+
"""
3+
import pytest
4+
import torch
5+
6+
from array_api_compat import torch as xp
7+
8+
9+
class TestResultType:
10+
def test_empty(self):
11+
with pytest.raises(ValueError):
12+
xp.result_type()
13+
14+
def test_one_arg(self):
15+
for x in [1, 1.0, 1j, '...', None]:
16+
with pytest.raises((ValueError, AttributeError)):
17+
xp.result_type(x)
18+
19+
for x in [xp.float32, xp.int64, torch.complex64]:
20+
assert xp.result_type(x) == x
21+
22+
for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
23+
assert xp.result_type(x) == x.dtype
24+
25+
def test_two_args(self):
26+
# Only include here things "unspecified" in the spec
27+
28+
# scalar, tensor or tensor,tensor
29+
for x, y in [
30+
(1., 1j),
31+
(1j, xp.arange(3)),
32+
(True, xp.asarray(3.)),
33+
(xp.ones(3) == 1, 1j*xp.ones(3)),
34+
]:
35+
assert xp.result_type(x, y) == torch.result_type(x, y)
36+
37+
# dtype, scalar
38+
for x, y in [
39+
(1j, xp.int64),
40+
(True, xp.float64),
41+
]:
42+
assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))
43+
44+
# dtype, dtype
45+
for x, y in [
46+
(xp.bool, xp.complex64)
47+
]:
48+
xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
49+
assert xp.result_type(x, y) == torch.result_type(xt, yt)
50+
51+
def test_multi_arg(self):
52+
torch.set_default_dtype(torch.float32)
53+
54+
args = [1, 2, 3j, xp.arange(3), 4, 5, 6]
55+
assert xp.result_type(*args) == xp.complex64
56+
57+
args = [1, 2, 3j, xp.float64, 4, 5, 6]
58+
assert xp.result_type(*args) == xp.complex128
59+
60+
args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
61+
assert xp.result_type(*args) == xp.complex128
62+
63+
with pytest.raises(ValueError):
64+
xp.result_type(1, 2, 3, 4)

0 commit comments

Comments
 (0)