-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: torch
dtype promotions
#298
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR enhances dtype promotion behavior in PyTorch compatibility by supporting additional unsigned integer types and standardizing deep copy semantics for sum and prod operations when axis is an empty tuple.
- Updated the promotion table to incorporate uint16, uint32, and uint64 on PyTorch ≥2.3.
- Introduced a new internal helper (_sum_prod_no_axis) to implement deep-copy behavior for sum and prod when axis=().
Comments suppressed due to low confidence (2)
array_api_compat/torch/_aliases.py:280
- Consider adding unit tests to explicitly verify the behavior of _sum_prod_no_axis across different dtypes, including the cases with and without _HAS_LARGE_UINT enabled.
def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
array_api_compat/torch/_aliases.py:280
- [nitpick] Consider renaming _sum_prod_no_axis to a more descriptive name that conveys its role in returning a deep copy when no axis is provided, for example, _cast_and_clone_no_axis.
def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
#253 seems to have run into a wall with uint promotions in pytorch, did something change |
No, nothing has changed on the torch side AFAIK.
|
You lost me here :-). So is it basically "let's add something because we want it, but not test it on CI because we know it'll break"? |
Well, we will need this and the other PR eventually when torch introduces support. |
result_type
sum
andprod
will now promote uint8, uint16 and uint32 to uint64.sum(x, axis=(), dtype=x.dtype)
andprod
with the same parameters, which were previously returningx
itself, to return a deep copy instead. While this is not part of the Array API, it's what both numpy an base torch do, and probably the most healthy thing to do xref https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html