Skip to content

Commit 8c7fe4e

Browse files
committed
torch: rewrite result_type to handle scalars, add a test
Test torch-specific behavior which is unspecified in the spec.
1 parent cf282bc commit 8c7fe4e

File tree

2 files changed

+87
-28
lines changed

2 files changed

+87
-28
lines changed

Diff for: array_api_compat/torch/_aliases.py

+21-28
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from functools import wraps as _wraps
3+
from functools import reduce as _reduce, wraps as _wraps
44
from builtins import all as _builtin_all, any as _builtin_any
55

66
from ..common import _aliases
@@ -124,43 +124,35 @@ def _fix_promotion(x1, x2, only_scalar=True):
124124

125125

126126
def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
127-
if len(arrays_and_dtypes) == 0:
128-
raise TypeError("At least one array or dtype must be provided")
129-
if len(arrays_and_dtypes) == 1:
127+
num = len(arrays_and_dtypes)
128+
129+
if num == 0:
130+
raise ValueError("At least one array or dtype must be provided")
131+
132+
elif num == 1:
130133
x = arrays_and_dtypes[0]
131134
if isinstance(x, torch.dtype):
132135
return x
133136
return x.dtype
134137

135-
if len(arrays_and_dtypes) > 2:
136-
# sort the scalars to the left so that they are treated last
137-
scalars, others = [], []
138-
for x in arrays_and_dtypes:
139-
if isinstance(x, _py_scalars):
140-
scalars.append(x)
141-
else:
142-
others.append(x)
143-
if len(scalars) == len(arrays_and_dtypes):
144-
raise ValueError("At least one array or dtype is required.")
138+
if num == 2:
139+
x, y = arrays_and_dtypes
140+
return _result_type(x, y)
145141

146-
arrays_and_dtypes = scalars + others
147-
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
142+
else:
143+
if _builtin_all(isinstance(x, _py_scalars) for x in arrays_and_dtypes):
144+
raise ValueError("At least one array or dtype must be provided")
148145

149-
# the binary case
150-
x, y = arrays_and_dtypes
146+
return _reduce(_result_type, arrays_and_dtypes)
151147

152-
if isinstance(x, _py_scalars):
153-
if isinstance(y, _py_scalars):
154-
raise ValueError("At least one array or dtype is required.")
155-
return y
156-
elif isinstance(y, _py_scalars):
157-
return x
158148

159-
xdt = x.dtype if not isinstance(x, torch.dtype) else x
160-
ydt = y.dtype if not isinstance(y, torch.dtype) else y
149+
def _result_type(x, y):
150+
if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
151+
xdt = x.dtype if not isinstance(x, torch.dtype) else x
152+
ydt = y.dtype if not isinstance(y, torch.dtype) else y
161153

162-
if (xdt, ydt) in _promotion_table:
163-
return _promotion_table[xdt, ydt]
154+
if (xdt, ydt) in _promotion_table:
155+
return _promotion_table[xdt, ydt]
164156

165157
# This doesn't result_type(dtype, dtype) for non-array API dtypes
166158
# because torch.result_type only accepts tensors. This does however, allow
@@ -169,6 +161,7 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple
169161
y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
170162
return torch.result_type(x, y)
171163

164+
172165
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
173166
if not isinstance(from_, torch.dtype):
174167
from_ = from_.dtype

Diff for: tests/test_torch.py

+66
Original file line numberDiff line numberDiff line change
"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
"""
import functools

import pytest
import torch

from array_api_compat import torch as xp


class TestResultType:
    """Exercise xp.result_type on inputs whose promotion the Array API spec
    leaves unspecified, pinning the torch-specific behavior instead."""

    def test_empty(self):
        # Zero arguments is an error (this wrapper raises ValueError).
        with pytest.raises(ValueError):
            xp.result_type()

    def test_one_arg(self):
        # A lone Python scalar (or junk) has no dtype to return: scalars
        # raise ValueError, objects without a .dtype raise AttributeError.
        for x in [1, 1.0, 1j, '...', None]:
            with pytest.raises((ValueError, AttributeError)):
                xp.result_type(x)

        # A lone dtype is returned unchanged.
        for x in [xp.float32, xp.int64, torch.complex64]:
            assert xp.result_type(x) == x

        # A lone array yields its own dtype.
        for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
            assert xp.result_type(x) == x.dtype

    def test_two_args(self):
        # Only include here things "unspecified" in the spec

        # scalar, tensor or tensor,tensor: must agree with torch's own promotion
        for x, y in [
            (1., 1j),
            (1j, xp.arange(3)),
            (True, xp.asarray(3.)),
            (xp.ones(3) == 1, 1j*xp.ones(3)),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, y)

        # dtype, scalar: equivalent to promoting the scalar with an empty
        # tensor of that dtype (torch.result_type only accepts tensors/scalars)
        for x, y in [
            (1j, xp.int64),
            (True, xp.float64),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))

        # dtype, dtype: compare via empty tensors of each dtype
        for x, y in [
            (xp.bool, xp.complex64)
        ]:
            xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
            assert xp.result_type(x, y) == torch.result_type(xt, yt)

    def test_multi_arg(self):
        # Pin the default dtype the expected values below depend on, and
        # restore it afterwards so this test doesn't leak global state into
        # other tests in the session.
        prev_default = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        try:
            # scalars + one integer tensor: complex scalar promotes with the
            # float32 default to complex64
            args = [1, 2, 3j, xp.arange(3), 4, 5, 6]
            assert xp.result_type(*args) == xp.complex64

            # a float64 dtype in the mix lifts the complex result to complex128
            args = [1, 2, 3j, xp.float64, 4, 5, 6]
            assert xp.result_type(*args) == xp.complex128

            # extra integer tensor and bool scalar don't change the outcome
            args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
            assert xp.result_type(*args) == xp.complex128

            # all-scalar input is rejected: at least one array or dtype needed
            with pytest.raises(ValueError):
                xp.result_type(1, 2, 3, 4)
        finally:
            torch.set_default_dtype(prev_default)

0 commit comments

Comments
 (0)