
Commit 1e44663

[WIP][RFC] TorchFT integration
Summary: This is a WIP TorchFT integration PR.

Test Plan:

```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 514fd10
Pull Request resolved: #806
1 parent cca0702 commit 1e44663

9 files changed (+268, -67 lines)


run_llama_train.sh

Lines changed: 4 additions & 0 deletions
```
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=http://localhost:29510 \
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE} $overrides
```

torchtitan/checkpoint.py

Lines changed: 123 additions & 53 deletions
```
@@ -13,12 +13,13 @@
 from dataclasses import dataclass, field
 from io import BytesIO
 from multiprocessing import get_context
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
@@ -143,16 +144,19 @@ def __init__(
         lr_schedulers: SchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional[Any] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.keep_latest_k = ckpt_config.keep_latest_k
+        self.ft_manager = ft_manager
+        self.enable_staging = (
+            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
-        """
-        Note: Pipeline Parallelism and Virtual Stages
 
+<<<<<<< HEAD
         1. even for simple PP schedules, there is a separate optimizer each PP rank.
         rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
         rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
@@ -186,6 +190,18 @@ def __init__(
                 "lr_scheduler": lr_schedulers,
             }
         )
+=======
+        self._initialize_states(
+            states, dataloader, model_parts, optimizers, lr_schedulers
+        )
+
+        async_mode = ckpt_config.async_mode.lower()
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
+>>>>>>> 3430d99 ([WIP][RFC] TorchFT integration)
 
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
@@ -201,6 +217,7 @@ def __init__(
         if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
             self.pg = dist.new_group(backend="gloo")
 
+        self.keep_latest_k = ckpt_config.keep_latest_k
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
 
@@ -224,10 +241,6 @@ def __init__(
                 daemon=True,
             )
             self.mp.start()
-            self.cpu_offload_state_dict = None
-            self.staging = False
-            self.staging_id = None
-            self.staging_stream = torch.cuda.Stream()
         else:
             raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")
 
@@ -241,8 +254,61 @@ def __del__(self):
             self.mp.join()
 
     def reset(self) -> None:
+        # We need to stage the local state if another replicate joins during the
+        # first step.
+        if self.ft_manager:
+            self.cpu_staging(None)
         self.begin_time = time.monotonic()
 
+    def _initialize_states(
+        self,
+        states: Dict[str, Any],
+        dataloader: DataLoader,
+        model_parts: List[nn.Module],
+        optimizers: OptimizersContainer,
+        lr_schedulers: SchedulersContainer,
+    ) -> None:
+        """
+        Note: Pipeline Parallelism and Virtual Stages
+
+        1. Even for simple PP schedules, there is a separate optimizer each PP rank.
+        rank0's optimizer would have a param_group[0] which refers to layers.0 in the
+        original model. rank1's would _also_ have a param_group[0], since it's index based,
+        but referring to layers.1.
+        When saving, these collide and one of them is lost. Then when reloading, only one
+        stage can restore its optimizer states, others will error.
+
+        The solution to this problem is optimizer flattening: it landed in #127071
+        and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
+        kwarg to DCP functions called in the OptimizerContainer.
+
+        2. With complex PP schedules, we have multiple model chunks per pp rank. This
+        compounds challenge (1) by also requiring us to reason about multiple 'optim'
+        objects locally.
+
+        We solve this in the Model and Optimizer wrapper classes by flattening the
+        state dicts from each object into one state dict before saving/loading.
+        We rely on the individual state_dicts to not collide, which is gauranteed for
+        the model by correct pipeline splitting and for the optimizer by the flattening
+        support described in (1).
+
+        3. LR schedulers also index model states like optimizers and would need to be
+        flattened properly to support resharding. Unfortunately, the implementations of
+        different lr_schedulers do not follow a clear pattern like optimizers do, so it's
+        hard to write a generic 'flattener' utility.
+
+        TODO: This is currently unsolved and needs a fix.
+        """
+        self.states = states
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+            }
+        )
+        self.states.update(lr_schedulers.get_lr_scheduler_state())
+
     def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")
 
@@ -325,31 +391,8 @@ def _async_wait(self) -> None:
             self.async_future.result()
 
     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-        try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict,
-                _create_cpu_state_dict,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-            ) from e
-        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-        if self.cpu_offload_state_dict is None:
-            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True, share_memory=True
-            )
-
-        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
-        with torch.cuda.stream(self.staging_stream):
-            self.cpu_offload_state_dict = _copy_state_dict(
-                state_dict,
-                self.cpu_offload_state_dict,
-                non_blocking=True,
-            )
-        self.staging = True
-        self.staging_id = checkpoint_id
+        self.cpu_staging(checkpoint_id)
+        self.sending_to_checkpoint_mp = True
 
     def save(self, curr_step: int, force: bool = False) -> None:
         """
@@ -359,6 +402,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
         for initial seed checkpoint.
         """
         if not self._should_save(curr_step, force):
+            if self.ft_manager:
+                self.cpu_staging(None)
             return
 
         begin = time.monotonic()
@@ -382,26 +427,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
             f"in {time.monotonic() - begin:.2f} seconds."
         )
 
+    def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+        """Offload state_dict to CPU memory"""
+        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+        if self.cpu_offload_state_dict is None:
+            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
+            self.cpu_offload_state_dict = _create_cpu_state_dict(
+                state_dict, pin_memory=True, share_memory=True
+            )
+
+        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
+        with torch.cuda.stream(self.staging_stream):
+            self.cpu_offload_state_dict = _copy_state_dict(
+                state_dict,
+                self.cpu_offload_state_dict,
+                non_blocking=True,
+            )
+        self.staging = True
+        self.staging_id = checkpoint_id
+
+    def wait_for_staging(self) -> None:
+        if not self.staging_stream.query():
+            self.staging_stream.synchronize()
+        self.staging = False
+
+    def staging_results(self) -> Dict[str, Any]:
+        self.maybe_wait_for_staging()
+        return self.cpu_offload_state_dict
+
     def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
-            if not self.staging_stream.query():
-                self.staging_stream.synchronize()
-
-            def sync_func():
-                self.mp_queue_send.put_nowait(
-                    (self.cpu_offload_state_dict, self.staging_id)
-                )
-
-            # This may be a faster way to do zero-overhead checkpointing staging
-            # checkpointing but we need more thorough investigation before
-            # swithing to this method.
-            # self.my_thread = threading.Thread(target=func).start()
-            sync_func()
-            self.staging = False
+        if self.enable_staging and self.staging:
+            self.wait_for_staging()
+
+        if self.sending_to_checkpoint_mp:
+            # Copy the sync staging result to another process.
+            def sync_func():
+                self.mp_queue_send.put_nowait(
+                    (self.cpu_offload_state_dict, self.staging_id)
+                )
+
+            # This may be a faster way to do zero-overhead checkpointing staging
+            # checkpointing but we need more thorough investigation before
+            # swithing to this method.
+            # self.my_thread = threading.Thread(target=func).start()
            sync_func()
+            self.sending_to_checkpoint_mp = False
 
     def load(self, step: int = -1) -> bool:
         if not self.enable_checkpoint:
```
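The new cpu_staging(), wait_for_staging(), and staging_results() methods all revolve around one pattern: allocate a pinned, shareable CPU mirror of the state dict once, copy into it on a side CUDA stream, and synchronize that stream before anyone reads the copy. Below is a minimal, standalone sketch of that pattern using the same private helpers the diff imports (_create_cpu_state_dict, _copy_state_dict); the tensor name and size are made up for illustration and are not part of this PR.

```
import torch
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict

if torch.cuda.is_available():
    # Stand-in for dcp.state_dict_saver._stateful_to_state_dict(self.states);
    # this tensor is hypothetical.
    state_dict = {"model.weight": torch.randn(1024, 1024, device="cuda")}

    # One-time allocation of pinned, shareable CPU buffers matching the state_dict
    # (what cpu_staging() does when self.cpu_offload_state_dict is None).
    cpu_offload_state_dict = _create_cpu_state_dict(
        state_dict, pin_memory=True, share_memory=True
    )

    # Device-to-host copy on a side stream so the default (training) stream
    # is not blocked.
    staging_stream = torch.cuda.Stream()
    with torch.cuda.stream(staging_stream):
        cpu_offload_state_dict = _copy_state_dict(
            state_dict, cpu_offload_state_dict, non_blocking=True
        )

    # Consumers (the checkpoint subprocess, or torchft's state_dict callback via
    # staging_results()) must wait for the copy before reading the CPU tensors.
    if not staging_stream.query():
        staging_stream.synchronize()
```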

torchtitan/config_manager.py

Lines changed: 13 additions & 0 deletions
```
@@ -585,6 +585,19 @@ def __init__(self):
             action="store_true",
         )
 
+        self.parser.add_argument(
+            "--experimental.enable_torchft",
+            action="store_true",
+            help="Enable TorchFT integration.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_replica_group_id",
+            type=int,
+            default=-1,
+            help="The FT replicate group of this run.",
+        )
+
     def to_dict(self):
         return self.args_dict
 
```
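For context, a small sketch of how the two new flags surface after parsing, matching the replica-group-1 command from the test plan. It assumes the usual torchtitan pattern where a --<section>.<option> argument becomes job_config.<section>.<option> after JobConfig.parse_args(); the exact parsing entry point is an assumption, not part of this diff.

```
from torchtitan.config_manager import JobConfig

# Mirrors the second command in the test plan (replica group 1).
job_config = JobConfig()
job_config.parse_args(
    [
        "--experimental.enable_torchft",
        "--experimental.ft_replica_group_id",
        "1",
    ]
)

assert job_config.experimental.enable_torchft
assert job_config.experimental.ft_replica_group_id == 1
```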

torchtitan/ft.py

Lines changed: 58 additions & 0 deletions
```
@@ -0,0 +1,58 @@
+import importlib
+from typing import Any, Callable, Optional
+
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
+
+from torchtitan.config_manager import JobConfig
+
+if importlib.util.find_spec("torchft") is not None:
+    import torchft as ft
+
+    has_torchft = True
+else:
+    has_torchft = False
+
+
+def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
+    """
+    Initialize the FT manager for the given job.
+    """
+    if not job.experimental.enable_torchft:
+        return None
+
+    if not has_torchft:
+        raise ImportError("torchft is not installed. Please install it.")
+
+    pg = ft.ProcessGroupBabyNCCL()
+    manager = ft.Manager(
+        pg=pg,
+        min_replica_size=1,
+        load_state_dict=None,
+        state_dict=None,
+        use_async_quorum=True,
+        replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}",
+    )
+
+    return manager
+
+
+def set_ft_state_dict_fns(manager: Optional["ft.Manager"], ckpt_manager) -> None:
+    """
+    Set the state dict for the given manager.
+    """
+    if manager is None:
+        return
+
+    def state_dict():
+        ret = {}
+        for k, v in ckpt_manager.staging_results().items():
+            if k in {"model", "optimizer", "lr_schedulers"}:
+                ret[k] = v
+        return ret
+
+    def load_state_dict(state_dict):
+        assert state_dict is not None
+        for k, v in state_dict.items():
+            ckpt_manager.states[k].load_state_dict(v)
+
+    manager.set_state_dict_fns(load_state_dict, state_dict)
```