 from dataclasses import dataclass, field
 from io import BytesIO
 from multiprocessing import get_context
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
@@ -143,50 +144,29 @@ def __init__(
         lr_schedulers: SchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional[Any] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.keep_latest_k = ckpt_config.keep_latest_k
+        self.ft_manager = ft_manager
+        async_mode = ckpt_config.async_mode.lower()
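+        # Staging (copying the local GPU state_dict into pinned CPU memory) is needed
+        # both for async checkpointing with pinned memory and whenever a fault-tolerance
+        # manager is present, since a joining replica may need the locally staged state.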
+        self.enable_staging = (
+            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+        ) or self.ft_manager

-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
-        """
-        Note: Pipeline Parallelism and Virtual Stages
-
-        1. even for simple PP schedules, there is a separate optimizer each PP rank.
-        rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
-        rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
-        When saving, these collide and one of them is lost. Then when reloading, only one stage can
-        restore its optimizer states, others will error.
-
-        The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
-        by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
-
-        2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
-        requiring us to reason about multiple 'optim' objects locally.
-
-        We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
-        into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
-        which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
-        support described in (1).
-
-        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-        resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-        optimizers do, so it's hard to write a generic 'flattener' utility.
-
-        TODO: This is currently unsolved and needs a fix.
-        """
-        self.states = states

-        self.states.update(
-            {
-                "model": ModelWrapper(model_parts),
-                "optimizer": optimizers,
-                "dataloader": dataloader,
-                "lr_scheduler": lr_schedulers,
-            }
+        self._initialize_states(
+            states, dataloader, model_parts, optimizers, lr_schedulers
         )

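+        # Bookkeeping for asynchronous CPU staging. The dedicated CUDA stream (created
+        # only when staging is enabled) lets the GPU-to-CPU copy of the state_dict
+        # overlap with training.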
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
+
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
             IntervalType.SECONDS
@@ -201,6 +181,7 @@ def __init__(
         if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
             self.pg = dist.new_group(backend="gloo")

+        self.keep_latest_k = ckpt_config.keep_latest_k
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]

@@ -224,10 +205,6 @@ def __init__(
                 daemon=True,
             )
             self.mp.start()
-            self.cpu_offload_state_dict = None
-            self.staging = False
-            self.staging_id = None
-            self.staging_stream = torch.cuda.Stream()
         else:
             raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")

@@ -241,8 +218,61 @@ def __del__(self):
             self.mp.join()

     def reset(self) -> None:
+        # We need to stage the local state if another replica joins during the
+        # first step.
+        if self.ft_manager:
+            self.cpu_staging(None)
         self.begin_time = time.monotonic()

+    def _initialize_states(
+        self,
+        states: Dict[str, Any],
+        dataloader: DataLoader,
+        model_parts: List[nn.Module],
+        optimizers: OptimizersContainer,
+        lr_schedulers: SchedulersContainer,
+    ) -> None:
+        """
+        Note: Pipeline Parallelism and Virtual Stages
+
+        1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
+        rank0's optimizer would have a param_group[0] which refers to layers.0 in the
+        original model. rank1's would _also_ have a param_group[0], since it's index based,
+        but referring to layers.1.
+        When saving, these collide and one of them is lost. Then when reloading, only one
+        stage can restore its optimizer states, others will error.
+
+        The solution to this problem is optimizer flattening: it landed in #127071
+        and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
+        kwarg to DCP functions called in the OptimizerContainer.
+
+        2. With complex PP schedules, we have multiple model chunks per pp rank. This
+        compounds challenge (1) by also requiring us to reason about multiple 'optim'
+        objects locally.
+
+        We solve this in the Model and Optimizer wrapper classes by flattening the
+        state dicts from each object into one state dict before saving/loading.
+        We rely on the individual state_dicts to not collide, which is guaranteed for
+        the model by correct pipeline splitting and for the optimizer by the flattening
+        support described in (1).
+
+        3. LR schedulers also index model states like optimizers and would need to be
+        flattened properly to support resharding. Unfortunately, the implementations of
+        different lr_schedulers do not follow a clear pattern like optimizers do, so it's
+        hard to write a generic 'flattener' utility.
+
+        TODO: This is currently unsolved and needs a fix.
+        """
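+        # For example, rank0's param_groups[0] covers layers.0 while rank1's
+        # param_groups[0] covers layers.1; after flattening, optimizer state is keyed by
+        # parameter FQNs rather than group indices, so the merged checkpoint has no
+        # collisions.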
+        self.states = states
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
+            }
+        )
+

     def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")
@@ -325,31 +355,8 @@ def _async_wait(self) -> None:
             self.async_future.result()

     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-        try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict,
-                _create_cpu_state_dict,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-            ) from e
-        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-        if self.cpu_offload_state_dict is None:
-            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True, share_memory=True
-            )
-
-        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
-        with torch.cuda.stream(self.staging_stream):
-            self.cpu_offload_state_dict = _copy_state_dict(
-                state_dict,
-                self.cpu_offload_state_dict,
-                non_blocking=True,
-            )
-        self.staging = True
-        self.staging_id = checkpoint_id
+        self.cpu_staging(checkpoint_id)
+        self.sending_to_checkpoint_mp = True

     def save(self, curr_step: int, force: bool = False) -> None:
         """
@@ -359,6 +366,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
         for initial seed checkpoint.
         """
         if not self._should_save(curr_step, force):
+            if self.ft_manager:
+                self.cpu_staging(None)
             return

         begin = time.monotonic()
@@ -382,26 +391,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
             f"in {time.monotonic() - begin:.2f} seconds."
         )

+    def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+        """Offload state_dict to CPU memory"""
+        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+        if self.cpu_offload_state_dict is None:
+            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
+            self.cpu_offload_state_dict = _create_cpu_state_dict(
+                state_dict, pin_memory=True, share_memory=True
+            )
+
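+        # Pinned (page-locked) CPU memory enables the non-blocking device-to-host copy
+        # below, and share_memory=True lets the checkpoint background process read the
+        # staged state without an extra copy.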
+        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
+        with torch.cuda.stream(self.staging_stream):
+            self.cpu_offload_state_dict = _copy_state_dict(
+                state_dict,
+                self.cpu_offload_state_dict,
+                non_blocking=True,
+            )
+        self.staging = True
+        self.staging_id = checkpoint_id
+
+    def wait_for_staging(self) -> None:
+        if not self.staging_stream.query():
+            self.staging_stream.synchronize()
+        self.staging = False
+
+    def staging_results(self) -> Dict[str, Any]:
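+        # Block until the CPU copy has finished, then hand back the staged state_dict.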
+        self.maybe_wait_for_staging()
+        return self.cpu_offload_state_dict
+
     def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
-            if not self.staging_stream.query():
-                self.staging_stream.synchronize()
-
-            def sync_func():
-                self.mp_queue_send.put_nowait(
-                    (self.cpu_offload_state_dict, self.staging_id)
-                )
-
-            # This may be a faster way to do zero-overhead checkpointing staging
-            # checkpointing but we need more thorough investigation before
-            # swithing to this method.
-            # self.my_thread = threading.Thread(target=func).start()
-            sync_func()
-            self.staging = False
+        if self.enable_staging and self.staging:
+            self.wait_for_staging()
+
+            if self.sending_to_checkpoint_mp:
+                # Copy the staged result to the checkpointing process.
+                def sync_func():
+                    self.mp_queue_send.put_nowait(
+                        (self.cpu_offload_state_dict, self.staging_id)
+                    )
+
+                # This may be a faster way to do zero-overhead checkpoint staging but we
+                # need more thorough investigation before switching to this method.
+                # self.my_thread = threading.Thread(target=func).start()
+                sync_func()
+                self.sending_to_checkpoint_mp = False

     def load(self, step: int = -1) -> bool:
         if not self.enable_checkpoint: