Skip to content

Commit fc04a64

Browse files
committed
torch: allow python scalars in result_type
1 parent 4d81a0d commit fc04a64

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

Diff for: array_api_compat/torch/_aliases.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,11 @@ def _fix_promotion(x1, x2, only_scalar=True):
125125
x1 = x1.to(dtype)
126126
return x1, x2
127127

128-
def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
128+
129+
_py_scalars = (bool, int, float, complex)
130+
131+
132+
def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
129133
if len(arrays_and_dtypes) == 0:
130134
raise TypeError("At least one array or dtype must be provided")
131135
if len(arrays_and_dtypes) == 1:
@@ -137,6 +141,9 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
137141
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
138142

139143
x, y = arrays_and_dtypes
144+
if isinstance(x, _py_scalars) or isinstance(y, _py_scalars):
145+
return torch.result_type(x, y)
146+
140147
xdt = x.dtype if not isinstance(x, torch.dtype) else x
141148
ydt = y.dtype if not isinstance(y, torch.dtype) else y
142149

0 commit comments

Comments
 (0)