Commit f504a14

Disable integration test between optimizer-in-backward and pp (#793)
Optimizer-in-backward frees gradient memory during the backward pass, which caused the integration test with PP to fail at the gradient-scaling step. Disable the PP test for now; it can be re-enabled once multi-schedule PP is supported. Add tests with DP, TP, CP, and HSDP instead.
1 parent 2fa6d83 commit f504a14
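
For context, the pattern being disabled works roughly as follows (a minimal, hedged sketch, not torchtitan code; the toy Linear model, AdamW settings, and the final check are illustrative assumptions): each parameter gets its own optimizer, which is stepped and zeroed from a post-accumulate-grad hook, so the gradient is released while backward is still running. A schedule that rescales gradients after backward, as the PP gradient-scaling step does per the commit message, then finds nothing left to scale.

# Illustrative sketch only (not torchtitan code): per-parameter optimizers
# stepped from post-accumulate-grad hooks, so each gradient is freed as soon
# as it has been applied -- before any post-backward gradient scaling can run.
import torch

model = torch.nn.Linear(16, 16)  # hypothetical toy model
optim_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optim_hook(param) -> None:
    optim_dict[param].step()
    optim_dict[param].zero_grad()  # releases param.grad (sets it to None)

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optim_hook)

loss = model(torch.randn(4, 16)).sum()
loss.backward()  # parameters are updated during backward itself
print(all(p.grad is None for p in model.parameters()))  # True: nothing left to scale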

File tree

2 files changed: +34 −15 lines

tests/integration_tests.py

+12 −4

@@ -378,14 +378,22 @@ def build_test_list():
         [
             [
                 "--checkpoint.enable_checkpoint",
-                "--experimental.pipeline_parallel_degree 2",
+                "--training.tensor_parallel_degree=2",
+                "--experimental.context_parallel_degree=2",
+                "--training.enable_cpu_offload",
+                "--optimizer.early_step_in_backward",
+            ],
+            [
+                "--training.tensor_parallel_degree=2",
+                "--experimental.context_parallel_degree=2",
+                "--training.data_parallel_replicate_degree=2",
                 "--training.enable_cpu_offload",
                 "--optimizer.early_step_in_backward",
             ],
         ],
-        "Enable CPU Offload with PP",
-        "enable_cpu_offload+PP",
-        ngpu=4,
+        "Enable CPU Offload, Optimizer in backward with TP, DP, CP",
+        "cpu_offload+opt_in_bwd+TP+DP+CP",
+        ngpu=8,
     ),
     OverrideDefinitions(
         [
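
For reference, a hedged sketch of how the second new override list could be launched as a standalone run; the run_llama_train.sh wrapper and the NGPU environment variable are assumptions about the usual torchtitan entry point, not something shown in this diff.

# Hedged sketch, not the test harness itself: launching the second new test
# case directly, assuming torchtitan's run_llama_train.sh wrapper and the
# NGPU environment variable (both assumptions, not taken from this diff).
import subprocess

overrides = [
    "--training.tensor_parallel_degree=2",
    "--experimental.context_parallel_degree=2",
    "--training.data_parallel_replicate_degree=2",
    "--training.enable_cpu_offload",
    "--optimizer.early_step_in_backward",
]
subprocess.run(
    "NGPU=8 ./run_llama_train.sh " + " ".join(overrides),
    shell=True,
    check=True,
)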

torchtitan/optimizer.py

+22 −11

@@ -81,30 +81,37 @@ def __init__(
     ) -> None:
         self.optimizers = []
         self.model_parts = model_parts
+        optim_dict = {}
         for model in self.model_parts:
             if name == "Adam":
                 # TODO: make the optimizer options configurable by toml/cmd args
-                optim_dict = {
-                    param: torch.optim.Adam([param], **optimizer_kwargs)
-                    for param in model.parameters()
-                }
+                optim_dict.update(
+                    {
+                        param: torch.optim.Adam([param], **optimizer_kwargs)
+                        for param in model.parameters()
+                    }
+                )
             elif name == "AdamW":
-                optim_dict = {
-                    param: torch.optim.AdamW([param], **optimizer_kwargs)
-                    for param in model.parameters()
-                }
+                optim_dict.update(
+                    {
+                        param: torch.optim.AdamW([param], **optimizer_kwargs)
+                        for param in model.parameters()
+                    }
+                )
             else:
                 raise NotImplementedError(f"Optimizer {name} not added.")
 
-            def optim_hook(param) -> None:
-                optim_dict[param].step()
-                optim_dict[param].zero_grad()
+        def optim_hook(param) -> None:
+            optim_dict[param].step()
+            optim_dict[param].zero_grad()
 
+        for model in self.model_parts:
             for param in model.parameters():
                 if param.requires_grad:
                     param.register_post_accumulate_grad_hook(optim_hook)
 
             self.optimizers.extend([optim_dict[param] for param in model.parameters()])
+
         self._validate_length(
             sum(
                 len([param for param in model.parameters()])

@@ -127,6 +134,10 @@ def build_optimizers(
     step() and zero_grad() method for all the child optimizers.
     """
     optim_in_bwd = job_config.optimizer.early_step_in_backward
+    if optim_in_bwd and job_config.experimental.pipeline_parallel_degree > 1:
+        raise NotImplementedError(
+            "Optimizers in backward is not supported with pipeline parallelism."
+        )
     name = job_config.optimizer.name
     lr = job_config.optimizer.lr
     fused = job_config.optimizer.fused
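
One effect of the refactor above is that optim_dict is now a single dict accumulated across all model parts, with optim_hook defined and registered outside the per-part build loop. The hedged standalone sketch below (toy Linear modules, not torchtitan code) shows why that matters when there is more than one model part: a hook defined inside a loop closes over the variable optim_dict, not the dict that existed when the hook was created.

# Hedged aside: Python closures capture the variable, so building a fresh dict
# per part would leave earlier parts' hooks pointing at whichever dict was
# built last.
import torch

parts = [torch.nn.Linear(4, 4), torch.nn.Linear(2, 2)]  # stand-ins for model parts

hooks = []
for part in parts:
    optim_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in part.parameters()}
    hooks.append(lambda param: optim_dict[param])  # every hook sees the same name

first_param = next(parts[0].parameters())
try:
    hooks[0](first_param)  # looks up a part-0 parameter in part-1's dict
except KeyError:
    print("KeyError: parameter from the first part is missing from the last dict")

Accumulating every parameter into one shared optim_dict, as the patch does, keeps the single optim_hook valid for parameters from any model part.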
