from internlm.model.modules.linear import new_linear
from internlm.model.modules.utils import Gelu, Silu
+from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
from internlm.utils.utils import ActivationType

@@ -259,6 +260,67 @@ def forward(self, x, batch_sizes=None):
        return out


+class FluxFeedForward(nn.Module):
+    """
+    Flux FeedForward.
+    Args:
+        in_features (int): size of each input sample
+        hidden_features (int): size of the hidden state of the FFN
+        out_features (int): size of each output sample
+        bias (bool): whether bias is needed for the linears. True by default, but it must be
+            set to False here (see the assert below), as is typical in the config.
+        device (Optional[Union[str, torch.device]]): the device to be used.
+        dtype (Optional[torch.dtype]): the type of data.
+        activation_type (str): the activation function used for the feed forward, "swiglu" by default.
+        num_groups (int): the number of expert groups for the grouped linears. 1 by default.
+        backend (str): the backend used by the grouped linears. "bmm" by default.
+        is_expert (bool): whether the linears hold expert weights. False by default.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        out_features: int = None,
+        bias: bool = True,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        activation_type: str = "swiglu",
+        num_groups: int = 1,
+        backend: str = "bmm",
+        is_expert: bool = False,
+    ):
+        super().__init__()
+
+        # TODO: support gelu...
+        assert activation_type in ("swiglu",), f"Unsupported activation type: {activation_type}"
+        assert bias is False, "Grouped FeedForward only supports bias=False."
+
+        self.w1 = new_linear(
+            "grouped_w1",
+            in_features,
+            hidden_features,
+            bias,
+            device=device,
+            dtype=dtype,
+            num_groups=num_groups,
+            backend=backend,
+            is_expert=is_expert,
+        )
+        self.w2 = new_linear(
+            "grouped_w2",
+            hidden_features,
+            out_features,
+            bias,
+            device=device,
+            dtype=dtype,
+            num_groups=num_groups,
+            backend=backend,
+            is_expert=is_expert,
+        )
+        self._register_load_state_dict_pre_hook(_grouped_mlp_pre_load_convert, with_module=True)
+        self._register_state_dict_hook(_grouped_mlp_save_convert)
+
+
def new_feed_forward(
    in_features: int,
    hidden_features: int,
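
FluxFeedForward.forward is not visible in this excerpt of the hunk. Purely as an illustration of the "swiglu" activation the constructor asserts on, here is a minimal standalone sketch with plain tensors (the gate/up/down names and shapes are hypothetical; the real module runs grouped Flux kernels, not dense matmuls):

    import torch
    import torch.nn.functional as F

    def swiglu_mlp(x, w_gate, w_up, w_down):
        # SwiGLU: silu(x @ w_gate) elementwise-gates (x @ w_up),
        # then w_down projects back to the output width.
        return (F.silu(x @ w_gate) * (x @ w_up)) @ w_down

    x = torch.randn(4, 8)
    w_gate, w_up = torch.randn(8, 16), torch.randn(8, 16)
    w_down = torch.randn(16, 8)
    out = swiglu_mlp(x, w_gate, w_up, w_down)  # shape (4, 8)
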
@@ -276,6 +338,19 @@ def new_feed_forward(
    if use_grouped_mlp:
        num_groups = kwargs.pop("num_groups", 1)
        backend = kwargs.pop("backend", "bmm")
+        if gpc.config.model.moe_type == "Flux":
+            return FluxFeedForward(
+                in_features,
+                hidden_features,
+                out_features,
+                bias,
+                device,
+                dtype,
+                activation_type,
+                num_groups=num_groups,
+                backend=backend,
+                is_expert=is_expert,
+            )
        return GroupedFeedForward(
            in_features,
            hidden_features,
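
The new branch keys on gpc.config.model.moe_type, which is why the global_context import is added at the top of the file. A sketch of the corresponding config fragment, assuming InternEvo's usual python-dict model config (moe_type comes from the diff; everything else is illustrative):

    model = dict(
        moe_type="Flux",  # routes new_feed_forward(...) to FluxFeedForward
        # ... remaining model fields unchanged ...
    )
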