support grad accumulation for fsdp2 (#981)
Summary: Pull Request resolved: #981

Reviewed By: diego-urgell

Differential Revision: D70720659

fbshipit-source-id: 84f2af7c7c06c9f730518ebdb17aed3d8cc31a55
JKSenthil authored and facebook-github-bot committed Mar 7, 2025
1 parent 99c5cda commit 4059cc4
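
For context, here is a minimal sketch (not torchtnt code) of the pattern this commit implements inside AutoUnit.train_step: with FSDP2, gradient reduce-scatter is switched off on accumulation steps and switched back on for the step that updates the weights. It relies only on the FSDPModule.set_requires_gradient_sync call used in the diff below; the helper name, loss function, and loop structure are illustrative.

from typing import Iterable, Tuple

import torch
from torch.distributed.fsdp import FSDPModule


def train_with_accumulation(
    module: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    batches: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    accumulation_steps: int,
) -> None:
    for step, (inputs, targets) in enumerate(batches):
        should_update_weights = (step + 1) % accumulation_steps == 0
        # FSDP2 has no no_sync() context manager; gradient sync is toggled
        # explicitly: off while accumulating, back on for the update step.
        if isinstance(module, FSDPModule):
            module.set_requires_gradient_sync(should_update_weights)
        loss = torch.nn.functional.mse_loss(module(inputs), targets)
        (loss / accumulation_steps).backward()
        if should_update_weights:
            optimizer.step()
            optimizer.zero_grad()
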
Showing 2 changed files with 42 additions and 3 deletions.
30 changes: 30 additions & 0 deletions tests/framework/test_auto_unit.py
@@ -750,6 +750,36 @@ def test_detect_anomaly_disabled_with_torch_compile(self) -> None:

self.assertIsNone(auto_unit.detect_anomaly)

@patch("torchtnt.framework.auto_unit._is_fsdp2_module", return_value=True)
def test_gradient_accumulation_fsdp2(self, _) -> None:
auto_unit = DummyAutoUnit(
module=torch.nn.Linear(1, 1),
gradient_accumulation_steps=3,
)

# Dynamically add a mocked method as an attribute to module
fsdp_module_mock = MagicMock()
auto_unit.module.set_requires_gradient_sync = fsdp_module_mock
auto_unit._is_last_batch = False

state = get_dummy_train_state()

# Simulate train steps
for step in range(4):
# Call train_step to trigger set_requires_gradient_sync
auto_unit.train_step(state, (torch.rand(1, 1), torch.rand(1, 1)))

# Check if set_requires_gradient_sync is called with the correct boolean
if (step + 1) % auto_unit.gradient_accumulation_steps == 0:
auto_unit.module.set_requires_gradient_sync.assert_called_with(True)
else:
auto_unit.module.set_requires_gradient_sync.assert_called_with(False)

# Reset mock for the next iteration
auto_unit.module.set_requires_gradient_sync.reset_mock()

auto_unit.train_progress.increment_step()


Batch = Tuple[torch.Tensor, torch.Tensor]

15 changes: 12 additions & 3 deletions torchtnt/framework/auto_unit.py
@@ -14,6 +14,7 @@
from dataclasses import dataclass
from typing import (
Any,
cast,
ContextManager,
Generic,
Iterator,
@@ -26,7 +27,7 @@

import torch
from pyre_extensions import none_throws
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.swa_utils import SWALR
from torchtnt.framework._unit_utils import _step_requires_iterator
@@ -42,6 +43,7 @@
GradScaler,
)
from torchtnt.utils.prepare_module import (
_is_fsdp2_module,
_is_fsdp_module,
ActivationCheckpointParams,
FSDPStrategy,
@@ -672,12 +674,19 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
# https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
# https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync
maybe_no_sync = (
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
module.no_sync()
if not should_update_weights
-and (isinstance(module, DDP) or _is_fsdp_module(module))
+and (isinstance(module, DDP) or isinstance(module, FSDP))
else contextlib.nullcontext()
)
# fsdp2 has a separate way of disabling gradient sync
if _is_fsdp2_module(module):
if not should_update_weights:
cast(FSDPModule, module).set_requires_gradient_sync(False)
elif should_update_weights and self.gradient_accumulation_steps > 1:
# if gradient accumulation is used and it's time to update weights,
# we need to re-enable gradient sync
cast(FSDPModule, module).set_requires_gradient_sync(True)

# if detect_anomaly is true, run forward and backward pass in detect_anomaly context
detect_anomaly = self.detect_anomaly
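
From the user's side nothing changes beyond passing gradient_accumulation_steps; a hypothetical AutoUnit subclass with an FSDP2-sharded module is sketched below. The fully_shard import path and the AutoUnit subclass hooks (compute_loss, configure_optimizers_and_lr_scheduler) are assumptions to verify against your torch and torchtnt versions; MyUnit, its loss, and the optimizer choice are illustrative.

from typing import Any, Tuple

import torch
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point; import path may vary by torch version
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State

Batch = Tuple[torch.Tensor, torch.Tensor]


class MyUnit(AutoUnit[Batch]):
    # compute_loss / configure_optimizers_and_lr_scheduler are the usual AutoUnit
    # hooks; check the exact signatures against your torchtnt version.
    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
        inputs, targets = data
        outputs = self.module(inputs)
        return torch.nn.functional.mse_loss(outputs, targets), outputs

    def configure_optimizers_and_lr_scheduler(
        self, module: torch.nn.Module
    ) -> Tuple[torch.optim.Optimizer, None]:
        return torch.optim.SGD(module.parameters(), lr=0.01), None


# Requires an initialized process group / device mesh (e.g. launched via torchrun).
model = torch.nn.Linear(8, 8)
fully_shard(model)  # model is now an FSDPModule; train_step toggles its gradient sync
unit = MyUnit(module=model, gradient_accumulation_steps=3)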
