 from dataclasses import dataclass, field
 from io import BytesIO
 from multiprocessing import get_context
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
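With `_copy_state_dict` and `_create_cpu_state_dict` now imported at module level, the ImportError guard that previously lived inside `_async_with_pinned_memory` (see the removed block further down) disappears. If older PyTorch builds without these private helpers should still fail with a clear message, one possible sketch is to keep the guard at module level instead:

try:
    from torch.distributed._state_dict_utils import (
        _copy_state_dict,
        _create_cpu_state_dict,
    )
except ImportError as e:  # older PyTorch builds lack these private helpers
    raise ImportError(
        "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
    ) from e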
@@ -143,16 +144,19 @@ def __init__(
         lr_schedulers: SchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional[Any] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.keep_latest_k = ckpt_config.keep_latest_k
+        self.ft_manager = ft_manager
+        self.enable_staging = (
+            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
-        """
-        Note: Pipeline Parallelism and Virtual Stages
 
+<<<<<<< HEAD
         1. even for simple PP schedules, there is a separate optimizer each PP rank.
         rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
         rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
@@ -186,6 +190,18 @@ def __init__(
186190 "lr_scheduler": lr_schedulers,
187191 }
188192 )
193+ =======
194+ self._initialize_states(
195+ states, dataloader, model_parts, optimizers, lr_schedulers
196+ )
197+
198+ async_mode = ckpt_config.async_mode.lower()
199+ self.staging = False
200+ self.sending_to_checkpoint_mp = False
201+ self.staging_id = None
202+ self.cpu_offload_state_dict = None
203+ self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
204+ >>>>>>> 3430d99 ([WIP][RFC] TorchFT integration)
189205
190206 self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
191207 self.interval_type = (
@@ -201,6 +217,7 @@ def __init__(
         if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
             self.pg = dist.new_group(backend="gloo")
 
+        self.keep_latest_k = ckpt_config.keep_latest_k
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
 
@@ -224,10 +241,6 @@ def __init__(
                 daemon=True,
             )
             self.mp.start()
-            self.cpu_offload_state_dict = None
-            self.staging = False
-            self.staging_id = None
-            self.staging_stream = torch.cuda.Stream()
         else:
             raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")
 
@@ -241,8 +254,61 @@ def __del__(self):
             self.mp.join()
 
     def reset(self) -> None:
+        # We need to stage the local state if another replica joins during the
+        # first step.
+        if self.ft_manager:
+            self.cpu_staging(None)
         self.begin_time = time.monotonic()
 
+    def _initialize_states(
+        self,
+        states: Dict[str, Any],
+        dataloader: DataLoader,
+        model_parts: List[nn.Module],
+        optimizers: OptimizersContainer,
+        lr_schedulers: SchedulersContainer,
+    ) -> None:
271+ """
272+ Note : Pipeline Parallelism and Virtual Stages
273+
274+ 1. Even for simple PP schedules , there is a separate optimizer each PP rank .
275+ rank0 's optimizer would have a param_group [0 ] which refers to layers .0 in the
276+ original model . rank1 's would _also_ have a param_group[0], since it' s index based ,
277+ but referring to layers .1.
278+ When saving , these collide and one of them is lost . Then when reloading , only one
279+ stage can restore its optimizer states , others will error .
280+
281+ The solution to this problem is optimizer flattening : it landed in #127071
282+ and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
283+ kwarg to DCP functions called in the OptimizerContainer .
284+
285+ 2. With complex PP schedules , we have multiple model chunks per pp rank . This
286+ compounds challenge (1 ) by also requiring us to reason about multiple 'optim'
287+ objects locally .
288+
289+ We solve this in the Model and Optimizer wrapper classes by flattening the
290+ state dicts from each object into one state dict before saving / loading .
291+ We rely on the individual state_dicts to not collide , which is gauranteed for
292+ the model by correct pipeline splitting and for the optimizer by the flattening
293+ support described in (1 ).
294+
295+ 3. LR schedulers also index model states like optimizers and would need to be
296+ flattened properly to support resharding . Unfortunately , the implementations of
297+ different lr_schedulers do not follow a clear pattern like optimizers do , so it 's
298+ hard to write a generic 'flattener' utility .
299+
300+ TODO : This is currently unsolved and needs a fix .
301+ """
+        self.states = states
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+            }
+        )
+        self.states.update(lr_schedulers.get_lr_scheduler_state())
+
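To make point (1) of the docstring concrete, here is a toy illustration (plain dictionaries and made-up parameter names, not the real torchtitan wrappers) of why index-keyed optimizer state from two PP stages cannot be merged safely, while FQN-keyed, flattened state can. The actual fix is the 'flatten_optimizer_state_dict' option passed to the DCP state-dict helpers, as the docstring notes.

# Two PP stages each produce optimizer state keyed by param index.
stage0_state = {0: {"step": 10}}  # index 0 refers to layers.0 on rank 0
stage1_state = {0: {"step": 10}}  # index 0 refers to layers.1 on rank 1
merged_by_index = {**stage0_state, **stage1_state}
assert len(merged_by_index) == 1  # the two entries collide; one stage's state is lost

# Flattened state is keyed by fully qualified parameter names, so nothing collides.
stage0_flat = {"layers.0.attention.wq.weight": {"step": 10}}
stage1_flat = {"layers.1.attention.wq.weight": {"step": 10}}
merged_by_fqn = {**stage0_flat, **stage1_flat}
assert len(merged_by_fqn) == 2  # both stages' state survives the merge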
     def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")
 
@@ -325,31 +391,8 @@ def _async_wait(self) -> None:
             self.async_future.result()
 
     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-        try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict,
-                _create_cpu_state_dict,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-            ) from e
-        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-        if self.cpu_offload_state_dict is None:
-            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True, share_memory=True
-            )
-
-        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
-        with torch.cuda.stream(self.staging_stream):
-            self.cpu_offload_state_dict = _copy_state_dict(
-                state_dict,
-                self.cpu_offload_state_dict,
-                non_blocking=True,
-            )
-        self.staging = True
-        self.staging_id = checkpoint_id
+        self.cpu_staging(checkpoint_id)
+        self.sending_to_checkpoint_mp = True
 
     def save(self, curr_step: int, force: bool = False) -> None:
         """
@@ -359,6 +402,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
         for initial seed checkpoint.
         """
         if not self._should_save(curr_step, force):
+            if self.ft_manager:
+                self.cpu_staging(None)
             return
 
         begin = time.monotonic()
@@ -382,26 +427,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
382427 f"in {time.monotonic() - begin:.2f} seconds."
383428 )
384429
+    def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+        """Offload state_dict to CPU memory."""
+        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+        if self.cpu_offload_state_dict is None:
+            logger.debug(f"Preparing the CPU memory, {time.monotonic()=:.2f}")
+            self.cpu_offload_state_dict = _create_cpu_state_dict(
+                state_dict, pin_memory=True, share_memory=True
+            )
+
+        logger.debug(f"Staging the state_dict, {time.monotonic()=:.2f}")
+        with torch.cuda.stream(self.staging_stream):
+            self.cpu_offload_state_dict = _copy_state_dict(
+                state_dict,
+                self.cpu_offload_state_dict,
+                non_blocking=True,
+            )
+        self.staging = True
+        self.staging_id = checkpoint_id
+
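The pattern inside cpu_staging() is worth seeing in isolation. A minimal sketch, assuming a CUDA device and the same private helpers imported at the top of this file: allocate pinned, shared CPU buffers once, then copy the device state dict into them asynchronously on a side stream so the training stream is not blocked. Reusing the CPU buffers across steps is what keeps later stagings allocation-free.

import torch
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict

gpu_state = {"w": torch.randn(1024, 1024, device="cuda")}

# One-time allocation: pinned memory makes the device-to-host copy truly asynchronous,
# shared memory lets the checkpoint background process read the buffers directly.
cpu_state = _create_cpu_state_dict(gpu_state, pin_memory=True, share_memory=True)

staging_stream = torch.cuda.Stream()
with torch.cuda.stream(staging_stream):
    cpu_state = _copy_state_dict(gpu_state, cpu_state, non_blocking=True)

# Synchronize before handing cpu_state to another process or reading it on the CPU.
staging_stream.synchronize()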
+    def wait_for_staging(self) -> None:
+        if not self.staging_stream.query():
+            self.staging_stream.synchronize()
+        self.staging = False
+
+    def staging_results(self) -> Dict[str, Any]:
+        self.maybe_wait_for_staging()
+        return self.cpu_offload_state_dict
+
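staging_results() appears to be the hook the fault-tolerance side would call to fetch the latest CPU copy of the training state for a recovering replica. The snippet below is only a hypothetical sketch of such a consumer (the `checkpointer` argument and the function name are not part of this diff); it shows that the call blocks until any in-flight device-to-host copy has finished before returning the pinned state dict.

from typing import Any, Dict


def state_dict_for_recovery(checkpointer) -> Dict[str, Any]:
    # Blocks via maybe_wait_for_staging() until staging completes, then returns
    # the CPU state dict produced by cpu_staging().
    return checkpointer.staging_results()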
     def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
-            if not self.staging_stream.query():
-                self.staging_stream.synchronize()
-
-            def sync_func():
-                self.mp_queue_send.put_nowait(
-                    (self.cpu_offload_state_dict, self.staging_id)
-                )
-
-            # This may be a faster way to do zero-overhead checkpointing staging
-            # checkpointing but we need more thorough investigation before
-            # swithing to this method.
-            # self.my_thread = threading.Thread(target=func).start()
-            sync_func()
-            self.staging = False
+        if self.enable_staging and self.staging:
+            self.wait_for_staging()
+
+            if self.sending_to_checkpoint_mp:
+                # Copy the sync staging result to another process.
+                def sync_func():
+                    self.mp_queue_send.put_nowait(
+                        (self.cpu_offload_state_dict, self.staging_id)
+                    )
+
+                # This may be a faster way to do zero-overhead checkpoint staging,
+                # but we need more thorough investigation before switching to this
+                # method.
+                # self.my_thread = threading.Thread(target=func).start()
+                sync_func()
+                self.sending_to_checkpoint_mp = False
 
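The sending_to_checkpoint_mp branch hands the staged CPU state dict to the background checkpoint process through the multiprocessing queue created in __init__. Below is a stripped-down sketch of that handoff with a stand-in worker and an illustrative checkpoint id; the real worker receives the same (state_dict, checkpoint_id) pairs and persists them with DCP.

from multiprocessing import get_context


def fake_checkpoint_worker(recv_queue) -> None:
    # Stand-in for the real checkpoint worker process.
    state_dict, checkpoint_id = recv_queue.get()
    print(f"would persist {len(state_dict)} entries to {checkpoint_id}")


if __name__ == "__main__":
    ctx = get_context("spawn")
    send_queue = ctx.Queue()
    worker = ctx.Process(target=fake_checkpoint_worker, args=(send_queue,), daemon=True)
    worker.start()

    # Equivalent of sync_func() above: a non-blocking put of the staged state dict.
    send_queue.put_nowait(({"w": 0.0}, "outputs/checkpoint/step-10"))
    worker.join(timeout=30)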
     def load(self, step: int = -1) -> bool:
         if not self.enable_checkpoint: