Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: torch.asarray device propagation #296

Closed
wants to merge 1 commit into from

Conversation

crusaderky
Copy link
Contributor

@crusaderky crusaderky commented Apr 3, 2025

@Copilot Copilot bot review requested due to automatic review settings April 3, 2025 09:14
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This pull request implements a workaround for torch.asarray’s device propagation issue by adjusting imports and type annotations, along with minor device‐related tweaks across multiple array API compatibility modules.

  • Update type hints for device parameters and return types (from str to Device)
  • Replace the custom Device import with torch.device in torch/_typing.py
  • Introduce a new asarray function in torch/_aliases.py to address dtype propagation

Reviewed Changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.

Show a summary per file
File Description
array_api_compat/torch/_typing.py Updated import of Device from torch and modified the all order
array_api_compat/torch/_info.py Adjusted docstrings to reflect Device type for default device methods
array_api_compat/torch/_aliases.py Added a new asarray function and removed redundant torch.asarray calls
array_api_compat/numpy/_info.py Changed device type annotations from str to Device
array_api_compat/dask/array/_info.py Changed device type annotations from str to Device
array_api_compat/cupy/_info.py Updated device type annotations and added Notes in docstring
array_api_compat/common/_aliases.py Updated inline comments to include Dask in creation functions
Comments suppressed due to low confidence (3)

array_api_compat/torch/_typing.py:3

  • The change to import Device directly from torch (using torch.device) may affect the expected Device behavior compared to the previous custom definition. Please verify that this change preserves the intended device propagation semantics throughout the code.
from torch import device as Device, dtype as DType, Tensor as Array

array_api_compat/torch/_aliases.py:228

  • In the new asarray function, only torch.Tensor inputs have their device set automatically. Consider verifying that non-tensor objects are handled correctly when device is None.
if device is None and isinstance(obj, torch.Tensor):

array_api_compat/torch/_aliases.py:308

  • Removal of the explicit torch.asarray conversion in functions like prod, sum, any, and all might lead to unintended type issues if the input is not already a tensor. Confirm that callers always supply a tensor or that the new asarray function is applied upstream to avoid runtime errors.
    -    x = torch.asarray(x)

@crusaderky
Copy link
Contributor Author

CI failures unrelated (#297)

@ev-br
Copy link
Member

ev-br commented Apr 3, 2025

close/reopen to rerun the CI

@crusaderky
Copy link
Contributor Author

I can't reopen a PR someone else closed.
Repoened in #299.

@ev-br
Copy link
Member

ev-br commented Apr 3, 2025

Apologies for the glitch. No idea why it's not got reopened.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants