Currently, lr_scheduler is stored differently from optimizer, model, and data_loader: each scheduler is saved under its own key ("lr_scheduler_0", "lr_scheduler_1", ...) in the state.
This PR flattens lr_scheduler so that all the schedulers are stored as a list of state_dicts under self.state['lr_scheduler'], consistent with how optimizer is stored.
Here we assume that all the optimizers share the same lr_scheduler, so saving a single lr_scheduler's state_dict and loading it into all the schedulers is sufficient.
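A minimal sketch of that idea, assuming a hypothetical SchedulersContainer (names here are illustrative, not the PR's actual class):

```python
import copy
from typing import Any, Dict, List

import torch


class SchedulersContainer:
    """Holds one lr_scheduler per optimizer (e.g. one per pipeline stage)."""

    def __init__(self, schedulers: List[torch.optim.lr_scheduler.LRScheduler]) -> None:
        self.schedulers = schedulers

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def state_dict(self) -> Dict[str, Any]:
        # All schedulers are assumed to be identical, so saving the first
        # one's state_dict is enough to restore any of them.
        assert self.schedulers, "expected at least one scheduler"
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Load (a deep copy of) the single saved state_dict into every
        # scheduler, so that loading works even when the number of
        # schedulers changes across resharding.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(copy.deepcopy(state_dict))
```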
The lr_scheduler state_dict looks like:
`{'base_lrs': [0.0003], 'last_epoch': 1, 'verbose': False, '_step_count': 2, '_get_lr_called_within_step': False, '_last_lr': [2.985074626865671e-06], 'lr_lambdas': [{}]}`
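For reference, a state_dict of this shape comes from a LambdaLR schedule; the snippet below (with a hypothetical warmup function, not necessarily the PR's exact schedule) reproduces the keys:

```python
import functools

import torch


def warmup(step: int, warmup_steps: int) -> float:
    # Linear warmup factor; a hypothetical schedule for illustration.
    return min(1.0, (step + 1) / (warmup_steps + 1))


optimizer = torch.optim.AdamW(torch.nn.Linear(8, 8).parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=functools.partial(warmup, warmup_steps=200)
)

optimizer.step()
scheduler.step()
# Prints a dict with the keys shown above: base_lrs, last_epoch,
# _step_count, _last_lr, lr_lambdas, ...
print(scheduler.state_dict())
```

Note that LambdaLR saves only the `__dict__` of callable objects (such as `functools.partial`) and skips plain functions, which is why `lr_lambdas` appears as `[{}]` here.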
The PR is tested in two parts:
1. Check lr_scheduler values before and after checkpointing, with resharding that changes the tp and pp degrees:
[dp=2, tp=4, pp=1] -> [dp=2, tp=1, pp=4]
[dp=2, tp=1, pp=4] -> [dp=2, tp=4, pp=1]
(data_loader does not support resharding right now; a self-contained sketch of this check appears after the memory traces below.)
Logs:
[dp=2, tp=4, pp=1]
step 5 before saving to checkpoint:
[{'lr': 8.955223880597014e-06, ...}]
step 10 after loading from checkpoint and resharding to [dp=2, tp=2, pp=2]:
[{'lr': 1.6417910447761194e-05, ...}, {'lr': 1.6417910447761194e-05, ...}]
[dp=8, tp=1, pp=1] baseline:
step 5 without checkpoint:
[{'lr': 8.955223880597014e-06, ...}]
step 10 without checkpoint:
[{'lr': 1.6417910447761194e-05, ...}]
2. Memory trace:
Before the flattening, rerunning llama3_8b.toml from step 5 to step 10:
<img width="1166" alt="Screenshot 2025-01-16 at 2 40 03 PM"
src="https://github.com/user-attachments/assets/d3e84d63-30be-4604-823b-68bd217498a0"
/>
After the flattening, rerunning llama3_8b.toml from step 5 to step 10:
<img width="1166" alt="Screenshot 2025-01-16 at 2 40 21 PM"
src="https://github.com/user-attachments/assets/b6ed68ae-2dbf-400a-b723-06eae6740ade"
/>
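For part 1, here is a self-contained sketch of the kind of roundtrip being checked (hypothetical warmup schedule and toy optimizers; the real test drives the full training loop and checkpoint manager):

```python
import copy
import functools

import torch


def warmup(step: int, warmup_steps: int) -> float:
    # Hypothetical linear-warmup schedule; with lr=3e-4 and
    # warmup_steps=200 this reproduces the lr values logged above.
    return min(1.0, (step + 1) / (warmup_steps + 1))


def make_scheduler() -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
    optimizer = torch.optim.AdamW(torch.nn.Linear(8, 8).parameters(), lr=3e-4)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=functools.partial(warmup, warmup_steps=200)
    )
    return optimizer, scheduler


# Run 5 steps with a single scheduler (pp=1) and "checkpoint" its state.
opt, sched = make_scheduler()
for _ in range(5):
    opt.step()
    sched.step()
checkpoint = copy.deepcopy(sched.state_dict())
print(sched.get_last_lr())  # ~8.9552e-06, matching the step-5 log

# "Reshard" to two pipeline stages: load the single saved state_dict into
# both schedulers, then continue training to step 10.
stages = [make_scheduler(), make_scheduler()]
for _, s in stages:
    s.load_state_dict(copy.deepcopy(checkpoint))
for _ in range(5):
    for o, s in stages:
        o.step()
        s.step()
# Both stages should report the same lr as an uninterrupted 10-step run,
# ~1.64179e-05, matching the step-10 log.
print([s.get_last_lr() for _, s in stages])
```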