Skip to content

Commit cf282bc

Browse files
committed
BUG: torch: sort arguments of result_type
This way, python scalars are always treated last. Otherwise, "runs" of python scalars are problematic: `result_type(float32, int64, 1, 2) -> ValueError`
1 parent f33c0c9 commit cf282bc

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

Diff for: array_api_compat/torch/_aliases.py

+14
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,24 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple
131131
if isinstance(x, torch.dtype):
132132
return x
133133
return x.dtype
134+
134135
if len(arrays_and_dtypes) > 2:
136+
# sort the scalars to the left so that they are treated last
137+
scalars, others = [], []
138+
for x in arrays_and_dtypes:
139+
if isinstance(x, _py_scalars):
140+
scalars.append(x)
141+
else:
142+
others.append(x)
143+
if len(scalars) == len(arrays_and_dtypes):
144+
raise ValueError("At least one array or dtype is required.")
145+
146+
arrays_and_dtypes = scalars + others
135147
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
136148

149+
# the binary case
137150
x, y = arrays_and_dtypes
151+
138152
if isinstance(x, _py_scalars):
139153
if isinstance(y, _py_scalars):
140154
raise ValueError("At least one array or dtype is required.")

0 commit comments

Comments
 (0)