 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import init_logger, logger
 from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer
+from torchtitan.utils import GarbageCollection


 class IntervalType(enum.Enum):
@@ -106,6 +107,12 @@ class SaveDone:
     pass


+@torch.no_grad()
+def save_with_gc(state, checkpoint_id):
+    dcp.save(state, checkpoint_id=checkpoint_id)
+    GarbageCollection.collect("GC collection invoked by checkpointer.")
+
+
 def checkpoint_mp(recv, send):
     init_logger()
     os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2)
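
The new save_with_gc helper wraps dcp.save and forces a garbage-collection pass as soon as the synchronous save returns, so temporary CPU state built during the save does not linger until the next automatic GC. GarbageCollection itself comes from torchtitan.utils and is not part of this diff; the sketch below is only an assumption of what a compatible collect() helper could look like (log the reason, run Python's cyclic collector), not the actual torchtitan implementation:

```python
import gc
import logging
import time

logger = logging.getLogger(__name__)


class GarbageCollection:
    """Hypothetical stand-in for torchtitan.utils.GarbageCollection."""

    @staticmethod
    def collect(reason: str) -> None:
        # Run the cyclic collector explicitly and record how long it took,
        # so GC cost shows up next to the checkpointing logs.
        begin = time.monotonic()
        gc.collect(1)  # the younger generations are usually enough for checkpoint temporaries
        logger.info("%s Took %.2f seconds.", reason, time.monotonic() - begin)
```
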
@@ -125,7 +132,7 @@ def checkpoint_mp(recv, send):
             assert isinstance(obj, tuple)
             begin = time.monotonic()
             state, checkpoint_id = obj
-            dcp.save(state, checkpoint_id=checkpoint_id)
+            save_with_gc(state, checkpoint_id=checkpoint_id)
             logger.info(
                 "Finish saving the checkpoint in the background process in "
                 f"{time.monotonic() - begin:.2f} seconds."
@@ -274,7 +281,7 @@ def _save_last_step(self, curr_step: int) -> None:
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

-        dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+        save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
         self.reset()

     def _should_save(self, curr_step: int, force: bool = False) -> bool:
@@ -363,16 +370,21 @@ def save(self, curr_step: int, force: bool = False) -> None:
         begin = time.monotonic()
         checkpoint_id = self._create_checkpoint_id(curr_step)
         self._async_wait()
+        # This GC is called for async checkpoint as it is useless to do
+        # GC right after async_save -- the CPU memory is not able to be
+        # freed until _async_wait()
         if force:
             self._save_last_step(curr_step)
         elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+            GarbageCollection.collect("GC collection invoked by checkpointer.")
             self._async_with_pinned_memory(checkpoint_id)
         elif self.async_mode == AsyncMode.ASYNC:
+            GarbageCollection.collect("GC collection invoked by checkpointer.")
             self.async_future = dcp.async_save(
                 self.states, checkpoint_id=checkpoint_id, process_group=self.pg
             )
         else:
-            dcp.save(self.states, checkpoint_id=checkpoint_id)
+            save_with_gc(self.states, checkpoint_id=checkpoint_id)
         self.reset()
         self._purge_stale_checkpoints()
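
In the two async branches the GC pass is issued before the save is launched rather than inside save_with_gc, because (as the new comment explains) the CPU copies staged by an async save stay reachable until _async_wait() confirms the previous upload finished. _async_wait is defined elsewhere in checkpoint.py; the sketch below is an illustrative reconstruction of its role under that assumption (the queue attribute name is hypothetical), not the verbatim method:

```python
def _async_wait(self) -> None:
    # Illustrative sketch: block until the previous async checkpoint is done,
    # so the staging buffers it holds become unreachable and a subsequent
    # GarbageCollection.collect() can actually reclaim that CPU memory.
    if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
        # Background-process mode: wait for the worker to report it is done.
        self.mp_queue_recv.get()  # attribute name is an assumption
    elif self.async_mode == AsyncMode.ASYNC:
        if self.async_future is not None:
            self.async_future.result()
```
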
@@ -451,6 +463,7 @@ def load(self, step: int = -1) -> bool:
         # bugfix from above: restore the original stateful objects,
         # whose states were already updated in-place by dcp.load()
         states.update(original_stateful_states)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
         return True

     def _purge_stale_checkpoints(self):
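
Taken together, every synchronous save path (background process, last-step save, and the plain fallback) now goes through save_with_gc, and loading also triggers a collection once the restored state is back in place. A minimal single-process usage sketch of the helper, assuming save_with_gc from this diff is in scope and relying on dcp.save's non-distributed fallback when no process group is initialized:

```python
import tempfile

import torch

# Single-process illustration only; real torchtitan runs this across ranks
# through the CheckpointManager shown in the diff above.
state = {"model": torch.nn.Linear(4, 4).state_dict()}
with tempfile.TemporaryDirectory() as checkpoint_id:
    save_with_gc(state, checkpoint_id=checkpoint_id)  # saves, then runs an explicit GC pass
```
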