21
21
import threading
22
22
from abc import ABC
23
23
from datetime import timedelta
24
- from typing import TYPE_CHECKING , Dict , List , Optional , Tuple , Type , Union
24
+ from typing import Any , TYPE_CHECKING , Dict , List , Optional , Tuple , Type , Union
25
25
26
26
import torch
27
27
import torch .distributed as dist
@@ -858,6 +858,8 @@ def extend_device_mesh(
858
858
859
859
860
860
class ManagedDeviceMesh (DeviceMesh ):
861
+ replicate_pg_singleton : Optional ["ManagedProcessGroup" ]
862
+
861
863
def __init__ (
862
864
self ,
863
865
mesh : Optional [DeviceMesh ],
@@ -886,6 +888,15 @@ def __init__(
886
888
self ._flatten_mesh_list : Tuple [DeviceMesh , ...] = tuple ()
887
889
self ._thread_id : Optional [int ] = None
888
890
891
def __getstate__(self) -> Dict[str, Any]:
    """Return a shallow copy of the instance state with ``replicate_pg`` cleared.

    The returned dict is independent of ``self.__dict__``; the live
    ``replicate_pg`` attribute on this instance is left untouched. Only the
    copy carries ``None`` under the ``"replicate_pg"`` key.
    """
    # Build the copy and blank out the process-group entry in one step.
    return {**self.__dict__, "replicate_pg": None}
895
+
896
def __setstate__(self, state: Dict[str, Any]) -> None:
    """Restore instance attributes from ``state`` and re-attach ``replicate_pg``.

    Every key in ``state`` is written into the instance ``__dict__`` first;
    afterwards ``replicate_pg`` is overwritten with ``replicate_pg_singleton``.
    """
    vars(self).update(state)
    # NOTE(review): replicate_pg_singleton is only annotated on the class in
    # the visible code and assigned inside ft_init_device_mesh; presumably
    # this lookup raises AttributeError if that has not run in the current
    # process -- confirm.
    shared_pg = self.replicate_pg_singleton
    self.replicate_pg = shared_pg
899
+
889
900
def __getitem__ (self , mesh_dim_names : Union [str , Tuple [str , ...]]) -> DeviceMesh :
890
901
if isinstance (mesh_dim_names , str ):
891
902
if mesh_dim_names == self .replicate_dim_name :
@@ -903,13 +914,14 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
903
914
return self .mesh [mesh_dim_names ]
904
915
else :
905
916
assert isinstance (mesh_dim_names , tuple )
906
- if self .replicate_dim_name in mesh_dim_names :
917
+ if self .replicate_dim_name not in mesh_dim_names :
907
918
assert self .mesh is not None
908
919
return self .mesh [mesh_dim_names ]
909
920
else :
910
921
assert self .mesh is not None
922
+ mesh_dim_names_wo_replicate = tuple (n for n in mesh_dim_names if n != self .replicate_dim_name )
911
923
return ManagedDeviceMesh (
912
- self .mesh [mesh_dim_names ],
924
+ self .mesh [mesh_dim_names_wo_replicate ],
913
925
mesh_dim_names ,
914
926
self .replicate_pg ,
915
927
mesh_dim_names .index (self .replicate_dim_name ),
@@ -944,14 +956,16 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
944
956
return flatten_mesh
945
957
946
958
def size(self, mesh_dim: Optional[int] = None) -> int:
    """Return the size of the whole managed mesh or of a single dimension.

    Args:
        mesh_dim: Dimension index to query. ``None`` asks for the total size.

    Returns:
        * ``mesh_dim is None``: the replicate group size, multiplied by the
          wrapped mesh's total size when a wrapped mesh is present.
        * ``mesh_dim == self.replicate_dim``: the replicate group size.
        * otherwise: the wrapped mesh's size along the dimension returned by
          ``self._real_mesh_dim(mesh_dim)``.

    Raises:
        AssertionError: if a non-replicate dimension is requested while
            ``self.mesh`` is ``None``.
    """
    # A replicate group that reports zero members is counted as one.
    replica_count = self.replicate_pg.size() or 1

    if mesh_dim is not None:
        if mesh_dim == self.replicate_dim:
            return replica_count
        assert self.mesh is not None
        return self.mesh.size(self._real_mesh_dim(mesh_dim))

    if self.mesh is None:
        return replica_count
    return self.mesh.size() * replica_count
@@ -1001,7 +1015,11 @@ def get_coordinate(self) -> Optional[List[int]]:
1001
1015
dimensions of the mesh. If this rank is not part of the mesh, return None.
1002
1016
"""
1003
1017
assert self .mesh is not None
1004
- return self .mesh ._coordinate_on_dim if self .mesh ._coordinate_on_dim else None
1018
+ ret = self .mesh ._coordinate_on_dim if self .mesh ._coordinate_on_dim else None
1019
+ if ret :
1020
+ ret = ret .copy ()
1021
+ ret .insert (get_rank (self .replicate_pg ), self .replicate_dim )
1022
+ return ret
1005
1023
1006
1024
def get_all_groups(self) -> List[BaseProcessGroup]:
    """Unsupported on this mesh type: always raises ``NotImplementedError``."""
    raise NotImplementedError()
@@ -1076,6 +1094,8 @@ def ft_init_device_mesh(
1076
1094
# the same backend has been registered.
1077
1095
replicate_pg .register (mesh_dim_names [replicate_dim ])
1078
1096
1097
+ ManagedDeviceMesh .replicate_pg_singleton = replicate_pg
1098
+
1079
1099
return ManagedDeviceMesh (
1080
1100
mesh = mesh ,
1081
1101
mesh_dim_names = mesh_dim_names ,
0 commit comments