
Commit aa623bf

committed Mar 20, 2025
add flux expert layer to support ep overlap
1 parent 58f723a commit aa623bf

File tree

6 files changed: +402 -6 lines changed

‎configs/7B_MoE4_sft.py

+6 -5
@@ -1,7 +1,7 @@
 JOB_NAME = "7b_moe_train"
 DO_ALERT = False

-SEQ_LEN = 2048
+SEQ_LEN = 1024
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 4 / 3
@@ -170,8 +170,9 @@
     # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
     qk_interleaved=False,
     num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
-    moe_type="GShard",  # Support: "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless"
-    num_experts=4,
+    moe_type="Flux",  # Support: "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless", "Flux"
+    mlp_layer_fusion=True,
+    num_experts=8,
     top_k=2,
 )
 """
@@ -217,10 +218,10 @@
 """
 parallel = dict(
     zero1=dict(size=-1),
-    tensor=dict(size=1, mode="mtp"),
+    tensor=dict(size=8, mode="msp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
-    expert=dict(size=-1, no_tp=False),
+    expert=dict(size=8, no_tp=True),
     expert_weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
 )

‎internlm/model/modules/mlp.py

+75
@@ -8,6 +8,7 @@

 from internlm.model.modules.linear import new_linear
 from internlm.model.modules.utils import Gelu, Silu
+from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
 from internlm.utils.utils import ActivationType

@@ -259,6 +260,67 @@ def forward(self, x, batch_sizes=None):
         return out


+class FluxFeedForward(nn.Module):
+    """
+    Flux FeedForward.
+
+    Args:
+        in_features (int): size of each input sample
+        hidden_features (int): size of the hidden state of the FFN
+        out_features (int): size of each output sample
+        bias (bool): whether bias is needed for the linears. True by default, but it is typically set to False
+                     in the config, and this class requires it to be False.
+        device (Optional[Union[str, torch.device]]): the device to use.
+        dtype (Optional[torch.dtype]): the data type.
+        activation_type (str): the activation function used in the feed forward, "swiglu" by default.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        out_features: int = None,
+        bias: bool = True,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        activation_type: str = "swiglu",
+        num_groups: int = 1,
+        backend: str = "bmm",
+        is_expert: bool = False,
+    ):
+        super().__init__()
+
+        # TODO: support gelu...
+        assert activation_type in ("swiglu",), f"Unsupported activation type: {activation_type}"
+        assert bias is False, "FluxFeedForward only supports bias=False."
+
+        self.w1 = new_linear(
+            "grouped_w1",
+            in_features,
+            hidden_features,
+            bias,
+            device=device,
+            dtype=dtype,
+            num_groups=num_groups,
+            backend=backend,
+            is_expert=is_expert,
+        )
+        self.w2 = new_linear(
+            "grouped_w2",
+            hidden_features,
+            out_features,
+            bias,
+            device=device,
+            dtype=dtype,
+            num_groups=num_groups,
+            backend=backend,
+            is_expert=is_expert,
+        )
+        self._register_load_state_dict_pre_hook(_grouped_mlp_pre_load_convert, with_module=True)
+        self._register_state_dict_hook(_grouped_mlp_save_convert)
+
+
 def new_feed_forward(
     in_features: int,
     hidden_features: int,
@@ -276,6 +338,19 @@ def new_feed_forward(
     if use_grouped_mlp:
         num_groups = kwargs.pop("num_groups", 1)
         backend = kwargs.pop("backend", "bmm")
+        if gpc.config.model.moe_type == "Flux":
+            return FluxFeedForward(
+                in_features,
+                hidden_features,
+                out_features,
+                bias,
+                device,
+                dtype,
+                activation_type,
+                num_groups=num_groups,
+                backend=backend,
+                is_expert=is_expert,
+            )
         return GroupedFeedForward(
            in_features,
            hidden_features,
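
A minimal sketch of how the new branch in new_feed_forward is reached. It mirrors the call that FluxMoELayer makes later in this commit; the concrete sizes are placeholders, and it assumes gpc.config.model.moe_type has already been set to "Flux" (as in the config change above) and that an InternEvo training context is initialized, since the dispatch reads gpc.config at call time.

from internlm.model.modules.mlp import new_feed_forward

# With moe_type == "Flux" and use_grouped_mlp=True, the grouped-MLP branch returns a
# FluxFeedForward whose w1/w2 are grouped linears ("grouped_w1"/"grouped_w2"),
# one group per local expert.
experts = new_feed_forward(
    in_features=4096,      # placeholder model hidden size
    hidden_features=2048,  # placeholder expert FFN size
    out_features=4096,
    bias=False,            # FluxFeedForward asserts bias is False
    is_expert=True,
    use_grouped_mlp=True,
    num_groups=8 // 8,     # num_experts // ep_size
    backend="gmm",
)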

‎internlm/model/moe/__init__.py

+2
@@ -1,6 +1,7 @@
 from .dropless_layer import DroplessMoELayer
 from .experts import Experts
 from .gshard_layer import GShardMoELayer
+from .flux_layer import FluxMoELayer
 from .megablocks import (
     MegaBlockdMoE,
     MegaBlockFeedForward,
@@ -18,4 +19,5 @@
     "MegaBlockFeedForward",
     "MegaBlockGroupedFeedForward",
     "DroplessMoELayer",
+    "FluxMoELayer",
 ]

‎internlm/model/moe/flux_layer.py

+315
@@ -0,0 +1,315 @@
+import math
+from typing import Optional
+
+import flux
+import torch
+from torch import Tensor
+
+from internlm.core.context import global_context as gpc
+from internlm.core.context.process_group_initializer import ParallelMode
+from internlm.model.modules.mlp import new_feed_forward
+from internlm.model.moe.base_layer import BaseMoELayer
+from internlm.model.moe.dropless_layer import TopKGate
+from internlm.utils.utils import TensorParallelMode
+
+
+class FluxMoELayer(torch.nn.Module):
+    """MoE layer that uses Flux grouped-GEMM kernels to overlap expert-parallel
+    communication (dispatch/combine) with the expert GEMMs."""
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        out_features: int,
+        num_experts: int,
+        top_k: int,
+        ep_group: Optional[torch.distributed.ProcessGroup],
+        ep_size: int,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        mlp_layer_fusion: bool = False,
+        multiple_of: int = 256,
+        activation_type: str = "swiglu",
+        drop_and_pad: bool = False,
+        drop_policy: str = "probs",
+        capacity_factor: Optional[float] = None,
+        noisy_gate_policy: Optional[str] = None,
+        enable_fused_permute: bool = True,
+        token_dispatch_policy: str = "alltoall",
+        use_grouped_mlp: bool = True,
+        deterministic_mode: bool = False,
+    ):
+        super().__init__()
+
+        seq_len = gpc.config.data["seq_len"]
+        micro_bsz = gpc.config.data["micro_bsz"]
+        tp_world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        tp_group = gpc.get_group(ParallelMode.TENSOR)
+        ep_group = gpc.get_group(ParallelMode.EXPERT)
+        world_size = torch.distributed.get_world_size()
+        global_rank = gpc.get_global_rank()
+        local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+
+        torch.cuda.set_device(local_rank)
+
+        if gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name:
+            max_ntokens = seq_len * micro_bsz * top_k
+        else:
+            max_ntokens = seq_len * micro_bsz * top_k // tp_world_size
+
+        with torch.cuda.device(local_rank):
+            initialized = False
+            flux.init_flux_shm(tp_group)
+
+            tp_env = flux.DistEnvTPWithEP(
+                tp_group=tp_group,
+                nnodes=1,
+                ep_group=ep_group,
+            )
+            # Debug output: report the device binding for this rank.
+            print("\n--- Device Context Check ---", flush=True)
+            print(f"Torch current device: {torch.cuda.current_device()} (cuda:{torch.cuda.current_device()})", flush=True)
+            print(f"Torch default dtype: {torch.get_default_dtype()}", flush=True)
+            print(f"Visible CUDA devices: {torch.cuda.device_count()}", flush=True)
+            print(f"local_rank: {local_rank}", flush=True)
+            print("----------------------------\n")
+
+            moe_args = flux.MoeArguments(
+                max_ntokens=max_ntokens // top_k,
+                hidden=in_features,
+                ffn_hidden=hidden_features,
+                nexperts=num_experts,
+                topk=top_k,
+                input_dtype=dtype,
+                output_dtype=dtype,
+            )
+
+            print(f"Rank {torch.distributed.get_rank()} bound to cuda:{local_rank}", flush=True)
+
+            if not initialized:
+                if flux.util.get_arch() >= 90:
+                    self.flux_ag_op = flux.GemmGroupedV3AGScatter(tp_env=tp_env, moe_args=moe_args)
+                    self.flux_rs_op = flux.GemmGroupedV3GatherRS(
+                        num_experts,
+                        max_ntokens,
+                        in_features,
+                        top_k,
+                        global_rank,
+                        world_size,
+                        tp_world_size,
+                        ep_size,
+                        1,
+                    )
+                else:
+                    self.flux_ag_op = flux.GemmGroupedV2AGScatterOp(tp_env=tp_env, moe_args=moe_args)
+                    self.flux_rs_op = flux.GemmGroupedV2GatherRSOp(
+                        tp_group,
+                        num_experts,
+                        max_ntokens,
+                        in_features,
+                        top_k,
+                        dtype,
+                        tp_world_size,
+                        ep_size,
+                        1,
+                    )
+                initialized = True
+
+        torch.cuda.synchronize(device=local_rank)
+
+        assert mlp_layer_fusion is True, "mlp_layer_fusion must be True for the Flux MoE layer"
+        self.experts = new_feed_forward(
+            in_features,
+            hidden_features,
+            out_features,
+            bias=False,
+            device=device,
+            dtype=dtype,
+            is_expert=True,
+            use_grouped_mlp=True,
+            num_groups=num_experts // ep_size,
+            backend="gmm",
+        )
+
+        self.gate = TopKGate(
+            in_features,
+            num_experts,
+            top_k,
+            noisy_gate_policy,
+        )
+
+        self.hidden_features = hidden_features
+        self.num_experts = num_experts
+        self.topk = top_k
+        self.deterministic_mode = deterministic_mode
+        self.device = device
+        self.dtype = dtype
+        self.drop_and_pad = drop_and_pad
+        self.capacity_factor = capacity_factor
+        self.drop_policy = drop_policy
+        self.activation_type = activation_type
+        if self.drop_and_pad:
+            assert self.capacity_factor is not None
+
+    def forward(self, *inputs: Tensor):
+        d_model = inputs[0].shape[-1]
+
+        # Initial implementation -> Reshape into S tokens by dropping the sequence dimension.
+        # Reshape into G groups so that each group can distribute tokens equally.
+        reshaped_inputs = inputs[0].reshape(-1, d_model)
+
+        gates = self.gate(reshaped_inputs)
+        # TODO: how to use expert_weights
+        expert_weights, indices, tokens_per_expert_before_capacity = self.topk_softmax_with_capacity(gates)
+        l_aux = self.load_balancing_loss(tokens_per_expert_before_capacity, gates)
+
+        splits_gpu = tokens_per_expert_before_capacity.to(self.device)
+        splits_cpu = splits_gpu.to("cpu")
+
+        intermediate_output = torch.empty(
+            (reshaped_inputs.size(0) * self.topk, self.hidden_features),
+            dtype=self.dtype,
+            device=self.device,
+        )
+
+        # MLP layer 0 (dispatch and GEMM0)
+        self.flux_ag_op.forward(
+            inputs_shard=reshaped_inputs,
+            weights=self.experts.w1.weight.transpose(-1, -2).contiguous(),
+            splits_gpu=splits_gpu.to(dtype=torch.int32),
+            scatter_index=indices.to(dtype=torch.int32),
+            outputs_buf=intermediate_output,
+        )
+        # Activation
+        if self.activation_type == "swiglu":
+            ac_func = torch.nn.functional.silu
+        else:
+            ac_func = torch.nn.functional.gelu
+
+        intermediate_output = ac_func(intermediate_output)
+        # MLP layer 1 (GEMM1 and combine)
+        mlp_output = self.flux_rs_op.forward_gather_rs(
+            input=intermediate_output,
+            weight=self.experts.w2.weight.transpose(-1, -2).contiguous(),
+            splits_cpu=splits_cpu,
+            routing_idx=indices,
+        )
+
+        return mlp_output, l_aux
+
+    def topk_softmax_with_capacity(self, gates):
+        expert_weights, indices = torch.topk(gates, self.topk, dim=1)
+        expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+        # We compute num_local_tokens_per_expert here. If there is no drop and padding,
+        # num_local_tokens_per_expert is already the final value; otherwise it is recomputed in self.process(.)
+        # histc(.) can be faster than bincount(.), but causes non-deterministic behavior
+        if self.deterministic_mode:
+            num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
+        else:
+            num_local_tokens_per_expert = torch.histc(indices, bins=self.num_experts, min=0, max=self.num_experts)
+
+        # without capacity
+        if self.capacity_factor is None:
+            # shape: [num_token, topk]
+            return expert_weights, indices, num_local_tokens_per_expert
+
+        # with capacity
+        expert_capacity = self.get_capacity(
+            num_tokens=gates.shape[0] * self.topk,
+            num_experts=gates.shape[1],
+            capacity_factor=self.capacity_factor,
+        )
+        # TopK selection, mask out unused experts
+        topk_masked_gates = torch.zeros_like(gates).scatter(1, indices, expert_weights)
+        topk_mask = torch.zeros_like(gates).scatter(1, indices, 1)
+        if self.drop_policy == "probs":
+            capacity_probs, capacity_indices = torch.topk(topk_masked_gates, k=expert_capacity, dim=0, sorted=False)
+            capacity_mask = torch.zeros_like(gates).scatter(0, capacity_indices, 1)
+        elif self.drop_policy == "position":
+            _, capacity_indices = torch.topk(topk_mask, k=expert_capacity, dim=0, sorted=False)
+            capacity_mask = torch.zeros_like(gates).scatter(0, capacity_indices, 1)
+            capacity_probs = torch.gather(topk_masked_gates, 0, capacity_indices)
+        else:
+            raise ValueError(f"Invalid drop_policy: {self.drop_policy}")
+        if self.drop_and_pad:
+            # shape: [num_expert, capacity]
+            final_expert_weights, final_indices = (
+                capacity_probs.T.contiguous(),
+                capacity_indices.T.contiguous(),
+            )
+        else:
+            # Get the exceed mask and mask out exceeded probs and indices
+            final_mask = torch.logical_and(topk_mask, capacity_mask)
+            drop_mask = torch.logical_not(final_mask)
+            exceed_mask = torch.gather(drop_mask, 1, indices)
+            # shape: [num_token, topk]
+            final_expert_weights = expert_weights * torch.logical_not(exceed_mask)
+            final_indices = indices.clone().masked_fill_(exceed_mask, torch.iinfo(torch.long).max)
+
+        tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
+
+        return final_expert_weights, final_indices, tokens_per_expert_before_capacity
+
+    @staticmethod
+    def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_capacity=None):
+        capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
+        if min_capacity is not None and capacity < min_capacity:
+            capacity = min_capacity
+        return capacity
+
+    def load_balancing_loss(self, num_local_tokens_per_expert, gates):
+        """Calculate the load balancing loss contribution."""
+        assert len(gates.size()) == 2
+        tokens, num_experts = gates.size()
+        assert num_experts == self.num_experts
+        assert len(num_local_tokens_per_expert.size()) == 1
+        (num_experts,) = num_local_tokens_per_expert.size()
+        assert num_experts == self.num_experts
+        scale = self.num_experts / (tokens * self.topk)
+        return scale * torch.dot(num_local_tokens_per_expert.to(gates.dtype), gates.mean(dim=0))
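
The routing math that produces the inputs to the two flux ops is easy to check in isolation. The sketch below reproduces the no-capacity path of topk_softmax_with_capacity and the auxiliary loss from load_balancing_loss on random data with plain PyTorch; the gate output is simulated with a softmax over random logits, so it stands in for TopKGate and touches neither flux nor any process group.

import torch

torch.manual_seed(0)
num_tokens, num_experts, top_k = 16, 8, 2

# Stand-in for self.gate(reshaped_inputs): a softmax over random expert logits.
gates = torch.randn(num_tokens, num_experts).softmax(dim=-1)

# Top-k routing, as in topk_softmax_with_capacity without a capacity factor.
expert_weights, indices = torch.topk(gates, top_k, dim=1)
expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)

# tokens_per_expert is what feeds splits_gpu/splits_cpu for the flux scatter/gather ops.
tokens_per_expert = torch.bincount(indices.view(-1), minlength=num_experts)

# Load-balancing auxiliary loss, matching load_balancing_loss above.
scale = num_experts / (num_tokens * top_k)
l_aux = scale * torch.dot(tokens_per_expert.to(gates.dtype), gates.mean(dim=0))

print(tokens_per_expert.tolist(), float(l_aux))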

‎internlm/model/moe/moe.py

+3
@@ -9,6 +9,7 @@
 from internlm.model.moe.gshard_layer import GShardMoELayer
 from internlm.model.moe.megablocks.megablock_dmoe import MegaBlockdMoE
 from internlm.model.moe.megablocks.megablock_moe import MegaBlockMoE
+from internlm.model.moe.flux_layer import FluxMoELayer
 from internlm.utils.logger import get_logger

 # global llm logger
@@ -24,6 +25,8 @@ def new_moe_layer(moe_type: str, **kwargs):
         return MegaBlockMoE(**kwargs)
     elif moe_type == "MegaBlock-Dropless":
         return MegaBlockdMoE(**kwargs)
+    elif moe_type == "Flux":
+        return FluxMoELayer(**kwargs)
     else:
         raise ValueError(f"Unsupported model type: {moe_type}")
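
A hedged usage sketch of the extended factory. The keyword names mirror the FluxMoELayer constructor added in this commit, the concrete sizes are placeholders, and the call only works inside an initialized InternEvo distributed context, since the layer reads gpc.config and the TENSOR/EXPERT process groups during construction; reading the expert group size via gpc.get_world_size(ParallelMode.EXPERT) is an assumption that mirrors the tensor-parallel calls used in flux_layer.py.

import torch

from internlm.core.context import global_context as gpc
from internlm.core.context.process_group_initializer import ParallelMode
from internlm.model.moe.moe import new_moe_layer

# Assumes gpc has been initialized (process groups, gpc.config.data, gpc.config.parallel).
moe = new_moe_layer(
    moe_type="Flux",
    in_features=4096,        # placeholder model hidden size
    hidden_features=2048,    # placeholder expert FFN size
    out_features=4096,
    num_experts=8,
    top_k=2,
    ep_group=gpc.get_group(ParallelMode.EXPERT),
    ep_size=gpc.get_world_size(ParallelMode.EXPERT),
    dtype=torch.bfloat16,
    mlp_layer_fusion=True,   # FluxMoELayer asserts this is True
)

# The layer returns (output, l_aux), as in the forward() added above.
hidden_states = torch.randn(1, 1024, 4096, dtype=torch.bfloat16, device="cuda")
output, l_aux = moe(hidden_states)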

‎internlm/train/pipeline.py

+1 -1
@@ -200,7 +200,7 @@ def set_param_unique_tracking_name(model):
 f"{local_fqn}",
 )

-assert hasattr(child, "offset"), f"{child}"
+# assert hasattr(child, "offset"), f"{child}"
 map_fqn_local_to_global[local_fqn] = global_fqn
 map_fqn_global_to_local[global_fqn] = local_fqn