Update vjp implementation to attach residuals to all outputs #1755

Draft · wants to merge 4 commits into base: ivan-1732-1
Changes from all commits
19 changes: 6 additions & 13 deletions thunder/core/rematerialization.py
@@ -537,19 +537,12 @@ def match_fw_and_bw_saved_for_bw_proxies(
     Returns:
         new_required_for_bakward_fw_to_bw_map: Dict[str, Proxy]: mapping of bw names to forward proxies
     """
-
-    old_saved_for_backward_fw = (*fw_trace.bound_symbols[-1].args[1][0], *fw_trace.bound_symbols[-1].args[1][1])
-    old_saved_for_backward_bw = []
-    for bsym in bw_trace.bound_symbols:
-        if bsym.sym.id == PrimIDs.UNPACK_SEQUENCE:
-            flattened_args = tree_flatten(bw_trace.args[1])[0]
-            proxy_names = {y.name for y in flattened_args if isinstance(y, ProxyInterface)}
-            if all(
-                not isinstance(out, CollectionProxy) and out.name not in proxy_names
-                for out in bsym.flat_outs
-                if out is not None
-            ):
-                old_saved_for_backward_bw += bsym.flat_outs
+    old_saved_tensors_fw = fw_trace.output[1][0]
+    old_saved_other_fw = fw_trace.output[1][1]
+    old_saved_for_backward_fw = (*old_saved_tensors_fw, *old_saved_other_fw)
+    old_saved_tensors_bw = bw_trace.args[0][0]
+    old_saved_other_bw = bw_trace.args[0][1]
+    old_saved_for_backward_bw = (*old_saved_tensors_bw, *old_saved_other_bw)
     assert len(old_saved_for_backward_fw) == len(old_saved_for_backward_bw)
     new_required_for_backward_bw_to_fw_map = {
         x.name: y for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None
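To make the new indexing concrete: the rewritten block assumes the forward trace's return value carries saved-for-backward values as a pair (saved tensors, saved other) in `output[1]`, and that the backward trace receives the same pair as `args[0]`, so the two flattened tuples line up positionally. Below is a minimal stand-alone sketch of that pairing; `FakeProxy`, `fw_output_saved`, and `bw_args_saved` are hypothetical stand-ins, not Thunder's real `Proxy`/`TraceCtx` objects.

```python
from collections import namedtuple

FakeProxy = namedtuple("FakeProxy", "name")

# Mimics fw_trace.output[1]: (saved tensors, saved "other" values)
fw_output_saved = ((FakeProxy("t0"), FakeProxy("t1")), (FakeProxy("i0"),))
# Mimics bw_trace.args[0]: the same values as unpacked by the backward trace,
# possibly under different proxy names
bw_args_saved = ((FakeProxy("t2"), FakeProxy("t3")), (FakeProxy("i1"),))

old_saved_for_backward_fw = (*fw_output_saved[0], *fw_output_saved[1])
old_saved_for_backward_bw = (*bw_args_saved[0], *bw_args_saved[1])
assert len(old_saved_for_backward_fw) == len(old_saved_for_backward_bw)

# Map backward-trace names to the forward proxies they came from
new_required_for_backward_bw_to_fw_map = {
    x.name: y for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None
}
print(new_required_for_backward_bw_to_fw_map)
# {'t2': FakeProxy(name='t0'), 't3': FakeProxy(name='t1'), 'i1': FakeProxy(name='i0')}
```

Compared with the deleted code, this relies on the saved-for-backward collections being exposed directly on the trace signature instead of being recovered from UNPACK_SEQUENCE bound symbols.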
29 changes: 20 additions & 9 deletions thunder/core/transforms.py
@@ -56,6 +56,7 @@
     OrderedSet,
     ProxyDict,
 )
+from thunder.core.codeutils import is_literal
 import thunder.clang as clang
 from thunder.clang import (
     empty,
@@ -2293,12 +2294,27 @@ def iter_bound_symbols(bound_symbols):
     for symbol in bound_symbols:
         if symbol.sym.id in trace_interpreter_skip_list:
             continue
+        elif all(is_literal(sym_out) for sym_out in symbol.flat_outs):
+            continue
         elif symbol.output is None:
             continue
         else:
             yield symbol
 
 
+def get_first_proxy(proxies) -> Proxy | None:
+    """Get the first proxy from a list of proxies.
+
+    Args:
+        proxies (List[Proxy]): List of proxies
+
+    Returns:
+        Proxy: First proxy from the list
+    """
+    proxies = sequencify(proxies)
+    return next((proxy for proxy in proxies if isinstance(proxy, Proxy)), None)
+
+
 def deconstruct_forward_env_for_backward(trace, env):
     # Note [Saving the forward environment in the backward rule]
     # We cannot save the trace object in the residuals because executors may not
@@ -2310,7 +2326,7 @@ def deconstruct_forward_env_for_backward(trace, env):
     # arguments. See test_grad.py:test_torch_autograd_function for an example
     # where this is tested.
     bound_symbols = iter_bound_symbols(trace.bound_symbols)
-    saved_for_backward = tuple(env[sequencify(symbol.output)[0].name].residuals for symbol in bound_symbols)
+    saved_for_backward = tuple(env[get_first_proxy(symbol.output).name].residuals for symbol in bound_symbols)
     return saved_for_backward
 
 
@@ -2320,7 +2336,7 @@ def reconstruct_forward_env_for_backward(trace, saved_for_backward):
     reconstructed_env = {}
 
     for idx, sym in enumerate(bound_symbols):
-        k = sequencify(sym.output)[0].name
+        k = get_first_proxy(sym.output).name
         v = VJPDual(None, saved_for_backward[idx])
         reconstructed_env[k] = v
 
@@ -2569,12 +2585,7 @@ def vjp_impl_const(symbol, *args, **kwargs):
     def _vjp_impl(*args, **kwargs):
         primals, kwargs = tree_map(lambda x: x.primal if isinstance(x, VJPDual) else x, (args, kwargs))
         out_primal, out_residuals = vjp_impl(*primals, **kwargs)
-        # We are saving the residuals and pullback only in the first output
-        # backward_pass then retrieves the residuals and pullback from the first output
-        if isinstance(out_primal, Sequence):
-            return (VJPDual(out_primal[0], out_residuals), *(VJPDual(o, tuple()) for o in out_primal[1:]))
-
-        return (VJPDual(out_primal, out_residuals),)
+        return tree_map(lambda x: VJPDual(x, out_residuals), sequencify(out_primal))
 
     return _vjp_impl
 
@@ -2729,7 +2740,7 @@ def put_grad(v: Variable, val: Any) -> None:
         # Having a single cotangent is a common case, so we flatten it
         # Otherwise, we will need to rewrite the pullback functions
         cotangents = tree_flatten(cotangents)[0]
-        residuals = forward_env[symbol_output[0].name].residuals
+        residuals = forward_env[get_first_proxy(symbol_output).name].residuals
         if is_constant_for_vjp(symbol):
             # We can skip the pullback if all the arguments are constant
             continue
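The net effect of the transforms.py changes: a VJP rule that returns several primal outputs now gets the same residuals attached to every output (via `tree_map` over `sequencify(out_primal)`), and consumers look residuals up through the first `Proxy` in an output sequence instead of indexing element 0, which, judging by the new `is_literal` filter, can now be a literal rather than a proxy. Below is a self-contained sketch of that behavior; `VJPDual`, `FakeProxy`, `sequencify`, `get_first_proxy`, and `attach_residuals` are simplified stand-ins, not Thunder's actual classes or functions.

```python
from collections import namedtuple

# Stand-in for thunder.core.transforms.VJPDual: (primal, residuals) pair.
VJPDual = namedtuple("VJPDual", ("primal", "residuals"))


class FakeProxy:
    def __init__(self, name):
        self.name = name


def sequencify(x):
    return x if isinstance(x, (tuple, list)) else (x,)


def get_first_proxy(proxies):
    # Mirrors the helper added in this PR: first proxy in the (sequencified) outputs.
    return next((p for p in sequencify(proxies) if isinstance(p, FakeProxy)), None)


def attach_residuals(out_primal, out_residuals):
    # New behavior: every output carries the residuals (previously only the first did).
    return tuple(VJPDual(o, out_residuals) for o in sequencify(out_primal))


outputs = (2, FakeProxy("t0"), FakeProxy("t1"))  # a literal may appear among the outputs
residuals = ("saved_a", "saved_b")
duals = attach_residuals(outputs, residuals)
assert all(d.residuals == residuals for d in duals)

# sequencify(outputs)[0] would pick the literal 2; get_first_proxy skips to the proxy "t0".
assert get_first_proxy(outputs).name == "t0"
```

Under this scheme the residuals can be recovered from whichever output proxy is encountered, which is what the `get_first_proxy` lookups in `deconstruct_forward_env_for_backward`, `reconstruct_forward_env_for_backward`, and `backward_pass` rely on.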
12 changes: 12 additions & 0 deletions thunder/tests/test_grad.py
@@ -1934,6 +1934,18 @@ def func(x):
     torch.testing.assert_close(actual_gr, expected_gr)
 
 
+# See https://github.com/Lightning-AI/lightning-thunder/issues/1732
+def test_symbolic_shape_for_backward_issue_1732():
(Review comment from the PR author, attached to this line: "@jjsjann123, is there any other test you tried that didn't work?")
+    @partial(thunder.jit, cache="symbolic values")
+    def f(a, b):
+        return a * b
+
+    a = make_tensor((1, 32, 232, 232), device="cpu", dtype=torch.float32, requires_grad=True)
+    b = make_tensor((1, 1, 232, 232), device="cpu", dtype=torch.float32, requires_grad=False)
+    out = f(a, b)
+    out.backward(torch.ones_like(out))
+
+
 @pytest.mark.parametrize("device", ("cuda", "cpu"))
 def test_backward_recomputation_decomposed_ops(device):
     if device == "cuda" and not torch.cuda.is_available():
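The new test compiles a broadcast multiply with cache="symbolic values" and runs its backward, exercising the path tracked in issue #1732. To run just this test locally, something along these lines should work with the repository's usual pytest setup:

    pytest thunder/tests/test_grad.py -k test_symbolic_shape_for_backward_issue_1732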