Integrate TorchFT #834

Merged (57 commits, Feb 28, 2025)
Commits (57)
14a1d46  Update (fegin, Feb 11, 2025)
f9ff656  Update (fegin, Feb 12, 2025)
d443dcb  Update (fegin, Feb 12, 2025)
df758cc  Update (fegin, Feb 12, 2025)
59f6ec7  Update (fegin, Feb 12, 2025)
d871420  Update (fegin, Feb 12, 2025)
61f4a03  Update (fegin, Feb 12, 2025)
8e82769  Update (fegin, Feb 12, 2025)
c60d577  Update (fegin, Feb 12, 2025)
24f6aef  Update (fegin, Feb 12, 2025)
43c83e2  Update (fegin, Feb 12, 2025)
4eb6515  Update (fegin, Feb 12, 2025)
5e85f99  Update (fegin, Feb 13, 2025)
e7e14b8  Update (fegin, Feb 13, 2025)
dd4fa3a  Update (fegin, Feb 13, 2025)
4871b35  Update (fegin, Feb 13, 2025)
6652ec9  Update (fegin, Feb 13, 2025)
0eed0d0  Update (fegin, Feb 13, 2025)
a70b1eb  Update (fegin, Feb 13, 2025)
be50645  Update (fegin, Feb 13, 2025)
e7448d2  Update (fegin, Feb 13, 2025)
c9e32b4  Update (fegin, Feb 13, 2025)
06c669b  Update (fegin, Feb 13, 2025)
da07d3f  Update (fegin, Feb 14, 2025)
70d45d1  Update (fegin, Feb 14, 2025)
ed493d6  Update (fegin, Feb 14, 2025)
ee01914  Update (fegin, Feb 14, 2025)
8c731f9  Update (fegin, Feb 19, 2025)
542ce60  Update (fegin, Feb 20, 2025)
b4d2b96  Update (fegin, Feb 21, 2025)
a6741e7  Update (fegin, Feb 21, 2025)
89aad7d  Update (fegin, Feb 21, 2025)
5076158  Update (fegin, Feb 24, 2025)
821895b  Update (fegin, Feb 25, 2025)
e915263  Update (fegin, Feb 25, 2025)
a7a43e1  Update (fegin, Feb 25, 2025)
a2389c1  Update (fegin, Feb 25, 2025)
c0c86f4  Update (fegin, Feb 25, 2025)
10b5a4c  Update (fegin, Feb 25, 2025)
a288ee1  Update (fegin, Feb 26, 2025)
ba5992b  Update (fegin, Feb 26, 2025)
afb4187  Update (fegin, Feb 26, 2025)
f11c3bf  Update (fegin, Feb 27, 2025)
6c3ef30  Update (fegin, Feb 27, 2025)
3dd05fd  Update (fegin, Feb 27, 2025)
6def618  Update (fegin, Feb 27, 2025)
e4174a3  Update (fegin, Feb 27, 2025)
1da7c08  Update (fegin, Feb 27, 2025)
0ba37c5  Update (fegin, Feb 27, 2025)
c2b7c59  Update (fegin, Feb 27, 2025)
358098a  Update (fegin, Feb 28, 2025)
0a8ab01  Update (fegin, Feb 28, 2025)
1bd38af  Update (fegin, Feb 28, 2025)
46650fe  Update (fegin, Feb 28, 2025)
5cd025c  Update (fegin, Feb 28, 2025)
c4e82e9  Update (fegin, Feb 28, 2025)
89dd3cb  Update (fegin, Feb 28, 2025)
3 changes: 3 additions & 0 deletions run_train.sh
@@ -19,7 +19,10 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides
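Note on the change above: run_train.sh now defaults TORCHFT_LIGHTHOUSE to http://localhost:29510 and forwards it to the training process so TorchFT can reach its lighthouse. As a minimal illustration only (the helper name below is not part of this PR), the resolution the script sets up corresponds to:

import os

def lighthouse_address() -> str:
    # Mirror the shell default in run_train.sh: an externally supplied
    # TORCHFT_LIGHTHOUSE wins, otherwise fall back to the local default.
    return os.environ.get("TORCHFT_LIGHTHOUSE", "http://localhost:29510")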
5 changes: 3 additions & 2 deletions scripts/estimate/estimation.py
@@ -15,6 +15,7 @@
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan.components.ft import init_ft_manager
from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
@@ -102,7 +103,6 @@ def estimate_memory(job_config: JobConfig):
if not job_config.memory_estimation.disable_fake_mode
else contextlib.nullcontext()
):

logger.info(
f"Building {train_spec.name} {job_config.model.flavor} with {model_config}"
)
@@ -122,7 +122,8 @@ def estimate_memory(job_config: JobConfig):
model.train()

# build optimizer after applying parallelisms to the model
optimizers = build_optimizers([model], job_config)
ft_manager = init_ft_manager(job_config)
optimizers = build_optimizers([model], job_config, ft_manager)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
# Post optimizer step model converters hook.
# e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
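For readers following the estimation change: build_optimizers now takes an FT manager produced by init_ft_manager. A minimal sketch of the updated wiring, assuming fault tolerance is disabled in the job config (the wrapper function name is illustrative, not part of the PR):

from torchtitan.components.ft import init_ft_manager
from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers

def build_training_components(model, job_config):
    # init_ft_manager returns an FTManager (disabled when fault tolerance is
    # off), which build_optimizers now requires as a third argument.
    ft_manager = init_ft_manager(job_config)
    optimizers = build_optimizers([model], job_config, ft_manager)
    lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
    return ft_manager, optimizers, lr_schedulers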
23 changes: 23 additions & 0 deletions tests/unit_tests/test_checkpoint.py
@@ -61,10 +61,18 @@ class DummyJob:
dump_folder: str = "dummy_folder"


@dataclass
class DummyFaultTolerance:
replica_id = 0
group_size = 1


@dataclass
class DummyJobConfig:
checkpoint: DummyCheckpointConfig = field(default_factory=DummyCheckpointConfig)
job: DummyJob = field(default_factory=DummyJob)
fault_tolerance: DummyFaultTolerance = field(default_factory=DummyFaultTolerance)
ft_manager = None


# Dummy instances to supply as constructor arguments.
@@ -103,13 +111,16 @@ def tearDown(self):
def test_save(self, *_):
"""Test that calling save() writes a checkpoint file to disk."""
job_config = DummyJobConfig(job=self.dummy_job)
ft_manager = mock.Mock()
ft_manager.enabled = False
manager = CheckpointManager(
dummy_dataloader,
dummy_model_parts,
dummy_optimizers,
dummy_lr_schedulers,
{"trainer": self.trainer_state},
job_config,
ft_manager,
)
step = 20
manager.save(curr_step=step, force=True)
@@ -133,13 +144,16 @@ def test_save(self, *_):
def test_load(self, *_):
"""Test that load() properly reads the checkpoint file from disk and restores state."""
job_config = DummyJobConfig(job=self.dummy_job)
ft_manager = mock.Mock()
ft_manager.enabled = False
manager = CheckpointManager(
dummy_dataloader,
dummy_model_parts,
dummy_optimizers,
dummy_lr_schedulers,
{"trainer": self.trainer_state},
job_config,
ft_manager,
)
step = 30
manager.save(curr_step=step, force=True)
@@ -171,13 +185,16 @@ def test_purge_stale_checkpoints_rank_zero(self, *_):
"""
job_config = DummyJobConfig(job=self.dummy_job)
job_config.checkpoint.keep_latest_k = 3
ft_manager = mock.Mock()
ft_manager.enabled = False
manager = CheckpointManager(
dummy_dataloader,
dummy_model_parts,
dummy_optimizers,
dummy_lr_schedulers,
{"trainer": self.trainer_state},
job_config,
ft_manager,
)
steps = [10, 20, 30, 40, 50]
for s in steps:
@@ -215,13 +232,16 @@ def test_purge_stale_checkpoints_rank_nonzero(self, *_):
"""
job_config = DummyJobConfig(job=self.dummy_job)
job_config.checkpoint.keep_latest_k = 3
ft_manager = mock.Mock()
ft_manager.enabled = False
manager = CheckpointManager(
dummy_dataloader,
dummy_model_parts,
dummy_optimizers,
dummy_lr_schedulers,
{"trainer": self.trainer_state},
job_config,
ft_manager,
)
steps = [10, 20, 30, 40, 50]
for s in steps:
@@ -252,13 +272,16 @@ def test_async_save_calls_async_wait(self, *_):
# Set async_mode to "async" in the job configuration.
job_config = DummyJobConfig(job=self.dummy_job)
job_config.checkpoint.async_mode = "async"
ft_manager = mock.Mock()
ft_manager.enabled = False
manager = CheckpointManager(
dummy_dataloader,
dummy_model_parts,
dummy_optimizers,
dummy_lr_schedulers,
{"trainer": self.trainer_state},
job_config,
ft_manager,
)
# First save: should schedule an async save.
manager.save(curr_step=10, force=False)
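The tests above all stub the new constructor argument the same way; a small helper capturing that pattern could look like this (a sketch only, assuming unittest.mock as already used in the test file):

from unittest import mock

def make_disabled_ft_manager():
    # CheckpointManager only reads `enabled` (and, when enabled, `.manager`),
    # so a bare Mock with enabled=False exercises the non-FT code path.
    ft_manager = mock.Mock()
    ft_manager.enabled = False
    return ft_manager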
142 changes: 115 additions & 27 deletions torchtitan/components/checkpoint.py
@@ -31,6 +31,7 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader

from torchtitan.components.ft import FTManager
from torchtitan.components.optimizer import LRSchedulersContainer, OptimizersContainer
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.tools.logging import init_logger, logger
@@ -214,6 +215,19 @@ class CheckpointManager:
3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
with the assumption that all lr_schedulers have the same state_dict.

Note: TorchFT checkpointing flow

There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
checkpoint, 2) the per-replica checkpoint.

The full persistent checkpoint is saved by the replica with
``ft_manager.participating_rank() == 0``. It contains everything including the model,
optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
checkpoint is loaded by all replicas. However, we can optimize it to only load if
there are no other alive replicas.

The per-replica checkpoint contains only the dataloader and is saved/loaded by all
replicas to/from its own folder. The folder name is prefixed with the ft_replica_id.

Args:
dataloader (DataLoader): The dataloader used to load the data.
@@ -223,6 +237,7 @@ class CheckpointManager:
states (Dict[str, Any]): The states that need to be saved, other than the
previous 4 components.
job_config (JobConfig): The job config used to configure the checkpointing.
ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
"""

def __init__(
@@ -233,16 +248,41 @@ def __init__(
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
ft_manager: FTManager,
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.ft_manager = ft_manager.manager if ft_manager.enabled else None

if self.ft_manager:
optimizers.init_cache_state_dict()

def state_dict():
ret = {}
for k, v in self.states.items():
if k in {
MODEL,
OPTIMIZER,
LR_SCHEDULER,
TRAIN_STATE,
}:
ret[k] = v.state_dict()
return ret

def load_state_dict(state_dict):
assert state_dict is not None
for k, v in state_dict.items():
self.states[k].load_state_dict(v)

self.ft_manager.set_state_dict_fns(load_state_dict, state_dict)
self.ft_replica_id = job_config.fault_tolerance.replica_id

async_mode = ckpt_config.async_mode.lower()
self.enable_staging = (
self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
)
) or self.ft_manager

if not self.enable_checkpoint:
if not self.enable_checkpoint and self.ft_manager is None:
return

self.states = states
@@ -254,6 +294,13 @@ def __init__(
LR_SCHEDULER: lr_schedulers,
}
)
self.ft_states = {DATALOADER: dataloader}

self.staging = False
self.sending_to_checkpoint_mp = False
self.staging_id = None
self.cpu_offload_state_dict = None
self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

self.staging = False
self.sending_to_checkpoint_mp = False
@@ -264,7 +311,7 @@ def __init__(
self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval = ckpt_config.interval
async_mode = ckpt_config.async_mode.lower()
if async_mode == AsyncMode.ASYNC:
if async_mode == AsyncMode.ASYNC or self.ft_manager:
self.pg = dist.new_group(backend="gloo")

self.keep_latest_k = ckpt_config.keep_latest_k
@@ -339,35 +386,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
None
"""

if self.ft_manager:
self._ft_save(curr_step)

if not self._should_save(curr_step, force):
return

begin = time.monotonic()
logger.info("Saving the checkpoint (or staging if async is enabled).")
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
# This GC is called for async checkpoint as it is useless to do
# GC right after async_save -- the CPU memory is not able to be
# freed until _async_wait()
if force:
self._save_last_step(curr_step)
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(self.states, checkpoint_id=checkpoint_id)
self._purge_stale_checkpoints()
if not self.ft_manager or self.ft_manager.participating_rank() == 0:
logger.info("Saving the checkpoint (or staging if async is enabled).")
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
# This GC is called for async checkpoint as it is useless to do
# GC right after async_save -- the CPU memory is not able to be
# freed until _async_wait()
if force:
self._save_last_step(curr_step)
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(self.states, checkpoint_id=checkpoint_id)
self._purge_stale_checkpoints()

logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
elif self.ft_manager:
logger.info(
"Replica %d doesn't save checkpoint.",
self.ft_manager.participating_rank(),
)

@torch.no_grad()
def load(self, step: int = -1) -> bool:
@@ -384,6 +440,9 @@ def load(self, step: int = -1) -> bool:
bool: Whether the checkpoint was loaded successfully.
"""

if self.ft_manager:
self._ft_load()

if not self.enable_checkpoint or not os.path.isdir(self.folder):
return False

@@ -467,10 +526,36 @@ def _find_load_step(self, folder: str = "") -> int:
return -1
return max(step_counts)

def _ft_folder(self) -> str:
return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")

def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
folder = folder if folder else self.folder
return os.path.join(folder, f"step-{step}")

def _ft_save(self, step: int) -> None:
begin = time.monotonic()
self._async_wait()
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
self.async_future = dcp.async_save(
self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
)
logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")

def _ft_load(self) -> None:
step = self._find_load_step(folder=self._ft_folder())
if step == -1:
return

begin = time.monotonic()
logger.info(f"Loading the FT checkpoint at step {step}.")
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
)

def _states_to_load(self, step: int) -> Dict[str, Any]:
"""Determines which states to load for the given step.

@@ -491,6 +576,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
for exclude_key in self.exclude_from_loading:
if exclude_key not in states:
raise ValueError(f"{exclude_key} not found in state_dict.")
if self.ft_manager:
states_to_load.pop(DATALOADER)
return states_to_load

def _save_last_step(self, curr_step: int) -> None:
@@ -577,6 +664,7 @@ def _purge_stale_checkpoints(self):
self.keep_latest_k > 0
and dist.get_rank() == 0
and os.path.isdir(self.folder)
and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
):
discovered_checkpoints = []
for filename in os.listdir(self.folder):
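To summarize the save path added in checkpoint.py: every replica writes a small per-replica checkpoint (dataloader only) into its ft-prefixed folder, while only the replica with participating_rank() == 0 writes the full persistent checkpoint. A condensed sketch of that control flow, not the actual implementation shown above; the _save_full helper name is invented here to stand in for the full-save body:

def save(self, curr_step: int, force: bool = False) -> None:
    if self.ft_manager:
        # Per-replica checkpoint: async-save the dataloader state into
        # this replica's own folder.
        self._ft_save(curr_step)
    if not self._should_save(curr_step, force):
        return
    if not self.ft_manager or self.ft_manager.participating_rank() == 0:
        # Full persistent checkpoint: model, optimizer, lr_scheduler,
        # dataloader, and train_state.
        self._save_full(curr_step)  # invented helper standing in for the body above
    else:
        logger.info(
            "Replica %d doesn't save checkpoint.",
            self.ft_manager.participating_rank(),
        )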