From b98b52844c95ac1b3528242bf818e8cf48111713 Mon Sep 17 00:00:00 2001
From: hongbinl
Date: Thu, 2 Jul 2026 08:27:19 -0700
Subject: [PATCH] fix: restore FSDP backward hooks for fused experts

Signed-off-by: hongbinl
---
 megatron/core/transformer/moe/experts.py      | 23 ++++++++++
 .../transformer/moe/test_grouped_mlp.py       | 40 ++++++++++++++++++-
 2 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py
index 0504d3ccd39..a47c3c7c238 100644
--- a/megatron/core/transformer/moe/experts.py
+++ b/megatron/core/transformer/moe/experts.py
@@ -577,6 +577,7 @@ def _make_fused_ops(self) -> torch.nn.Module:
 
         # Emulate submodule pre-forward hooks
         ops.register_forward_pre_hook(self._make_fused_impl_pre_forward_hook())
+        ops.register_forward_hook(self._make_fused_impl_post_forward_hook())
 
         return ops
 
@@ -605,6 +606,28 @@ def forward_pre_hook(module, *_) -> None:
 
         return forward_pre_hook
 
+    def _make_fused_impl_post_forward_hook(self) -> Callable:
+        """Forward submodule hooks to the fused output.
+
+        Megatron FSDP uses GroupedLinear forward hooks to attach parameter
+        all-gathers immediately before backward. The op fuser bypasses the
+        GroupedLinear module calls, so attach those hooks to the fused MLP output.
+        """
+
+        def forward_post_hook(_module, _inputs, output):
+            for submodule in chain(self.linear_fc1.modules(), self.linear_fc2.modules()):
+                # Snapshot the hooks: one may remove itself or register another mid-loop.
+                for hook_id, hook in list(submodule._forward_hooks.items()):
+                    if hook_id in submodule._forward_hooks_with_kwargs:
+                        ret = hook(submodule, (), {}, output)
+                    else:
+                        ret = hook(submodule, (), output)
+                    if ret is not None:
+                        output = ret
+            return output
+
+        return forward_post_hook
+
     def _fused_forward(
         self,
         permuted_local_hidden_states: torch.Tensor,
diff --git a/tests/unit_tests/transformer/moe/test_grouped_mlp.py b/tests/unit_tests/transformer/moe/test_grouped_mlp.py
index 51ded378af4..c340b08b745 100644
--- a/tests/unit_tests/transformer/moe/test_grouped_mlp.py
+++ b/tests/unit_tests/transformer/moe/test_grouped_mlp.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 
 import argparse
 import sys
@@ -88,6 +88,9 @@ class FakeSequential(list):
         def register_forward_pre_hook(self, hook):
             self.forward_pre_hook = hook
 
+        def register_forward_hook(self, hook):
+            self.forward_post_hook = hook
+
     fake_te = SimpleNamespace(
         pytorch=SimpleNamespace(
             GroupedLinear=FakeGroupedLinear,
@@ -148,6 +151,7 @@ def register_forward_pre_hook(self, hook):
     assert ops[2].device == "meta"
     assert ops[2].weight is module.linear_fc2.weight
     assert hasattr(ops, "forward_pre_hook")
+    assert hasattr(ops, "forward_post_hook")
 
 
 def test_fused_forward_caches_ops_and_forwards_expected_arguments():
@@ -254,6 +258,34 @@ def test_make_fused_impl_pre_forward_hook_rejects_input_modifying_hook():
         hook(object())
 
 
+def test_make_fused_impl_post_forward_hook_dispatches_submodule_hooks():
+    module = TEGroupedMLP.__new__(TEGroupedMLP)
+    torch.nn.Module.__init__(module)
+    fc1_child = torch.nn.Linear(2, 2)
+    fc2_child = torch.nn.Linear(2, 2)
+    module.linear_fc1 = torch.nn.Sequential(fc1_child)
+    module.linear_fc2 = torch.nn.Sequential(fc2_child)
+
+    calls = []
+
+    def fc1_hook(submodule, _inputs, output):
+        calls.append(("fc1", submodule))
+        return output + 1
+
+    def fc2_hook(submodule, _inputs, _kwargs, output):
+        calls.append(("fc2", submodule))
+        return output + 1
+
+    fc1_child.register_forward_hook(fc1_hook)
+    fc2_child.register_forward_hook(fc2_hook, with_kwargs=True)
+
+    hook = module._make_fused_impl_post_forward_hook()
+    output = hook(None, (), torch.zeros(2, 2))
+
+    assert {label for label, _ in calls} == {"fc1", "fc2"}
+    torch.testing.assert_close(output, torch.full_like(output, 2))
+
+
 def test_make_fused_ops_handles_single_grouped_weight_for_fc1(monkeypatch):
     class FakeGroupedLinear(torch.nn.Module):
         def __init__(
@@ -295,6 +327,9 @@ class FakeSequential(list):
         def register_forward_pre_hook(self, hook):
             self.forward_pre_hook = hook
 
+        def register_forward_hook(self, hook):
+            self.forward_post_hook = hook
+
     fake_te = SimpleNamespace(
         pytorch=SimpleNamespace(
             GroupedLinear=FakeGroupedLinear,
@@ -420,6 +455,9 @@ class FakeSequential(list):
         def register_forward_pre_hook(self, hook):
             self.forward_pre_hook = hook
 
+        def register_forward_hook(self, hook):
+            self.forward_post_hook = hook
+
     return (
         SimpleNamespace(
             pytorch=SimpleNamespace(