From 698dfc7b1c81c3f6f1a1f9d8a4b20f6019f45ba1 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 7 Feb 2025 14:34:02 +0200 Subject: [PATCH 1/4] Simplify and fix match_fw_and_bw_saved_for_bw_proxies implementation --- thunder/core/rematerialization.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index 9f95fc7572..6b16485837 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -537,19 +537,12 @@ def match_fw_and_bw_saved_for_bw_proxies( Returns: new_required_for_bakward_fw_to_bw_map: Dict[str, Proxy]: mapping of bw names to forward proxies """ - - old_saved_for_backward_fw = (*fw_trace.bound_symbols[-1].args[1][0], *fw_trace.bound_symbols[-1].args[1][1]) - old_saved_for_backward_bw = [] - for bsym in bw_trace.bound_symbols: - if bsym.sym.id == PrimIDs.UNPACK_SEQUENCE: - flattened_args = tree_flatten(bw_trace.args[1])[0] - proxy_names = {y.name for y in flattened_args if isinstance(y, ProxyInterface)} - if all( - not isinstance(out, CollectionProxy) and out.name not in proxy_names - for out in bsym.flat_outs - if out is not None - ): - old_saved_for_backward_bw += bsym.flat_outs + old_saved_tensors_fw = fw_trace.output[1][0] + old_saved_other_fw = fw_trace.output[1][1] + old_saved_for_backward_fw = (*old_saved_tensors_fw, *old_saved_other_fw) + old_saved_tensors_bw = bw_trace.args[0][0] + old_saved_other_bw = bw_trace.args[0][1] + old_saved_for_backward_bw = (*old_saved_tensors_bw, *old_saved_other_bw) assert len(old_saved_for_backward_fw) == len(old_saved_for_backward_bw) new_required_for_backward_bw_to_fw_map = { x.name: y for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None From 7a2f7f7d192d428219d858241ba68124b720eaba Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 7 Feb 2025 14:41:27 +0200 Subject: [PATCH 2/4] Skip symbols with only literals in the output; they do not contribute to the interpreting environment --- thunder/core/transforms.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index e5d0b67bea..0b89a21360 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -56,6 +56,7 @@ OrderedSet, ProxyDict, ) +from thunder.core.codeutils import is_literal import thunder.clang as clang from thunder.clang import ( empty, @@ -2293,6 +2294,8 @@ def iter_bound_symbols(bound_symbols): for symbol in bound_symbols: if symbol.sym.id in trace_interpreter_skip_list: continue + elif all(is_literal(sym_out) for sym_out in symbol.flat_outs): + continue elif symbol.output is None: continue else: From b50f9466f98b8871873bd4dbca952d647a53b171 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 7 Feb 2025 14:46:32 +0200 Subject: [PATCH 3/4] Store vjp residuals on all outputs; retrieve them from the first proxy --- thunder/core/transforms.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 0b89a21360..abe0f1a0cf 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -2302,6 +2302,19 @@ def iter_bound_symbols(bound_symbols): yield symbol +def get_first_proxy(proxies) -> Proxy | None: + """Get the first proxy from a list of proxies. + + Args: + proxies (List[Proxy]): List of proxies + + Returns: + Proxy: First proxy from the list + """ + proxies = sequencify(proxies) + return next((proxy for proxy in proxies if isinstance(proxy, Proxy)), None) + + def deconstruct_forward_env_for_backward(trace, env): # Note [Saving the forward environment in the backward rule] # We cannot save the trace object in the residuals because executors may not @@ -2313,7 +2326,7 @@ def deconstruct_forward_env_for_backward(trace, env): # arguments. See test_grad.py:test_torch_autograd_function for an example # where this is tested. bound_symbols = iter_bound_symbols(trace.bound_symbols) - saved_for_backward = tuple(env[sequencify(symbol.output)[0].name].residuals for symbol in bound_symbols) + saved_for_backward = tuple(env[get_first_proxy(symbol.output).name].residuals for symbol in bound_symbols) return saved_for_backward @@ -2323,7 +2336,7 @@ def reconstruct_forward_env_for_backward(trace, saved_for_backward): reconstructed_env = {} for idx, sym in enumerate(bound_symbols): - k = sequencify(sym.output)[0].name + k = get_first_proxy(sym.output).name v = VJPDual(None, saved_for_backward[idx]) reconstructed_env[k] = v @@ -2572,12 +2585,7 @@ def vjp_impl_const(symbol, *args, **kwargs): def _vjp_impl(*args, **kwargs): primals, kwargs = tree_map(lambda x: x.primal if isinstance(x, VJPDual) else x, (args, kwargs)) out_primal, out_residuals = vjp_impl(*primals, **kwargs) - # We are saving the residuals and pullback only in the first output - # backward_pass then retrieves the residuals and pullback from the first output - if isinstance(out_primal, Sequence): - return (VJPDual(out_primal[0], out_residuals), *(VJPDual(o, tuple()) for o in out_primal[1:])) - - return (VJPDual(out_primal, out_residuals),) + return tree_map(lambda x: VJPDual(x, out_residuals), sequencify(out_primal)) return _vjp_impl @@ -2732,7 +2740,7 @@ def put_grad(v: Variable, val: Any) -> None: # Having a single cotangent is a common case, so we flatten it # Otherwise, we will need to rewrite the pullback functions cotangents = tree_flatten(cotangents)[0] - residuals = forward_env[symbol_output[0].name].residuals + residuals = forward_env[get_first_proxy(symbol_output).name].residuals if is_constant_for_vjp(symbol): # We can skip the pullback if all the arguments are constant continue From 8d053ea8d79c38d27e9dce40d22a1f9e772389ac Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 7 Feb 2025 14:53:00 +0200 Subject: [PATCH 4/4] Add test from issue #1732 --- thunder/tests/test_grad.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 718465538c..0c9cdc6722 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1934,6 +1934,18 @@ def func(x): torch.testing.assert_close(actual_gr, expected_gr) +# See https://github.com/Lightning-AI/lightning-thunder/issues/1732 +def test_symbolic_shape_for_backward_issue_1732(): + @partial(thunder.jit, cache="symbolic values") + def f(a, b): + return a * b + + a = make_tensor((1, 32, 232, 232), device="cpu", dtype=torch.float32, requires_grad=True) + b = make_tensor((1, 1, 232, 232), device="cpu", dtype=torch.float32, requires_grad=False) + out = f(a, b) + out.backward(torch.ones_like(out)) + + @pytest.mark.parametrize("device", ("cuda", "cpu")) def test_backward_recomputation_decomposed_ops(device): if device == "cuda" and not torch.cuda.is_available():