diff --git a/docs/checkpoint.md b/docs/checkpoint.md index d5dc36487..198412c2b 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -20,7 +20,6 @@ In your torchtitan training config, ensure that `enable_checkpoint` is set to Tr [checkpoint] enable_checkpoint = true folder = "checkpoint" -interval_type = "steps" interval = 500 ``` @@ -47,7 +46,6 @@ export_dtype = "bfloat16" [checkpoint] enable_checkpoint = true folder = "checkpoint" -interval_type = "steps" interval = 10 load_step = 5 model_weights_only = true diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py new file mode 100644 index 000000000..0a701d1d5 --- /dev/null +++ b/tests/unit_tests/test_checkpoint.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import shutil +import tempfile +import time +import unittest +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from unittest import mock + +import torch + +from torchtitan.components.checkpoint import CheckpointManager + + +def fake_dcp_save(state, checkpoint_id): + state = {k: v.state_dict() for k, v in state.items()} + os.makedirs(checkpoint_id, exist_ok=True) + torch.save(state, os.path.join(checkpoint_id, "state.pt")) + + +def fake_dcp_load(state, checkpoint_id): + state["trainer"].dcp_load_is_called = 7312 + + +def fake_async_save(state, checkpoint_id, process_group): + def run_save(): + fake_dcp_save(state, checkpoint_id) + + with ThreadPoolExecutor(max_workers=1) as executor: + f = executor.submit(run_save) + + mock_future = mock.Mock() + mock_future.result = mock.Mock(side_effect=f.result) + return mock_future + + +def fake_get_model_state_dict(model, *args, **kwargs): + return model.state_dict() + + +@dataclass +class DummyCheckpointConfig: + enable_checkpoint: bool = True + folder: str = "dummy_folder" + interval: int = 10 + async_mode: str = "disabled" + keep_latest_k: int = 0 + model_weights_only: bool = False + export_dtype: str = "float32" + exclude_from_loading = [] + + +@dataclass +class DummyJob: + dump_folder: str = "dummy_folder" + + +@dataclass +class DummyJobConfig: + checkpoint: DummyCheckpointConfig = field(default_factory=DummyCheckpointConfig) + job: DummyJob = field(default_factory=DummyJob) + + +# Dummy instances to supply as constructor arguments. +dummy_dataloader = mock.Mock() +dummy_dataloader.state_dict = mock.Mock(side_effect=lambda: {"dataloader": 1}) +dummy_model_parts = [mock.Mock()] +dummy_model_parts[0].state_dict = mock.Mock(side_effect=lambda: {"model": 2}) +dummy_optimizers = mock.Mock() +dummy_optimizers.state_dict = mock.Mock(side_effect=lambda: {"optimizer": 3}) +dummy_lr_schedulers = mock.Mock() +dummy_lr_schedulers.state_dict = mock.Mock(side_effect=lambda: {"lr_scheduler": 4}) + + +class TestCheckpointManager(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + + self.dummy_job = DummyJob(dump_folder=self.temp_dir) + self.job_config = DummyJobConfig(job=self.dummy_job) + self.checkpoint_folder = os.path.join( + self.dummy_job.dump_folder, self.job_config.checkpoint.folder + ) + os.makedirs(self.checkpoint_folder, exist_ok=True) + self.trainer_state = mock.Mock() + self.trainer_state.state_dict = mock.Mock(side_effect=lambda: {"my_state": 765}) + + def tearDown(self): + # Remove the temporary directory after each test. 
+        shutil.rmtree(self.temp_dir)
+
+    @mock.patch(
+        "torchtitan.components.checkpoint.get_model_state_dict",
+        side_effect=fake_get_model_state_dict,
+    )
+    @mock.patch("torchtitan.components.checkpoint.dcp.save", side_effect=fake_dcp_save)
+    def test_save(self, *_):
+        """Test that calling save() writes a checkpoint file to disk."""
+        job_config = DummyJobConfig(job=self.dummy_job)
+        manager = CheckpointManager(
+            dummy_dataloader,
+            dummy_model_parts,
+            dummy_optimizers,
+            dummy_lr_schedulers,
+            {"trainer": self.trainer_state},
+            job_config,
+        )
+        step = 20
+        manager.save(curr_step=step, force=True)
+        state_file = self._checkpoint_id(step)
+        self.assertTrue(
+            os.path.exists(state_file), "The checkpoint file was not created on disk."
+        )
+        loaded_state = torch.load(state_file, weights_only=False)
+        self.assertEqual(
+            loaded_state["trainer"]["my_state"],
+            765,
+            "Saved state does not match expected value.",
+        )
+
+    @mock.patch(
+        "torchtitan.components.checkpoint.get_model_state_dict",
+        side_effect=fake_get_model_state_dict,
+    )
+    @mock.patch("torchtitan.components.checkpoint.dcp.load", side_effect=fake_dcp_load)
+    @mock.patch("torchtitan.components.checkpoint.dcp.save", side_effect=fake_dcp_save)
+    def test_load(self, *_):
+        """Test that load() properly reads the checkpoint file from disk and restores state."""
+        job_config = DummyJobConfig(job=self.dummy_job)
+        manager = CheckpointManager(
+            dummy_dataloader,
+            dummy_model_parts,
+            dummy_optimizers,
+            dummy_lr_schedulers,
+            {"trainer": self.trainer_state},
+            job_config,
+        )
+        step = 30
+        manager.save(curr_step=step, force=True)
+        # Simulate a state change.
+        manager.states["test"] = 999
+        success = manager.load(step=step)
+        self.assertTrue(
+            success,
+            "The load() method should have returned True for an existing checkpoint.",
+        )
+        self.assertTrue(hasattr(manager.states["trainer"], "dcp_load_is_called"))
+
+        self.assertEqual(
+            manager.states["trainer"].dcp_load_is_called,
+            7312,
+            "The state was not correctly restored after loading.",
+        )
+
+    @mock.patch("torchtitan.components.checkpoint.dist.get_rank", return_value=0)
+    @mock.patch(
+        "torchtitan.components.checkpoint.get_model_state_dict",
+        side_effect=fake_get_model_state_dict,
+    )
+    @mock.patch("torchtitan.components.checkpoint.dcp.save", side_effect=fake_dcp_save)
+    def test_purge_stale_checkpoints_rank_zero(self, *_):
+        """
+        Test that when keep_latest_k is 3 and dist.get_rank() returns 0, stale checkpoints
+        are purged by placing the correct paths into the purge queue.
+        """
+        job_config = DummyJobConfig(job=self.dummy_job)
+        job_config.checkpoint.keep_latest_k = 3
+        manager = CheckpointManager(
+            dummy_dataloader,
+            dummy_model_parts,
+            dummy_optimizers,
+            dummy_lr_schedulers,
+            {"trainer": self.trainer_state},
+            job_config,
+        )
+        steps = [10, 20, 30, 40, 50]
+        for s in steps:
+            manager.save(curr_step=s, force=False)
+        while not manager.purge_queue.empty():
+            time.sleep(1)
+        time.sleep(1)
+        os.sync()
+        expected_paths = [
+            os.path.join(self.checkpoint_folder, "step-30"),
+            os.path.join(self.checkpoint_folder, "step-40"),
+            os.path.join(self.checkpoint_folder, "step-50"),
+        ]
+        for step in [10, 20]:
+            self.assertFalse(
+                os.path.exists(self._checkpoint_id(step)),
+                "The checkpoint is not purged.",
+            )
+
+        for step in [30, 40, 50]:
+            self.assertTrue(
+                os.path.exists(self._checkpoint_id(step)), "The checkpoint is purged."
+            )
+
+    @mock.patch("torchtitan.components.checkpoint.dist.get_rank", return_value=1)
+    @mock.patch(
+        "torchtitan.components.checkpoint.get_model_state_dict",
+        side_effect=fake_get_model_state_dict,
+    )
+    @mock.patch("torchtitan.components.checkpoint.dcp.save", side_effect=fake_dcp_save)
+    def test_purge_stale_checkpoints_rank_nonzero(self, *_):
+        """
+        Test that when dist.get_rank() returns a non-zero value, the purge logic does not
+        place any paths in the purge queue.
+        """
+        job_config = DummyJobConfig(job=self.dummy_job)
+        job_config.checkpoint.keep_latest_k = 3
+        manager = CheckpointManager(
+            dummy_dataloader,
+            dummy_model_parts,
+            dummy_optimizers,
+            dummy_lr_schedulers,
+            {"trainer": self.trainer_state},
+            job_config,
+        )
+        steps = [10, 20, 30, 40, 50]
+        for s in steps:
+            manager.save(curr_step=s, force=False)
+        while not manager.purge_queue.empty():
+            time.sleep(1)
+        time.sleep(1)
+        os.sync()
+
+        for step in [10, 20, 30, 40, 50]:
+            self.assertTrue(
+                os.path.exists(self._checkpoint_id(step)), "The checkpoint is purged."
+            )
+
+    @mock.patch("torchtitan.components.checkpoint.dist.new_group")
+    @mock.patch(
+        "torchtitan.components.checkpoint.get_model_state_dict",
+        side_effect=fake_get_model_state_dict,
+    )
+    @mock.patch(
+        "torchtitan.components.checkpoint.dcp.async_save", side_effect=fake_async_save
+    )
+    def test_async_save_calls_async_wait(self, *_):
+        """
+        Test that in async mode (AsyncMode.ASYNC), calling save() twice correctly waits
+        on the previous async future via _async_wait().
+        """
+        # Set async_mode to "async" in the job configuration.
+        job_config = DummyJobConfig(job=self.dummy_job)
+        job_config.checkpoint.async_mode = "async"
+        manager = CheckpointManager(
+            dummy_dataloader,
+            dummy_model_parts,
+            dummy_optimizers,
+            dummy_lr_schedulers,
+            {"trainer": self.trainer_state},
+            job_config,
+        )
+        # First save: should schedule an async save.
+ manager.save(curr_step=10, force=False) + f = manager.async_future + f.result.assert_not_called() + manager.save(curr_step=20, force=False) + f.result.assert_called_once() + f = manager.async_future + f.result.assert_not_called() + + def _checkpoint_id(self, step): + checkpoint_id = os.path.join(self.checkpoint_folder, f"step-{step}") + state_file = os.path.join(checkpoint_id, "state.pt") + return state_file + + +if __name__ == "__main__": + unittest.main() diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index e4394d595..38ab77d99 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -6,19 +6,23 @@ import enum import functools +import multiprocessing as mp import os +import queue import re import shutil +import threading import time from dataclasses import dataclass, field from io import BytesIO from multiprocessing import get_context -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp import torch.nn as nn +from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, set_model_state_dict, @@ -33,9 +37,11 @@ from torchtitan.tools.utils import GarbageCollection -class IntervalType(enum.Enum): - SECONDS = enum.auto() - STEPS = enum.auto() +MODEL = "model" +OPTIMIZER = "optimizer" +LR_SCHEDULER = "lr_scheduler" +DATALOADER = "dataloader" +TRAIN_STATE = "train_state" class AsyncMode(str, enum.Enum): @@ -44,6 +50,8 @@ class AsyncMode(str, enum.Enum): ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem" +# TODO: move this out from checkpoint.py and merge it with the trainer.py +# We probably want to create a Trainer object. @dataclass class TrainState(Stateful): step: int = 0 @@ -84,12 +92,13 @@ def load_state_dict(self, state_dict) -> None: class ModelWrapper(Stateful): def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None: self.model = [model] if isinstance(model, nn.Module) else model - - def state_dict(self) -> Dict[str, Any]: - return { + self.cache_state_dict = { k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items() } + def state_dict(self) -> Dict[str, Any]: + return self.cache_state_dict + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: func = functools.partial( set_model_state_dict, @@ -113,7 +122,15 @@ def save_with_gc(state, checkpoint_id): GarbageCollection.collect("GC collection invoked by checkpointer.") -def checkpoint_mp(recv, send): +def checkpoint_mp(recv: mp.Queue, send: mp.Queue): + """Process to save the checkpoint in the background. + + This is only used when async_checkpoint_with_pinned_memory is enabled. + + Args: + recv (mp.Queue): The queue to receive the state_dict and Terminate signal. + send (mp.Queue): The queue to send the SaveDone signal. + """ init_logger() os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2) os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" @@ -134,15 +151,80 @@ def checkpoint_mp(recv, send): state, checkpoint_id = obj save_with_gc(state, checkpoint_id=checkpoint_id) logger.info( - "Finish saving the checkpoint in the background process in " - f"{time.monotonic() - begin:.2f} seconds." 
+            "Finish saving the checkpoint in the background process in %.2f seconds.",
+            time.monotonic() - begin,
         )
     finally:
         logger.info("Destroying the process group.")
         dist.destroy_process_group()
 
 
+def purge_thread(purge_queue: queue.Queue):
+    """Thread to purge the old checkpoints.
+
+    This is only used when keep_latest_k > 0.
+
+    Args:
+        purge_queue (queue.Queue): The queue to receive the path to purge and Terminate signal.
+    """
+    try:
+        while True:
+            path = purge_queue.get()
+            if isinstance(path, Terminate):
+                return
+            assert isinstance(path, str)
+            logger.info("Checkpointer is deleting %s.", path)
+            begin = time.monotonic()
+            shutil.rmtree(path, ignore_errors=True)
+            logger.info(
+                "Checkpointer deleted %s in %.2f seconds.",
+                path,
+                time.monotonic() - begin,
+            )
+    finally:
+        logger.info("Destroying the purge thread.")
+
+
 class CheckpointManager:
+    """This class manages the checkpointing logic for the TorchTitan trainer.
+
+
+    Note: Pipeline Parallelism and Virtual Stages
+
+    1. even for simple PP schedules, there is a separate optimizer for each PP rank.
+    rank0's optimizer would have a param_group[0] which refers to layers.0 in the original
+    model. rank1's would _also_ have a param_group[0], since it's index based, but
+    referring to layers.1. When saving, these collide and one of them is lost. Then when
+    reloading, only one stage can restore its optimizer states, others will error.
+
+    The solution to this problem is optimizer flattening: it landed in #127071 and is
+    enabled in TorchTitan by passing the 'flatten_optimizer_state_dict' kwarg to DCP
+    functions called in the OptimizerContainer.
+    See PR #127071 (https://github.com/pytorch/pytorch/pull/127071) for the example of
+    a flattening state_dict.
+
+    2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds
+    challenge (1) by also requiring us to reason about multiple 'optim' objects locally.
+
+    We solve this in the Model and Optimizer wrapper classes by flattening the state dicts
+    from each object into one state dict before saving/loading. We rely on the individual
+    state_dicts to not collide, which is guaranteed for the model by correct pipeline
+    splitting and for the optimizer by the flattening support described in (1).
+
+    3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
+    with the assumption that all lr_schedulers have the same state_dict.
+
+
+    Args:
+        dataloader (DataLoader): The dataloader used to load the data.
+        model_parts (List[nn.Module]): List of model parts to be optimized.
+        optimizers (OptimizersContainer): The optimizers used to optimize the model.
+        lr_schedulers (LRSchedulersContainer): The lr schedulers used to optimize the model.
+        states (Dict[str, Any]): The states that need to be saved, other than the
+            previous 4 components.
+        job_config (JobConfig): The job config used to configure the checkpointing.
+    """
+
     def __init__(
         self,
         dataloader: DataLoader,
@@ -154,58 +236,52 @@ def __init__(
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.keep_latest_k = ckpt_config.keep_latest_k
+
+        async_mode = ckpt_config.async_mode.lower()
+        self.enable_staging = (
+            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+        )
 
         if not self.enable_checkpoint:
             return
-        """
-        Note: Pipeline Parallelism and Virtual Stages
-
-        1. even for simple PP schedules, there is a separate optimizer each PP rank.
- rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model. - rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1. - When saving, these collide and one of them is lost. Then when reloading, only one stage can - restore its optimizer states, others will error. - The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan - by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer. - - 2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also - requiring us to reason about multiple 'optim' objects locally. - - We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object - into one state dict before saving/loading. We rely on the individual state_dicts to not collide, - which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening - support described in (1). - - 3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers with the assumption that - all lr_schedulers have the same state_dict. - """ self.states = states - self.states.update( { - "model": ModelWrapper(model_parts), - "optimizer": optimizers, - "dataloader": dataloader, - "lr_scheduler": lr_schedulers, + MODEL: ModelWrapper(model_parts), + OPTIMIZER: optimizers, + DATALOADER: dataloader, + LR_SCHEDULER: lr_schedulers, } ) + self.staging = False + self.sending_to_checkpoint_mp = False + self.staging_id = None + self.cpu_offload_state_dict = None + self.staging_stream = torch.cuda.Stream() if self.enable_staging else None + self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) - self.interval_type = ( - IntervalType.SECONDS - if ckpt_config.interval_type == "seconds" - else IntervalType.STEPS - ) self.interval = ckpt_config.interval - self.begin_time = 0 - self.time_sync_work = None - self.time_sync_result = None async_mode = ckpt_config.async_mode.lower() - if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS: + if async_mode == AsyncMode.ASYNC: self.pg = dist.new_group(backend="gloo") + self.keep_latest_k = ckpt_config.keep_latest_k + if self.keep_latest_k > 0: + if self.keep_latest_k == 1: + raise ValueError( + "We need to maintain at least 2 checkpoint replicas, " + "as the last one may be in the process of being saved." 
+                )
+            self.purge_queue = queue.Queue()
+            self.purge_thread = threading.Thread(
+                target=purge_thread, args=(self.purge_queue,)
+            )
+            self.purge_thread.start()
+        else:
+            self.purge_thread = None
+
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
         self.exclude_from_loading = ckpt_config.exclude_from_loading
@@ -230,10 +306,6 @@ def __init__(
                 daemon=True,
             )
             self.mp.start()
-            self.cpu_offload_state_dict = None
-            self.staging = False
-            self.staging_id = None
-            self.staging_stream = torch.cuda.Stream()
         else:
             raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")
@@ -242,28 +314,198 @@ def __init__(
         )
 
     def __del__(self):
-        if self.enable_checkpoint and self.mp and self.mp.is_alive():
-            self.mp_queue_send.put(Terminate())
-            self.mp.join()
+        if self.enable_checkpoint:
+            if self.mp and self.mp.is_alive():
+                self.mp_queue_send.put(Terminate())
+                self.mp.join()
+            if self.purge_thread and self.purge_thread.is_alive():
+                self.purge_queue.put(Terminate())
+                self.purge_thread.join()
+
+    @torch.no_grad()
+    def save(self, curr_step: int, force: bool = False) -> None:
+        """Save the checkpoint for the current step.
+
+        This function will save the checkpoint for the current step. If ``force`` is
+        true, it will save the checkpoint even if the interval has not been reached.
+        This only happens when train_state.step == job_config.training.steps, or
+        for the initial seed checkpoint.
+
+        Args:
+            curr_step (int): The current step.
+            force (bool, optional): Whether to force save the checkpoint. Defaults to False.
 
-    def reset(self) -> None:
-        self.begin_time = time.monotonic()
+        Returns:
+            None
+        """
 
-    def _create_checkpoint_id(self, step: int) -> str:
-        return os.path.join(self.folder, f"step-{step}")
+        if not self._should_save(curr_step, force):
+            return
+
+        begin = time.monotonic()
+        logger.info("Saving the checkpoint (or staging if async is enabled).")
+        checkpoint_id = self._create_checkpoint_id(curr_step)
+        self._async_wait()
+        # This GC is called for async checkpoint as it is useless to do
+        # GC right after async_save -- the CPU memory is not able to be
+        # freed until _async_wait()
+        if force:
+            self._save_last_step(curr_step)
+        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+            GarbageCollection.collect("GC collection invoked by checkpointer.")
+            self._async_with_pinned_memory(checkpoint_id)
+        elif self.async_mode == AsyncMode.ASYNC:
+            GarbageCollection.collect("GC collection invoked by checkpointer.")
+            self.async_future = dcp.async_save(
+                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+            )
+            GarbageCollection.collect("GC collection invoked by checkpointer.")
+        else:
+            save_with_gc(self.states, checkpoint_id=checkpoint_id)
+        self._purge_stale_checkpoints()
+
+        logger.info(
+            "Finished saving the checkpoint (or staging if async is enabled) "
+            f"in {time.monotonic() - begin:.2f} seconds."
+        )
+
+    @torch.no_grad()
+    def load(self, step: int = -1) -> bool:
+        """Load the checkpoint for the given step.
+
+        This function will load the checkpoint for the given step. If ``step`` is -1, it
+        will load the latest checkpoint. If the checkpoint does not exist, it will return
+        False and load nothing.
+
+        Args:
+            step (int, optional): The step to load the checkpoint for. Defaults to -1.
+
+        Returns:
+            bool: Whether the checkpoint was loaded successfully.
+        """
+
+        if not self.enable_checkpoint or not os.path.isdir(self.folder):
+            return False
+
+        if step == -1:
+            step = self._find_load_step()
+            if step == -1:
+                return False
+
+        checkpoint_id = self._create_checkpoint_id(step)
+        if not os.path.isdir(checkpoint_id):
+            return False
+
+        logger.info(f"Loading the checkpoint at step {step}.")
+        begin = time.monotonic()
+        states = self._states_to_load(step)
+        dcp.load(states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+        return True
+
+    def maybe_wait_for_staging(self) -> None:
+        """Wait for the staging to finish if it is enabled.
+
+        This function will wait for staging to finish. The staging is only enabled
+        with ``async_checkpoint_with_pinned_memory``.
+        """
+        if self.enable_staging and self.staging:
+            if not self.staging_stream.query():
+                begin = time.monotonic()
+                self.staging_stream.synchronize()
+                logger.info(
+                    "Checkpointer waited for staging %.2f seconds.",
+                    time.monotonic() - begin,
+                )
+            self.staging = False
+
+            if self.sending_to_checkpoint_mp:
+                # Copy the sync staging result to another process.
+                def sync_func():
+                    self.mp_queue_send.put_nowait(
+                        (self.cpu_offload_state_dict, self.staging_id)
+                    )
+
+                # This may be a faster way to do zero-overhead checkpoint staging,
+                # but we need a more thorough investigation before switching to
+                # this method.
+                # self.my_thread = threading.Thread(target=sync_func).start()
+                begin = time.monotonic()
+                sync_func()
+                logger.info(
+                    "Checkpointer sent staged state_dict to another process in %.2f seconds.",
+                    time.monotonic() - begin,
+                )
+                self.sending_to_checkpoint_mp = False
+
+    def _find_load_step(self, folder: str = "") -> int:
+        """Find the step to load the checkpoint for.
+
+        Args:
+            folder (str, optional): The folder to find the checkpoint for. If ``folder``
+            is "", then ``self.folder`` will be used.
+
+        Returns:
+            int: The step to load the checkpoint for.
+        """
+        folder = folder if folder else self.folder
+        pattern = r"step-(\d+)"
+        step_counts = []
+
+        if not os.path.isdir(folder):
+            return -1
+
+        for filename in os.listdir(folder):
+            match = re.search(pattern, filename)
+            metadata_probe = os.path.join(folder, filename, ".metadata")
+            if match and os.path.isfile(metadata_probe):
+                step_counts.append(int(match.group(1)))
+        if not step_counts:
+            return -1
+        return max(step_counts)
+
+    def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
+        folder = folder if folder else self.folder
+        return os.path.join(folder, f"step-{step}")
+
+    def _states_to_load(self, step: int) -> Dict[str, Any]:
+        """Determines which states to load for the given step.
+
+        Once the checkpointer has determined which step to load, this API is used to
+        decide which states should be loaded for that step.
+
+        Args:
+            step (int): The step to load the checkpoint for.
+
+        Returns:
+            Dict[str, Any]: The states to load for the given step.
+        """
+        # For the first step, we will only load the model weights.
+ states = {MODEL: self.states[MODEL]} if step == 0 else self.states + states_to_load = { + k: v for k, v in states.items() if k not in self.exclude_from_loading + } + for exclude_key in self.exclude_from_loading: + if exclude_key not in states: + raise ValueError(f"{exclude_key} not found in state_dict.") + return states_to_load def _save_last_step(self, curr_step: int) -> None: # We only consider saving weights only at the end of the training. So # this won't affect preemption and training resume. We also only allow # dtype conversion when we are checkpoint model weights only and the # current dtype is not the same as the export dtype at the end of the training. + if self.model_weights_only: # We update self.states to keep the model only. # After this update, self.states = { # 'tok_embeddings.weight':..., # 'layers.0.attention.wq.weight': ... # }. - self.states = self.states["model"].state_dict() + self.states = self.states[MODEL].state_dict() # For now, we will manually pop the freqs_cis buffer, as we made this permanent # temporarily and we don't want to include it in the exported state_dict. @@ -282,41 +524,18 @@ def _save_last_step(self, curr_step: int) -> None: logger.info(f"Saving a full checkpoint at last step, step {curr_step}.") save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step)) - self.reset() def _should_save(self, curr_step: int, force: bool = False) -> bool: if not self.enable_checkpoint: return False - if not force: - if self.interval_type == IntervalType.STEPS and not ( - curr_step % self.interval == 0 - ): - return False - if self.interval_type == IntervalType.SECONDS: - time_sync_result = (time.monotonic() - self.begin_time) >= self.interval - self.time_sync_result = torch.tensor(int(time_sync_result)) - if self.time_sync_work is None: - self.time_sync_work = dist.all_reduce( - self.time_sync_result, group=self.pg, async_op=True - ) - return False - elif curr_step % 5 == 4: - self.time_sync_work.wait() - self.time_sync_work = None - time_sync_result = self.time_sync_result.item() - self.time_sync_result = None - if time_sync_result == 0: - return False - else: - return False - - if self.time_sync_work: - self.time_sync_work.wait() - self.time_sync_work = None - self.time_sync_result = None + if force: + return True - return True + if curr_step % self.interval == 0: + return True + + return False def _async_wait(self) -> None: if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: @@ -331,15 +550,11 @@ def _async_wait(self) -> None: self.async_future.result() def _async_with_pinned_memory(self, checkpoint_id: str) -> None: - try: - from torch.distributed._state_dict_utils import ( - _copy_state_dict, - _create_cpu_state_dict, - ) - except ImportError as e: - raise ImportError( - "Please install the latest PyTorch nightly to use async checkpointing with pinned memory." - ) from e + self._cpu_staging(checkpoint_id) + self.sending_to_checkpoint_mp = True + + def _cpu_staging(self, checkpoint_id: Optional[str]) -> None: + """Offload state_dict to CPU memory""" state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states) if self.cpu_offload_state_dict is None: logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f") @@ -357,117 +572,12 @@ def _async_with_pinned_memory(self, checkpoint_id: str) -> None: self.staging = True self.staging_id = checkpoint_id - def save(self, curr_step: int, force: bool = False) -> None: - """ - force = True will force the checkpoint to be saved, even if the interval - has not been reached. 
- This only happens when train_state.step == job_config.training.steps, or - for initial seed checkpoint. - """ - if not self._should_save(curr_step, force): - return - - begin = time.monotonic() - checkpoint_id = self._create_checkpoint_id(curr_step) - self._async_wait() - # This GC is called for async checkpoint as it is useless to do - # GC right after async_save -- the CPU memory is not able to be - # freed until _async_wait() - if force: - self._save_last_step(curr_step) - elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: - GarbageCollection.collect("GC collection invoked by checkpointer.") - self._async_with_pinned_memory(checkpoint_id) - elif self.async_mode == AsyncMode.ASYNC: - GarbageCollection.collect("GC collection invoked by checkpointer.") - self.async_future = dcp.async_save( - self.states, checkpoint_id=checkpoint_id, process_group=self.pg - ) - else: - save_with_gc(self.states, checkpoint_id=checkpoint_id) - self.reset() - self._purge_stale_checkpoints() - - logger.info( - "Finished saving the checkpoint (or staging if async is enabled)" - f"in {time.monotonic() - begin:.2f} seconds." - ) - - def maybe_wait_for_staging(self) -> None: + def _purge_stale_checkpoints(self): if ( - self.enable_checkpoint - and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM - and self.staging + self.keep_latest_k > 0 + and dist.get_rank() == 0 + and os.path.isdir(self.folder) ): - if not self.staging_stream.query(): - self.staging_stream.synchronize() - - def sync_func(): - self.mp_queue_send.put_nowait( - (self.cpu_offload_state_dict, self.staging_id) - ) - - # This may be a faster way to do zero-overhead checkpointing staging - # checkpointing but we need more thorough investigation before - # swithing to this method. - # self.my_thread = threading.Thread(target=func).start() - sync_func() - self.staging = False - - def load(self, step: int = -1) -> bool: - if not self.enable_checkpoint: - return False - if not os.path.isdir(self.folder): - return False - if step != -1 and not os.path.isdir(self._create_checkpoint_id(step)): - return False - - if step == -1: - step_counts = [] - for filename in os.listdir(self.folder): - match = re.search(r"step-(\d+)", filename) - metadata_probe = os.path.join(self.folder, filename, ".metadata") - if match and os.path.isfile(metadata_probe): - step_counts.append(int(match.group(1))) - if not step_counts: - return False - step = max(step_counts) - - # We won't have optimizer states to load, if we are loading a seed checkpoint - states = {"model": self.states["model"]} if step == 0 else self.states - # PyTorch bug: (pytorch/pytorch#138575) - # dcp.load() replaces the values of stateful elements in `states` with new objects - # from loading the checkpoint, in addition to updating the states of the original - # objects from `states` in-place. This is a problem because the state_dict no longer - # refers to the objects being used in the train loop, meaning any future checkpoints - # will not include updates to these objects (such as updated optimizer states, etc.) 
- original_stateful_states = { - k: v for k, v in states.items() if isinstance(v, Stateful) - } - logger.info(f"Loading the checkpoint at step {step}.") - begin = time.monotonic() - states_to_load = { - k: v for k, v in states.items() if k not in self.exclude_from_loading - } - for exclude_key in self.exclude_from_loading: - if exclude_key not in states: - raise ValueError(f"{exclude_key} not found in state_dict.") - dcp.load( - states_to_load, - checkpoint_id=self._create_checkpoint_id(step), - ) - states.update(states_to_load) - logger.info( - f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." - ) - # bugfix from above: restore the original stateful objects, - # whose states were already updated in-place by dcp.load() - states.update(original_stateful_states) - GarbageCollection.collect("GC collection for checkpoint loading.") - return True - - def _purge_stale_checkpoints(self): - if self.keep_latest_k > 0: discovered_checkpoints = [] for filename in os.listdir(self.folder): match = re.search(r"step-(\d+)", filename) @@ -478,5 +588,5 @@ def _purge_stale_checkpoints(self): to_delete = discovered_checkpoints[: -1 * self.keep_latest_k] for _, path in to_delete: - logger.info(f"Deleting old checkpoint {path}") - shutil.rmtree(path, ignore_errors=True) + assert self.purge_thread is not None + self.purge_queue.put(path) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index d8a98a745..66fd2377c 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -502,17 +502,11 @@ def __init__(self): When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}. """, ) - self.parser.add_argument( - "--checkpoint.interval_type", - type=str, - default="steps", - help="Checkpointing interval unit of measurement ['step', 'seconds']", - ) self.parser.add_argument( "--checkpoint.interval", type=int, default=500, - help="Checkpointing interval, in steps or seconds depending on --checkpoint.interval_type", + help="Checkpointing interval in steps.", ) self.parser.add_argument( "--checkpoint.model_weights_only", @@ -572,7 +566,8 @@ def __init__(self): default=0, help=""" Keeps only the latest k checkpoints, and purging older ones. If 0, keep all checkpoints. - 0 is the default value. + 0 is the default value. k cannot be 1 as the last one may be in the process of being + saved. As a result, the metadata of the last one may not be ready yet. 
""", ) self.parser.add_argument( diff --git a/torchtitan/models/llama/train_configs/debug_model.toml b/torchtitan/models/llama/train_configs/debug_model.toml index 8f4a40dd6..04fe37e60 100644 --- a/torchtitan/models/llama/train_configs/debug_model.toml +++ b/torchtitan/models/llama/train_configs/debug_model.toml @@ -53,7 +53,6 @@ enable_async_tensor_parallel = false [checkpoint] enable_checkpoint = false folder = "checkpoint" -interval_type = "steps" interval = 10 model_weights_only = false export_dtype = "float32" diff --git a/torchtitan/models/llama/train_configs/llama3_405b.toml b/torchtitan/models/llama/train_configs/llama3_405b.toml index 87ee3ac2c..b5a63195e 100644 --- a/torchtitan/models/llama/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama/train_configs/llama3_405b.toml @@ -46,7 +46,6 @@ enable_async_tensor_parallel = true [checkpoint] enable_checkpoint = false folder = "checkpoint" -interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32" diff --git a/torchtitan/models/llama/train_configs/llama3_70b.toml b/torchtitan/models/llama/train_configs/llama3_70b.toml index 6745cecab..b25ab7270 100644 --- a/torchtitan/models/llama/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama/train_configs/llama3_70b.toml @@ -45,7 +45,6 @@ pipeline_parallel_degree = 1 [checkpoint] enable_checkpoint = false folder = "checkpoint" -interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32" diff --git a/torchtitan/models/llama/train_configs/llama3_8b.toml b/torchtitan/models/llama/train_configs/llama3_8b.toml index e551a7706..2aec70773 100644 --- a/torchtitan/models/llama/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama/train_configs/llama3_8b.toml @@ -45,7 +45,6 @@ pipeline_parallel_degree = 1 [checkpoint] enable_checkpoint = false folder = "checkpoint" -interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32" diff --git a/torchtitan/train.py b/torchtitan/train.py index e6d4083d5..3763175e7 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -240,8 +240,6 @@ def main(job_config: JobConfig): time_last_log = time.perf_counter() device_memory_monitor.reset_peak_stats() - checkpoint.reset() - # train loop logger.info( f"Training starts at step {train_state.step + 1}, "