@@ -14,8 +14,8 @@
logger = get_logger(__file__)


-def split_fused_mlp_weight(w1_w3):
-    w1, w3 = torch.split(w1_w3, w1_w3.shape[0] // 2, dim=0)
+def split_fused_mlp_weight(w1_w3, split_dim=0):
+    w1, w3 = torch.split(w1_w3, w1_w3.shape[split_dim] // 2, dim=split_dim)
    return w1, w3

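The helper now splits along a configurable dimension because the grouped path stores its fused weight as `[in, out * 2]` (see the pre-load hook below) rather than along dim 0. A minimal round-trip sketch; the tensor shapes are illustrative, not taken from this PR:

```python
# Self-contained round-trip check for split_fused_mlp_weight with the new
# split_dim parameter. Shapes are made up for illustration; only torch is needed.
import torch


def split_fused_mlp_weight(w1_w3, split_dim=0):
    w1, w3 = torch.split(w1_w3, w1_w3.shape[split_dim] // 2, dim=split_dim)
    return w1, w3


w1 = torch.randn(16, 64)  # grouped-backend layout per the hook comment: [in, out]
w3 = torch.randn(16, 64)

fused = torch.cat([w1, w3], dim=1)  # [in, out * 2], as built by the pre-load hook
w1_back, w3_back = split_fused_mlp_weight(fused, split_dim=1)

assert torch.equal(w1, w1_back) and torch.equal(w3, w3_back)
```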
@@ -41,6 +41,31 @@ def _mlp_save_convert(module: "FeedForward", state_dict, prefix: str, *args, **k
    return state_dict


+def _grouped_mlp_pre_load_convert(
+    module: "FeedForward", state_dict, prefix: str, *args, **kwargs  # pylint: disable=W0613
+) -> None:
+    w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight"
+
+    if module.mlp_layer_fusion and fused_name not in state_dict:
+        w1, w3 = state_dict.pop(w1_name), state_dict.pop(w3_name)
+        # loaded w1, w3: [in, out]; need: [in, out*2]
+        state_dict[fused_name] = torch.cat([w1, w3], dim=1)
+
+    if not module.mlp_layer_fusion and (w1_name not in state_dict or w3_name not in state_dict):
+        state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name), split_dim=1)
+
+
+def _grouped_mlp_save_convert(
+    module: "FeedForward", state_dict, prefix: str, *args, **kwargs  # pylint: disable=W0613
+) -> Dict:  # pylint: disable=W0613
+    w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight"
+
+    if module.mlp_layer_fusion:
+        state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name), split_dim=1)
+
+    return state_dict
+
+
class FeedForward(nn.Module):
    """
    Base FeedForward in flash implementation.
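For reviewers, a hedged illustration of what the fusion branch of the pre-load hook does to a checkpoint's keys, using a `SimpleNamespace` stand-in for the module (the hooks only read `mlp_layer_fusion`); the `mlp.` prefix and the tensor shapes are invented for the example:

```python
# Stand-in demonstration of _grouped_mlp_pre_load_convert's fusion branch: an
# unfused checkpoint (w1.weight / w3.weight) is rewritten into the single
# fused_w1_w3.weight key a fused module expects. Prefix and shapes are made up.
from types import SimpleNamespace

import torch


def split_fused_mlp_weight(w1_w3, split_dim=0):
    w1, w3 = torch.split(w1_w3, w1_w3.shape[split_dim] // 2, dim=split_dim)
    return w1, w3


def _grouped_mlp_pre_load_convert(module, state_dict, prefix, *args, **kwargs):
    w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight"
    if module.mlp_layer_fusion and fused_name not in state_dict:
        w1, w3 = state_dict.pop(w1_name), state_dict.pop(w3_name)
        state_dict[fused_name] = torch.cat([w1, w3], dim=1)  # [in, out] -> [in, out * 2]
    if not module.mlp_layer_fusion and (w1_name not in state_dict or w3_name not in state_dict):
        state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name), split_dim=1)


module = SimpleNamespace(mlp_layer_fusion=True)  # stand-in: only this flag is read
state_dict = {"mlp.w1.weight": torch.randn(16, 64), "mlp.w3.weight": torch.randn(16, 64)}

_grouped_mlp_pre_load_convert(module, state_dict, "mlp.")
print(sorted(state_dict))                          # ['mlp.fused_w1_w3.weight']
print(state_dict["mlp.fused_w1_w3.weight"].shape)  # torch.Size([16, 128])
```

The reverse direction (fused checkpoint, unfused module) is handled by the second branch, and `_grouped_mlp_save_convert` splits the weight back apart on save, so checkpoints stay loadable regardless of the `mlp_layer_fusion` setting.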
@@ -164,7 +189,30 @@ def __init__(
        hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

        if self.mlp_layer_fusion:
-            assert False, "do not support for grouped mlp."
+            self.fused_w1_w3 = new_linear(
+                "grouped_w13",
+                in_features,
+                hidden_features * 2,
+                bias,
+                device=device,
+                dtype=dtype,
+                num_groups=num_groups,
+                backend=backend,
+                is_expert=is_expert,
+            )
+            self.w2 = new_linear(
+                "grouped_w2",
+                hidden_features,
+                out_features,
+                bias,
+                device=device,
+                dtype=dtype,
+                num_groups=num_groups,
+                backend=backend,
+                is_expert=is_expert,
+            )
+            self._register_load_state_dict_pre_hook(_grouped_mlp_pre_load_convert, with_module=True)
+            self._register_state_dict_hook(_grouped_mlp_save_convert)
        else:
            self.w1 = new_linear(
                "grouped_w1",
@@ -205,7 +253,8 @@ def forward(self, x, batch_sizes=None):
            w1_o = self.w1(x, batch_sizes)
            w3_o = self.w3(x, batch_sizes)
        else:
-            assert False
+            w13_o = self.fused_w1_w3(x, batch_sizes)
+            w1_o, w3_o = torch.split(w13_o, w13_o.shape[-1] // 2, dim=-1)
        out = self.w2(Silu(w1_o, w3_o), batch_sizes)
        return out

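A numerical sanity check that the fused branch matches the unfused one, using plain `F.linear` as a stand-in for `new_linear` and assuming `Silu(a, b)` computes `F.silu(a) * b` (the usual SwiGLU form); the weights here use the standard `nn.Linear` `[out, in]` layout rather than the grouped backend's:

```python
# Equivalence sketch for the new else-branch in forward(): one projection to
# 2 * hidden followed by a split along the last dim should reproduce the
# separate w1/w3 projections. Plain torch stand-ins; no grouped backend involved.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 16)    # [tokens, in_features]
w1 = torch.randn(64, 16)  # nn.Linear-style [out, in] weights (illustrative)
w3 = torch.randn(64, 16)

# unfused path: w1_o = self.w1(x), w3_o = self.w3(x)
ref = F.silu(F.linear(x, w1)) * F.linear(x, w3)

# fused path: w13_o = self.fused_w1_w3(x), then split the output in half
w13_o = F.linear(x, torch.cat([w1, w3], dim=0))  # [tokens, 2 * hidden]
w1_o, w3_o = torch.split(w13_o, w13_o.shape[-1] // 2, dim=-1)
fused = F.silu(w1_o) * w3_o

assert torch.allclose(ref, fused, atol=1e-6)
```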
@@ -241,15 +290,16 @@ def new_feed_forward(
            backend=backend,
            is_expert=is_expert,
        )
-    return FeedForward(
-        in_features,
-        hidden_features,
-        out_features,
-        bias,
-        device,
-        dtype,
-        multiple_of,
-        mlp_layer_fusion,
-        activation_type,
-        is_expert,
-    )
+    else:
+        return FeedForward(
+            in_features,
+            hidden_features,
+            out_features,
+            bias,
+            device,
+            dtype,
+            multiple_of,
+            mlp_layer_fusion,
+            activation_type,
+            is_expert=is_expert,
+        )