
Commit d989842

Fix PP+CP handling of freqs_cis buffer
While developing test_pp_cp and chatting with fegin, we realized the freqs_cis buffers are not handled correctly in torchtitan when pipelining is enabled. CP needs to modify the freqs_cis buffer to account for sharding on the sequence dim, but the previous titan code implemented this incorrectly: `model.freqs_cis` was passed to CP for sharding, yet pipelining does not use `model` at all; it uses the separate stage-models contained in the `model_parts` list. The fix is to tell the CP context about each freqs_cis buffer inside the `model_parts` models.

Alternatively, we could tie the freqs_cis buffers of all pp stages together, by doing so explicitly after calling init_weights per pp-stage. However, this is of limited value, so we skip it.

ghstack-source-id: 7aa3935
Pull Request resolved: #792
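For illustration, the skipped alternative would look roughly like the sketch below. This is a hedged, self-contained sketch, not code from this change: the `Stage` module and its sizes are hypothetical stand-ins for the per-stage models produced by the pipelining split. It relies only on the nn.Module behavior that assigning a tensor to a registered buffer name replaces that buffer.

```python
import torch
import torch.nn as nn

class Stage(nn.Module):
    """Hypothetical pp-stage module; each stage registers its own freqs_cis."""
    def __init__(self, seq_len=128, head_dim=16):
        super().__init__()
        self.register_buffer("freqs_cis", torch.randn(seq_len, head_dim), persistent=False)

model_parts = [Stage(), Stage(), Stage()]

# Tie the buffers together: assigning a tensor to a registered buffer name
# replaces that buffer, so every stage now aliases the first stage's tensor,
# and CP would only need to shard a single freqs_cis.
shared = model_parts[0].freqs_cis
for m in model_parts[1:]:
    m.freqs_cis = shared

assert all(m.freqs_cis is shared for m in model_parts)
```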
1 parent f504a14 commit d989842

File tree

1 file changed: +5 −2 lines

train.py (+5 −2)
@@ -154,6 +154,8 @@ def loss_fn(pred, labels):
         pp_schedule, model_parts = models_pipelining_fns[model_name](
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
+        # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
+        del model
 
         # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
         # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
@@ -268,11 +270,12 @@ def loss_fn(pred, labels):
         optimizers.zero_grad()
 
         # apply context parallelism if cp is enabled
+        # ensure CP handles the separate freqs_cis buffer for each pp stage
         optional_context_parallel_ctx = (
             utils.create_context_parallel_ctx(
                 cp_mesh=world_mesh["cp"],
-                cp_buffers=[input_ids, labels, model.freqs_cis],
-                cp_seq_dims=[1, 1, 0],
+                cp_buffers=[input_ids, labels] + [m.freqs_cis for m in model_parts],
+                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                 cp_no_restore_buffers={input_ids, labels},
                 cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
             )
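As a self-contained sanity check of the new list construction (again using a hypothetical `Stage` module rather than torchtitan code), each pp stage contributes its own freqs_cis entry, sharded on seq dim 0, while input_ids and labels are sharded on dim 1:

```python
import torch
import torch.nn as nn

class Stage(nn.Module):
    """Hypothetical pp-stage module owning an independent freqs_cis buffer."""
    def __init__(self, seq_len=128, head_dim=16):
        super().__init__()
        self.register_buffer("freqs_cis", torch.randn(seq_len, head_dim), persistent=False)

model_parts = [Stage(), Stage()]
input_ids = torch.zeros(2, 128, dtype=torch.long)  # (batch, seq)
labels = torch.zeros(2, 128, dtype=torch.long)

# Mirror the fixed call: one buffer / seq-dim pair per pp stage.
cp_buffers = [input_ids, labels] + [m.freqs_cis for m in model_parts]
cp_seq_dims = [1, 1] + [0 for _ in model_parts]
assert len(cp_buffers) == len(cp_seq_dims) == 2 + len(model_parts)
```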
