@@ -709,7 +709,9 @@ def __init__(
        )

        self._remap_sharding_plan(
-           plan, self._global_rank, world_size // sharding_group_size
+           plan=plan,
+           rank=self._global_rank,
+           num_nodes=world_size // sharding_group_size,
        )
        super().__init__(
            module,
@@ -733,7 +735,7 @@ def sync(self, include_optimizer_state: bool = True) -> None:
        """
        Syncs the DMP weights across the allreduce (inter) process group

-       This method is called after each forward pass to synchronize the weights of the sharded modules.
+       This method is called after each train step to synchronize the weights of the sharded modules.
        It uses the `dist.AllreduceCoalescedOptions` to perform an all-reduce operation on the weights,
        which averages the weights across all processes in the inter-process group.

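For context on the docstring change above, a minimal sketch of where the call sits in a training loop; it assumes `model` is the 2D-parallel wrapper defined in this file, and `optimizer`, `dataloader`, and the loss-returning forward are illustrative stand-ins, not part of this diff:

    # Illustrative only: `model`, `optimizer`, and `dataloader` are assumed to
    # be set up elsewhere; only the placement of sync() matters here.
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()
        loss = model(batch)   # forward pass -- no weight sync happens here
        loss.backward()
        optimizer.step()
        # After the train step, average the sharded-module weights (and,
        # optionally, optimizer state) across the inter-node allreduce group.
        model.sync(include_optimizer_state=True)
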
@@ -782,10 +784,10 @@ def _create_process_groups(
        replication process group, and allreduce process group.
        """
        peer_matrix = []
-       step = world_size // local_size
+       num_nodes = world_size // local_size

        for group_rank in range(world_size // local_size):
-           peers = [step * r + group_rank for r in range(local_size)]
+           peers = [num_nodes * r + group_rank for r in range(local_size)]
            peer_matrix.append(peers)

        mesh = DeviceMesh(
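
A quick numeric check of the peer layout this loop produces; the `world_size` and `local_size` values are just an example, and nothing below is part of the diff:

    world_size, local_size = 8, 4           # e.g. 2 nodes with 4 GPUs each
    num_nodes = world_size // local_size    # 2

    peer_matrix = []
    for group_rank in range(world_size // local_size):
        peers = [num_nodes * r + group_rank for r in range(local_size)]
        peer_matrix.append(peers)

    print(peer_matrix)  # [[0, 2, 4, 6], [1, 3, 5, 7]]

Each row strides global ranks by `num_nodes`, which is what makes the rename from `step` read more clearly.
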
@@ -805,7 +807,9 @@ def _create_process_groups(

        return mesh, sharding_pg, replica_pg

-   def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None:
+   def _remap_sharding_plan(
+       self, plan: ShardingPlan, rank: int, num_nodes: int
+   ) -> None:
        """
        Remaps the sharding plan to the local replica process group ranks
        ShardingPlan is remapped inplace.
@@ -816,22 +820,22 @@ def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None
        Args:
            plan (ShardingPlan): The original sharding plan.
            global_rank (int): The global rank of the current process.
-           step (int): The number of nodes.
+           num_nodes (int): The number of nodes.
        """

-       group_start = rank % step
+       group_start = rank % num_nodes
        for key in plan.plan:
            # pyre-ignore[16]
            for _, param_sharding in plan.plan[key].items():
                new_ranks = []
                for shard_rank in param_sharding.ranks:
-                   new_ranks.append(shard_rank * step + group_start)
+                   new_ranks.append(shard_rank * num_nodes + group_start)
                param_sharding.ranks = new_ranks
                if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
                    shards = param_sharding.sharding_spec.shards
                    if shards is not None:
                        for shard in shards:
-                           shard_rank = shard.placement._rank * step + group_start
+                           shard_rank = shard.placement._rank * num_nodes + group_start
                            shard.placement = _remote_device(
                                f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}"
                            )
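
To make the remapping arithmetic concrete, a standalone numeric check consistent with the peer layout shown earlier; the plan ranks `[0, 1, 2, 3]` are illustrative, standing in for what a planner would write into `param_sharding.ranks`:

    num_nodes = 2                   # world_size // sharding_group_size
    rank = 5                        # this process's global rank
    group_start = rank % num_nodes  # 1

    planner_ranks = [0, 1, 2, 3]    # illustrative ranks for one sharding group
    remapped = [r * num_nodes + group_start for r in planner_ranks]
    print(remapped)  # [1, 3, 5, 7] -- the peer group that contains global rank 5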