Commit aa623bf

add flux expert layer to support ep overlap
1 parent 58f723a commit aa623bf

File tree

6 files changed: +402 -6 lines

configs/7B_MoE4_sft.py

Lines changed: 6 additions & 5 deletions
@@ -1,7 +1,7 @@
 JOB_NAME = "7b_moe_train"
 DO_ALERT = False
 
-SEQ_LEN = 2048
+SEQ_LEN = 1024
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 4 / 3
@@ -170,8 +170,9 @@
     # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
     qk_interleaved=False,
     num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
-    moe_type="GShard",  # Support: "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless"
-    num_experts=4,
+    moe_type="Flux",  # Support: "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless", "Flux"
+    mlp_layer_fusion=True,
+    num_experts=8,
     top_k=2,
 )
 """
@@ -217,10 +218,10 @@
 """
 parallel = dict(
     zero1=dict(size=-1),
-    tensor=dict(size=1, mode="mtp"),
+    tensor=dict(size=8, mode="msp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
-    expert=dict(size=-1, no_tp=False),
+    expert=dict(size=8, no_tp=True),
     expert_weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
 )
 
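Read together, the config now trains with SEQ_LEN=1024, 8 experts routed top-2 through the new Flux MoE path, tensor parallel size 8 in msp mode, and an expert parallel group of size 8 with no_tp=True. Below is a minimal, hypothetical sanity check (not part of this commit; the divisibility rule is an assumed, typical expert-parallel constraint) that mirrors the updated dicts:

# hypothetical sanity check -- not part of this commit
model = dict(
    moe_type="Flux",       # new MoE backend selected by this commit
    mlp_layer_fusion=True,
    num_experts=8,
    top_k=2,
)
parallel = dict(
    tensor=dict(size=8, mode="msp"),
    expert=dict(size=8, no_tp=True),
)

def check_moe_config(model_cfg, parallel_cfg):
    ep_size = parallel_cfg["expert"]["size"]
    # assumed constraint: experts shard evenly across the expert-parallel group
    assert model_cfg["num_experts"] % ep_size == 0, (
        f"num_experts={model_cfg['num_experts']} is not divisible by expert parallel size {ep_size}"
    )
    assert model_cfg["top_k"] <= model_cfg["num_experts"]

check_moe_config(model, parallel)  # passes: 8 experts over an expert-parallel group of 8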

internlm/model/modules/mlp.py

Lines changed: 75 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 from internlm.model.modules.linear import new_linear
 from internlm.model.modules.utils import Gelu, Silu
+from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
 from internlm.utils.utils import ActivationType
 
@@ -259,6 +260,67 @@ def forward(self, x, batch_sizes=None):
         return out
 
 
+class FluxFeedForward(nn.Module):
+    """
+    Flux FeedForward.
+    Args:
+        in_features (int): size of each input sample
+        hidden_features (int): size of the hidden state of the FFN
+        out_features (int): size of each output sample
+        bias (bool): Whether bias is needed for the linears. True by default, but it is typically set to False
+                     in the config.
+        device (Optional[Union[str, torch.device]]): The device to be used.
+        dtype (Optional[torch.dtype]): The type of data.
+        multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
+        mlp_layer_fusion (Optional[bool]): Some linears without bias in FFN can be fused to reduce the comm cost of SP.
+        activation_type (str): the activation function used for the feed forward, "swiglu" by default.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        out_features: int = None,
+        bias: bool = True,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        activation_type: str = "swiglu",
+        num_groups: int = 1,
+        backend: str = "bmm",
+        is_expert: bool = False,
+    ):
+        super().__init__()
+
+        # TODO: support gelu...
+        assert activation_type in ("swiglu",), f"Unsupported activation type: {activation_type}"
+        assert bias is False, "Grouped FeedForward only supports bias=False."
+
+        self.w1 = new_linear(
+            "grouped_w1",
+            in_features,
+            hidden_features,
+            bias,
+            device=device,
+            dtype=dtype,
+            num_groups=num_groups,
+            backend=backend,
+            is_expert=is_expert,
+        )
+        self.w2 = new_linear(
+            "grouped_w2",
+            hidden_features,
+            out_features,
+            bias,
+            device=device,
+            dtype=dtype,
+            num_groups=num_groups,
+            backend=backend,
+            is_expert=is_expert,
+        )
+        self._register_load_state_dict_pre_hook(_grouped_mlp_pre_load_convert, with_module=True)
+        self._register_state_dict_hook(_grouped_mlp_save_convert)
+
+
 def new_feed_forward(
     in_features: int,
     hidden_features: int,
@@ -276,6 +338,19 @@ def new_feed_forward(
     if use_grouped_mlp:
         num_groups = kwargs.pop("num_groups", 1)
         backend = kwargs.pop("backend", "bmm")
+        if gpc.config.model.moe_type == "Flux":
+            return FluxFeedForward(
+                in_features,
+                hidden_features,
+                out_features,
+                bias,
+                device,
+                dtype,
+                activation_type,
+                num_groups=num_groups,
+                backend=backend,
+                is_expert=is_expert,
+            )
     return GroupedFeedForward(
         in_features,
         hidden_features,
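FluxFeedForward only declares the grouped w1/w2 projections here; its forward pass is not shown in this hunk. As a rough, standalone sketch of what a grouped SwiGLU expert MLP with a bmm backend computes (the fused gate/up layout of w1 is an assumption, not taken from this commit):

# Illustrative only: a standalone grouped SwiGLU expert MLP using torch.bmm.
import torch
import torch.nn.functional as F

def grouped_swiglu_mlp(x, w1, w2):
    """
    x:  (num_groups, tokens_per_group, in_features)
    w1: (num_groups, in_features, 2 * hidden_features)  # assumed fused gate + up projection
    w2: (num_groups, hidden_features, out_features)
    """
    fused = torch.bmm(x, w1)           # (g, t, 2 * hidden)
    gate, up = fused.chunk(2, dim=-1)  # split the fused projection
    hidden = F.silu(gate) * up         # SwiGLU activation
    return torch.bmm(hidden, w2)       # (g, t, out_features)

# Example shapes: 8 experts, 16 tokens per expert, 4096 -> 2048 -> 4096
out = grouped_swiglu_mlp(
    torch.randn(8, 16, 4096),
    torch.randn(8, 4096, 2 * 2048),
    torch.randn(8, 2048, 4096),
)
print(out.shape)  # torch.Size([8, 16, 4096])

With moe_type="Flux" set in the config above, new_feed_forward now routes grouped expert MLPs to FluxFeedForward instead of GroupedFeedForward.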

internlm/model/moe/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 from .dropless_layer import DroplessMoELayer
 from .experts import Experts
 from .gshard_layer import GShardMoELayer
+from .flux_layer import FluxMoELayer
 from .megablocks import (
     MegaBlockdMoE,
     MegaBlockFeedForward,
@@ -18,4 +19,5 @@
     "MegaBlockFeedForward",
     "MegaBlockGroupedFeedForward",
     "DroplessMoELayer",
+    "FluxMoELayer",
 ]
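For context, here is a hypothetical dispatch sketch (not part of this commit) showing how the newly exported FluxMoELayer could be selected by moe_type, mirroring the gpc.config.model.moe_type check added in mlp.py:

# hypothetical dispatch helper -- the actual MoE layer selection logic is not shown in this commit
from internlm.model.moe import DroplessMoELayer, FluxMoELayer, GShardMoELayer

_MOE_LAYERS = {
    "GShard": GShardMoELayer,
    "Dropless": DroplessMoELayer,
    "Flux": FluxMoELayer,  # newly exported above
}

def get_moe_layer_cls(moe_type: str):
    if moe_type not in _MOE_LAYERS:
        raise ValueError(f"Unsupported moe_type: {moe_type}")
    return _MOE_LAYERS[moe_type]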
