Commit cca0702

[BE] LR scheduler flatten (#794)
Currently, lr_scheduler is stored differently from optimizer, model, and dataloader, with keys "lr_scheduler_0", "lr_scheduler_1", ... in the state. This PR flattens lr_scheduler so that the schedulers are stored under self.states['lr_scheduler'], consistent with optimizer.

We assume that all the optimizers use the same lr_scheduler, so saving a single lr_scheduler's state_dict and loading it into all the schedulers is sufficient. An lr_scheduler's state_dict looks like:

`{'base_lrs': [0.0003], 'last_epoch': 1, 'verbose': False, '_step_count': 2, '_get_lr_called_within_step': False, '_last_lr': [2.985074626865671e-06], 'lr_lambdas': [{}]}`

The PR is tested in two parts:

1. Check the lr_scheduler values before and after checkpointing, resharding with degree changes on tp and pp: [dp=2, tp=4, pp=1] -> [dp=2, tp=1, pp=4] and [dp=2, tp=1, pp=4] -> [dp=2, tp=4, pp=1]. (data_loader does not support resharding right now.) Logs:

   - [dp=2, tp=4, pp=1], step 5, before saving to checkpoint: [{'lr': 8.955223880597014e-06, ...}]
   - step 10, after loading from checkpoint and resharding to [dp=2, tp=2, pp=2]: [{'lr': 1.6417910447761194e-05, ...}, {'lr': 1.6417910447761194e-05, ...}]
   - [dp=8, tp=1, pp=1], step 5, without checkpoint: [{'lr': 8.955223880597014e-06, ...}]
   - step 10, without checkpoint: [{'lr': 1.6417910447761194e-05, ...}]

2. Memory trace: rerun llama3_8b.toml from step 5 to step 10.

   Before the flatten:
   <img width="1166" alt="Screenshot 2025-01-16 at 2 40 03 PM" src="https://github.com/user-attachments/assets/d3e84d63-30be-4604-823b-68bd217498a0" />

   After the flatten:
   <img width="1166" alt="Screenshot 2025-01-16 at 2 40 21 PM" src="https://github.com/user-attachments/assets/b6ed68ae-2dbf-400a-b723-06eae6740ade" />
1 parent 2271b63 commit cca0702
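
The key assumption above, that one scheduler's state_dict can be loaded into every scheduler, can be sanity-checked in isolation. Below is a minimal sketch (a toy `Linear` model and a hypothetical linear-warmup lambda, not code from this PR): after loading, the second scheduler reports the same `last_epoch` and learning rate as the one that was actually stepped.

```python
# Minimal sketch with assumed toy components (not from this PR): save one
# LambdaLR's state_dict and load it into another scheduler built the same way.
import torch
from torch.optim.lr_scheduler import LambdaLR


def make_optimizer_and_scheduler():
    # Hypothetical stand-in for one pipeline stage's optimizer/scheduler pair.
    opt = torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=3e-4)
    sched = LambdaLR(opt, lr_lambda=lambda step: min(1.0, (step + 1) / 100))
    return opt, sched


opt_a, sched_a = make_optimizer_and_scheduler()
opt_b, sched_b = make_optimizer_and_scheduler()

# Step only the first pair, as if it were the scheduler whose state gets saved.
for _ in range(5):
    opt_a.step()
    sched_a.step()

state = sched_a.state_dict()           # plain Python values, e.g. 'last_epoch': 5
sched_b.load_state_dict(state.copy())  # .copy() mirrors the extra-safety copy in the PR

assert sched_b.last_epoch == sched_a.last_epoch
assert sched_b.get_last_lr() == sched_a.get_last_lr()
```

This holds as long as the schedule itself (e.g. `training.steps` and the warmup settings) is unchanged when resuming, which is the same caveat noted in the code comments below.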

2 files changed: +17 −12 lines changed
torchtitan/checkpoint.py

+1 −1

@@ -183,9 +183,9 @@ def __init__(
                 "model": ModelWrapper(model_parts),
                 "optimizer": optimizers,
                 "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
             }
         )
-        self.states.update(lr_schedulers.get_lr_scheduler_state())
 
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
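
Registering `lr_schedulers` alongside model, optimizer, and dataloader means the checkpointer can treat every entry in `self.states` through the same `Stateful` interface, instead of special-casing per-index scheduler keys. A hedged sketch of that contract (the `collect_state`/`restore_state` helpers are made up for illustration, not torchtitan APIs):

```python
# Toy illustration of the Stateful contract that the one-line change above relies on.
from typing import Any, Dict

from torch.distributed.checkpoint.stateful import Stateful


def collect_state(states: Dict[str, Stateful]) -> Dict[str, Any]:
    # Conceptually what gets saved: one nested state_dict per registered object,
    # with a single "lr_scheduler" entry instead of "lr_scheduler_0", "lr_scheduler_1", ...
    return {key: obj.state_dict() for key, obj in states.items()}


def restore_state(states: Dict[str, Stateful], saved: Dict[str, Any]) -> None:
    # Conceptually what happens on load: each registered object restores itself.
    for key, obj in states.items():
        obj.load_state_dict(saved[key])
```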

torchtitan/optimizer.py

+16 −11

@@ -178,7 +178,7 @@ def linear_warmup_linear_decay(
     return curr_adjustment
 
 
-class SchedulersContainer:
+class SchedulersContainer(Stateful):
     """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""
 
     def __init__(self, optimizers, lr_lambda) -> None:
@@ -190,16 +190,21 @@ def step(self) -> None:
         for scheduler in self.schedulers:
             scheduler.step()
 
-    def get_lr_scheduler_state(self) -> Dict[str, Any]:
-        state_dict = {}
-        if len(self.schedulers) == 1:
-            state_dict["lr_scheduler"] = self.schedulers[0]
-        else:
-            # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
-            # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
-            for idx, lr_scheduler in enumerate(self.schedulers):
-                state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
-        return state_dict
+    def state_dict(self) -> Dict[str, Any]:
+        # Currently, we have one scheduler per optimizer. However, when using MultiSchedule PP or optimizer-in-backward,
+        # there are multiple optimizers and schedulers, but the scheduler state_dict remains the same for all.
+        # Therefore, we only save the first one and later load it for all.
+        assert (
+            len(self.schedulers) > 0
+        ), "Must have at least one scheduler to save state_dict"
+        return self.schedulers[0].state_dict()
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # Load the same state_dict for all schedulers. The key value we're concerned with in scheduler.state_dict() is `last_epoch`,
+        # which is an integer that will be automatically copied. As long as `training.steps` and `training.warmup_steps` remain
+        # unchanged when resuming from a checkpoint, this approach is safe. We call `.copy()` here to ensure extra safety.
+        for scheduler in self.schedulers:
+            scheduler.load_state_dict(state_dict.copy())
 
 
 def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer:
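
Since the container now saves a single scheduler state_dict regardless of how many schedulers it holds, the saved state no longer depends on the pipeline-parallel degree, which is what allows the resharding round trip in the test plan. A hedged sketch of that behavior (a toy re-implementation with LambdaLR and a made-up `make_scheduler` helper, not the actual torchtitan classes):

```python
# Toy container mirroring the state_dict/load_state_dict logic above: a "pp=4" run
# with 4 schedulers saves its state, and a "pp=2" run with 2 schedulers restores it.
import torch
from torch.optim.lr_scheduler import LambdaLR


def make_scheduler():
    # Hypothetical per-stage optimizer/scheduler pair, for illustration only.
    opt = torch.optim.AdamW(torch.nn.Linear(2, 2).parameters(), lr=3e-4)
    return opt, LambdaLR(opt, lr_lambda=lambda step: min(1.0, (step + 1) / 100))


class ToySchedulersContainer:
    def __init__(self, pairs):
        self.optimizers = [opt for opt, _ in pairs]
        self.schedulers = [sched for _, sched in pairs]

    def step(self):
        for opt, sched in zip(self.optimizers, self.schedulers):
            opt.step()
            sched.step()

    def state_dict(self):
        # Save only the first scheduler's state; all schedulers are assumed identical.
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict):
        # Load the same state into every scheduler, copying to avoid mutation.
        for sched in self.schedulers:
            sched.load_state_dict(state_dict.copy())


run_a = ToySchedulersContainer([make_scheduler() for _ in range(4)])
for _ in range(5):
    run_a.step()
saved = run_a.state_dict()

# A run with a different number of schedulers restores without any key remapping.
run_b = ToySchedulersContainer([make_scheduler() for _ in range(2)])
run_b.load_state_dict(saved)
assert all(sched.last_epoch == 5 for sched in run_b.schedulers)
```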
