
Commit dfcb08d

Update on "[WIP][RFC] TorchFT integration"
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**
This doesn't work at the moment, as existing groups hang when a new group joins.

**Issue 1:**
~Group 0 and group 1 will hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~

Fixed with: pytorch/torchft#83

**Issue 2:**
~Group 0 and group 1 will pass the `should_commit`, but group 0 needs healing, which is wrong, and the healing process will cause another hang.~

Fixed with: pytorch/torchft#83

**Issue 3:**
~The byproduct of issues 1 and 2: group 1 will continue to print out~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```
Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:**
When there are 3 groups, everyone requests the state dict every step.

***How to reproduce?*** Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command.

Seems to be fixed; more tests are needed.

**Issue 5:**
A hang will happen when using functional collectives.

***How to reproduce?*** Pull the latest version of this PR, comment out line 41, and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**
1. Patch TorchFT with pytorch/torchft#82
2. Execute the lighthouse.
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

[ghstack-poisoned]
2 parents ae425b8 + d95dff9 commit dfcb08d
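Editor's note on Issue 5: the hang is reported only when the functional-collective code path in `torchtitan/utils.py` is enabled. The snippet below is a hypothetical, minimal illustration of the difference between the two call styles; the helper names and the reduce op are assumptions, not the actual lines 41/42 referenced in the message.

```python
# Hypothetical sketch of the two collective styles contrasted by Issue 5; the
# helper names and reduce op are illustrative, not the real torchtitan code.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def dist_max_eager(x: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    # Eager collective: reduces in place and synchronizes on the process group.
    dist.all_reduce(x, op=dist.ReduceOp.MAX, group=pg)
    return x


def dist_max_functional(x: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    # Functional collective: returns a new (async) tensor; per Issue 5, this
    # path is the one that hangs when the group is torchft's managed PG.
    return funcol.all_reduce(x, reduceOp="max", group=pg)
```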

File tree

2 files changed: +4 −39 lines


torchtitan/checkpoint.py

+1 −37

```diff
@@ -156,41 +156,6 @@ def __init__(
         if not self.enable_checkpoint and self.ft_manager is None:
             return
 
-<<<<<<< HEAD
-        1. even for simple PP schedules, there is a separate optimizer each PP rank.
-        rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
-        rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
-        When saving, these collide and one of them is lost. Then when reloading, only one stage can
-        restore its optimizer states, others will error.
-
-        The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
-        by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
-
-        2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
-        requiring us to reason about multiple 'optim' objects locally.
-
-        We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
-        into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
-        which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
-        support described in (1).
-
-        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-        resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-        optimizers do, so it's hard to write a generic 'flattener' utility.
-
-        TODO: This is currently unsolved and needs a fix.
-        """
-        self.states = states
-
-        self.states.update(
-            {
-                "model": ModelWrapper(model_parts),
-                "optimizer": optimizers,
-                "dataloader": dataloader,
-                "lr_scheduler": lr_schedulers,
-            }
-        )
-=======
         self._initialize_states(
             states, dataloader, model_parts, optimizers, lr_schedulers
         )
@@ -201,7 +166,6 @@ def __init__(
         self.staging_id = None
         self.cpu_offload_state_dict = None
         self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
->>>>>>> 3430d99 ([WIP][RFC] TorchFT integration)
 
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
@@ -305,9 +269,9 @@ def _initialize_states(
                 "model": ModelWrapper(model_parts),
                 "optimizer": optimizers,
                 "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
             }
         )
-        self.states.update(lr_schedulers.get_lr_scheduler_state())
 
     def _create_checkpoint_id(self, step: int) -> str:
        return os.path.join(self.folder, f"step-{step}")
```
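For readability, here is a sketch of what the consolidated `_initialize_states` helper plausibly looks like after this change, based only on the call site and the updated dictionary shown in the diff; the surrounding class and `ModelWrapper` are assumed from context.

```python
# Sketch assembled from the diff above: all checkpointable objects are now
# registered in one place, with the LR schedulers stored directly under
# "lr_scheduler" instead of via lr_schedulers.get_lr_scheduler_state().
def _initialize_states(self, states, dataloader, model_parts, optimizers, lr_schedulers):
    self.states = states
    self.states.update(
        {
            "model": ModelWrapper(model_parts),
            "optimizer": optimizers,
            "dataloader": dataloader,
            "lr_scheduler": lr_schedulers,
        }
    )
```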

torchtitan/utils.py

+3 −2

```diff
@@ -408,9 +408,10 @@ def clip_grad_norm_(
         # If only using PP, total_norm will be a local tensor.
         mesh = total_norm._spec.mesh
         if isinstance(mesh, ft.process_group.ManagedDeviceMesh):
+            # The gradients along the replicated dim have already been reduced,
+            # so we don't need another reduction before removing the
+            # replicate dimension.
             local_tensor = total_norm.to_local()
-            dist.all_reduce(local_tensor, op=dist.ReduceOp.AVG, group=mesh.replicate_pg)
-
             placements = list(copy.copy(total_norm._spec.placements))
             placements.pop(mesh.replicate_dim)
             mesh = mesh.mesh
```
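The change above drops the extra all-reduce over the replicate process group because, with TorchFT, gradients are already averaged across replicas before `clip_grad_norm_` runs. Below is a minimal sketch of the resulting pattern, assuming `total_norm` is a DTensor on a torchft `ManagedDeviceMesh`; the final `DTensor.from_local` rebuild is an assumed next step, not shown in the diff.

```python
# Sketch of the post-change logic: strip the replicate placement from the norm
# without another all_reduce, since torchft already averaged gradients across
# the replica group. The DTensor.from_local call is an assumed follow-up step.
import copy

from torch.distributed.tensor import DTensor


def drop_replicate_dim(total_norm: DTensor) -> DTensor:
    mesh = total_norm._spec.mesh  # a torchft ManagedDeviceMesh in the FT case
    local_tensor = total_norm.to_local()
    # No dist.all_reduce over mesh.replicate_pg here: the replicated dim is
    # already reduced, so we only drop its placement and its mesh dimension.
    placements = list(copy.copy(total_norm._spec.placements))
    placements.pop(mesh.replicate_dim)
    return DTensor.from_local(local_tensor, mesh.mesh, placements)
```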
