
Commit 5ae50d0

[llama4] add auxiliary-loss-free load balancing to MoE token routing (#1114)
There are two known issues in this solution:
1. The communication (syncing tokens per expert across all DP ranks) happens on the default stream. It may need to be moved onto the FSDP/DDP communication stream.
2. The communication blocks the expert bias update, so it is always exposed.
We need to evaluate whether issue 2 is a performance problem; issue 1 is acceptable if issue 2 is.
1 parent 20e2f06 commit 5ae50d0
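
For readers skimming the diffs below, here is a condensed, self-contained sketch of the mechanism this commit adds. It is illustrative only: the helper names (`route`, `update_expert_bias`) and the toy sizes are not from the repository; the actual implementation lives in `moe.py` (the `expert_bias`/`tokens_per_expert` buffers and the `_update_expert_bias` backward hook) and `parallelize_llama.py` (the DP all-reduce of `tokens_per_expert`).

```python
# Illustrative sketch of auxiliary-loss-free load balancing (not the repo code;
# see torchtitan/experiments/llama4/model/moe.py for the real implementation).
import torch

num_experts, top_k, coeff = 4, 2, 1e-3        # toy sizes; coeff mirrors load_balance_coeff
expert_bias = torch.zeros(num_experts)        # persistent buffer in MoE
tokens_per_expert = torch.zeros(num_experts)  # per-step routing counts

def route(scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Top-k selection uses biased scores; gating uses the original scores."""
    _, selected = torch.topk(scores + expert_bias, k=top_k, dim=1)
    top_scores = scores.gather(dim=1, index=selected)
    # count how many tokens each expert received this step
    tokens_per_expert.add_(
        torch.histc(selected.float(), bins=num_experts, min=0, max=num_experts)
    )
    return top_scores, selected

def update_expert_bias() -> None:
    """Once per step: nudge under-used experts up and over-used experts down.
    In the commit this runs in a full backward hook on MoE; a prepended hook
    first all-reduces tokens_per_expert across DP ranks (currently blocking,
    on the default stream -- the two issues called out above)."""
    global expert_bias
    delta = coeff * torch.sign(tokens_per_expert.mean() - tokens_per_expert)
    expert_bias = expert_bias + (delta - delta.mean())  # keep the bias zero-mean
    tokens_per_expert.zero_()

# Toy usage: routing scores skewed toward expert 0 get gradually counter-biased.
for _ in range(3):
    scores = torch.sigmoid(
        torch.randn(8, num_experts) + torch.tensor([2.0, 0.0, 0.0, 0.0])
    )
    route(scores)
    update_expert_bias()
print(expert_bias)  # expert 0 (over-used) drifts toward a negative bias
```

Note that the gating values multiplied into the routed tokens are still taken from the unbiased scores (the `scores.gather` in the router), so the bias only reshapes which experts are selected, not how much each selected expert contributes.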

File tree: 7 files changed (+106, -28 lines)
Lines changed: 6 additions & 5 deletions
@@ -1,7 +1,10 @@
 **The Llama 4 folder is still under development.**
 
+#### Issue tracking
+https://github.com/pytorch/torchtitan/issues/1118
+
 #### Available features
-- Llama 4 model definition (text-only), including the MoE architecture with token-choice routing using efficient bfloat16 Grouped MM kernels
+- Llama 4 model (text-only), including a token-choice MoE architecture with efficient bfloat16 Grouped MM kernels and auxiliary-loss-free load balancing
 - FSDP, TP, PP, CP support
 - DCP checkpoint conversion scripts
 
@@ -13,17 +16,15 @@ python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E
 
 #### To be added
 - Modeling
-  - iRoPE implementation
-  - load balance loss for token-choice MoE
   - alternative expert-choice MoE
   - multimodal support
 - Parallelism
-  - Context Parallel support for FlexAttention, iRoPE, and multimodal inputs
+  - Context Parallel support for FlexAttention and multimodal inputs
   - Expert Parallel support
 - torch.compile
   - for MoE layers
 - Quantization
-  - efficient float8 GroupedGEMM kernels (from torchao)
+  - efficient float8 Grouped MM kernels (from torchao)
 - Testing
   - perfomance and loss converging tests
   - CI integration

torchtitan/experiments/llama4/infra/parallelize_llama.py

Lines changed: 30 additions & 3 deletions
@@ -21,6 +21,8 @@
 )
 from torchtitan.tools.logging import logger
 
+from ..model.moe import MoE
+
 
 def parallelize_llama(
     model: nn.Module,
@@ -74,17 +76,19 @@ def parallelize_llama(
     # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
     torch._dynamo.config.capture_scalar_outputs = True
 
+    dp_mesh: DeviceMesh | None = None
     if (
         parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
     ):  # apply FSDP or HSDP, potentially with Context Parallel
         if parallel_dims.dp_replicate_enabled:
             dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
         else:
             dp_mesh_dim_names = ("dp_shard_cp",)
+        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
 
         apply_fsdp(
             model,
-            world_mesh[tuple(dp_mesh_dim_names)],
+            dp_mesh,
             param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             pp_enabled=parallel_dims.pp_enabled,
@@ -105,13 +109,36 @@ def parallelize_llama(
     elif parallel_dims.dp_replicate_enabled:
         if world_mesh.ndim > 1:
             raise RuntimeError("DDP has not supported > 1D parallelism")
+        dp_mesh = world_mesh
         apply_ddp(
             model,
-            world_mesh,
+            dp_mesh,
             enable_compile=job_config.training.compile,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
 
+    # for MoE auxiliary-loss-free load balancing
+    if dp_mesh is not None:
+        # NOTE: Currently this sync is blocking (thus exposed) and happens on the
+        # default compute stream. Need to assess if this is OK performance-wise.
+        def _sync_tokens_per_expert(module, *_):
+            assert isinstance(module, MoE)
+            torch.distributed.all_reduce(
+                module.tokens_per_expert, group=dp_mesh.get_group()
+            )
+
+        for transformer_block in model.layers.values():
+            if transformer_block.moe_enabled:
+                load_balance_coeff = transformer_block.moe.load_balance_coeff
+                if load_balance_coeff is not None and load_balance_coeff > 0:
+                    # prepend=True so that the sync runs before
+                    # the _update_expert_bias hook in MoE
+                    transformer_block.moe.register_full_backward_hook(
+                        _sync_tokens_per_expert, prepend=True
+                    )
+                else:
+                    break
+
     return model
 
 
@@ -127,7 +154,7 @@ def apply_moe_tp(
 
     from .expert_parallel import NoParallel, TensorParallel
 
-    for _, transformer_block in model.layers.items():
+    for transformer_block in model.layers.values():
         moe_layer_plan = {
             # input / output sharding on the seqlen dim
             # all-gather for input, reduce-scatter for output

torchtitan/experiments/llama4/model/args.py

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ class TransformerModelArgs(BaseModelArgs):
     # token-choice
     top_k: int = 1
     use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
+    load_balance_coeff: float | None = 1e-3
 
     def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
         self.vocab_size = tokenizer.n_words

torchtitan/experiments/llama4/model/model.py

Lines changed: 3 additions & 3 deletions
@@ -341,12 +341,12 @@ def forward(
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
-    def init_weights(self):
+    def init_weights(self, buffer_device: torch.device):
         for norm in (self.attention_norm, self.ffn_norm):
             norm.reset_parameters()
         self.attention.init_weights(self.weight_init_std)
         if self.moe_enabled:
-            self.moe.init_weights(self.weight_init_std)
+            self.moe.init_weights(self.weight_init_std, buffer_device)
         else:
             self.feed_forward.init_weights(self.weight_init_std)
 
@@ -417,7 +417,7 @@ def init_weights(
             nn.init.normal_(self.tok_embeddings.weight)
         for layer in self.layers.values():
             if layer is not None:
-                layer.init_weights()
+                layer.init_weights(buffer_device=buffer_device)
         if self.norm is not None:
             self.norm.reset_parameters()
         final_out_std = self.model_args.dim**-0.5

torchtitan/experiments/llama4/model/moe.py

Lines changed: 59 additions & 9 deletions
@@ -120,7 +120,7 @@ def __init__(
         self.use_sigmoid = use_sigmoid
 
     def forward(
-        self, x: torch.Tensor
+        self, x: torch.Tensor, expert_bias: torch.Tensor = None
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -139,13 +139,17 @@ def forward(
 
         # By default, sigmoid or softmax is performed in float32 to avoid loss explosion
         if self.use_sigmoid:
-            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
+            scores = torch.sigmoid(scores.to(torch.float32))
         else:
-            scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype)
+            scores = F.softmax(scores.to(torch.float32), dim=1)
 
         # top scores shape (bs*slen, top_k)
-        top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)
-        # top_scores /= top_scores.sum(dim=-1, keep_dim=True).to(x.dtype)
+        # NOTE: The expert_bias is only used for routing. The gating value
+        # top_scores is still derived from the original scores.
+        _, selected_experts_indices = torch.topk(
+            scores + expert_bias, k=self.top_k, dim=1
+        )
+        top_scores = scores.gather(dim=1, index=selected_experts_indices)
 
         # group tokens together by expert indices from 0 to num_experts and pass that to experts forward
         num_local_tokens_per_expert = torch.histc(
@@ -167,7 +171,6 @@ def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
 
 
-# TODO: implement load balancing auxiliary loss for token-choice routing
 class MoE(nn.Module):
     def __init__(self, model_args: TransformerModelArgs):
         super().__init__()
@@ -209,6 +212,35 @@ def __init__(self, model_args: TransformerModelArgs):
             else None
         )
 
+        # auxiliary-loss-free load balancing
+        self.load_balance_coeff = model_args.load_balance_coeff
+        # the fields below are defined even when load_balance_coeff is None
+        # to make initialization and checkpointing code simpler
+        self.register_buffer(
+            "expert_bias",
+            torch.zeros(num_experts, dtype=torch.float32),
+            persistent=True,
+        )
+        self.register_buffer(
+            "tokens_per_expert",
+            torch.zeros(num_experts, dtype=torch.float32),
+            persistent=True,
+        )
+
+        # NOTE: forward hook, forward pre hook, or backward pre hook
+        # would conflict with activation checkpointing
+        if self.load_balance_coeff is not None and self.load_balance_coeff > 0:
+            self.register_full_backward_hook(self._update_expert_bias)
+
+    def _update_expert_bias(self, *_):
+        expert_bias_delta = self.load_balance_coeff * torch.sign(
+            self.tokens_per_expert.mean() - self.tokens_per_expert
+        )
+        expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
+        self.expert_bias = self.expert_bias + expert_bias_delta
+
+        self.tokens_per_expert.zero_()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
@@ -218,13 +250,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
         """
         bs, slen, dim = x.shape
+
         # top_scores and selected_indices shape (bs*slen*top_k,)
        # num_local_tokens_per_expert shape (num_experts,)
         (
             top_scores,
             token_indices,
             num_local_tokens_per_expert,
-        ) = self.router(x.reshape(bs * slen, dim))
+        ) = self.router(x.reshape(bs * slen, dim), self.expert_bias)
+
+        # will be used to update the expert bias for load balancing
+        self.tokens_per_expert += num_local_tokens_per_expert
 
         # shape (bs*slen*top_k, dim)
         token_indices = token_indices.reshape(-1, 1).expand(-1, dim)
@@ -235,7 +271,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             dim=0,
             index=token_indices,
         )
-        routed_input = routed_input * top_scores.reshape(-1, 1)
+        routed_input = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to(
+            x.dtype
+        )
 
         if self.use_grouped_mm:
             # NOTE: In order to use torch._grouped_mm, we need to make sure
@@ -285,8 +323,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = out.reshape(bs, slen, dim)
         return out
 
-    def init_weights(self, init_std: float):
+    def init_weights(
+        self,
+        init_std: float,
+        buffer_device: torch.device,
+    ):
         self.experts.init_weights(init_std)
         self.router.init_weights(init_std)
         if self.shared_expert is not None:
             self.shared_expert.init_weights(init_std)
+
+        with torch.device(buffer_device):
+            self.expert_bias = torch.zeros(
+                self.experts.num_experts, dtype=torch.float32
+            )
+            self.tokens_per_expert = torch.zeros(
+                self.experts.num_experts, dtype=torch.float32
+            )

torchtitan/models/llama3/model.py

Lines changed: 6 additions & 7 deletions
@@ -8,7 +8,6 @@
 
 
 from dataclasses import dataclass
-from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -25,10 +24,10 @@ class TransformerModelArgs(BaseModelArgs):
     dim: int = 4096
     n_layers: int = 32
     n_heads: int = 32
-    n_kv_heads: Optional[int] = None
+    n_kv_heads: int | None = None
     vocab_size: int = -1  # defined later by tokenizer
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
-    ffn_dim_multiplier: Optional[float] = None
+    ffn_dim_multiplier: float | None = None
     norm_eps: float = 1e-5
     rope_theta: float = 10000
 
@@ -93,7 +92,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Te
     Args:
         dim (int): Dimension of the frequency tensor.
         end (int): End index for precomputing frequencies.
-        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+        theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0.
 
     Returns:
         torch.Tensor: Precomputed frequency tensor with complex exponentials.
@@ -271,7 +270,7 @@ class FeedForward(nn.Module):
         dim (int): Input dimension.
         hidden_dim (int): Hidden dimension of the feedforward layer.
         multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
-        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+        ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None.
 
     Attributes:
         w1 (Linear): Linear transformation for the first layer.
@@ -285,7 +284,7 @@ def __init__(
         dim: int,
         hidden_dim: int,
         multiple_of: int,
-        ffn_dim_multiplier: Optional[float],
+        ffn_dim_multiplier: float | None,
     ):
         super().__init__()
         hidden_dim = int(2 * hidden_dim / 3)
@@ -419,7 +418,7 @@ def __init__(self, model_args: TransformerModelArgs):
 
     def init_weights(
         self,
-        buffer_device: Optional[torch.device] = None,
+        buffer_device: torch.device | None = None,
     ):
         """
         [Note: On ``init_weights`` vs. ``reset_parameters``]

torchtitan/models/llama3/parallelize_llama.py

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ def apply_tp(
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
     # by folding (and unfolding) the batch dimension and the sequence dimension.
     # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
-    for layer_id, transformer_block in model.layers.items():
+    for transformer_block in model.layers.values():
         layer_plan = {
             "attention_norm": SequenceParallel(),
             "attention": prepare_module_input(
