Skip to content

Commit 3996b63

Browse files
H-Huangtianyu-l
andauthored
Support ZBVZeroBubbleSchedule (#817)
This is dependent on the changes in this pytorch stack: pytorch/pytorch#146217 Add support for running `ZBVZeroBubbleSchedule` and v-shaped CSV schedules in torchtitan Fixes #774 --------- Co-authored-by: tianyu-l <[email protected]>
1 parent 49c6d6f commit 3996b63

File tree

3 files changed

+44
-13
lines changed

3 files changed

+44
-13
lines changed

tests/integration_tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,18 @@ def build_test_list():
139139
"pp_looped_zero_bubble",
140140
ngpu=4,
141141
),
142+
OverrideDefinitions(
143+
[
144+
[
145+
"--experimental.pipeline_parallel_degree 2",
146+
"--experimental.pipeline_parallel_schedule ZBVZeroBubble",
147+
"--experimental.pipeline_parallel_microbatches 8",
148+
],
149+
],
150+
"PP zero bubble test (v shaped)",
151+
"pp_zbv",
152+
ngpu=2,
153+
),
142154
OverrideDefinitions(
143155
[
144156
[

torchtitan/parallelisms/pipeline_llama.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
import torch.nn as nn
1414
from torch.distributed import DeviceMesh
1515
from torch.distributed.pipelining import PipelineStage
16-
16+
from torch.distributed.pipelining.schedules import (
17+
get_schedule_class,
18+
ScheduleZBVZeroBubble,
19+
)
1720
from torchtitan.config_manager import JobConfig
1821
from torchtitan.logging import logger
1922
from torchtitan.models.llama.model import ModelArgs
@@ -43,7 +46,16 @@ def pipeline_llama(
4346

4447
pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
4548

46-
return pp_schedule, models
49+
# This is used in the train loop to determine whether to pass in the input_ids and labels
50+
has_first_stage = False
51+
has_last_stage = False
52+
for stage in stages:
53+
if stage.is_first:
54+
has_first_stage = True
55+
if stage.is_last:
56+
has_last_stage = True
57+
58+
return pp_schedule, models, has_first_stage, has_last_stage
4759

4860

4961
def pipeline_llama_manual_split(
@@ -103,7 +115,13 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
103115

104116
stages = []
105117
models = []
106-
for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):
118+
119+
schedule_class = get_schedule_class(
120+
job_config.experimental.pipeline_parallel_schedule
121+
)
122+
style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
123+
124+
for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
107125
start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
108126
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
109127
stage, model_chunk = _build_stage(

train.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,12 @@ def loss_fn(pred, labels):
151151
# apply parallelisms and initialization
152152
if parallel_dims.pp_enabled:
153153
# apply PT-D Pipeline Parallel
154-
pp_schedule, model_parts = models_pipelining_fns[model_name](
154+
(
155+
pp_schedule,
156+
model_parts,
157+
has_first_stage,
158+
has_last_stage,
159+
) = models_pipelining_fns[model_name](
155160
model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
156161
)
157162
# when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -285,22 +290,18 @@ def loss_fn(pred, labels):
285290

286291
if parallel_dims.pp_enabled:
287292
# Pipeline Parallel forward / backward inside step() call
288-
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
289-
290293
with train_context(optional_context_parallel_ctx):
291-
if pp_mesh.get_local_rank() == 0:
292-
pp_schedule.step(input_ids)
293-
elif is_last_stage:
294-
losses = []
295-
pp_schedule.step(target=labels, losses=losses)
294+
targets, losses = (labels, []) if has_last_stage else (None, None)
295+
if has_first_stage:
296+
pp_schedule.step(input_ids, target=targets, losses=losses)
296297
else:
297-
pp_schedule.step()
298+
pp_schedule.step(target=targets, losses=losses)
298299

299300
# accumulate losses across pipeline microbatches
300301
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
301302
loss = (
302303
torch.mean(torch.stack(losses)).to(device)
303-
if is_last_stage
304+
if has_last_stage
304305
else torch.tensor([-1.0], device=device)
305306
)
306307
else:

0 commit comments

Comments
 (0)