
Commit c6f41aa

iamzainhuda authored and facebook-github-bot committed
update doc string and clean up variable naming (#2709)
Summary: Pull Request resolved: #2709 tsia
Reviewed By: kausv
Differential Revision: D68774673
fbshipit-source-id: be7b7c359df62877009ec01994879bd7942b4f4a
1 parent 52b0749 commit c6f41aa

File tree

1 file changed: +13 -9 lines changed


torchrec/distributed/model_parallel.py (+13 -9)
@@ -709,7 +709,9 @@ def __init__(
         )

         self._remap_sharding_plan(
-            plan, self._global_rank, world_size // sharding_group_size
+            plan=plan,
+            rank=self._global_rank,
+            num_nodes=world_size // sharding_group_size,
         )
         super().__init__(
             module,
@@ -733,7 +735,7 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         """
         Syncs the DMP weights across the allreduce (inter) process group

-        This method is called after each forward pass to synchronize the weights of the sharded modules.
+        This method is called after each train step to synchronize the weights of the sharded modules.
         It uses the `dist.AllreduceCoalescedOptions` to perform an all-reduce operation on the weights,
         which averages the weights across all processes in the inter-process group.
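
For orientation, here is a minimal sketch of what this cross-replica weight sync amounts to. It assumes a replication process group handle (`replica_pg`, a name chosen for the example) and does a plain per-parameter all-reduce followed by a divide; per the docstring above, the real `sync()` instead issues a single coalesced all-reduce, so this is illustrative only:

    import torch
    import torch.distributed as dist

    def sync_weights_across_replicas(
        module: torch.nn.Module, replica_pg: dist.ProcessGroup
    ) -> None:
        # Average each parameter across the inter (replication) process group.
        # The actual sync() coalesces these into one all-reduce (see docstring above).
        replica_world = dist.get_world_size(group=replica_pg)
        with torch.no_grad():
            for param in module.parameters():
                dist.all_reduce(param.data, op=dist.ReduceOp.SUM, group=replica_pg)
                param.data.div_(replica_world)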
@@ -782,10 +784,10 @@ def _create_process_groups(
             replication process group, and allreduce process group.
         """
         peer_matrix = []
-        step = world_size // local_size
+        num_nodes = world_size // local_size

         for group_rank in range(world_size // local_size):
-            peers = [step * r + group_rank for r in range(local_size)]
+            peers = [num_nodes * r + group_rank for r in range(local_size)]
             peer_matrix.append(peers)

         mesh = DeviceMesh(
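
To make the node-strided layout behind `peer_matrix` concrete, a small standalone sketch with made-up sizes (world_size=8 and local_size=4 are assumptions for illustration, not values from the code):

    # Illustrative numbers only: 8 ranks total, 4 ranks per group.
    world_size, local_size = 8, 4
    num_nodes = world_size // local_size  # 2

    peer_matrix = []
    for group_rank in range(world_size // local_size):
        # Ranks in each row are strided by num_nodes, starting at group_rank.
        peers = [num_nodes * r + group_rank for r in range(local_size)]
        peer_matrix.append(peers)

    print(peer_matrix)  # [[0, 2, 4, 6], [1, 3, 5, 7]]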
@@ -805,7 +807,9 @@ def _create_process_groups(

         return mesh, sharding_pg, replica_pg

-    def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None:
+    def _remap_sharding_plan(
+        self, plan: ShardingPlan, rank: int, num_nodes: int
+    ) -> None:
         """
         Remaps the sharding plan to the local replica process group ranks
         ShardingPlan is remapped inplace.
@@ -816,22 +820,22 @@ def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None
         Args:
             plan (ShardingPlan): The original sharding plan.
             global_rank (int): The global rank of the current process.
-            step (int): The number of nodes.
+            num_nodes (int): The number of nodes.
         """

-        group_start = rank % step
+        group_start = rank % num_nodes
         for key in plan.plan:
             # pyre-ignore[16]
             for _, param_sharding in plan.plan[key].items():
                 new_ranks = []
                 for shard_rank in param_sharding.ranks:
-                    new_ranks.append(shard_rank * step + group_start)
+                    new_ranks.append(shard_rank * num_nodes + group_start)
                 param_sharding.ranks = new_ranks
                 if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
                     shards = param_sharding.sharding_spec.shards
                     if shards is not None:
                         for shard in shards:
-                            shard_rank = shard.placement._rank * step + group_start
+                            shard_rank = shard.placement._rank * num_nodes + group_start
                             shard.placement = _remote_device(
                                 f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}"
                             )
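A self-contained illustration of the remapping arithmetic itself (the rank and num_nodes values below are made up; only the formula `shard_rank * num_nodes + group_start` comes from the diff):

    # Illustrative numbers only: 2 nodes, current process has global rank 5.
    num_nodes = 2
    rank = 5
    group_start = rank % num_nodes  # 1

    # Ranks as they appear in the original sharding plan (group-relative).
    plan_ranks = [0, 1, 2, 3]
    remapped = [shard_rank * num_nodes + group_start for shard_rank in plan_ranks]
    print(remapped)  # [1, 3, 5, 7]

Note that with these toy numbers the remapped ranks line up with one row of the peer-matrix sketch earlier.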
