@@ -325,3 +325,122 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
                self._prepare_output_fn, self.output_layouts, self.use_local_output
            ),
        )
+
+
+# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
+class ExpertTensorParallel(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        tp_mesh: DeviceMesh,
+        ep_mesh: DeviceMesh,
+    ):
+        super().__init__()
+        # TODO: has to pass in the meshes in addition to device_mesh,
+        #       as there's an issue from DeviceMesh that
+        #       "Cannot create a submesh from a submesh."
+        self.tp_mesh = tp_mesh
+        self.ep_mesh = ep_mesh
+
+    @staticmethod
+    def _prepare_input_fn(tp_mesh, ep_mesh, mod, inputs, device_mesh):
+        input_tensor = inputs[0]
+        # input_tensor of placements Shard(1) on the tp mesh
+        assert not isinstance(input_tensor, DTensor)
+
+        # a2a(ep)
+        input_tensor = DTensor.from_local(input_tensor, ep_mesh, (Shard(1),))
+        input_tensor = input_tensor.redistribute(placements=(Shard(0),)).to_local()
+        # ag(tp)
+        input_tensor = DTensor.from_local(input_tensor, tp_mesh, (Shard(1),))
+        input_tensor = input_tensor.redistribute(placements=(Replicate(),))
+
+        return input_tensor
+
+    @staticmethod
+    def _partition_fn(tp_mesh, ep_mesh, name, module, device_mesh):
+        # TODO: FSDP doesn't support sharding a 2D Tensor into a 3D one yet
+        # module.register_parameter(
+        #     "gate_proj",
+        #     nn.Parameter(
+        #         distribute_tensor(module.gate_proj, device_mesh, [Shard(0), Shard(2)])
+        #     ),
+        # )  # Column-wise sharding
+        # module.register_parameter(
+        #     "down_proj",
+        #     nn.Parameter(
+        #         distribute_tensor(module.down_proj, device_mesh, [Shard(0), Shard(1)])
+        #     ),
+        # )  # Row-wise sharding
+        # module.register_parameter(
+        #     "up_proj",
+        #     nn.Parameter(
+        #         distribute_tensor(module.up_proj, device_mesh, [Shard(0), Shard(2)])
+        #     ),
+        # )  # Column-wise sharding
+
+        # TODO: Instead, for MoE experts, we shard on the EP mesh and then "forget" it.
+        #       This would become an issue from DCP resharding perspective.
+        module.register_parameter(
+            "gate_proj",
+            nn.Parameter(
+                DTensor.from_local(
+                    (
+                        distribute_tensor(
+                            module.gate_proj, device_mesh, [Shard(0), Shard(2)]
+                        ).to_local()
+                    ),
+                    tp_mesh,
+                    (Shard(2),),
+                )
+            ),
+        )  # Column-wise sharding
+        module.register_parameter(
+            "down_proj",
+            nn.Parameter(
+                DTensor.from_local(
+                    (
+                        distribute_tensor(
+                            module.down_proj, device_mesh, [Shard(0), Shard(1)]
+                        ).to_local()
+                    ),
+                    tp_mesh,
+                    (Shard(1),),
+                )
+            ),
+        )  # Row-wise sharding
+        module.register_parameter(
+            "up_proj",
+            nn.Parameter(
+                DTensor.from_local(
+                    (
+                        distribute_tensor(
+                            module.up_proj, device_mesh, [Shard(0), Shard(2)]
+                        ).to_local()
+                    ),
+                    tp_mesh,
+                    (Shard(2),),
+                )
+            ),
+        )  # Column-wise sharding
+
+    @staticmethod
+    def _prepare_output_fn(tp_mesh, ep_mesh, mod, outputs, device_mesh):
+        # outputs of placements Partial() on the tp mesh
+
+        # rs(tp)
+        outputs = outputs.redistribute(placements=(Shard(1),)).to_local()
+        # a2a(ep)
+        outputs = DTensor.from_local(outputs, ep_mesh, (Shard(0),))
+        outputs = outputs.redistribute(placements=(Shard(1),)).to_local()
+
+        return outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            partial(self._partition_fn, self.tp_mesh, self.ep_mesh),
+            partial(self._prepare_input_fn, self.tp_mesh, self.ep_mesh),
+            partial(self._prepare_output_fn, self.tp_mesh, self.ep_mesh),
+        )
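As a usage sketch: this style would be applied to an experts module through `parallelize_module`, passing the 2D ep-by-tp mesh as the device mesh and the two submeshes explicitly (see the TODO in `__init__`). The mesh names, sizes, and the `model.moe.experts` path below are illustrative assumptions, not taken from this PR.

# Hypothetical usage sketch; mesh sizes, dim names, and the module path are assumptions.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

# 2D mesh spanning expert parallelism (ep) and tensor parallelism (tp).
ep_tp_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("ep", "tp"))

parallelize_module(
    module=model.moe.experts,  # assumed module holding gate_proj/up_proj/down_proj
    device_mesh=ep_tp_mesh,  # 2D mesh consumed by _partition_fn
    parallelize_plan=ExpertTensorParallel(
        tp_mesh=ep_tp_mesh["tp"],
        ep_mesh=ep_tp_mesh["ep"],
    ),
)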