Skip to content

Commit 6cc8008

Browse files
authored
Merge pull request #101 from asmeurer/torch-scalars
Allow Python scalars in torch functions
2 parents 700c665 + b1b5ecd commit 6cc8008

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

array_api_compat/torch/_aliases.py

+2
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def _f(x1, x2, /, **kwargs):
9999
return _f
100100

101101
def _fix_promotion(x1, x2, only_scalar=True):
102+
if not isinstance(x1, torch.Tensor) or not isinstance(x2, torch.Tensor):
103+
return x1, x2
102104
if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
103105
return x1, x2
104106
# If an argument is 0-D pytorch downcasts the other argument

0 commit comments

Comments
 (0)