Skip to content

Commit f33c0c9

Browse files
committed
BUG: torch: fix result_type with python scalars
1 parent e14754b commit f33c0c9

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

Diff for: array_api_compat/torch/_aliases.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,12 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple
135135
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
136136

137137
x, y = arrays_and_dtypes
138-
if isinstance(x, _py_scalars) or isinstance(y, _py_scalars):
139-
return torch.result_type(x, y)
138+
if isinstance(x, _py_scalars):
139+
if isinstance(y, _py_scalars):
140+
raise ValueError("At least one array or dtype is required.")
141+
return y
142+
elif isinstance(y, _py_scalars):
143+
return x
140144

141145
xdt = x.dtype if not isinstance(x, torch.dtype) else x
142146
ydt = y.dtype if not isinstance(y, torch.dtype) else y

0 commit comments

Comments
 (0)