Commit ab94a99

Do GC collect after dcp.save and dcp.load (#839)
We disable automatic GC and manually run a collection every 50 steps. This can cause issues when the checkpoint interval is smaller than the GC interval, because the garbage produced by frequent saves accumulates between collections. This PR changes the checkpoint manager to run a GC collection whenever a save or load happens.
1 parent: fb0a942

2 files changed (+23, -5 lines)

torchtitan/checkpoint.py (+16, -3)

@@ -30,6 +30,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import init_logger, logger
 from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer
+from torchtitan.utils import GarbageCollection


 class IntervalType(enum.Enum):
@@ -106,6 +107,12 @@ class SaveDone:
     pass


+@torch.no_grad()
+def save_with_gc(state, checkpoint_id):
+    dcp.save(state, checkpoint_id=checkpoint_id)
+    GarbageCollection.collect("GC collection invoked by checkpointer.")
+
+
 def checkpoint_mp(recv, send):
     init_logger()
     os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2)
@@ -125,7 +132,7 @@ def checkpoint_mp(recv, send):
             assert isinstance(obj, tuple)
             begin = time.monotonic()
             state, checkpoint_id = obj
-            dcp.save(state, checkpoint_id=checkpoint_id)
+            save_with_gc(state, checkpoint_id=checkpoint_id)
             logger.info(
                 "Finish saving the checkpoint in the background process in "
                 f"{time.monotonic() - begin:.2f} seconds."
@@ -274,7 +281,7 @@ def _save_last_step(self, curr_step: int) -> None:
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

-        dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+        save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
         self.reset()

     def _should_save(self, curr_step: int, force: bool = False) -> bool:
@@ -363,16 +370,21 @@ def save(self, curr_step: int, force: bool = False) -> None:
         begin = time.monotonic()
         checkpoint_id = self._create_checkpoint_id(curr_step)
         self._async_wait()
+        # This GC is called for async checkpoint as it is useless to do
+        # GC right after async_save -- the CPU memory is not able to be
+        # freed until _async_wait()
         if force:
             self._save_last_step(curr_step)
         elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+            GarbageCollection.collect("GC collection invoked by checkpointer.")
             self._async_with_pinned_memory(checkpoint_id)
         elif self.async_mode == AsyncMode.ASYNC:
+            GarbageCollection.collect("GC collection invoked by checkpointer.")
             self.async_future = dcp.async_save(
                 self.states, checkpoint_id=checkpoint_id, process_group=self.pg
             )
         else:
-            dcp.save(self.states, checkpoint_id=checkpoint_id)
+            save_with_gc(self.states, checkpoint_id=checkpoint_id)
         self.reset()
         self._purge_stale_checkpoints()

@@ -451,6 +463,7 @@ def load(self, step: int = -1) -> bool:
         # bugfix from above: restore the original stateful objects,
         # whose states were already updated in-place by dcp.load()
         states.update(original_stateful_states)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
         return True

     def _purge_stale_checkpoints(self):
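For the synchronous paths the change simply wraps the blocking dcp.save with an immediate collection, so the temporaries created while serializing the state dict are reclaimed right away instead of waiting for the next periodic GC. A minimal, self-contained sketch of that pattern, assuming a single process (recent PyTorch versions let dcp.save run without an initialized process group) and a writable ./checkpoint_demo directory; the toy model and the checkpoint path are illustrative only, not torchtitan's actual state objects:

import gc

import torch
import torch.distributed.checkpoint as dcp
from torch import nn


@torch.no_grad()
def save_with_gc(state, checkpoint_id):
    # Blocking save, then an immediate collection so the garbage produced
    # while writing the checkpoint is reclaimed right away rather than
    # waiting for the next scheduled collection.
    dcp.save(state, checkpoint_id=checkpoint_id)
    gc.collect(1)


# Hypothetical usage with a toy model.
model = nn.Linear(8, 8)
save_with_gc({"model": model.state_dict()}, checkpoint_id="./checkpoint_demo")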

torchtitan/utils.py (+7, -2)

@@ -163,11 +163,16 @@ def __init__(self, gc_freq=1000):
         assert gc_freq > 0, "gc_freq must be a positive integer"
         self.gc_freq = gc_freq
         gc.disable()
-        gc.collect(1)
+        self.collect("Initial GC collection.")

     def run(self, step_count):
         if step_count > 1 and step_count % self.gc_freq == 0:
-            gc.collect(1)
+            self.collect("Peforming periodical GC collection.")
+
+    @staticmethod
+    def collect(reason: str):
+        logger.info(reason)
+        gc.collect(1)


 TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
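The helper keeps automatic GC disabled for the whole run and only collects on its own schedule, which is why the checkpointer now has to trigger collections itself. A rough usage sketch; the bare loop, step counts, and checkpoint interval are assumptions for illustration, and only the GarbageCollection API comes from the diff above:

from torchtitan.utils import GarbageCollection

# Automatic GC is disabled inside the constructor; collections now happen
# only every gc_freq steps plus wherever collect() is called explicitly.
gc_handler = GarbageCollection(gc_freq=50)

for step in range(1, 201):
    # ... one training step would run here ...

    if step % 10 == 0:
        # Checkpoints happen more often than the periodic GC (every 10 steps
        # vs. every 50), so after this commit the checkpointer collects on its
        # own right after each dcp.save / dcp.load.
        GarbageCollection.collect("GC collection invoked by checkpointer.")

    gc_handler.run(step)  # periodic collection on the gc_freq schedule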
