
Commit f8d74c8

[MoE][PoC] Expert Parallel: tp and tp2ep

ghstack-source-id: 351648e881125aee147c14f6af4aee4219632dd4
Pull Request resolved: #731

1 parent: 624dd8e
4 files changed: +413 −0 lines changed

torchtitan/config_manager.py (+10)

@@ -375,6 +375,16 @@ def __init__(self):
                 The default value is 'allgather'.
             """,
         )
+        self.parser.add_argument(
+            "--experimental.expert_parallel_mode",
+            type=str,
+            default="none",
+            choices=["none", "tp", "tp2ep"],
+            help="""
+                Expert Parallel mode.
+                'tp2ep' would use the entire TP mesh to shard non-shared experts on the num_experts dimension.
+            """,
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
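For context: once parsed, the flag surfaces on the job config like the other "--experimental.*" options, and the same key can be set under the [experimental] table of the training TOML file. A minimal sketch, assuming torchtitan's usual JobConfig behavior; the call sites below are illustrative, not part of this diff.

# Hedged sketch: dotted argparse flags become nested attributes after parsing.
from torchtitan.config_manager import JobConfig

job_config = JobConfig()
job_config.parse_args(["--experimental.expert_parallel_mode", "tp2ep"])
print(job_config.experimental.expert_parallel_mode)  # -> "tp2ep"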
New file (+327)

@@ -0,0 +1,327 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement


# This is similar to PrepareModuleInput and PrepareModuleOutput,
# but applies them simultaneously.
class PrepareModuleInputOutput(ParallelStyle):
    def __init__(
        self,
        *,
        input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None,
        desired_input_layouts: Optional[
            Union[Placement, Tuple[Optional[Placement]]]
        ] = None,
        input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
        desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
        use_local_input: bool = True,
        output_layouts: Union[Placement, Tuple[Placement]],
        desired_output_layouts: Union[Placement, Tuple[Placement]],
        use_local_output: bool = True,
    ):
        # for input
        self.input_layouts = (
            (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
        )
        self.desired_input_layouts = (
            (desired_input_layouts,)
            if isinstance(desired_input_layouts, Placement)
            else desired_input_layouts
        )
        self.use_local_input = use_local_input
        if self.input_layouts is not None:
            assert (
                self.desired_input_layouts is not None
            ), "desired module inputs should not be None!"
            assert len(self.input_layouts) == len(
                self.desired_input_layouts
            ), "input_layouts and desired_input_layouts should have same length!"
        self.with_kwargs = input_kwarg_layouts is not None
        self.input_kwarg_layouts = input_kwarg_layouts or {}
        self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
        if self.with_kwargs:
            assert len(self.input_kwarg_layouts) == len(
                self.desired_input_kwarg_layouts
            ), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"

        # for output
        self.output_layouts = (
            (output_layouts,)
            if isinstance(output_layouts, Placement)
            else output_layouts
        )
        self.desired_output_layouts = (
            (desired_output_layouts,)
            if isinstance(desired_output_layouts, Placement)
            else desired_output_layouts
        )
        self.use_local_output = use_local_output
        assert len(self.output_layouts) == len(
            self.desired_output_layouts
        ), "output_layouts and desired_output_layouts should have same length!"

    def _prepare_input_arg(
        self,
        input: Any,
        mesh: DeviceMesh,
        input_layout: Optional[Placement],
        desired_layout: Optional[Placement],
    ):
        if input_layout is not None:
            if isinstance(input, DTensor):
                # TODO: re-enable the check once we fix the compile path
                # assert inp.placements[0] == input_layout
                dt_inp = input
            else:
                assert isinstance(
                    input, torch.Tensor
                ), "expecting input to be a torch.Tensor!"
                dt_inp = DTensor.from_local(
                    input, mesh, (input_layout,), run_check=False
                )

            if desired_layout is not None and input_layout != desired_layout:
                dt_inp = dt_inp.redistribute(placements=(desired_layout,))

            return dt_inp.to_local() if self.use_local_input else dt_inp
        else:
            return input

    def _prepare_input_fn(self, inputs, device_mesh):
        if self.input_layouts is None:
            return inputs
        prepared_inputs = []
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        if len(inputs) != len(self.input_layouts):
            raise ValueError("module inputs and input_layouts should have same length!")

        assert (
            self.desired_input_layouts is not None
        ), "desired module inputs should not be None!"
        for inp, input_layout, desired_layout in zip(
            inputs, self.input_layouts, self.desired_input_layouts
        ):
            prepared_inputs.append(
                self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)
            )
        return tuple(prepared_inputs)

    def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
        prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
        prepared_kwarg_inputs = {}
        for kwarg_key in kwarg_inputs.keys():
            kwarg_val = kwarg_inputs[kwarg_key]
            input_layout = self.input_kwarg_layouts.get(kwarg_key)
            desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)

            prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(
                kwarg_val, device_mesh, input_layout, desired_input_layout
            )

        return (prepared_arg_inputs, prepared_kwarg_inputs)

    def _prepare_out_fn(self, outputs, device_mesh):
        prepared_outputs = []
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        if len(outputs) != len(self.output_layouts):
            raise ValueError(
                "module outputs and output_layouts should have same length!"
            )
        for out, out_layout, desired_out_layout in zip(
            outputs, self.output_layouts, self.desired_output_layouts
        ):
            if out_layout is not None:
                if isinstance(out, DTensor):
                    # TODO: re-enable the check once we fix the compile path
                    # assert out.placements[0] == out_layout
                    dt_out = out
                else:
                    dt_out = DTensor.from_local(
                        out, device_mesh, (out_layout,), run_check=False
                    )

                if out_layout != desired_out_layout:
                    dt_out = dt_out.redistribute(placements=(desired_out_layout,))
                prepared_outputs.append(
                    dt_out.to_local() if self.use_local_output else dt_out
                )
            else:
                prepared_outputs.append(out)
        if len(prepared_outputs) == 1:
            return prepared_outputs[0]
        else:
            return tuple(prepared_outputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # for input
        if self.with_kwargs:
            module.register_forward_pre_hook(
                lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(
                    inputs, kwargs, device_mesh
                ),
                with_kwargs=True,
            )  # type: ignore[misc]
        else:
            module.register_forward_pre_hook(
                lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)
            )  # type: ignore[misc, call-arg]

        # for output
        module.register_forward_hook(
            lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)
        )  # type: ignore[misc, call-arg]

        return module


class TensorParallel(ParallelStyle):
    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Replicate(),)
        self.output_layouts = (output_layouts or Partial(),)
        self.desired_input_layouts = (Replicate(),)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(
        input_layouts, desired_input_layouts, mod, inputs, device_mesh
    ):
        # TODO: figure out dynamo support for instance method and switch this to instance method

        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, input_layouts, run_check=False
            )

        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(
                placements=desired_input_layouts, async_op=True
            )
        return input_tensor

    def _partition_fn(self, name, module, device_mesh):
        module.register_parameter(
            "gate_proj",
            nn.Parameter(distribute_tensor(module.gate_proj, device_mesh, [Shard(2)])),
        )  # Column-wise sharding
        module.register_parameter(
            "down_proj",
            nn.Parameter(distribute_tensor(module.down_proj, device_mesh, [Shard(1)])),
        )  # Row-wise sharding
        module.register_parameter(
            "up_proj",
            nn.Parameter(distribute_tensor(module.up_proj, device_mesh, [Shard(2)])),
        )  # Column-wise sharding

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            partial(
                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
            ),
            partial(
                self._prepare_output_fn, self.output_layouts, self.use_local_output
            ),
        )


class ExpertParallel(ParallelStyle):
    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Shard(0),)
        self.output_layouts = (output_layouts or Shard(0),)
        self.desired_input_layouts = (Shard(0),)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(
        input_layouts, desired_input_layouts, mod, inputs, device_mesh
    ):
        # TODO: figure out dynamo support for instance method and switch this to instance method

        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, input_layouts, run_check=False
            )

        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(
                placements=desired_input_layouts, async_op=True
            )
        return input_tensor

    def _partition_fn(self, name, module, device_mesh):
        # shard on the expert dimension
        for name, param in module.named_parameters(recurse=False):
            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
            module.register_parameter(name, dist_param)

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            partial(
                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
            ),
            partial(
                self._prepare_output_fn, self.output_layouts, self.use_local_output
            ),
        )
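As an illustration (not part of this commit's diff), PrepareModuleInputOutput collapses what would otherwise be separate PrepareModuleInput and PrepareModuleOutput entries in a parallelize_module plan. A minimal sketch follows; the layout choices are made up for the example, and the class is the one defined in the new file above (its import path is omitted because the filename is not shown in this excerpt).

# Hedged usage sketch for PrepareModuleInputOutput; intended to run inside a
# torchrun job that already owns a 1-D tensor-parallel DeviceMesh.
import torch.nn as nn
from torch.distributed.tensor import DeviceMesh, Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module

def prepare_module_io(module: nn.Module, tp_mesh: DeviceMesh) -> nn.Module:
    # Inputs arrive replicated over the TP mesh and are redistributed to
    # Shard(1) before the wrapped module runs; the module's Shard(1) output is
    # redistributed back to Replicate afterwards. With use_local_input /
    # use_local_output left at their defaults, the module sees plain tensors.
    return parallelize_module(
        module,
        tp_mesh,
        PrepareModuleInputOutput(  # defined in the new file above
            input_layouts=(Replicate(),),
            desired_input_layouts=(Shard(1),),
            output_layouts=(Shard(1),),
            desired_output_layouts=(Replicate(),),
        ),
    )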

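Similarly, a hedged sketch of how the two styles map onto the new expert_parallel_mode values, assuming a grouped-experts module whose gate_proj/up_proj/down_proj weights are stacked with num_experts as the leading dimension (which is what the partition functions above expect). The wiring below is an assumption for illustration, not the commit's integration code.

# Hedged sketch: choose a ParallelStyle for the (non-shared) experts module
# based on the new flag. TensorParallel / ExpertParallel are the classes from
# the new file above (import path omitted; the filename is not shown here).
import torch.nn as nn
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module

def parallelize_experts(experts: nn.Module, tp_mesh: DeviceMesh, ep_mode: str) -> nn.Module:
    if ep_mode == "tp":
        # Keep experts tensor-parallel: gate/up sharded column-wise (Shard(2)),
        # down sharded row-wise (Shard(1)) across the TP mesh.
        return parallelize_module(experts, tp_mesh, TensorParallel())
    if ep_mode == "tp2ep":
        # Repurpose the entire TP mesh for expert parallelism: every expert
        # weight is sharded on the leading num_experts dimension (Shard(0)).
        return parallelize_module(experts, tp_mesh, ExpertParallel())
    return experts  # "none": leave the experts unsharded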