@@ -770,7 +770,7 @@ def _create_process_groups(
     ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
         """
         Creates process groups for sharding and replication, the process groups
-        are created in the same exact order on all ranks as per `dist.new_group` API.
+        are created using the DeviceMesh API.
 
         Args:
             global_rank (int): The global rank of the current process.
@@ -781,44 +781,27 @@ def _create_process_groups(
             Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: A tuple containing the device mesh,
                 replication process group, and allreduce process group.
         """
-        # TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a
         peer_matrix = []
-        sharding_pg, replica_pg = None, None
         step = world_size // local_size
 
-        my_group_rank = global_rank % step
         for group_rank in range(world_size // local_size):
             peers = [step * r + group_rank for r in range(local_size)]
-            backend = dist.get_backend(self._pg)
-            curr_pg = dist.new_group(backend=backend, ranks=peers)
             peer_matrix.append(peers)
-            if my_group_rank == group_rank:
-                logger.warning(
-                    f"[Connection] 2D sharding_group: [{global_rank}] -> [{peers}]"
-                )
-                sharding_pg = curr_pg
-        assert sharding_pg is not None, "sharding_pg is not initialized!"
-        dist.barrier()
-
-        my_inter_rank = global_rank // step
-        for inter_rank in range(local_size):
-            peers = [inter_rank * step + r for r in range(step)]
-            backend = dist.get_backend(self._pg)
-            curr_pg = dist.new_group(backend=backend, ranks=peers)
-            if my_inter_rank == inter_rank:
-                logger.warning(
-                    f"[Connection] 2D replica_group: [{global_rank}] -> [{peers}]"
-                )
-                replica_pg = curr_pg
-        assert replica_pg is not None, "replica_pg is not initialized!"
-        dist.barrier()
 
         mesh = DeviceMesh(
             device_type=self._device.type,
             mesh=peer_matrix,
             mesh_dim_names=("replicate", "shard"),
         )
         logger.warning(f"[Connection] 2D Device Mesh created: {mesh}")
+        sharding_pg = mesh.get_group(mesh_dim="shard")
+        logger.warning(
+            f"[Connection] 2D sharding_group: [{global_rank}] -> [{mesh['shard']}]"
+        )
+        replica_pg = mesh.get_group(mesh_dim="replicate")
+        logger.warning(
+            f"[Connection] 2D replica_group: [{global_rank}] -> [{mesh['replicate']}]"
+        )
 
         return mesh, sharding_pg, replica_pg
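For readers comparing the two approaches, below is a minimal, standalone sketch of the new DeviceMesh-based path with the two retired `dist.new_group` loops collapsed into a single mesh construction. This is an illustration under stated assumptions, not code from this commit: the `create_2d_groups` helper name, the `gloo`/CPU backend, and the 8-rank, `local_size=4` topology are all hypothetical.

```python
# Hypothetical standalone sketch of DeviceMesh-based 2D group creation.
# Launch with: torchrun --nproc-per-node=8 sketch.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def create_2d_groups(world_size: int, local_size: int):
    # Same peer layout as the diff: row i is a sharding group of
    # `local_size` ranks; column j is a replication group of
    # `world_size // local_size` ranks.
    step = world_size // local_size
    peer_matrix = [
        [step * r + group_rank for r in range(local_size)]
        for group_rank in range(step)
    ]
    # DeviceMesh creates the per-dimension sub-process-groups itself, in the
    # same deterministic order on every rank, which is why the explicit
    # dist.new_group() loops, asserts, and barriers could be deleted.
    mesh = DeviceMesh(
        device_type="cpu",  # the real code uses self._device.type (e.g. "cuda")
        mesh=peer_matrix,
        mesh_dim_names=("replicate", "shard"),
    )
    return mesh, mesh.get_group(mesh_dim="shard"), mesh.get_group(mesh_dim="replicate")


if __name__ == "__main__":
    dist.init_process_group("gloo")  # gloo so the sketch runs on CPU
    mesh, sharding_pg, replica_pg = create_2d_groups(
        dist.get_world_size(), local_size=4
    )
    # Example use of the replication group: an all-reduce across replicas.
    t = torch.ones(1)
    dist.all_reduce(t, group=replica_pg)
    print(f"rank {dist.get_rank()}: replica group size = {t.item():.0f}")
    dist.destroy_process_group()
```

The design point is ordering: `dist.new_group` obliges every rank to create every group by hand in the exact same sequence, whereas `DeviceMesh` maintains that invariant internally, which is what lets the diff drop the two manual loops, the `assert`s, and both `dist.barrier()` calls.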