-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: torch.asarray
device propagation
#296
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This pull request implements a workaround for torch.asarray’s device propagation issue by adjusting imports and type annotations, along with minor device‐related tweaks across multiple array API compatibility modules.
- Update type hints for device parameters and return types (from str to Device)
- Replace the custom Device import with torch.device in torch/_typing.py
- Introduce a new asarray function in torch/_aliases.py to address dtype propagation
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
Show a summary per file
File | Description |
---|---|
array_api_compat/torch/_typing.py | Updated import of Device from torch and modified the all order |
array_api_compat/torch/_info.py | Adjusted docstrings to reflect Device type for default device methods |
array_api_compat/torch/_aliases.py | Added a new asarray function and removed redundant torch.asarray calls |
array_api_compat/numpy/_info.py | Changed device type annotations from str to Device |
array_api_compat/dask/array/_info.py | Changed device type annotations from str to Device |
array_api_compat/cupy/_info.py | Updated device type annotations and added Notes in docstring |
array_api_compat/common/_aliases.py | Updated inline comments to include Dask in creation functions |
Comments suppressed due to low confidence (3)
array_api_compat/torch/_typing.py:3
- The change to import Device directly from torch (using torch.device) may affect the expected Device behavior compared to the previous custom definition. Please verify that this change preserves the intended device propagation semantics throughout the code.
from torch import device as Device, dtype as DType, Tensor as Array
array_api_compat/torch/_aliases.py:228
- In the new asarray function, only torch.Tensor inputs have their device set automatically. Consider verifying that non-tensor objects are handled correctly when device is None.
if device is None and isinstance(obj, torch.Tensor):
array_api_compat/torch/_aliases.py:308
- Removal of the explicit torch.asarray conversion in functions like prod, sum, any, and all might lead to unintended type issues if the input is not already a tensor. Confirm that callers always supply a tensor or that the new asarray function is applied upstream to avoid runtime errors.
- x = torch.asarray(x)
CI failures unrelated (#297) |
close/reopen to rerun the CI |
I can't reopen a PR someone else closed. |
Apologies for the glitch. No idea why it's not got reopened. |
asarray
: device does not propagate from input to output afterset_default_device
pytorch/pytorch#150199prod
,sum
,any
andall
, which had the same issue induced by array-api-compat