
Grad transform patch #1712 (Draft)
jjsjann123 wants to merge 5 commits into main
Conversation

jjsjann123 (Collaborator) commented on Jan 29, 2025:

This PR refactors the grad transform to add a _grad_skip_ids set, which allows us to skip symbols that do not participate in the grad transformation.

The main motivation is to ensure that prims.shape does not break the saved-for-backward handling in the transform.

register_grad(pids.SHAPE, prims.shape)
# NOTE we need to skip pids.SHAPE in the grad transform, because `prims.shape` creates multiple
# outputs, which breaks the assumption in the pass that each symbol has a single output
register_grad(pids.SHAPE, prims.shape, True)
jjsjann123 (Author) commented on the diff:

I only needed this change for prims.shape, but I think it could/should apply to any operation that doesn't require gradient.
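To make the mechanism concrete, here is a minimal sketch of the registration path, assuming a simplified registry: the _grad_skip_ids set and the extra boolean flag match this PR, while the parameter name, the _grad_fn_map stand-in, and the exact signature are illustrative assumptions, not the actual Thunder implementation.

# Minimal sketch, not the actual Thunder implementation.
_grad_skip_ids: set = set()
_grad_fn_map: dict = {}  # hypothetical stand-in for the real grad registry

def register_grad(sym_id, grad_fn, skip_in_grad_transform: bool = False):
    _grad_fn_map[sym_id] = grad_fn
    if skip_in_grad_transform:
        # Symbols in this set are skipped by the grad transform, so
        # multi-output bound symbols such as prims.shape never hit the
        # single-output assumption in the saved-for-backward handling.
        _grad_skip_ids.add(sym_id)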

def foo(a, b):
    a = a.unsqueeze(1)
    return a + b

jfoo = thunder.jit(foo, cache="symbolic values")
jjsjann123 (Author) commented:

For a = a.unsqueeze(0), the grad transform's deconstruct_forward_env_for_backward runs on the trace:

@torch.no_grad()
@no_autocast
def unsqueeze(a, dim):
  # a: "cpu f32[1, [IntegerProxy name=i0, value=3, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=4, static=CONSTRAINT.CONSTRAINABLE]]"
  (_, i0, i1) = prims.shape(a)
  (_, _, _) = prims.shape(a)
  (_, _, _) = prims.shape(a)
  t8 = prims.broadcast_in_dim(a, [1, 1, i0, i1], [0, 2, 3])  # t8: "cpu f32[1, 1, [IntegerProxy name=i0, value=3, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=4, static=CONSTRAINT.CONSTRAINABLE]]"
  return t8

The transform handles saved-for-backward with the assumption that each symbol produces only a single output:

saved_for_backward = tuple(env[sequencify(symbol.output)[0].name].residuals for symbol in bound_symbols)

which is not true for prims.shape.

Given that prims.shape doesn't participate in grad at all, I decided to skip it, instead of refactoring the transform to handle multiple outputs.
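As a toy illustration of that single-output assumption (not Thunder code; every name below is a simplified stand-in for the real transform machinery):

from collections import namedtuple

Entry = namedtuple("Entry", "residuals")

def sequencify(x):
    # stand-in for the real utility: always return a sequence
    return x if isinstance(x, (tuple, list)) else (x,)

class Proxy:
    def __init__(self, name):
        self.name = name

class BSym:
    def __init__(self, output):
        self.output = output

env = {"t8": Entry(residuals=("saved",))}
single_out = BSym(Proxy("t8"))                            # e.g. broadcast_in_dim -> t8
multi_out = BSym((Proxy("_"), Proxy("i0"), Proxy("i1")))  # e.g. prims.shape

# For the single-output symbol the lookup is unambiguous:
print(env[sequencify(single_out.output)[0].name].residuals)
# For the multi-output symbol the key is just the first output's name ("_");
# the other outputs are never consulted, and if nothing was stored under
# that name the lookup raises KeyError.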

A collaborator (reviewer) commented:

The transform handles saved_for_backward with the assumption that "residuals" are stored only for the first output. It does handle symbols with multiple outputs; one example is var_mean.
There's also a utility function that may be useful for skipping some of the symbols:

def iter_bound_symbols(bound_symbols):

jjsjann123 (Author) replied:

> The transform handles saved_for_backward with the assumption that "residuals" are stored only for the first output.

Thanks for correcting me. 😝

Yes, I am skipping the symbols in the util iter_bound_symbols, i.e. the registration puts them in _grad_skip_ids, and in iter_bound_symbols we skip any symbol in that set.
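A minimal sketch of that skip, assuming the _grad_skip_ids set from the earlier sketch; the real iter_bound_symbols lives in Thunder's transforms module and also consults trace_interpreter_skip_list, so this is a simplification:

def iter_bound_symbols(bound_symbols):
    for symbol in bound_symbols:
        # Skip anything registered via register_grad(..., True); these
        # symbols do not participate in the grad transform at all.
        if symbol.sym.id in _grad_skip_ids:
            continue
        yield symbol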

jjsjann123 (Author) commented:

Hmm, seems like we are hitting a real issue:

FAILED thunder/tests/test_networks.py::test_hf_kvcache - thunder.core.interpreter.InterpreterError: Encountered exception RuntimeError: dictionary changed size during iteration while tracing LlamaForCausalLM(

jjsjann123 (Author) commented:

Not sure if it's caused by this PR, since CI was clean before the merge. I'll double-check that.

IvanYashchuk (Collaborator) left a review:

The current pull request proposes to completely skip certain operations that do nothing in backward. However, we don't want to skip the forward computation, because the result may have proxies that could be consumed later. An alternative would be to extend the existing code without adding more complexity to register_grad. There's an is_constant_for_vjp(bsym: BoundSymbol) -> bool function

def is_constant_for_vjp(symbol: prims.Symbol) -> bool:

that could be modified instead (if the output contains no differentiable proxies, then it's considered constant computation with no backward computation). I think adding an additional condition to that function, not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_outs), should be enough to enable the tests from this pull request.
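A sketch of that suggestion, assuming FloatProxy and TensorProxy come from thunder.core.proxies and that the existing checks in is_constant_for_vjp stay in place; this illustrates the reviewer's proposed condition rather than the code in this PR:

from thunder.core.proxies import FloatProxy, TensorProxy

def is_constant_for_vjp(symbol) -> bool:
    # ... existing checks on the symbol's inputs would remain here ...
    # Proposed additional condition: if no output is a differentiable proxy,
    # treat the symbol as constant for VJP purposes, i.e. keep its forward
    # computation but generate no backward for it.
    no_differentiable_outputs = not any(
        isinstance(out, (FloatProxy, TensorProxy)) for out in symbol.flat_outs
    )
    return no_differentiable_outputs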

@@ -2293,6 +2300,8 @@ def iter_bound_symbols(bound_symbols):
    for symbol in bound_symbols:
        if symbol.sym.id in trace_interpreter_skip_list:
A collaborator (reviewer) commented:

We already have one skip list here, trace_interpreter_skip_list; instead of adding a new one, the existing list should be reused.

jjsjann123 (Author) replied:

Thanks. That's really helpful.

Sounds like is_constant_for_vjp is actually what I wanted, i.e. it skips the backward but keeps the forward as is?

I didn't think trace_interpreter_skip_list was the right list at first, since it's also used in trace_interpreter.py, which is outside the context of grad. But since you pointed out is_constant_for_vjp, I now think trace_interpreter_skip_list is the right list to add them to. I'll refactor this.

jjsjann123 (Author) commented:

> The current pull request proposes to completely skip certain operations that do nothing in backward. However, we don't want to skip the forward computation, because the result may have proxies that could be consumed later.

Yeah, that sounds about right.
I'm surprised that CI didn't catch any issue with that. 😢
