Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update vjp implementation to attach residuals to all outputs #1755

Draft
wants to merge 4 commits into
base: ivan-1732-1
Choose a base branch
from

Conversation

IvanYashchuk
Copy link
Collaborator

Base PR: #1754, I'm keeping this PR in Draft for stacking PRs

The current code assumes that the first output of all BoundSymbols will be a Proxy with a .name attribute. It's not always the case when DCE might replace unused outputs with None or some other objects are returned from Symbols.

The important changes are:

  • skipping symbols with no proxies in the output to attach additional metadata for propagation (using is_literal, the approach is similar to the attempt in Grad transform patch #1712)
  • instead of specialing on the assumption that the first output of bound symbols is always a proxy object we actually find the first output, if all outputs are non-proxies they should have been skipped due to changes from the previous bullet point

I added the test from the issue description of #1732.

Fixes #1732.

@@ -1934,6 +1934,18 @@ def func(x):
torch.testing.assert_close(actual_gr, expected_gr)


# See https://github.com/Lightning-AI/lightning-thunder/issues/1732
def test_symbolic_shape_for_backward_issue_1732():
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jjsjann123, is there any other test you tried that didn't work?

@IvanYashchuk
Copy link
Collaborator Author

This is ready for review, the draft status is only to prevent merges into the previous PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant