@@ -65,12 +65,17 @@ def parallelize_llama(
             enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
         )

-    ep_mode = job_config.experimental.expert_parallel_mode
-    if ep_mode != "none":
+    if parallel_dims.ep_mode != "none":
         apply_ep(
             model,
-            ep_mode=ep_mode,
+            ep_mode=parallel_dims.ep_mode,
+            ep_mesh=world_mesh["ep"] if parallel_dims.ep_mode == "dp2ep" else None,
             tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
+            ep_tp_mesh=(
+                world_mesh["ep", "tp"]
+                if parallel_dims.ep_mode == "dp2ep" and parallel_dims.tp_enabled
+                else None
+            ),
         )

     if job_config.activation_checkpoint.mode != "none":
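For readers unfamiliar with DeviceMesh slicing: the ep_mesh, tp_mesh, and ep_tp_mesh arguments above are submeshes taken by name from the world mesh. A minimal sketch of that slicing, outside this diff and with an illustrative 8-rank layout and dim names (the real shape and names come from ParallelDims), might look like:

# Illustrative only: an assumed 8-rank world mesh; torchtitan builds the real
# mesh from ParallelDims, but the slicing calls are the same as in the diff.
from torch.distributed.device_mesh import init_device_mesh

def build_ep_meshes():
    # run under torchrun with 8 ranks; layout: dp_shard_1 x ep x tp
    world_mesh = init_device_mesh(
        "cuda", (2, 2, 2), mesh_dim_names=("dp_shard_1", "ep", "tp")
    )
    ep_mesh = world_mesh["ep"]  # 1D submesh used for plain dp2ep
    tp_mesh = world_mesh["tp"]  # 1D submesh used for Tensor Parallel
    ep_tp_mesh = world_mesh["ep", "tp"]  # 2D submesh for dp2ep combined with TP
    return ep_mesh, tp_mesh, ep_tp_mesh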
@@ -86,20 +91,31 @@ def parallelize_llama(
         apply_compile(model)

     if (
-        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
+        parallel_dims.dp_shard_enabled
+        or parallel_dims.cp_enabled
+        or parallel_dims.ep_mode == "dp2ep"
     ):  # apply FSDP or HSDP, potentially with Context Parallel
         if parallel_dims.dp_replicate_enabled:
             dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
         else:
             dp_mesh_dim_names = ("dp_shard_cp",)

+        # the mesh dim names on which the MoE params are sharded
+        dp_mod_ep_mesh_dim_names = []
+        if parallel_dims.ep_mode == "dp2ep":
+            if parallel_dims.dp_replicate_enabled:
+                dp_mod_ep_mesh_dim_names.append("dp_replicate")
+            dp_mod_ep_mesh_dim_names.append("dp_shard_1")
+
         apply_fsdp(
             model,
             world_mesh[tuple(dp_mesh_dim_names)],
             param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             pp_enabled=parallel_dims.pp_enabled,
             cpu_offload=job_config.training.enable_cpu_offload,
+            ep_enabled=(parallel_dims.ep_mode == "dp2ep"),
+            dp_mod_ep_mesh=world_mesh[tuple(dp_mod_ep_mesh_dim_names)],
         )

         if parallel_dims.dp_replicate_enabled:
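The dim-name bookkeeping above decides which mesh dimensions FSDP uses for the dense parameters versus the MoE expert parameters. A small standalone sketch of the same selection logic (the ParallelDims fields are mirrored as plain arguments here, which is an assumption of this sketch, not torchtitan's API):

from typing import Optional, Tuple

def select_fsdp_dim_names(
    dp_replicate_enabled: bool, ep_mode: str
) -> Tuple[Tuple[str, ...], Optional[Tuple[str, ...]]]:
    # dense params: shard over "dp_shard_cp", optionally replicate for HSDP
    if dp_replicate_enabled:
        dense = ("dp_replicate", "dp_shard_cp")
    else:
        dense = ("dp_shard_cp",)

    # MoE expert params under dp2ep: the EP degree is borrowed from dp_shard,
    # so experts are FSDP-sharded only over the leftover "dp_shard_1" dim
    if ep_mode == "dp2ep":
        moe = (("dp_replicate",) if dp_replicate_enabled else ()) + ("dp_shard_1",)
        return dense, moe
    return dense, None

# select_fsdp_dim_names(True, "dp2ep")
# -> (("dp_replicate", "dp_shard_cp"), ("dp_replicate", "dp_shard_1"))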
@@ -222,11 +238,15 @@ def apply_tp(
 def apply_ep(
     model: nn.Module,
     ep_mode: str,
+    ep_mesh: Optional[DeviceMesh] = None,
     tp_mesh: Optional[DeviceMesh] = None,
+    ep_tp_mesh: Optional[DeviceMesh] = None,
 ):
     from torch.distributed.tensor import Partial
+    from torch.distributed.tensor.parallel import PrepareModuleOutput
     from torchtitan.parallelisms.expert_parallel import (
         ExpertParallel,
+        ExpertTensorParallel,
         PrepareModuleInputOutput,
         TensorParallel,
     )
@@ -282,6 +302,57 @@ def apply_ep(
                 parallelize_plan=moe_plan,
             )

+    elif ep_mode == "dp2ep":
+        if not tp_mesh:
+            assert ep_mesh is not None
+
+            for _, transformer_block in model.layers.items():
+                parallelize_module(
+                    module=transformer_block.moe.experts,
+                    device_mesh=ep_mesh,
+                    # input / output sharding on the tokens dim
+                    parallelize_plan=ExpertParallel(
+                        input_layouts=Shard(1),
+                        output_layouts=Shard(1),
+                    ),
+                )
+
+        else:  # dp2ep with TP (no Router Parallel)
+            assert ep_tp_mesh is not None
+
+            for _, transformer_block in model.layers.items():
+                moe_plan = {
+                    # input / output sharding on the seqlen dim
+                    "moe": PrepareModuleInputOutput(
+                        input_layouts=(Shard(1),),
+                        desired_input_layouts=(Replicate(),),
+                        output_layouts=(Partial(),),
+                        desired_output_layouts=(Shard(1),),
+                    ),
+                    # no Router Parallel
+                    # NOTE: still need to explicitly or implicitly turn the router into DTensor
+                    # for gradient clipping and optimizer to use DTensor foreach
+                    # top_scores, selected_token_indices sharded on the seqlen dim
+                    "moe.router": PrepareModuleOutput(
+                        output_layouts=(Replicate(), Replicate()),
+                        desired_output_layouts=(Shard(1), Shard(1)),
+                    ),
+                    "moe.shared_expert": TensorParallel(),
+                }
+                parallelize_module(
+                    module=transformer_block,
+                    device_mesh=tp_mesh,
+                    parallelize_plan=moe_plan,
+                )
+
+                parallelize_module(
+                    module=transformer_block.moe.experts,
+                    device_mesh=ep_tp_mesh,
+                    parallelize_plan=ExpertTensorParallel(
+                        tp_mesh=tp_mesh, ep_mesh=ep_mesh
+                    ),
+                )
+
     logger.info(f"Applied {ep_mode} Expert Parallelism to the model")


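The dp2ep-with-TP branch relies on parallelize_module plans keyed by submodule FQNs ("moe", "moe.router", "moe.shared_expert") relative to the transformer block. As a self-contained illustration of that mechanism, with stock parallel styles standing in for torchtitan's custom EP styles (the toy block below is an assumption, not the torchtitan model):

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyBlock(nn.Module):
    def __init__(self, dim: int = 256, hidden: int = 1024):
        super().__init__()
        self.ffn = nn.ModuleDict(
            {"w_in": nn.Linear(dim, hidden), "w_out": nn.Linear(hidden, dim)}
        )

def shard_toy_block(block: ToyBlock, tp_size: int) -> nn.Module:
    tp_mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))
    # keys are FQNs relative to `block`, just like "moe.router" in the plan above
    plan = {
        "ffn.w_in": ColwiseParallel(),
        "ffn.w_out": RowwiseParallel(),
    }
    return parallelize_module(block, tp_mesh, plan)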
@@ -375,7 +446,7 @@ def apply_compile(model: nn.Module):
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
     for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = torch.compile(transformer_block, fullgraph=True)
+        transformer_block = torch.compile(transformer_block, fullgraph=False)
         model.layers.register_module(layer_id, transformer_block)

     logger.info("Compiling each TransformerBlock with torch.compile")
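The switch to fullgraph=False matters because MoE token routing introduces data-dependent control flow that torch.compile cannot capture in a single graph. A small repro, independent of torchtitan, of the kind of graph break involved (the toy routing below is an assumption; the real model uses its own dispatch utilities):

# With fullgraph=True the data-dependent branch below raises a compile error;
# with fullgraph=False torch.compile inserts a graph break and falls back.
import torch
import torch.nn as nn

class ToyRoutedFFN(nn.Module):
    def __init__(self, dim: int = 64, num_experts: int = 4):
        super().__init__()
        self.router = nn.Linear(dim, num_experts)
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_experts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        top1 = self.router(x).argmax(dim=-1)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = top1 == i
            if mask.any():  # data-dependent branch -> graph break
                out[mask] = expert(x[mask])
        return out

block = torch.compile(ToyRoutedFFN(), fullgraph=False)  # tolerate graph breaks
print(block(torch.randn(8, 64)).shape)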
@@ -388,6 +459,8 @@ def apply_fsdp(
     reduce_dtype: torch.dtype,
     pp_enabled: bool,
     cpu_offload: bool = False,
+    ep_enabled: bool = False,
+    dp_mod_ep_mesh: Optional[DeviceMesh] = None,
 ):
     """
     Apply data parallelism to the model. FSDP2 is used here.
@@ -406,6 +479,16 @@ def apply_fsdp(
         # As an optimization, do not reshard after forward for the last
         # transformer block since FSDP would prefetch it immediately
         reshard_after_forward = int(layer_id) < len(model.layers) - 1
+
+        fsdp_mod_ep_config = fsdp_config.copy()
+        fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh
+        if ep_enabled:
+            fully_shard(
+                transformer_block.moe.experts,
+                **fsdp_mod_ep_config,
+                reshard_after_forward=reshard_after_forward,
+            )
+
         fully_shard(
             transformer_block,
             **fsdp_config,
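The ordering here is what makes the per-expert mesh work: fully_shard is applied to transformer_block.moe.experts first with the EP-modified mesh, and the later fully_shard on the whole block then only manages the parameters that are not already wrapped. A minimal standalone sketch of that nesting, with assumed mesh shapes, names, and module layout (torchtitan slices these meshes from one world mesh instead of building them separately):

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

def shard_moe_block(block: nn.Module) -> nn.Module:
    # assume 4 ranks: experts FSDP-sharded over 2 ranks, dense params over all 4;
    # `block.moe.experts` mirrors the layout used in the diff above
    ep_world = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp_shard_1", "ep"))
    dp_mod_ep_mesh = ep_world["dp_shard_1"]
    dp_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("dp_shard_cp",))

    # wrap the experts first, on the smaller data-parallel mesh
    fully_shard(block.moe.experts, mesh=dp_mod_ep_mesh)
    # wrapping the parent afterwards leaves the already-wrapped experts alone
    fully_shard(block, mesh=dp_mesh)
    return block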