Grad transform patch #1712
base: main

Conversation
-    register_grad(pids.SHAPE, prims.shape)
+    # NOTE we need to skip pids.SHAPE in the grad transform, because `prims.shape` creates multiple outputs,
+    # which breaks the assumption in the pass that each symbol has a single output
+    register_grad(pids.SHAPE, prims.shape, True)
I only needed this change for `prims.shape`, but I think it could/should apply to any operation that doesn't require gradient.
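A rough sketch of what such a registration flag could look like, assuming the `_grad_skip_ids` set this PR introduces; the body below is illustrative rather than Thunder's actual `register_grad`, and `_registered_grad_fns` is a hypothetical stand-in for the real registry:

```python
# Illustrative sketch only; the real register_grad lives in thunder.core.transforms.
# _registered_grad_fns is a hypothetical stand-in for its actual registry.
_grad_skip_ids: set = set()       # symbol ids the grad transform should ignore
_registered_grad_fns: dict = {}   # hypothetical: symbol id -> grad function


def register_grad(sym_id, grad_fn, skip_in_grad_transform: bool = False) -> None:
    _registered_grad_fns[sym_id] = grad_fn
    if skip_in_grad_transform:
        # Symbols registered with the flag (e.g. pids.SHAPE) are collected here and
        # filtered out later by iter_bound_symbols during the grad transform.
        _grad_skip_ids.add(sym_id)
```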
    a = a.unsqueeze(1)
    return a + b

jfoo = thunder.jit(foo, cache="symbolic values")
For `a = a.unsqueeze(0)`, the grad transform's function `deconstruct_forward_env_for_backward` runs on the trace:
@torch.no_grad()
@no_autocast
def unsqueeze(a, dim):
# a: "cpu f32[1, [IntegerProxy name=i0, value=3, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=4, static=CONSTRAINT.CONSTRAINABLE]]"
(_, i0, i1) = prims.shape(a)
(_, _, _) = prims.shape(a)
(_, _, _) = prims.shape(a)
t8 = prims.broadcast_in_dim(a, [1, 1, i0, i1], [0, 2, 3]) # t8: "cpu f32[1, 1, [IntegerProxy name=i0, value=3, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=4, static=CONSTRAINT.CONSTRAINABLE]]"
return t8
The transform handles saved_for_backward with the assumption that each symbol produces only a single output:

lightning-thunder/thunder/core/transforms.py
Line 2313 in 4f4d632

saved_for_backward = tuple(env[sequencify(symbol.output)[0].name].residuals for symbol in bound_symbols)

which is not true for `prims.shape`.

Given that `prims.shape` doesn't participate in grad at all, I decided to skip it instead of refactoring the transform to handle multiple outputs.
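As a minimal illustration of that lookup pattern (with hypothetical stand-in classes, not Thunder's real BoundSymbol or env types): residuals are fetched from the environment keyed by the name of each symbol's first output, so the pass expects a usable entry under that name for every bound symbol it visits.

```python
# Hypothetical stand-ins, only to illustrate the "residuals keyed by the first
# output's name" pattern from the quoted line; not Thunder's actual types.
from dataclasses import dataclass


@dataclass
class Proxy:
    name: str


@dataclass
class EnvEntry:
    residuals: tuple


@dataclass
class BoundSymbolStub:
    output: object  # a single Proxy, or a tuple of Proxies for multi-output symbols


def sequencify(x):
    return x if isinstance(x, (tuple, list)) else (x,)


# One residuals entry per symbol, keyed by its first output's name.
env = {"t8": EnvEntry(residuals=("a",))}
bound_symbols = [BoundSymbolStub(output=Proxy("t8"))]

# Mirrors the quoted line: every visited symbol needs an env entry for its first output.
saved_for_backward = tuple(env[sequencify(s.output)[0].name].residuals for s in bound_symbols)
print(saved_for_backward)  # (('a',),)
```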
The transform handles saved_for_backward with the assumption that "residuals" are stored only for the first output. It handles symbols with multiple outputs; one example is `var_mean`.

There's also a utility function that may be useful for skipping some of the symbols:

lightning-thunder/thunder/core/transforms.py
Line 2282 in ee6edb1

def iter_bound_symbols(bound_symbols):
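For reference, a minimal sketch of the skip pattern that utility implements (the real function is the one quoted above; the skip-list contents below are hypothetical placeholders):

```python
# Minimal sketch of the filtering pattern; not Thunder's actual implementation.
trace_interpreter_skip_list = {"python_return", "python_del"}  # hypothetical ids


def iter_bound_symbols(bound_symbols):
    """Yield bound symbols, skipping any whose id is on the skip list."""
    for symbol in bound_symbols:
        if symbol.sym.id in trace_interpreter_skip_list:
            continue
        yield symbol
```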
> The transform handles saved_for_backward with the assumption that "residuals" are stored only for the first output.

Thanks for correcting me. 😝
Yes, I am skipping the symbols in the util `iter_bound_symbols`, i.e. the registration would put them in `_grad_skip_ids`, and in `iter_bound_symbols` we'll skip any symbol in that set.
Hmm, seems like we are hitting a real issue.

Not sure if it's caused by this PR, since we had a clear CI before the merge. I'll double check that.
The current pull request proposes to completely skip certain operations that do nothing in backward. However, we don't want to skip the forward computation, because the result may have proxies that could be consumed later. An alternative would be to extend the existing code without adding more complexity to `register_grad`. There's the `is_constant_for_vjp(bsym: BoundSymbol) -> bool` function

lightning-thunder/thunder/core/transforms.py
Line 2505 in 40f7972

def is_constant_for_vjp(symbol: prims.Symbol) -> bool:

that could be modified instead (if the output contains no differentiable proxies, then it's considered constant computation with no backward computation). I think adding an additional condition to that function, `not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_outs)`, should be enough to enable the tests from this pull request.
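A hedged sketch of that suggestion; the signature, `FloatProxy`, `TensorProxy`, and `flat_outs` come from the comment above, while the import paths are assumed and the body is illustrative rather than the real `is_constant_for_vjp`:

```python
# Sketch only: the real is_constant_for_vjp has more checks; this shows just the
# proposed extra condition. Import paths are assumed, not verified against the repo.
from thunder.core.proxies import FloatProxy, TensorProxy
from thunder.core.symbol import BoundSymbol


def is_constant_for_vjp(bsym: BoundSymbol) -> bool:
    # If none of the flattened outputs are differentiable proxies, there is nothing to
    # differentiate: the forward computation still runs (keeping any proxies it produces
    # available to later symbols), but no backward is generated for this symbol.
    if not any(isinstance(out, (FloatProxy, TensorProxy)) for out in bsym.flat_outs):
        return True
    # ... the existing conditions of the real function would follow here ...
    return False
```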
@@ -2293,6 +2300,8 @@ def iter_bound_symbols(bound_symbols):
     for symbol in bound_symbols:
         if symbol.sym.id in trace_interpreter_skip_list:
We have one skip list here, `trace_interpreter_skip_list`; instead of adding a new one, the existing list should be reused.
Thanks. That's really helpful.
Sounds like `is_constant_for_vjp` is actually what I wanted, i.e. it skips the backward but keeps the forward as is?
I don't think `trace_interpreter_skip_list` is the right list though. It's used in trace_interpreter.py as well, which is outside the context of grad. But since you pointed out `is_constant_for_vjp`, I think `trace_interpreter_skip_list` is the right list to add them to. I'll refactor this.

Yeah, that sounds about right.
This PR refactors the grad transform to add a `_grad_skip_ids` set, which allows us to skip symbols that do not participate in the grad transformation. The main motivation is to ensure that `prims.shape` does not break saved_for_backward in the transform.