
Implement conversion for aten.expand by using ttnn.repeat #21

Merged: 2 commits into tenstorrent:main from kw/expand, Jul 3, 2024

Conversation

kevinwuTT (Contributor)

Problem description

Transform aten.expand by using ttnn.repeat.

What's changed

The two ops interpret their shape argument differently: for aten.expand, the shape argument is the desired output shape, while for ttnn.repeat it is a per-dimension multiplier of the input shape. The conversion therefore needs the shape of the input tensor.
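As a minimal illustration of the shape translation (the helper name here is hypothetical, not the code in this PR):

```python
# Sketch: derive ttnn.repeat multipliers from an aten.expand target shape.
# Assumes input and output ranks match; aten.expand uses -1 to mean
# "keep this dimension unchanged".
def expand_shape_to_repeat_multipliers(input_shape, output_shape):
    return [
        1 if out == -1 else out // inp
        for inp, out in zip(input_shape, output_shape)
    ]

# Expanding a (1, 4) tensor to (3, 4) repeats dim 0 three times:
assert expand_shape_to_repeat_multipliers([1, 4], [3, 4]) == [3, 1]
```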

  • Retain shape metadata for Placeholder nodes, since transform() removes it (see the sketch after this list)
    • Add missing metadata for conversions done under ReplaceMoreTtManually
  • Implement the conversion from aten.expand using the retained shape metadata
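A rough sketch of the retention step, assuming the stock torch.fx APIs (the actual pass in this PR may be structured differently):

```python
import torch.fx

def transform_preserving_placeholder_meta(gm: torch.fx.GraphModule, transformer_cls):
    # node.meta carries the shape/dtype info we need later; save it keyed
    # by node name, because Transformer.transform() rebuilds the graph
    # and drops it.
    saved_meta = {
        node.name: node.meta
        for node in gm.graph.nodes
        if node.op == "placeholder"
    }
    new_gm = transformer_cls(gm).transform()
    # Restore metadata for Placeholder nodes in the rebuilt graph.
    for node in new_gm.graph.nodes:
        if node.op == "placeholder" and node.name in saved_meta:
            node.meta = saved_meta[node.name]
    return new_gm
```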

In addition, ttnn.repeat behaves differently between TILE and ROW_MAJOR layouts. In most cases we need to change a node's layout to ROW_MAJOR before passing it to ttnn.repeat, and then check whether a change back to TILE is needed when a node that consumes the ttnn.repeat output is a tt compute node.

  • Add logic to insert layout changes before and/or after ttnn.repeat (sketched below)
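The sketch below shows the eager-level behavior the pass has to reproduce. ttnn.to_layout, ttnn.repeat, and the layout constants are real ttnn names, but the helper function and the consumer check are illustrative assumptions, not the FX rewrite itself:

```python
import ttnn

# Hypothetical helper mirroring what the inserted FX nodes compute.
def repeat_with_layout_handling(tensor, multipliers, consumer_is_tt_compute):
    # ttnn.repeat generally needs a ROW_MAJOR input, so convert first.
    tensor = ttnn.to_layout(tensor, ttnn.ROW_MAJOR_LAYOUT)
    out = ttnn.repeat(tensor, ttnn.Shape(multipliers))
    # tt compute consumers expect TILE layout, so convert back if needed.
    if consumer_is_tt_compute:
        out = ttnn.to_layout(out, ttnn.TILE_LAYOUT)
    return out
```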

Commits

  • Add retention of shape metadata information for Placeholder nodes since transform() removes them
    • Add missing metadata for conversions done under ReplaceMoreTtManually
  • Add logic to insert layout changes before and/or after ttnn.repeat
kevinwuTT requested a review from ayerofieiev-tt on July 3, 2024 at 21:04
Review comment on:

```python
# Replace more patterns with torch.fx.Transformer
gm = ReplaceMoreTt(gm).transform()

# Restore metadata for Placeholder nodes
```
Member: why do we lose metadata?

kevinwuTT (Author):

From my investigations so far, it could be an oversight. I found some very recent discussion https://discuss.pytorch.org/t/understanding-torch-fx-traceback-preserve-node-meta/205319 with the same issue but unfortunately no response yet. I could create an issue to track this actually. Could be something we optimize away at a later time if we find a better solution.

kevinwuTT (Author):

#22

kevinwuTT added this pull request to the merge queue on Jul 3, 2024
Merged via the queue into tenstorrent:main with commit 4ed0e35 on Jul 3, 2024
1 check passed
kevinwuTT deleted the kw/expand branch on July 3, 2024