[DO NOT REVIEW][SMOKE TEST] Skip prims shape in grad transform #1704
base: main
Conversation
for more information, see https://pre-commit.ci
…backward_transform_dependency_fix
…backward_transform_dependency_fix
Co-authored-by: Masaki Kozuki <[email protected]>
for more information, see https://pre-commit.ci
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -123,7 +123,7 @@ def keep_or_swap(p):
         if not isinstance(p, NumberProxyInterface):
             return p
         if p.name in seen:
-            return p.value  # don't make it a duplicate
+            return None  # don't make it a duplicate
note for myself, I needed this to avoid an error in the rematerialization pass. I'll open a separate PR with a repro once I have one.
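For context, a minimal standalone sketch of the dedup behavior this hunk changes. `keep_or_swap`, `seen`, and `NumberProxyInterface` come from the diff above; everything else below is a simplified stand-in I wrote, not Thunder's actual implementation. As I read the change, a repeated NumberProxy now maps to None rather than to its concrete `.value`, so the duplicate slot stays symbolic instead of re-introducing a baked-in number.

# Simplified stand-in; not Thunder's real NumberProxyInterface.
class NumberProxyInterface:
    def __init__(self, name, value):
        self.name = name
        self.value = value

def dedup(proxies):
    # Mimics the keep_or_swap logic above: keep the first occurrence of each
    # NumberProxy, map later occurrences to None instead of their concrete .value.
    seen = set()
    out = []
    for p in proxies:
        if not isinstance(p, NumberProxyInterface):
            out.append(p)
        elif p.name in seen:
            out.append(None)  # don't make it a duplicate
        else:
            seen.add(p.name)
            out.append(p)
    return out

i0 = NumberProxyInterface("i0", 232)
print(dedup([i0, 3, i0]))  # second i0 becomes None, not 232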
errrrr... this one doesn't seem to be working any more... I'm seeing the assert in rematerialization.py again.
import thunder
import torch

def foo(a, b):
    return a + b

# Only `a` requires grad; `b` intentionally does not, so one gradient path is unused.
a = torch.randn(1, 32, 232, 232)
a.requires_grad_()
b = torch.randn(1, 1, 232, 232)
# b.requires_grad_()

jfoo = thunder.jit(foo, cache="symbolic values")
out = jfoo(a, b)
took me a while to get a repro here.
I think the issue is that saving-for-backward doesn't properly identify which gradient path isn't required, so after DCE kicks in, saved_for_backward ends up inconsistent... now I suspect I just missed a DCE call somewhere.
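To illustrate the failure mode described above with a simplified, standalone analogue (none of these names are Thunder's actual APIs; this only sketches the shape of the bug under that reading): if saved_for_backward is collected as though every gradient path were required, and DCE later prunes the ops feeding the unneeded path, the saved set can reference values the pruned trace no longer produces.

# Hypothetical toy model of the mismatch; not Thunder code.
fwd_ops = [
    {"output": "t0", "inputs": {"a", "b"}},  # t0 = a + b (the forward result)
    {"output": "t1", "inputs": {"b"}},       # intermediate needed only for grad of b
]

# Collected as if both gradients were required.
saved_for_backward = {"t0", "t1"}

# DCE keeps only what the forward output and the required grad (of `a`) need.
needed = {"t0"}
pruned_fwd = [op for op in fwd_ops if op["output"] in needed]

produced = {op["output"] for op in pruned_fwd}
print(saved_for_backward - produced)  # {'t1'}: a saved value with no producer left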
trying my luck with #1725
             else:
                 assert isinstance(new, ProxyInterface), (old, new)
-                swap_map[variableify(new)] = old
+                if variableify(old) != variableify(new):
+                    swap_map[variableify(new.primal)] = old
note for myself, this is a separate change; break it out into its own PR.
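For context on what this hunk touches, a hypothetical, self-contained sketch of the general pattern: a swap map keyed by a hashable "variable" view of each proxy, mapping new proxies back to the ones they replace, and skipping pairs where nothing actually changed. `swap_map` and `variableify` appear in the diff above; the `Proxy` class, `build_swap_map`, and the key format are stand-ins I made up, not Thunder's implementation (which additionally keys on `new.primal` here).

# Hypothetical, self-contained sketch; not Thunder's real helpers.
from dataclasses import dataclass

@dataclass(frozen=True)
class Proxy:
    name: str

def variableify(p):
    # Stand-in for a variableify-style helper: wrap a proxy in a hashable key.
    return ("var", p.name)

def build_swap_map(pairs):
    # Map each new proxy (by its variable key) back to the old proxy it replaces,
    # skipping pairs where old and new are effectively the same variable.
    swap_map = {}
    for old, new in pairs:
        if variableify(old) != variableify(new):
            swap_map[variableify(new)] = old
    return swap_map

old_a, new_a = Proxy("a"), Proxy("a_0")
old_b, new_b = Proxy("b"), Proxy("b")  # unchanged; no self-swap entry
print(build_swap_map([(old_a, new_a), (old_b, new_b)]))
# {('var', 'a_0'): Proxy(name='a')}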
No description provided.