 logger = get_logger(__file__)
 
 
-def split_fused_mlp_weight(w1_w3):
-    w1, w3 = torch.split(w1_w3, w1_w3.shape[0] // 2, dim=0)
+def split_fused_mlp_weight(w1_w3, split_dim=0):
+    w1, w3 = torch.split(w1_w3, w1_w3.shape[split_dim] // 2, dim=split_dim)
     return w1, w3
 
 
@@ -41,6 +41,31 @@ def _mlp_save_convert(module: "FeedForward", state_dict, prefix: str, *args, **k
     return state_dict
 
 
+def _grouped_mlp_pre_load_convert(
+    module: "FeedForward", state_dict, prefix: str, *args, **kwargs  # pylint: disable=W0613
+) -> None:
+    w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight"
+
+    if module.mlp_layer_fusion and fused_name not in state_dict:
+        w1, w3 = state_dict.pop(w1_name), state_dict.pop(w3_name)
+        # loaded w1,w3: [in, out]; need: [in, out*2]
+        state_dict[fused_name] = torch.cat([w1, w3], dim=1)
+
+    if not module.mlp_layer_fusion and (w1_name not in state_dict or w3_name not in state_dict):
+        state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name), split_dim=1)
+
+
+def _grouped_mlp_save_convert(
+    module: "FeedForward", state_dict, prefix: str, *args, **kwargs  # pylint: disable=W0613
+) -> Dict:  # pylint: disable=W0613
+    w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight"
+
+    if module.mlp_layer_fusion:
+        state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name), split_dim=1)
+
+    return state_dict
+
+
 class FeedForward(nn.Module):
     """
     Base FeedForward in flash implementation.
@@ -164,7 +189,30 @@ def __init__(
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
 
         if self.mlp_layer_fusion:
-            assert False, "do not support for grouped mlp."
+            self.fused_w1_w3 = new_linear(
+                "grouped_w13",
+                in_features,
+                hidden_features * 2,
+                bias,
+                device=device,
+                dtype=dtype,
+                num_groups=num_groups,
+                backend=backend,
+                is_expert=is_expert,
+            )
+            self.w2 = new_linear(
+                "grouped_w2",
+                hidden_features,
+                out_features,
+                bias,
+                device=device,
+                dtype=dtype,
+                num_groups=num_groups,
+                backend=backend,
+                is_expert=is_expert,
+            )
+            self._register_load_state_dict_pre_hook(_grouped_mlp_pre_load_convert, with_module=True)
+            self._register_state_dict_hook(_grouped_mlp_save_convert)
         else:
             self.w1 = new_linear(
                 "grouped_w1",
@@ -205,7 +253,8 @@ def forward(self, x, batch_sizes=None):
             w1_o = self.w1(x, batch_sizes)
             w3_o = self.w3(x, batch_sizes)
         else:
-            assert False
+            w13_o = self.fused_w1_w3(x, batch_sizes)
+            w1_o, w3_o = torch.split(w13_o, w13_o.shape[-1] // 2, dim=-1)
         out = self.w2(Silu(w1_o, w3_o), batch_sizes)
         return out
 
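The fused path above assumes the fused weight is the dim-1 concatenation of w1 and w3 (the layout produced by `_grouped_mlp_pre_load_convert`), so splitting the fused projection along the last dimension recovers the two separate projections. A minimal sketch with plain matmuls, standing in for the grouped GEMM backend and assuming the per-group `[in, out]` weight layout noted in the hook comment:

```python
# Sketch only: plain torch matmul instead of the repo's grouped linear layers.
import torch

torch.manual_seed(0)
in_features, hidden_features, tokens = 8, 16, 4

w1 = torch.randn(in_features, hidden_features)
w3 = torch.randn(in_features, hidden_features)
fused_w1_w3 = torch.cat([w1, w3], dim=1)  # [in, 2*hidden], as in the pre-load hook

x = torch.randn(tokens, in_features)
w13_o = x @ fused_w1_w3                   # fused projection: [tokens, 2*hidden]
w1_o, w3_o = torch.split(w13_o, w13_o.shape[-1] // 2, dim=-1)  # same split as forward()

assert torch.allclose(w1_o, x @ w1)       # left half matches the w1 projection
assert torch.allclose(w3_o, x @ w3)       # right half matches the w3 projection
```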
@@ -241,15 +290,16 @@ def new_feed_forward(
             backend=backend,
             is_expert=is_expert,
         )
-    return FeedForward(
-        in_features,
-        hidden_features,
-        out_features,
-        bias,
-        device,
-        dtype,
-        multiple_of,
-        mlp_layer_fusion,
-        activation_type,
-        is_expert,
-    )
+    else:
+        return FeedForward(
+            in_features,
+            hidden_features,
+            out_features,
+            bias,
+            device,
+            dtype,
+            multiple_of,
+            mlp_layer_fusion,
+            activation_type,
+            is_expert=is_expert,
+        )
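For reference, a standalone sketch of the checkpoint round trip the two hooks implement: loading an unfused checkpoint into a fused module concatenates w1/w3 along dim=1, and saving splits fused_w1_w3 back out, so the two layouts stay interchangeable. Plain dicts and a hypothetical prefix stand in for the real state_dict and module prefix; the helper is re-declared locally to keep the snippet self-contained:

```python
import torch

def split_fused_mlp_weight(w1_w3, split_dim=0):
    # mirrors the helper in the diff above
    return torch.split(w1_w3, w1_w3.shape[split_dim] // 2, dim=split_dim)

prefix = "layers.0.feed_forward."  # hypothetical prefix, for illustration only
w1 = torch.randn(8, 16)
w3 = torch.randn(8, 16)
state_dict = {f"{prefix}w1.weight": w1.clone(), f"{prefix}w3.weight": w3.clone()}

# pre-load direction: unfused checkpoint keys -> fused parameter key
fused = torch.cat(
    [state_dict.pop(f"{prefix}w1.weight"), state_dict.pop(f"{prefix}w3.weight")], dim=1
)
state_dict[f"{prefix}fused_w1_w3.weight"] = fused

# save direction: fused parameter key -> unfused checkpoint keys
w1_back, w3_back = split_fused_mlp_weight(
    state_dict.pop(f"{prefix}fused_w1_w3.weight"), split_dim=1
)
assert torch.equal(w1_back, w1) and torch.equal(w3_back, w3)  # lossless round trip
```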