diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py
index 9f95fc7572..6b16485837 100644
--- a/thunder/core/rematerialization.py
+++ b/thunder/core/rematerialization.py
@@ -537,19 +537,12 @@ def match_fw_and_bw_saved_for_bw_proxies(
     Returns:
         new_required_for_bakward_fw_to_bw_map: Dict[str, Proxy]: mapping of bw names to forward proxies
     """
-
-    old_saved_for_backward_fw = (*fw_trace.bound_symbols[-1].args[1][0], *fw_trace.bound_symbols[-1].args[1][1])
-    old_saved_for_backward_bw = []
-    for bsym in bw_trace.bound_symbols:
-        if bsym.sym.id == PrimIDs.UNPACK_SEQUENCE:
-            flattened_args = tree_flatten(bw_trace.args[1])[0]
-            proxy_names = {y.name for y in flattened_args if isinstance(y, ProxyInterface)}
-            if all(
-                not isinstance(out, CollectionProxy) and out.name not in proxy_names
-                for out in bsym.flat_outs
-                if out is not None
-            ):
-                old_saved_for_backward_bw += bsym.flat_outs
+    old_saved_tensors_fw = fw_trace.output[1][0]
+    old_saved_other_fw = fw_trace.output[1][1]
+    old_saved_for_backward_fw = (*old_saved_tensors_fw, *old_saved_other_fw)
+    old_saved_tensors_bw = bw_trace.args[0][0]
+    old_saved_other_bw = bw_trace.args[0][1]
+    old_saved_for_backward_bw = (*old_saved_tensors_bw, *old_saved_other_bw)
     assert len(old_saved_for_backward_fw) == len(old_saved_for_backward_bw)
     new_required_for_backward_bw_to_fw_map = {
         x.name: y for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index e5d0b67bea..abe0f1a0cf 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -56,6 +56,7 @@
     OrderedSet,
     ProxyDict,
 )
+from thunder.core.codeutils import is_literal
 import thunder.clang as clang
 from thunder.clang import (
     empty,
@@ -2293,12 +2294,27 @@ def iter_bound_symbols(bound_symbols):
     for symbol in bound_symbols:
         if symbol.sym.id in trace_interpreter_skip_list:
             continue
+        elif all(is_literal(sym_out) for sym_out in symbol.flat_outs):
+            continue
         elif symbol.output is None:
             continue
         else:
             yield symbol
 
 
+def get_first_proxy(proxies) -> Proxy | None:
+    """Get the first proxy from a list of proxies.
+
+    Args:
+        proxies (List[Proxy]): List of proxies
+
+    Returns:
+        Proxy: First proxy from the list
+    """
+    proxies = sequencify(proxies)
+    return next((proxy for proxy in proxies if isinstance(proxy, Proxy)), None)
+
+
 def deconstruct_forward_env_for_backward(trace, env):
     # Note [Saving the forward environment in the backward rule]
     # We cannot save the trace object in the residuals because executors may not
@@ -2310,7 +2326,7 @@ def deconstruct_forward_env_for_backward(trace, env):
     # arguments. See test_grad.py:test_torch_autograd_function for an example
     # where this is tested.
     bound_symbols = iter_bound_symbols(trace.bound_symbols)
-    saved_for_backward = tuple(env[sequencify(symbol.output)[0].name].residuals for symbol in bound_symbols)
+    saved_for_backward = tuple(env[get_first_proxy(symbol.output).name].residuals for symbol in bound_symbols)
 
     return saved_for_backward
 
@@ -2320,7 +2336,7 @@
     reconstructed_env = {}
 
     for idx, sym in enumerate(bound_symbols):
-        k = sequencify(sym.output)[0].name
+        k = get_first_proxy(sym.output).name
         v = VJPDual(None, saved_for_backward[idx])
         reconstructed_env[k] = v
 
@@ -2569,12 +2585,7 @@ def vjp_impl_const(symbol, *args, **kwargs):
 
     def _vjp_impl(*args, **kwargs):
         primals, kwargs = tree_map(lambda x: x.primal if isinstance(x, VJPDual) else x, (args, kwargs))
         out_primal, out_residuals = vjp_impl(*primals, **kwargs)
-        # We are saving the residuals and pullback only in the first output
-        # backward_pass then retrieves the residuals and pullback from the first output
-        if isinstance(out_primal, Sequence):
-            return (VJPDual(out_primal[0], out_residuals), *(VJPDual(o, tuple()) for o in out_primal[1:]))
-
-        return (VJPDual(out_primal, out_residuals),)
+        return tree_map(lambda x: VJPDual(x, out_residuals), sequencify(out_primal))
 
     return _vjp_impl
@@ -2729,7 +2740,7 @@ def put_grad(v: Variable, val: Any) -> None:
         # Having a single cotangent is a common case, so we flatten it
        # Otherwise, we will need to rewrite the pullback functions
         cotangents = tree_flatten(cotangents)[0]
-        residuals = forward_env[symbol_output[0].name].residuals
+        residuals = forward_env[get_first_proxy(symbol_output).name].residuals
         if is_constant_for_vjp(symbol):
             # We can skip the pullback if all the arguments are constant
             continue
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 718465538c..0c9cdc6722 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1934,6 +1934,18 @@ def func(x):
     torch.testing.assert_close(actual_gr, expected_gr)
 
 
+# See https://github.com/Lightning-AI/lightning-thunder/issues/1732
+def test_symbolic_shape_for_backward_issue_1732():
+    @partial(thunder.jit, cache="symbolic values")
+    def f(a, b):
+        return a * b
+
+    a = make_tensor((1, 32, 232, 232), device="cpu", dtype=torch.float32, requires_grad=True)
+    b = make_tensor((1, 1, 232, 232), device="cpu", dtype=torch.float32, requires_grad=False)
+    out = f(a, b)
+    out.backward(torch.ones_like(out))
+
+
 @pytest.mark.parametrize("device", ("cuda", "cpu"))
 def test_backward_recomputation_decomposed_ops(device):
     if device == "cuda" and not torch.cuda.is_available():
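
A minimal, self-contained sketch of the failure mode the transforms.py changes appear to guard against. The Proxy class and sequencify below are hypothetical stand-ins, not thunder's real implementations; only the get_first_proxy body is copied from the patch. Under cache="symbolic values", a bound symbol's output may mix literals (plain Python numbers) with proxies, so the old sequencify(output)[0].name lookup could land on a value with no .name attribute.

class Proxy:
    # Stand-in for thunder.core.proxies.Proxy; real proxies carry more state.
    def __init__(self, name):
        self.name = name


def sequencify(x):
    # Stand-in for thunder's sequencify: wrap non-sequences in a tuple.
    return x if isinstance(x, (list, tuple)) else (x,)


def get_first_proxy(proxies):
    # Same logic as the helper added to thunder/core/transforms.py above.
    proxies = sequencify(proxies)
    return next((proxy for proxy in proxies if isinstance(proxy, Proxy)), None)


# A mixed output such as (2, Proxy("t0")) breaks sequencify(output)[0].name
# (AttributeError on the int); get_first_proxy skips past the literal.
mixed_output = (2, Proxy("t0"))
assert get_first_proxy(mixed_output).name == "t0"

# All-literal outputs produce no proxy at all; the new is_literal branch in
# iter_bound_symbols skips such bound symbols before any lookup happens.
assert get_first_proxy((1, 2)) is None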