Skip to content

Commit 471edf8

Browse files
committed
torch: allow python scalars in result_type
1 parent bfe3fcc commit 471edf8

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

array_api_compat/torch/_aliases.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,11 @@ def _fix_promotion(x1, x2, only_scalar=True):
119119
x1 = x1.to(dtype)
120120
return x1, x2
121121

122-
def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
122+
123+
_py_scalars = (bool, int, float, complex)
124+
125+
126+
def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
123127
if len(arrays_and_dtypes) == 0:
124128
raise TypeError("At least one array or dtype must be provided")
125129
if len(arrays_and_dtypes) == 1:
@@ -131,6 +135,9 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
131135
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
132136

133137
x, y = arrays_and_dtypes
138+
if isinstance(x, _py_scalars) or isinstance(y, _py_scalars):
139+
return torch.result_type(x, y)
140+
134141
xdt = x.dtype if not isinstance(x, torch.dtype) else x
135142
ydt = y.dtype if not isinstance(y, torch.dtype) else y
136143

0 commit comments

Comments
 (0)