torch binary operations broken for scalar inputs in _fix_promotion #85


Closed
asford opened this issue Feb 7, 2024 · 2 comments · Fixed by #101
Comments


asford commented Feb 7, 2024

The dtype promotion check in _fix_promotion does not correctly identify scalar inputs, and unconditionally accesses .dtype.
This breaks binary operators with float scalar inputs.

This can be fixed by accessing dtype via getattr with a None default, or by validating that the input is not a scalar.
Happy to provide a PR.

Minimal repro, against version 1.4:

import torch
import numpy
import array_api_compat as aac

aac.__version__  # '1.4'; __version__ is a string attribute, not a callable

t = torch.arange(10)
n = numpy.arange(10)

numpy.add(n, 1.0)                 # works
torch.add(t, 1.0)                 # works

aac.get_namespace(n).add(n, 1.0)  # works
aac.get_namespace(t).add(t, 1.0)  # raises AttributeError (traceback below)

Raises:

      9 torch.add(t, 1.0)
     11 aac.get_namespace(n).add(n, 1.0)
---> 12 aac.get_namespace(t).add(t, 1.0)

File ~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:91, in _two_arg.<locals>._f(x1, x2, **kwargs)
     89 @wraps(f)
     90 def _f(x1, x2, /, **kwargs):
---> 91     x1, x2 = _fix_promotion(x1, x2)
     92     return f(x1, x2, **kwargs)

File ~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:104, in _fix_promotion(x1, x2, only_scalar)
    103 def _fix_promotion(x1, x2, only_scalar=True):
--> 104     if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
    105         return x1, x2
    106     # If an argument is 0-D pytorch downcasts the other argument

AttributeError: 'float' object has no attribute 'dtype'

Would expect equivalent behavior to torch.add.

See:
https://gist.github.com/asford/ee688d59f0747a6507b9670a83fa7c47


asmeurer commented Mar 8, 2024

Somehow missed this issue. Python scalar inputs are portable in the standard, but we should still make them work if they work with the upstream functions.

asmeurer added a commit to asmeurer/array-api-compat that referenced this issue Mar 8, 2024

asmeurer commented Mar 8, 2024

> Python scalar inputs are portable in the standard

Correction: they *aren't* portable. You shouldn't use them if you are aiming for full array API compatibility.

Fix at #101
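For code that does aim for full portability, one option is to convert Python scalars to arrays in the target namespace before calling the binary op. The helper below is a hypothetical sketch (`portable_add` is not part of array-api-compat), shown here with plain NumPy as the namespace:

```python
import numpy as np

def portable_add(xp, x, other):
    # If `other` is a Python scalar (no .dtype), convert it to an array
    # in the same namespace and dtype before the binary op, so the call
    # never depends on scalar support in the underlying library.
    if not hasattr(other, "dtype"):
        other = xp.asarray(other, dtype=x.dtype)
    return xp.add(x, other)

n = np.arange(10)
print(portable_add(np, n, 1.0))
```

Note that forcing `x.dtype` on the scalar sidesteps value-based promotion entirely, which matches the standard's promotion rules but can silently truncate (here `1.0` becomes an integer `1`).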
