
Commit 519f193

iamzainhuda authored and facebook-github-bot committed
simplify 2D parallel process group init (#2694)
Summary: Pull Request resolved: #2694. DeviceMesh and manual PG initialization was redundant code, leading to more process groups being created than needed (2x as many). In this diff we update the init to use the process groups created by the DeviceMesh init instead.

Reviewed By: carlbunny, TroyGarden

Differential Revision: D68495749

fbshipit-source-id: 85123a5f43f0e1c55e50e5fd52b6dbc0d2c62107
1 parent dd5457c commit 519f193
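Before this change, the init built the same 2D layout twice: once by hand with dist.new_group (one group per sharding row and one per replica column, created on every rank) and once again inside DeviceMesh. The diff below keeps only the DeviceMesh path and pulls the per-dimension groups out of the mesh. The following is a rough standalone sketch of that pattern, not the torchrec code itself: it assumes a torchrun launch with the gloo backend, world_size divisible by local_size, and create_2d_groups is a hypothetical helper name chosen for illustration.

# Standalone sketch, assuming torchrun sets RANK/WORLD_SIZE and that
# world_size is divisible by local_size; not the torchrec implementation.
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def create_2d_groups(local_size: int):
    world_size = dist.get_world_size()
    step = world_size // local_size  # number of sharding groups (mesh rows)

    # Same layout as the diff: row g holds the ranks of sharding group g,
    # column r holds the ranks of replica group r.
    peer_matrix = [
        [step * r + group_rank for r in range(local_size)]
        for group_rank in range(step)
    ]

    mesh = DeviceMesh(
        device_type="cpu",  # "cuda" when each rank owns a GPU
        mesh=peer_matrix,
        mesh_dim_names=("replicate", "shard"),
    )
    # DeviceMesh already created one process group per mesh dimension;
    # get_group just returns the group that contains this rank.
    sharding_pg = mesh.get_group(mesh_dim="shard")
    replica_pg = mesh.get_group(mesh_dim="replicate")
    return mesh, sharding_pg, replica_pg


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    mesh, sharding_pg, replica_pg = create_2d_groups(local_size=2)
    print(
        f"rank {dist.get_rank()}: "
        f"shard group size {dist.get_world_size(sharding_pg)}, "
        f"replica group size {dist.get_world_size(replica_pg)}"
    )
    dist.destroy_process_group()

Because the DeviceMesh constructor already creates one process group per mesh dimension, fetching them with get_group avoids the second, redundant set of groups that the removed dist.new_group loops produced.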

File tree: 1 file changed (+9, −26 lines)


torchrec/distributed/model_parallel.py

+9 −26
@@ -770,7 +770,7 @@ def _create_process_groups(
     ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
         """
         Creates process groups for sharding and replication, the process groups
-        are created in the same exact order on all ranks as per `dist.new_group` API.
+        are created using the DeviceMesh API.
 
         Args:
             global_rank (int): The global rank of the current process.
@@ -781,44 +781,27 @@ def _create_process_groups(
             Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: A tuple containing the device mesh,
                 replication process group, and allreduce process group.
         """
-        # TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a
         peer_matrix = []
-        sharding_pg, replica_pg = None, None
         step = world_size // local_size
 
-        my_group_rank = global_rank % step
         for group_rank in range(world_size // local_size):
             peers = [step * r + group_rank for r in range(local_size)]
-            backend = dist.get_backend(self._pg)
-            curr_pg = dist.new_group(backend=backend, ranks=peers)
             peer_matrix.append(peers)
-            if my_group_rank == group_rank:
-                logger.warning(
-                    f"[Connection] 2D sharding_group: [{global_rank}] -> [{peers}]"
-                )
-                sharding_pg = curr_pg
-        assert sharding_pg is not None, "sharding_pg is not initialized!"
-        dist.barrier()
-
-        my_inter_rank = global_rank // step
-        for inter_rank in range(local_size):
-            peers = [inter_rank * step + r for r in range(step)]
-            backend = dist.get_backend(self._pg)
-            curr_pg = dist.new_group(backend=backend, ranks=peers)
-            if my_inter_rank == inter_rank:
-                logger.warning(
-                    f"[Connection] 2D replica_group: [{global_rank}] -> [{peers}]"
-                )
-                replica_pg = curr_pg
-        assert replica_pg is not None, "replica_pg is not initialized!"
-        dist.barrier()
 
         mesh = DeviceMesh(
             device_type=self._device.type,
             mesh=peer_matrix,
             mesh_dim_names=("replicate", "shard"),
         )
         logger.warning(f"[Connection] 2D Device Mesh created: {mesh}")
+        sharding_pg = mesh.get_group(mesh_dim="shard")
+        logger.warning(
+            f"[Connection] 2D sharding_group: [{global_rank}] -> [{mesh['shard']}]"
+        )
+        replica_pg = mesh.get_group(mesh_dim="replicate")
+        logger.warning(
+            f"[Connection] 2D replica_group: [{global_rank}] -> [{mesh['replicate']}]"
+        )
 
         return mesh, sharding_pg, replica_pg
