ENH: torch dtype promotions #298

Open
crusaderky wants to merge 1 commit into base: main

Conversation

crusaderky (Contributor) commented:

  • On PyTorch >=2.3, support uint16, uint32, and uint64 in result_type
  • On PyTorch >=2.3, sum and prod will now promote uint8, uint16 and uint32 to uint64.
  • Change sum(x, axis=(), dtype=x.dtype), and prod with the same parameters, which previously returned x itself, to return a deep copy instead. While this is not mandated by the Array API, it is what both numpy and base torch do, and probably the healthiest behaviour (xref https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html). See the illustration below.
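A hedged illustration of the intended semantics (my own sketch, not taken from the PR or its tests); it assumes PyTorch >= 2.3 and that the compat namespace is importable as array_api_compat.torch:

```python
import torch                       # assumed >= 2.3, so uint16/32/64 exist
import array_api_compat.torch as xp

# result_type is a pure promotion-table lookup, so it works even where
# torch kernels for the large unsigned dtypes are still missing.
assert xp.result_type(torch.uint16, torch.uint16) == torch.uint16
assert xp.result_type(torch.uint8, torch.uint64) == torch.uint64

# sum(x, axis=(), dtype=x.dtype) now returns a deep copy instead of x itself.
x = torch.asarray([1, 2, 3], dtype=torch.uint8)
y = xp.sum(x, axis=(), dtype=x.dtype)
assert y is not x
y[0] = 99                          # mutating the copy must not affect x
assert int(x[0]) == 1

# With no explicit dtype, sum/prod accumulate uint8/16/32 into uint64
# (kernel support permitting): xp.sum(x).dtype would be torch.uint64.
```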

Copilot review was requested due to automatic review settings on April 3, 2025 at 10:35.

Copilot AI left a comment:

Pull Request Overview

This PR enhances dtype promotion behavior in the PyTorch compatibility layer by supporting additional unsigned integer types and standardizing deep-copy semantics for sum and prod when axis is an empty tuple.

  • Updated the promotion table to incorporate uint16, uint32, and uint64 on PyTorch ≥2.3.
  • Introduced a new internal helper (_sum_prod_no_axis) to implement deep-copy behavior for sum and prod when axis=().
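For context, here is a hypothetical sketch of what such a helper could look like; the actual implementation in _aliases.py may differ, and the _HAS_LARGE_UINT check below is my own stand-in for the module's PyTorch >= 2.3 detection:

```python
from __future__ import annotations

import torch

# Stand-in assumption: the "large" unsigned dtypes (uint16/32/64) exist
# only on PyTorch >= 2.3.
_HAS_LARGE_UINT = hasattr(torch, "uint32")


def _sum_prod_no_axis(x: torch.Tensor, dtype: torch.dtype | None) -> torch.Tensor:
    """sum(x, axis=()) and prod(x, axis=()) reduce over no axes, so the result
    holds the same values as x; return a deep copy, never x itself."""
    if dtype is not None:
        # .to() returns x itself when the dtype already matches, hence clone()
        return x.clone() if dtype == x.dtype else x.to(dtype)
    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
        # New accumulation rule: small unsigned ints are promoted to uint64
        return x.to(torch.uint64)
    if x.dtype in (torch.bool, torch.int8, torch.int16, torch.int32):
        # Array API: bool and small signed ints accumulate in the default int
        return x.to(torch.int64)
    return x.clone()
```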
Comments suppressed due to low confidence (2)

array_api_compat/torch/_aliases.py:280

  • Consider adding unit tests to explicitly verify the behavior of _sum_prod_no_axis across different dtypes, including the cases with and without _HAS_LARGE_UINT enabled.
def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:

array_api_compat/torch/_aliases.py:280

  • [nitpick] Consider renaming _sum_prod_no_axis to a more descriptive name that conveys its role in returning a deep copy when no axis is provided, for example, _cast_and_clone_no_axis.
def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
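On the first suggestion, a minimal pytest sketch of such a test (my own illustration, not part of the PR; the dtype list, flag, and import path are assumptions):

```python
import pytest
import torch
import array_api_compat.torch as xp

# Stand-in for the module's PyTorch >= 2.3 flag.
_HAS_LARGE_UINT = hasattr(torch, "uint32")

DTYPES = [torch.float32, torch.int64, torch.uint8]
if _HAS_LARGE_UINT:
    DTYPES += [torch.uint16, torch.uint32, torch.uint64]


@pytest.mark.parametrize("func", [xp.sum, xp.prod])
@pytest.mark.parametrize("dtype", DTYPES)
def test_no_axis_returns_deep_copy(func, dtype):
    x = torch.asarray([1, 2, 3], dtype=dtype)
    out = func(x, axis=(), dtype=dtype)
    assert out.dtype == dtype
    assert out is not x                      # a deep copy, not x itself
    # Compare via int64 to avoid relying on comparison kernels for the
    # large unsigned dtypes.
    assert torch.equal(out.to(torch.int64), x.to(torch.int64))
```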

ev-br (Member) commented Apr 3, 2025:

#253 seems to have run into a wall with uint promotions in pytorch; did something change in the meantime?

crusaderky (Contributor, Author) commented Apr 3, 2025:

> #253 seems to have run into a wall with uint promotions in pytorch; did something change in the meantime?

No, nothing has changed on the torch side AFAIK.

  • The changes to the dtypes table in this PR are functionally the same as in ENH: torch: unsigned types #253, but IMHO cleaner, and they don't need to wait for full torch support in the various functions before being merged (see the example after this list).
  • The fix to sum(x, dtype=x.dtype, axis=()) is unrelated to dtype promotion.
  • The promotion of uint8/16/32 to uint64 in sum and prod is missing from ENH: torch: unsigned types #253, and is something you want to have anyway.
  • Unlike ENH: torch: unsigned types #253, this PR changes neither CI nor __array_namespace_info__, which is what makes that PR non-mergeable.
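To illustrate the first point, the table changes only involve dtype lookups, so they can be exercised without any uint kernels (a hedged example, assuming PyTorch >= 2.3):

```python
import torch
import array_api_compat.torch as xp

# Pure promotion-table lookups; no uint16/32/64 computation is involved,
# so these work even while many torch functions lack kernels for them.
assert xp.result_type(torch.uint8, torch.uint16) == torch.uint16
assert xp.result_type(torch.uint16, torch.uint32) == torch.uint32
assert xp.result_type(torch.uint32, torch.uint64) == torch.uint64
```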

ev-br (Member) commented Apr 3, 2025:

> this PR changes neither CI nor __array_namespace_info__, which is what makes that PR non-mergeable.

You lost me here :-). So is it basically "let's add something because we want it, but not test it on CI because we know it'll break"?

crusaderky (Contributor, Author) commented:

> > this PR changes neither CI nor __array_namespace_info__, which is what makes that PR non-mergeable.
>
> You lost me here :-). So is it basically "let's add something because we want it, but not test it on CI because we know it'll break"?

Well, we will need this and the other PR eventually, when torch introduces full support.
We can either keep this PR unmerged, or leave it in main but untested by CI. End users who somehow manage to avoid the many non-functioning APIs will benefit from it immediately.
