From b1b5ecd1e7db536bc39cebbd71d544264c268e7f Mon Sep 17 00:00:00 2001
From: Aaron Meurer
Date: Fri, 8 Mar 2024 13:25:25 -0700
Subject: [PATCH] Allow Python scalars in torch functions

Fixes #85.
---
 array_api_compat/torch/_aliases.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index bfa7610b..e7fc7a81 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -99,6 +99,8 @@ def _f(x1, x2, /, **kwargs):
     return _f
 
 def _fix_promotion(x1, x2, only_scalar=True):
+    if not isinstance(x1, torch.Tensor) or not isinstance(x2, torch.Tensor):
+        return x1, x2
     if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
         return x1, x2
     # If an argument is 0-D pytorch downcasts the other argument
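
Illustration (not part of the patch): a minimal sketch of the kind of call this
guard is meant to allow, assuming pow is one of the _two_arg-wrapped functions
that route through _fix_promotion. Without the guard, a Python scalar operand
would reach x2.dtype and fail; with it, the pair is returned as-is and torch's
own scalar handling applies.

    from array_api_compat import torch as xp

    x = xp.asarray([1.0, 2.0, 3.0])

    # 2.0 is not a torch.Tensor, so _fix_promotion returns (x, 2.0) unchanged
    # and torch.pow handles the Python scalar itself.
    y = xp.pow(x, 2.0)
    print(y)  # tensor([1., 4., 9.])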