
Commit 10ede60

feat(moe) impl w1&3 fused for moe (#422)

1 parent 30bb508
File tree: 2 files changed, +68 -18 lines


internlm/core/parallel/shard.py

Lines changed: 2 additions & 2 deletions

@@ -169,9 +169,9 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str:
         return "column"
     elif linear_name in ("wo", "out_proj", "w2"):
         return "row"
-    elif linear_name in ("grouped_w1", "grouped_w2", "grouped_w3") and tp_mode == "isp":
+    elif linear_name in ("grouped_w1", "grouped_w2", "grouped_w3", "grouped_w13") and tp_mode == "isp":
         return "grouped_wp"
-    elif linear_name in ("grouped_w1", "grouped_w3"):
+    elif linear_name in ("grouped_w1", "grouped_w3", "grouped_w13"):
         return "grouped_column"
     elif linear_name in ("grouped_w2"):
         return "grouped_row"

internlm/model/modules/mlp.py

Lines changed: 66 additions & 16 deletions
@@ -14,8 +14,8 @@
 logger = get_logger(__file__)


-def split_fused_mlp_weight(w1_w3):
-    w1, w3 = torch.split(w1_w3, w1_w3.shape[0] // 2, dim=0)
+def split_fused_mlp_weight(w1_w3, split_dim=0):
+    w1, w3 = torch.split(w1_w3, w1_w3.shape[split_dim] // 2, dim=split_dim)
     return w1, w3

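split_fused_mlp_weight now takes a split_dim so the same helper serves both the pre-existing dim-0 case and the grouped per-expert [in, out] layout, which fuses along dim 1 (see the hooks below). A quick round-trip check, with illustrative shapes:

import torch

w1, w3 = torch.randn(8, 4), torch.randn(8, 4)           # per-expert [in, out] layout
fused = torch.cat([w1, w3], dim=1)                       # [8, 8]
a, b = torch.split(fused, fused.shape[1] // 2, dim=1)    # split_dim=1 undoes the cat
assert torch.equal(a, w1) and torch.equal(b, w3)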
@@ -41,6 +41,31 @@ def _mlp_save_convert(module: "FeedForward", state_dict, prefix: str, *args, **k
     return state_dict


+def _grouped_mlp_pre_load_convert(
+    module: "FeedForward", state_dict, prefix: str, *args, **kwargs  # pylint: disable=W0613
+) -> None:
+    w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight"
+
+    if module.mlp_layer_fusion and fused_name not in state_dict:
+        w1, w3 = state_dict.pop(w1_name), state_dict.pop(w3_name)
+        # loaded w1, w3: [in, out]; need: [in, out*2]
+        state_dict[fused_name] = torch.cat([w1, w3], dim=1)
+
+    if not module.mlp_layer_fusion and (w1_name not in state_dict or w3_name not in state_dict):
+        state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name), split_dim=1)
+
+
+def _grouped_mlp_save_convert(
+    module: "FeedForward", state_dict, prefix: str, *args, **kwargs  # pylint: disable=W0613
+) -> Dict:
+    w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight"
+
+    if module.mlp_layer_fusion:
+        state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name), split_dim=1)
+
+    return state_dict
+
+
 class FeedForward(nn.Module):
     """
     Base FeedForward in flash implementation.
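
The two hooks keep checkpoints in the unfused w1/w3 layout no matter how mlp_layer_fusion is set: on load, an unfused checkpoint is concatenated into fused_w1_w3 along dim 1 (the per-expert output dim), and on save the fused parameter is split back. A minimal sketch of that round trip, with stand-in tensors and key names:

import torch

ckpt = {"w1.weight": torch.randn(8, 4), "w3.weight": torch.randn(8, 4)}

# pre-load with fusion enabled: merge the unfused checkpoint into one key
fused = torch.cat([ckpt.pop("w1.weight"), ckpt.pop("w3.weight")], dim=1)  # [8, 8]
ckpt["fused_w1_w3.weight"] = fused

# save: split back so the checkpoint on disk stays unfused
w = ckpt.pop("fused_w1_w3.weight")
ckpt["w1.weight"], ckpt["w3.weight"] = torch.split(w, w.shape[1] // 2, dim=1)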
@@ -164,7 +189,30 @@ def __init__(
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

         if self.mlp_layer_fusion:
-            assert False, "do not support for grouped mlp."
+            self.fused_w1_w3 = new_linear(
+                "grouped_w13",
+                in_features,
+                hidden_features * 2,
+                bias,
+                device=device,
+                dtype=dtype,
+                num_groups=num_groups,
+                backend=backend,
+                is_expert=is_expert,
+            )
+            self.w2 = new_linear(
+                "grouped_w2",
+                hidden_features,
+                out_features,
+                bias,
+                device=device,
+                dtype=dtype,
+                num_groups=num_groups,
+                backend=backend,
+                is_expert=is_expert,
+            )
+            self._register_load_state_dict_pre_hook(_grouped_mlp_pre_load_convert, with_module=True)
+            self._register_state_dict_hook(_grouped_mlp_save_convert)
         else:
             self.w1 = new_linear(
                 "grouped_w1",
@@ -205,7 +253,8 @@ def forward(self, x, batch_sizes=None):
             w1_o = self.w1(x, batch_sizes)
             w3_o = self.w3(x, batch_sizes)
         else:
-            assert False
+            w13_o = self.fused_w1_w3(x, batch_sizes)
+            w1_o, w3_o = torch.split(w13_o, w13_o.shape[-1] // 2, dim=-1)
         out = self.w2(Silu(w1_o, w3_o), batch_sizes)
         return out

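The fused forward relies on block-matrix structure: with the two weights concatenated along the output dim, x @ [W1 | W3] = [x @ W1 | x @ W3], so one grouped GEMM replaces two and a last-dim split recovers the SwiGLU inputs exactly. A small numerical check of that identity (a plain matmul stands in for the grouped linear):

import torch

x = torch.randn(5, 8)
W1, W3 = torch.randn(8, 4), torch.randn(8, 4)
w13_o = x @ torch.cat([W1, W3], dim=1)                         # one fused matmul
w1_o, w3_o = torch.split(w13_o, w13_o.shape[-1] // 2, dim=-1)
assert torch.allclose(w1_o, x @ W1) and torch.allclose(w3_o, x @ W3)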
@@ -241,15 +290,16 @@ def new_feed_forward(
             backend=backend,
             is_expert=is_expert,
         )
-    return FeedForward(
-        in_features,
-        hidden_features,
-        out_features,
-        bias,
-        device,
-        dtype,
-        multiple_of,
-        mlp_layer_fusion,
-        activation_type,
-        is_expert,
-    )
+    else:
+        return FeedForward(
+            in_features,
+            hidden_features,
+            out_features,
+            bias,
+            device,
+            dtype,
+            multiple_of,
+            mlp_layer_fusion,
+            activation_type,
+            is_expert=is_expert,
+        )
