
Commit c24fd0b

Committed Feb 3, 2025
[MoE][PoC] Expert Parallel: dp2ep
ghstack-source-id: 2a70ed917b742c32118ef5ca02f161f833ce46bc
Pull Request resolved: #732
1 parent f8d74c8 commit c24fd0b

File tree
7 files changed: +309 -20 lines

 

‎torchtitan/config_manager.py

+13 -5

@@ -375,15 +375,23 @@ def __init__(self):
                 The default value is 'allgather'.
             """,
         )
+        self.parser.add_argument(
+            "--experimental.expert_parallel_degree",
+            type=int,
+            default=1,
+            help="""
+                Expert parallelism degree. 1 means disabled.
+                When expert_parallel_mode is 'tp' or 'tp2ep', it has to be equal to tensor_parallel_degree.
+                When expert_parallel_mode is 'dp2ep', it has to be k * context_parallel_degree,
+                where k >= 1 and k | data_parallel_shard_degree.
+            """,
+        )
         self.parser.add_argument(
             "--experimental.expert_parallel_mode",
             type=str,
             default="none",
-            choices=["none", "tp", "tp2ep"],
-            help="""
-                Expert Parallel mode.
-                'tp2ep' would use the entire TP mesh to shard non-shared experts on the num_experts dimension.
-            """,
+            choices=["none", "tp", "tp2ep", "dp2ep"],
+            help="Expert Parallel mode",
         )
         self.parser.add_argument(
             "--training.mixed_precision_param",

‎torchtitan/optimizer.py

+1 -1

@@ -146,7 +146,7 @@ def build_optimizers(
         "betas": (0.9, 0.95),
         "weight_decay": 0.1,
         "fused": fused,
-        "foreach": not fused,
+        "foreach": False,
     }

     return (

‎torchtitan/parallelisms/expert_parallel.py

+119

@@ -325,3 +325,122 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
                 self._prepare_output_fn, self.output_layouts, self.use_local_output
             ),
         )
+
+
+# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
+class ExpertTensorParallel(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        tp_mesh: DeviceMesh,
+        ep_mesh: DeviceMesh,
+    ):
+        super().__init__()
+        # TODO: has to pass in the meshes in addition to device_mesh,
+        # as there's an issue from DeviceMesh that
+        # "Cannot create a submesh from a submesh."
+        self.tp_mesh = tp_mesh
+        self.ep_mesh = ep_mesh
+
+    @staticmethod
+    def _prepare_input_fn(tp_mesh, ep_mesh, mod, inputs, device_mesh):
+        input_tensor = inputs[0]
+        # input_tensor of placements Shard(1) on the tp mesh
+        assert not isinstance(input_tensor, DTensor)
+
+        # a2a(ep)
+        input_tensor = DTensor.from_local(input_tensor, ep_mesh, (Shard(1),))
+        input_tensor = input_tensor.redistribute(placements=(Shard(0),)).to_local()
+        # ag(tp)
+        input_tensor = DTensor.from_local(input_tensor, tp_mesh, (Shard(1),))
+        input_tensor = input_tensor.redistribute(placements=(Replicate(),))
+
+        return input_tensor
+
+    @staticmethod
+    def _partition_fn(tp_mesh, ep_mesh, name, module, device_mesh):
+        # TODO: FSDP doesn't support sharding a 2D Tensor into a 3D one yet
+        # module.register_parameter(
+        #     "gate_proj",
+        #     nn.Parameter(
+        #         distribute_tensor(module.gate_proj, device_mesh, [Shard(0), Shard(2)])
+        #     ),
+        # )  # Column-wise sharding
+        # module.register_parameter(
+        #     "down_proj",
+        #     nn.Parameter(
+        #         distribute_tensor(module.down_proj, device_mesh, [Shard(0), Shard(1)])
+        #     ),
+        # )  # Row-wise sharding
+        # module.register_parameter(
+        #     "up_proj",
+        #     nn.Parameter(
+        #         distribute_tensor(module.up_proj, device_mesh, [Shard(0), Shard(2)])
+        #     ),
+        # )  # Column-wise sharding
+
+        # TODO: Instead, for MoE experts, we shard on the EP mesh and then "forget" it.
+        # This would become an issue from DCP resharding perspective.
+        module.register_parameter(
+            "gate_proj",
+            nn.Parameter(
+                DTensor.from_local(
+                    (
+                        distribute_tensor(
+                            module.gate_proj, device_mesh, [Shard(0), Shard(2)]
+                        ).to_local()
+                    ),
+                    tp_mesh,
+                    (Shard(2),),
+                )
+            ),
+        )  # Column-wise sharding
+        module.register_parameter(
+            "down_proj",
+            nn.Parameter(
+                DTensor.from_local(
+                    (
+                        distribute_tensor(
+                            module.down_proj, device_mesh, [Shard(0), Shard(1)]
+                        ).to_local()
+                    ),
+                    tp_mesh,
+                    (Shard(1),),
+                )
+            ),
+        )  # Row-wise sharding
+        module.register_parameter(
+            "up_proj",
+            nn.Parameter(
+                DTensor.from_local(
+                    (
+                        distribute_tensor(
+                            module.up_proj, device_mesh, [Shard(0), Shard(2)]
+                        ).to_local()
+                    ),
+                    tp_mesh,
+                    (Shard(2),),
+                )
+            ),
+        )  # Column-wise sharding
+
+    @staticmethod
+    def _prepare_output_fn(tp_mesh, ep_mesh, mod, outputs, device_mesh):
+        # outputs of placements Partial() on the tp mesh
+
+        # rs(tp)
+        outputs = outputs.redistribute(placements=(Shard(1),)).to_local()
+        # a2a(ep)
+        outputs = DTensor.from_local(outputs, ep_mesh, (Shard(0),))
+        outputs = outputs.redistribute(placements=(Shard(1),)).to_local()
+
+        return outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            partial(self._partition_fn, self.tp_mesh, self.ep_mesh),
+            partial(self._prepare_input_fn, self.tp_mesh, self.ep_mesh),
+            partial(self._prepare_output_fn, self.tp_mesh, self.ep_mesh),
+        )
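For orientation, ExpertTensorParallel's input path is: an all-to-all over the EP mesh (tokens move to the ranks that own their experts), then an all-gather over the TP mesh (so each TP rank sees the full per-expert input); the output path reverses this with a reduce-scatter over TP followed by an all-to-all over EP. Below is a standalone sketch of the input sequence on a plain tensor, assuming a torchrun launch with 4 GPUs and a 2x2 ("ep", "tp") mesh; the shapes are illustrative and not taken from the commit.

import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # one GPU per rank under torchrun
# outer dim routes tokens to experts (EP), inner dim splits each expert (TP)
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("ep", "tp"))
ep_mesh, tp_mesh = mesh["ep"], mesh["tp"]

# local chunk of routed tokens: (num_local_experts, tokens, dim), values illustrative
x = torch.randn(4, 8, 16, device="cuda")

# a2a(ep): exchange tokens so each EP rank holds only its own experts' tokens
x = DTensor.from_local(x, ep_mesh, (Shard(1),))
x = x.redistribute(placements=(Shard(0),)).to_local()
# ag(tp): all-gather the token shards so every TP rank sees the full expert input
x = DTensor.from_local(x, tp_mesh, (Shard(1),))
x = x.redistribute(placements=(Replicate(),))
print(x.shape)  # global shape on the tp mesh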

‎torchtitan/parallelisms/parallel_dims.py

+78 -2

@@ -18,21 +18,24 @@ class ParallelDims:
     cp: int
     tp: int
     pp: int
+    ep: int
+    ep_mode: str
     world_size: int
     enable_loss_parallel: bool

     def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp_replicate, dp_shard, cp, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp, ep = (
             self.dp_replicate,
             self.dp_shard,
             self.cp,
             self.tp,
             self.pp,
+            self.ep,
         )
-        for d in (dp_replicate, cp, tp, pp):
+        for d in (dp_replicate, cp, tp, pp, ep):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."

@@ -45,7 +48,80 @@ def _validate(self):
                 f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
             )

+        if ep > 1:
+            assert self.ep_mode in ["tp", "tp2ep", "dp2ep"]
+            if self.ep_mode == "tp" or self.ep_mode == "tp2ep":
+                assert ep == tp
+            elif self.ep_mode == "dp2ep":
+                # EP would borrow all cp and some dp_shard degree
+                assert ep % cp == 0 and (dp_shard * cp) % ep == 0
+        else:
+            self.ep_mode = "none"
+
+    def _build_mesh_with_dp2ep(self, device_type):
+        # In dp2ep, dp_shard and ep are derived submeshes:
+        # dp_shard = dp_shard_1 * dp_shard_2
+        # ep = dp_shard_2 * cp
+        dp_shard_1 = self.dp_shard * self.cp // self.ep
+        dp_shard_2 = self.ep // self.cp
+
+        dims = []
+        names = []
+        for d, name in zip(
+            [self.pp, self.dp_replicate, dp_shard_1, dp_shard_2, self.cp, self.tp],
+            ["pp", "dp_replicate", "dp_shard_1", "dp_shard_2", "cp", "tp"],
+        ):
+            # dp_shard_1 is needed even if it's 1, whose FSDP wrapping
+            # helps the MoE layers do mixed precision training
+            if d > 1 or name == "dp_shard_1":
+                dims.append(d)
+                names.append(name)
+
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        names = tuple(names)
+        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+        # Create all the submesh here to ensure all required process groups are
+        # initialized:
+        # Mesh for data loading (no communication on this mesh)
+        dp_mesh_dim_names = []
+        # Mesh for param sharding
+        dp_shard_cp_mesh_dim_names = []
+        # Mesh for loss all-reduce
+        dp_cp_mesh_dim_names = []
+        # Mesh for ep
+        ep_mesh_dim_names = []
+
+        if self.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+            dp_cp_mesh_dim_names.append("dp_replicate")
+        # dp_shard_1 is always needed, even if it's 1
+        dp_mesh_dim_names.append("dp_shard_1")
+        dp_shard_cp_mesh_dim_names.append("dp_shard_1")
+        dp_cp_mesh_dim_names.append("dp_shard_1")
+        if "dp_shard_2" in names:
+            dp_mesh_dim_names.append("dp_shard_2")
+            dp_shard_cp_mesh_dim_names.append("dp_shard_2")
+            dp_cp_mesh_dim_names.append("dp_shard_2")
+            ep_mesh_dim_names.append("dp_shard_2")
+        if self.cp_enabled:
+            dp_shard_cp_mesh_dim_names.append("cp")
+            dp_cp_mesh_dim_names.append("cp")
+            ep_mesh_dim_names.append("cp")
+
+        mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+        mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(
+            mesh_dim_name="dp_shard_cp"
+        )
+        mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
+        mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")
+
+        return mesh
+
     def build_mesh(self, device_type):
+        if self.ep_mode == "dp2ep":
+            return self._build_mesh_with_dp2ep(device_type)
+
         dims = []
         names = []
         for d, name in zip(
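To make the dp_shard_1 / dp_shard_2 factorization in _build_mesh_with_dp2ep concrete: dp_shard_1 is the FSDP-only slice that keeps the MoE parameters wrapped, while dp_shard_2 is the slice of dp_shard that EP borrows together with cp. A worked example with illustrative degrees (not taken from the commit):

# 64 GPUs: pp=1, dp_replicate=1, dp_shard=16, cp=2, tp=2, and ep=8 (dp2ep)
dp_shard, cp, tp, pp, ep = 16, 2, 2, 1, 8

dp_shard_1 = dp_shard * cp // ep   # -> 4, FSDP-only dim (MoE params sharded here)
dp_shard_2 = ep // cp              # -> 4, the part of dp_shard borrowed by EP

assert dp_shard_1 * dp_shard_2 == dp_shard  # dp_shard = dp_shard_1 * dp_shard_2
assert dp_shard_2 * cp == ep                # ep = dp_shard_2 * cp
assert pp * dp_shard * cp * tp == 64        # world size check
print(dp_shard_1, dp_shard_2)               # 4 4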

‎torchtitan/parallelisms/parallelize_llama.py

+88 -5

@@ -65,12 +65,17 @@ def parallelize_llama(
             enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
         )

-    ep_mode = job_config.experimental.expert_parallel_mode
-    if ep_mode != "none":
+    if parallel_dims.ep_mode != "none":
         apply_ep(
             model,
-            ep_mode=ep_mode,
+            ep_mode=parallel_dims.ep_mode,
+            ep_mesh=world_mesh["ep"] if parallel_dims.ep_mode == "dp2ep" else None,
             tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
+            ep_tp_mesh=(
+                world_mesh["ep", "tp"]
+                if parallel_dims.ep_mode == "dp2ep" and parallel_dims.tp_enabled
+                else None
+            ),
         )

     if job_config.activation_checkpoint.mode != "none":

@@ -86,20 +91,31 @@ def parallelize_llama(
         apply_compile(model)

     if (
-        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
+        parallel_dims.dp_shard_enabled
+        or parallel_dims.cp_enabled
+        or parallel_dims.ep_mode == "dp2ep"
     ):  # apply FSDP or HSDP, potentially with Context Parallel
         if parallel_dims.dp_replicate_enabled:
             dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
         else:
             dp_mesh_dim_names = ("dp_shard_cp",)

+        # the mesh dim names on which the MoE params are sharded
+        dp_mod_ep_mesh_dim_names = []
+        if parallel_dims.ep_mode == "dp2ep":
+            if parallel_dims.dp_replicate_enabled:
+                dp_mod_ep_mesh_dim_names.append("dp_replicate")
+            dp_mod_ep_mesh_dim_names.append("dp_shard_1")
+
         apply_fsdp(
             model,
             world_mesh[tuple(dp_mesh_dim_names)],
             param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             pp_enabled=parallel_dims.pp_enabled,
             cpu_offload=job_config.training.enable_cpu_offload,
+            ep_enabled=(parallel_dims.ep_mode == "dp2ep"),
+            dp_mod_ep_mesh=world_mesh[tuple(dp_mod_ep_mesh_dim_names)],
         )

         if parallel_dims.dp_replicate_enabled:

@@ -222,11 +238,15 @@ def apply_tp(
 def apply_ep(
     model: nn.Module,
     ep_mode: str,
+    ep_mesh: Optional[DeviceMesh] = None,
     tp_mesh: Optional[DeviceMesh] = None,
+    ep_tp_mesh: Optional[DeviceMesh] = None,
 ):
     from torch.distributed.tensor import Partial
+    from torch.distributed.tensor.parallel import PrepareModuleOutput
     from torchtitan.parallelisms.expert_parallel import (
         ExpertParallel,
+        ExpertTensorParallel,
         PrepareModuleInputOutput,
         TensorParallel,
     )

@@ -282,6 +302,57 @@ def apply_ep(
                 parallelize_plan=moe_plan,
             )

+    elif ep_mode == "dp2ep":
+        if not tp_mesh:
+            assert ep_mesh is not None
+
+            for _, transformer_block in model.layers.items():
+                parallelize_module(
+                    module=transformer_block.moe.experts,
+                    device_mesh=ep_mesh,
+                    # input / output sharding on the tokens dim
+                    parallelize_plan=ExpertParallel(
+                        input_layouts=Shard(1),
+                        output_layouts=Shard(1),
+                    ),
+                )
+
+        else:  # dp2ep with TP (no Router Parallel)
+            assert ep_tp_mesh is not None
+
+            for _, transformer_block in model.layers.items():
+                moe_plan = {
+                    # input / output sharding on the seqlen dim
+                    "moe": PrepareModuleInputOutput(
+                        input_layouts=(Shard(1),),
+                        desired_input_layouts=(Replicate(),),
+                        output_layouts=(Partial(),),
+                        desired_output_layouts=(Shard(1),),
+                    ),
+                    # no Router Parallel
+                    # NOTE: still need to explicitly or implicitly turn the router into DTensor
+                    # for gradient clipping and the optimizer to use DTensor foreach
+                    # top_scores, selected_token_indices sharded on the seqlen dim
+                    "moe.router": PrepareModuleOutput(
+                        output_layouts=(Replicate(), Replicate()),
+                        desired_output_layouts=(Shard(1), Shard(1)),
+                    ),
+                    "moe.shared_expert": TensorParallel(),
+                }
+                parallelize_module(
+                    module=transformer_block,
+                    device_mesh=tp_mesh,
+                    parallelize_plan=moe_plan,
+                )
+
+                parallelize_module(
+                    module=transformer_block.moe.experts,
+                    device_mesh=ep_tp_mesh,
+                    parallelize_plan=ExpertTensorParallel(
+                        tp_mesh=tp_mesh, ep_mesh=ep_mesh
+                    ),
+                )
+
     logger.info(f"Applied {ep_mode} Expert Parallelism to the model")


@@ -375,7 +446,7 @@ def apply_compile(model: nn.Module):
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
     for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = torch.compile(transformer_block, fullgraph=True)
+        transformer_block = torch.compile(transformer_block, fullgraph=False)
         model.layers.register_module(layer_id, transformer_block)

     logger.info("Compiling each TransformerBlock with torch.compile")

@@ -388,6 +459,8 @@ def apply_fsdp(
     reduce_dtype: torch.dtype,
     pp_enabled: bool,
     cpu_offload: bool = False,
+    ep_enabled: bool = False,
+    dp_mod_ep_mesh: Optional[DeviceMesh] = None,
 ):
     """
     Apply data parallelism to the model. FSDP2 is used here.

@@ -406,6 +479,16 @@ def apply_fsdp(
         # As an optimization, do not reshard after forward for the last
         # transformer block since FSDP would prefetch it immediately
         reshard_after_forward = int(layer_id) < len(model.layers) - 1
+
+        fsdp_mod_ep_config = fsdp_config.copy()
+        fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh
+        if ep_enabled:
+            fully_shard(
+                transformer_block.moe.experts,
+                **fsdp_mod_ep_config,
+                reshard_after_forward=reshard_after_forward,
+            )
+
         fully_shard(
             transformer_block,
             **fsdp_config,
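In apply_fsdp above, the MoE experts get their own fully_shard call on the dp_mod_ep mesh (dp_replicate and dp_shard_1 only, since dp_shard_2 and cp are already consumed by EP), and the enclosing TransformerBlock is then wrapped on the regular data-parallel mesh, which skips the parameters already managed by the inner FSDP module. A minimal sketch of that double-wrap pattern on a toy module, assuming a torchrun launch with 4 GPUs; the module and mesh names are illustrative, and the import path follows FSDP2's composable API.

import os
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # one GPU per rank under torchrun
# 2x2 mesh standing in for (dp_shard_1, dp_shard_2); flatten both dims into "dp"
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp_shard_1", "dp_shard_2"))
mesh["dp_shard_1", "dp_shard_2"]._flatten(mesh_dim_name="dp")

block = nn.ModuleDict({"experts": nn.Linear(16, 16), "dense": nn.Linear(16, 16)}).cuda()

# experts: sharded only on dp_shard_1 (dp_shard_2 is already consumed by EP)
fully_shard(block["experts"], mesh=mesh["dp_shard_1"])
# the rest of the block: sharded on the flattened 4-way "dp" mesh; parameters of the
# already-wrapped experts module are left to the inner FSDP call
fully_shard(block, mesh=mesh["dp"])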

‎train.py

+8 -6

@@ -53,6 +53,8 @@ def main(job_config: JobConfig):
         cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
+        ep=job_config.experimental.expert_parallel_degree,
+        ep_mode=job_config.experimental.expert_parallel_mode,
         world_size=world_size,
         enable_loss_parallel=not job_config.training.disable_loss_parallel,
     )

@@ -314,12 +316,12 @@ def loss_fn(pred, labels):
             loss.backward()

             # clip gradients
-            utils.clip_grad_norm_(
-                [p for m in model_parts for p in m.parameters()],
-                job_config.training.max_norm,
-                foreach=True,
-                pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
-            )
+            # utils.clip_grad_norm_(
+            #     [p for m in model_parts for p in m.parameters()],
+            #     job_config.training.max_norm,
+            #     foreach=True,
+            #     pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
+            # )

             # optimizer step
             checkpoint.maybe_wait_for_staging()

‎train_configs/debug_model.toml

+2 -1

@@ -47,7 +47,8 @@ dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
 context_parallel_degree = 1
 pipeline_parallel_degree = 1
 enable_async_tensor_parallel = false
-expert_parallel_mode = "tp2ep"
+expert_parallel_degree = 8
+expert_parallel_mode = "dp2ep"

 [checkpoint]
 enable_checkpoint = false
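For the debug config, expert_parallel_degree = 8 presumably targets the usual 8-GPU debug run. Assuming data_parallel_shard_degree resolves to 8 and context_parallel_degree stays 1 (assumptions about the rest of this config, not shown in the diff), the validation and mesh factorization from parallel_dims.py work out as follows:

# Assumed effective degrees for an 8-GPU debug run: dp_shard=8, cp=1, tp=1, pp=1, ep=8
dp_shard, cp, ep = 8, 1, 8

assert ep % cp == 0 and (dp_shard * cp) % ep == 0  # dp2ep validation passes
dp_shard_1 = dp_shard * cp // ep   # -> 1: kept so FSDP still wraps the MoE layers
dp_shard_2 = ep // cp              # -> 8: all of dp_shard is borrowed by EP
print(dp_shard_1, dp_shard_2)      # 1 8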
