
Commit 25cfe6d

[MoE][PoC] Expert Parallel: dp2ep

ghstack-source-id: 17160930f23950b91faca7b822cd3e7f9d075f7d
Pull Request resolved: #732

1 parent 83d1714, commit 25cfe6d

7 files changed: +306 -22 lines changed

torchtitan/config_manager.py (+13 -5)

@@ -364,15 +364,23 @@ def __init__(self):
             default=1,
             help="Context parallelism degree. 1 means disabled.",
         )
+        self.parser.add_argument(
+            "--experimental.expert_parallel_degree",
+            type=int,
+            default=1,
+            help="""
+                Expert parallelism degree. 1 means disabled.
+                When expert_parallel_mode is 'tp' or 'tp2ep', it has to be equal to tensor_parallel_degree.
+                When expert_parallel_mode is 'dp2ep', it has to be k * context_parallel_degree,
+                where k >= 1 and k | data_parallel_shard_degree.
+            """,
+        )
         self.parser.add_argument(
             "--experimental.expert_parallel_mode",
             type=str,
             default="none",
-            choices=["none", "tp", "tp2ep"],
-            help="""
-                Expert Parallel mode.
-                'tp2ep' would use the entire TP mesh to shard non-shared experts on the num_experts dimension.
-            """,
+            choices=["none", "tp", "tp2ep", "dp2ep"],
+            help="Expert Parallel mode",
         )
         self.parser.add_argument(
             "--training.mixed_precision_param",

torchtitan/optimizer.py (+1 -1)

@@ -86,7 +86,7 @@ def build_optimizers(model_parts, job_config: JobConfig):
         "betas": (0.9, 0.95),
         "weight_decay": 0.1,
         "fused": fused,
-        "foreach": not fused,
+        "foreach": False,
     }

     return (

torchtitan/parallelisms/expert_parallel.py (+119)

@@ -325,3 +325,122 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
                 self._prepare_output_fn, self.output_layouts, self.use_local_output
             ),
         )
+
+
+# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
+class ExpertTensorParallel(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        tp_mesh: DeviceMesh,
+        ep_mesh: DeviceMesh,
+    ):
+        super().__init__()
+        # TODO: has to pass in the meshes in addition to device_mesh,
+        # as there's an issue from DeviceMesh that
+        # "Cannot create a submesh from a submesh."
+        self.tp_mesh = tp_mesh
+        self.ep_mesh = ep_mesh
+
+    @staticmethod
+    def _prepare_input_fn(tp_mesh, ep_mesh, mod, inputs, device_mesh):
+        input_tensor = inputs[0]
+        # input_tensor of placements Shard(1) on the tp mesh
+        assert not isinstance(input_tensor, DTensor)
+
+        # a2a(ep)
+        input_tensor = DTensor.from_local(input_tensor, ep_mesh, (Shard(1),))
+        input_tensor = input_tensor.redistribute(placements=(Shard(0),)).to_local()
+        # ag(tp)
+        input_tensor = DTensor.from_local(input_tensor, tp_mesh, (Shard(1),))
+        input_tensor = input_tensor.redistribute(placements=(Replicate(),))
+
+        return input_tensor
+
+    @staticmethod
+    def _partition_fn(tp_mesh, ep_mesh, name, module, device_mesh):
+        # TODO: FSDP doesn't support sharding a 2D Tensor into a 3D one yet
+        # module.register_parameter(
+        #     "gate_proj",
+        #     nn.Parameter(
+        #         distribute_tensor(module.gate_proj, device_mesh, [Shard(0), Shard(2)])
+        #     ),
+        # )  # Column-wise sharding
+        # module.register_parameter(
+        #     "down_proj",
+        #     nn.Parameter(
+        #         distribute_tensor(module.down_proj, device_mesh, [Shard(0), Shard(1)])
+        #     ),
+        # )  # Row-wise sharding
+        # module.register_parameter(
+        #     "up_proj",
+        #     nn.Parameter(
+        #         distribute_tensor(module.up_proj, device_mesh, [Shard(0), Shard(2)])
+        #     ),
+        # )  # Column-wise sharding
+
+        # TODO: Instead, for MoE experts, we shard on the EP mesh and then "forget" it.
+        # This would become an issue from DCP resharding perspective.
+        module.register_parameter(
+            "gate_proj",
+            nn.Parameter(
+                DTensor.from_local(
+                    (
+                        distribute_tensor(
+                            module.gate_proj, device_mesh, [Shard(0), Shard(2)]
+                        ).to_local()
+                    ),
+                    tp_mesh,
+                    (Shard(2),),
+                )
+            ),
+        )  # Column-wise sharding
+        module.register_parameter(
+            "down_proj",
+            nn.Parameter(
+                DTensor.from_local(
+                    (
+                        distribute_tensor(
+                            module.down_proj, device_mesh, [Shard(0), Shard(1)]
+                        ).to_local()
+                    ),
+                    tp_mesh,
+                    (Shard(1),),
+                )
+            ),
+        )  # Row-wise sharding
+        module.register_parameter(
+            "up_proj",
+            nn.Parameter(
+                DTensor.from_local(
+                    (
+                        distribute_tensor(
+                            module.up_proj, device_mesh, [Shard(0), Shard(2)]
+                        ).to_local()
+                    ),
+                    tp_mesh,
+                    (Shard(2),),
+                )
+            ),
+        )  # Column-wise sharding
+
+    @staticmethod
+    def _prepare_output_fn(tp_mesh, ep_mesh, mod, outputs, device_mesh):
+        # outputs of placements Partial() on the tp mesh
+
+        # rs(tp)
+        outputs = outputs.redistribute(placements=(Shard(1),)).to_local()
+        # a2a(ep)
+        outputs = DTensor.from_local(outputs, ep_mesh, (Shard(0),))
+        outputs = outputs.redistribute(placements=(Shard(1),)).to_local()
+
+        return outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            partial(self._partition_fn, self.tp_mesh, self.ep_mesh),
+            partial(self._prepare_input_fn, self.tp_mesh, self.ep_mesh),
+            partial(self._prepare_output_fn, self.tp_mesh, self.ep_mesh),
+        )
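
A hedged usage sketch, not part of the commit, of how ExpertTensorParallel might be wired up. The `world_mesh` handle, the "ep"/"tp" mesh-dim names, the 2-D experts mesh slice, and the grouped-experts module are assumptions for illustration; the commit itself does not show the call site.

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module


def apply_expert_tensor_parallel(experts: nn.Module, world_mesh: DeviceMesh) -> None:
    # Assumes world_mesh exposes "ep" and "tp" dims (e.g. from the dp2ep mesh in
    # parallel_dims.py below) and that `experts` holds gate_proj / down_proj /
    # up_proj weights with a leading num_experts dimension.
    tp_mesh = world_mesh["tp"]
    ep_mesh = world_mesh["ep"]
    parallelize_module(
        module=experts,
        device_mesh=world_mesh["ep", "tp"],  # assumed 2-D mesh: dim 0 = ep, dim 1 = tp
        parallelize_plan=ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh),
    )

Per the prepare fns above, tokens flow through an all-to-all over ep (Shard(1) to Shard(0)), an all-gather over tp (Shard(1) to Replicate) before expert compute, then a reduce-scatter over tp and a reverse all-to-all over ep on the way out.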

torchtitan/parallelisms/parallel_dims.py (+72 -2)

@@ -18,21 +18,24 @@ class ParallelDims:
     cp: int
     tp: int
     pp: int
+    ep: int
+    ep_mode: str
     world_size: int
     enable_loss_parallel: bool

     def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp_replicate, dp_shard, cp, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp, ep = (
             self.dp_replicate,
             self.dp_shard,
             self.cp,
             self.tp,
             self.pp,
+            self.ep,
         )
-        for d in (dp_replicate, cp, tp, pp):
+        for d in (dp_replicate, cp, tp, pp, ep):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
@@ -45,7 +48,74 @@ def _validate(self):
             f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
         )

+        if ep > 1:
+            assert self.ep_mode in ["tp", "tp2ep", "dp2ep"]
+            if self.ep_mode == "tp" or self.ep_mode == "tp2ep":
+                assert ep == tp
+            elif self.ep_mode == "dp2ep":
+                # EP would borrow all cp and some dp_shard degree
+                assert ep % cp == 0 and (dp_shard * cp) % ep == 0
+        else:
+            self.ep_mode = "none"

+    def _build_mesh_with_dp2ep(self, device_type):
+        # In dp2ep, dp_shard and ep are derived submeshes:
+        # dp_shard = dp_shard_1 * dp_shard_2
+        # ep = dp_shard_2 * cp
+        dp_shard_1 = self.dp_shard * self.cp // self.ep
+        dp_shard_2 = self.ep // self.cp
+
+        dims = []
+        names = []
+        for d, name in zip(
+            [self.pp, self.dp_replicate, dp_shard_1, dp_shard_2, self.cp, self.tp],
+            ["pp", "dp_replicate", "dp_shard_1", "dp_shard_2", "cp", "tp"],
+        ):
+            # dp_shard_1 is needed even if it's 1, whose FSDP wrapping
+            # helps the MoE layers do mixed precision training
+            if d > 1 or name == "dp_shard_1":
+                dims.append(d)
+                names.append(name)
+
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        names = tuple(names)
+        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+        # Create all the submesh here to ensure all required process groups are
+        # initialized:
+        # Mesh for data loading
+        dp_mesh_dim_names = []
+        if self.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+        dp_mesh_dim_names.append("dp_shard_1")
+        if "dp_shard_2" in names:
+            dp_mesh_dim_names.append("dp_shard_2")
+        mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+
+        # Mesh for param sharding
+        dp_shard_cp_mesh_dim_name = []
+        dp_shard_cp_mesh_dim_name.append("dp_shard_1")
+        if "dp_shard_2" in names:
+            dp_shard_cp_mesh_dim_name.append("dp_shard_2")
+        if self.cp_enabled:
+            dp_shard_cp_mesh_dim_name.append("cp")
+        mesh[tuple(dp_shard_cp_mesh_dim_name)]._flatten(mesh_dim_name="dp_shard_cp")
+
+        # Mesh for ep
+        ep_mesh_dim_names = []
+        if "dp_shard_2" in names:
+            ep_mesh_dim_names.append("dp_shard_2")
+        if self.cp_enabled:
+            ep_mesh_dim_names.append("cp")
+        assert len(ep_mesh_dim_names) > 0
+        mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")
+
+        return mesh
+
     def build_mesh(self, device_type):
+        if self.ep_mode == "dp2ep":
+            return self._build_mesh_with_dp2ep(device_type)
+
         dims = []
         names = []
         for d, name in zip(
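
For orientation, a small worked example (the numbers are illustrative, not from the commit) of the dp2ep mesh arithmetic in `_build_mesh_with_dp2ep`:

# Example: world_size = 16 with pp = 1, dp_replicate = 1, dp_shard = 8, cp = 2,
# tp = 1, ep = 4; this passes the dp2ep check ep % cp == 0 and (dp_shard * cp) % ep == 0.
dp_shard, cp, ep = 8, 2, 4
dp_shard_1 = dp_shard * cp // ep   # 4 -> outer FSDP dim, kept even when it equals 1
dp_shard_2 = ep // cp              # 2 -> inner dim borrowed by EP
assert dp_shard_1 * dp_shard_2 == dp_shard   # dp_shard = dp_shard_1 * dp_shard_2
assert dp_shard_2 * cp == ep                 # ep       = dp_shard_2 * cp
# Flattened submeshes built afterwards:
#   "dp"          = dp_replicate (if > 1) x dp_shard_1 x dp_shard_2 -> 8  (data loading)
#   "dp_shard_cp" = dp_shard_1 x dp_shard_2 x cp                    -> 16 (param sharding)
#   "ep"          = dp_shard_2 x cp                                 -> 4  (expert parallel)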
