Commit 624dd8e

[MoE][PoC] model code
ghstack-source-id: 105386d Pull Request resolved: #730
1 parent cca0702 commit 624dd8e

File tree

2 files changed, +268 -8 lines changed


torchtitan/models/llama/model.py

Lines changed: 64 additions & 8 deletions
@@ -34,6 +34,13 @@ class ModelArgs:
     depth_init: bool = True
     norm_type: str = "rmsnorm"
 
+    # MoE args
+    enable_moe: bool = True
+    num_experts: int = 8
+    capacity_factor: float = 1.0
+    use_shared_expert: bool = True
+    auto_scale_hidden_dim: bool = True
+
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
     """
@@ -283,12 +290,55 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
         self.n_heads = model_args.n_heads
         self.dim = model_args.dim
         self.attention = Attention(model_args)
-        self.feed_forward = FeedForward(
-            dim=model_args.dim,
-            hidden_dim=4 * model_args.dim,
-            multiple_of=model_args.multiple_of,
-            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
-        )
+        self.enable_moe = model_args.enable_moe
+
+        if not self.enable_moe:
+            self.feed_forward = FeedForward(
+                dim=model_args.dim,
+                hidden_dim=4 * model_args.dim,
+                multiple_of=model_args.multiple_of,
+                ffn_dim_multiplier=model_args.ffn_dim_multiplier,
+            )
+        else:
+            from torchtitan.models.llama.moe_layer import (
+                ExpertChoiceTopKRouter,
+                GroupedExperts,
+                MoE,
+            )
+
+            hidden_dim_denom = 1
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim_denom = model_args.capacity_factor + int(
+                    model_args.use_shared_expert
+                )
+
+            dim = model_args.dim
+            hidden_dim = 4 * model_args.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            if model_args.ffn_dim_multiplier is not None:
+                hidden_dim = int(model_args.ffn_dim_multiplier * hidden_dim)
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim = int(hidden_dim / hidden_dim_denom)
+            hidden_dim += -hidden_dim % model_args.multiple_of
+
+            num_experts = model_args.num_experts
+            self.moe = MoE(
+                experts=GroupedExperts(
+                    dim_in=dim, dim_out=hidden_dim, num_experts=num_experts
+                ),
+                router=ExpertChoiceTopKRouter(
+                    gate=nn.Linear(dim, num_experts, bias=False),
+                    dim=dim,
+                    num_experts=num_experts,
+                    capacity_factor=model_args.capacity_factor,
+                ),
+                shared_expert=(
+                    GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1)
+                    if model_args.use_shared_expert
+                    else None
+                ),
+            )
+
         self.layer_id = layer_id
         self.num_layers = model_args.n_layers
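
To make the auto_scale_hidden_dim sizing concrete, the arithmetic in the branch above can be replayed on its own. The values below (dim=4096, multiple_of=256, capacity_factor=1.0, shared expert enabled) are illustrative and not taken from this commit:

# Standalone sketch of the MoE hidden-dim sizing above (illustrative values).
dim = 4096
multiple_of = 256
ffn_dim_multiplier = None
capacity_factor = 1.0
use_shared_expert = True

# Same base formula as the dense FeedForward sizing.
hidden_dim = 4 * dim                  # 16384
hidden_dim = int(2 * hidden_dim / 3)  # 10922
if ffn_dim_multiplier is not None:
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)

# auto_scale_hidden_dim shrinks each expert, presumably to keep the activated
# parameter count roughly comparable to the dense FFN:
# denom = capacity_factor + 1 shared expert = 2.0
hidden_dim = int(hidden_dim / (capacity_factor + int(use_shared_expert)))  # 5461

# Round up to the next multiple of `multiple_of`.
hidden_dim += -hidden_dim % multiple_of
print(hidden_dim)  # 5632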

@@ -321,14 +371,20 @@ def forward(
 
         """
         h = x + self.attention(self.attention_norm(x), freqs_cis)
-        out = h + self.feed_forward(self.ffn_norm(h))
+        if not self.enable_moe:
+            out = h + self.feed_forward(self.ffn_norm(h))
+        else:
+            out = h + self.moe(self.ffn_norm(h))
         return out
 
     def init_weights(self):
         for norm in (self.attention_norm, self.ffn_norm):
             norm.reset_parameters()
         self.attention.init_weights(self.weight_init_std)
-        self.feed_forward.init_weights(self.weight_init_std)
+        if not self.enable_moe:
+            self.feed_forward.init_weights(self.weight_init_std)
+        else:
+            self.moe.init_weights(self.weight_init_std)
 
 
 class Transformer(nn.Module):
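
Taken together, the model.py changes are driven entirely by the new ModelArgs fields. A hypothetical config sketch (only the MoE fields come from this commit; the other values are illustrative):

# Hypothetical configuration sketch: flips TransformerBlock from FeedForward to MoE.
from torchtitan.models.llama.model import ModelArgs

model_args = ModelArgs(
    dim=256,
    n_layers=8,
    n_heads=16,
    vocab_size=32000,
    enable_moe=True,             # use the MoE branch added in this commit
    num_experts=8,
    capacity_factor=1.0,
    use_shared_expert=True,
    auto_scale_hidden_dim=True,  # shrink per-expert hidden_dim as computed above
)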

torchtitan/models/llama/moe_layer.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class GroupedExperts(nn.Module):
+    """This class implements the grouped experts layer used in Mixture of Experts. Each expert
+    is a variant of the Gated Linear Units network. See more details in https://arxiv.org/pdf/2002.05202.
+
+    Args:
+        dim_in (int): Input dimension.
+        dim_out (int): Output dimension.
+        num_experts (int): Number of experts in this grouped experts layer. Default is 1.
+        swiglu (bool): Whether to use gated linear unit. Default is True.
+        activation (Callable): Activation function to use. Default is F.silu.
+    """
+
+    def __init__(
+        self,
+        *,
+        dim_in: int,
+        dim_out: int,
+        num_experts: int = 1,
+        swiglu: bool = True,
+        activation: Callable = F.silu,
+    ):
+        super().__init__()
+        self.dim_in = dim_in
+        self.num_experts = num_experts
+        self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
+        self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
+        if swiglu:
+            self.up_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
+            self.act_fn = F.silu
+        else:
+            self.up_proj = None
+            self.act_fn = activation
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice(EC).
+
+        Returns:
+            torch.Tensor: with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice(EC).
+        """
+        # Expert Choice(EC) forward
+        # x shape (num_experts, tokens_per_expert, dim_in)
+        h = self.act_fn(torch.bmm(x, self.gate_proj))
+        if self.up_proj is not None:
+            h = h * torch.bmm(x, self.up_proj)
+        # out shape (num_experts, tokens_per_expert, dim_in)
+        out = torch.bmm(h, self.down_proj)
+        return out
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.gate_proj, mean=0.0, std=0.02)
+        if self.up_proj is not None:
+            nn.init.trunc_normal_(self.up_proj, mean=0.0, std=init_std)
+        nn.init.trunc_normal_(self.down_proj, mean=0.0, std=init_std)
+
+
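
A hypothetical shape walk-through for GroupedExperts, only to illustrate the batched-bmm layout it expects (sizes and init_std below are made up):

import torch
from torchtitan.models.llama.moe_layer import GroupedExperts

num_experts, tokens_per_expert, dim, hidden_dim = 8, 16, 256, 512

experts = GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=num_experts)
experts.init_weights(init_std=0.02)

# One contiguous slab of tokens per expert; a single bmm runs every expert at once.
x = torch.randn(num_experts, tokens_per_expert, dim)
out = experts(x)
print(out.shape)  # torch.Size([8, 16, 256]) -- projected back to dim_in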
+class ExpertChoiceTopKRouter(nn.Module):
+    """This class implements expert choice routing. Each expert will select its top-k tokens based on
+    the router scores. Refer to https://arxiv.org/abs/2202.09368 for more details.
+
+    Args:
+        gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
+        dim (int): Dimension of input tokens.
+        num_experts (int): Number of experts in each moe layer.
+        capacity_factor (float): Capacity factor determines how many tokens each expert can choose.
+            expert capacity = (number of tokens * capacity factor) / number of experts.
+        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is True.
+    """
+
+    def __init__(
+        self,
+        *,
+        gate: nn.Module,
+        dim: int,
+        num_experts: int,
+        capacity_factor: float,
+        use_sigmoid: bool = True,
+    ):
+        super().__init__()
+        self.gate = gate
+        self.dim = dim
+        self.num_experts = num_experts
+        self.capacity_factor = capacity_factor
+        self.use_sigmoid = use_sigmoid
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.
+
+        Returns:
+            top_scores (torch.Tensor): Router scores of the tokens chosen by each expert with shape
+                ``(num_experts, tokens_per_expert)``.
+            selected_token_indices (torch.Tensor): Indices of the tokens chosen by each expert with shape
+                ``(num_experts, tokens_per_expert)``.
+        """
+        # scores shape (num_experts, bs*slen)
+        scores = self.gate(x).transpose(0, 1)
+        # By default, we perform sigmoid and softmax in float32 to avoid loss explosion.
+        if self.use_sigmoid:
+            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
+        else:
+            scores = F.softmax(scores.to(torch.float32), dim=0).to(x.dtype)
+        tokens_per_expert = int(x.shape[0] * self.capacity_factor / self.num_experts)
+        tokens_per_expert += -tokens_per_expert % 8
+        # Take the smaller of tokens_per_expert and the number of tokens
+        tokens_per_expert = min(tokens_per_expert, x.shape[0])
+        # top_scores shape (num_experts, tokens_per_expert)
+        top_scores, selected_token_indices = torch.topk(
+            scores, k=tokens_per_expert, dim=1
+        )
+
+        return top_scores, selected_token_indices
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
+
+
class MoE(nn.Module):
136+
"""This class implements the moe layer which is Mixture of Experts. Mixture of Experts
137+
typically consists of a set of expert networks, alongside with a router, which directs input tokens
138+
to the appropriate experts. See more details in https://arxiv.org/pdf/2407.06204.
139+
140+
Args:
141+
experts (nn.Module): experts module.
142+
router (nn.Module): router module.
143+
shared_expert (Optional[nn.Module]): shared expert module. Default is None.
144+
"""
145+
146+
def __init__(
147+
self,
148+
*,
149+
experts: nn.Module,
150+
router: nn.Module,
151+
shared_expert: Optional[nn.Module] = None,
152+
):
153+
super().__init__()
154+
self.experts = experts
155+
self.router = router
156+
self.shared_expert = shared_expert
157+
158+
def forward(self, x: torch.Tensor) -> torch.Tensor:
159+
"""
160+
Args:
161+
x (torch.Tensor): Input tensor with shape ``(bz, slen, dim)``.
162+
163+
Returns:
164+
out (torch.Tensor): Output tensor with shape ``(bz, slen, dim)``.
165+
"""
166+
bz, slen, dim = x.shape
167+
168+
# routed_input shape (num_experts*tokens_per_expert, dim) for EC
169+
x = x.reshape(bz * slen, dim)
170+
top_scores, selected_token_indices = self.router(x)
171+
num_experts, _ = top_scores.shape
172+
173+
# token_indices shape (num_experts*tokens_per_expert, dim)
174+
token_indices = selected_token_indices.reshape(-1, 1).expand(-1, dim)
175+
# routed_input shape (num_experts*tokens_per_expert, dim)
176+
routed_input = torch.gather(x, dim=0, index=token_indices)
177+
routed_input = routed_input * top_scores.reshape(-1, 1)
178+
179+
# routed_input shape (num_experts, tokens_per_expert, dim_in)
180+
routed_input = routed_input.reshape(num_experts, -1, dim)
181+
# routed_output shape (num_experts, tokens_per_expert, dim_out)
182+
routed_output = self.experts(routed_input)
183+
# routed_output shape (num_experts*tokens_per_expert, dim_out)
184+
routed_output = routed_output.reshape(-1, dim)
185+
186+
# shared expert
187+
if self.shared_expert is not None:
188+
out = self.shared_expert(x.reshape(1, bz * slen, dim)).reshape(
189+
bz * slen, dim
190+
)
191+
else:
192+
out = torch.zeros_like(x.reshape(bz * slen, dim))
193+
194+
# add experts output
195+
# doing in in place might be faster
196+
out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
197+
out = out.reshape(bz, slen, dim)
198+
return out
199+
200+
def init_weights(self, init_std: float):
201+
self.experts.init_weights(init_std)
202+
self.router.init_weights(init_std)
203+
if self.shared_expert is not None:
204+
self.shared_expert.init_weights(init_std)
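
Finally, a hypothetical end-to-end sketch that wires the three modules together the same way TransformerBlock.__init__ does above (sizes and init_std are illustrative):

import torch
from torch import nn
from torchtitan.models.llama.moe_layer import ExpertChoiceTopKRouter, GroupedExperts, MoE

bz, slen, dim, hidden_dim, num_experts = 2, 128, 256, 512, 8

moe = MoE(
    experts=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=num_experts),
    router=ExpertChoiceTopKRouter(
        gate=nn.Linear(dim, num_experts, bias=False),
        dim=dim,
        num_experts=num_experts,
        capacity_factor=1.0,
    ),
    shared_expert=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1),
)
moe.init_weights(init_std=0.02)

# Drop-in replacement for FeedForward: input and output shapes match.
x = torch.randn(bz, slen, dim)
out = moe(x)
print(out.shape)  # torch.Size([2, 128, 256])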
