20
20
import queue
21
21
import threading
22
22
from datetime import timedelta
23
- from typing import TYPE_CHECKING , Dict , List , Optional , Tuple , Type , Union
23
+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , Type , Union
24
24
25
25
import torch
26
26
import torch .distributed as dist
@@ -871,6 +871,8 @@ def extend_device_mesh(
871
871
872
872
873
873
class ManagedDeviceMesh (DeviceMesh ):
874
+ replicate_pg_singleton : Optional ["ManagedProcessGroup" ] = None
875
+
874
876
def __init__ (
875
877
self ,
876
878
mesh : Optional [DeviceMesh ],
@@ -899,6 +901,16 @@ def __init__(
899
901
self ._flatten_mesh_list : Tuple [DeviceMesh , ...] = tuple ()
900
902
self ._thread_id : Optional [int ] = None
901
903
904
+ def __getstate__ (self ) -> Dict [str , Any ]:
905
+ state = self .__dict__ .copy ()
906
+ state ["replicate_pg" ] = None
907
+ return state
908
+
909
+ def __setstate__ (self , state : Dict [str , Any ]) -> None :
910
+ self .__dict__ .update (state )
911
+ assert self .replicate_pg_singleton is not None
912
+ self .replicate_pg = self .replicate_pg_singleton
913
+
902
914
def __getitem__ (self , mesh_dim_names : Union [str , Tuple [str , ...]]) -> DeviceMesh :
903
915
if isinstance (mesh_dim_names , str ):
904
916
if mesh_dim_names == self .replicate_dim_name :
@@ -916,13 +928,16 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
916
928
return self .mesh [mesh_dim_names ]
917
929
else :
918
930
assert isinstance (mesh_dim_names , tuple )
919
- if self .replicate_dim_name in mesh_dim_names :
931
+ if self .replicate_dim_name not in mesh_dim_names :
920
932
assert self .mesh is not None
921
933
return self .mesh [mesh_dim_names ]
922
934
else :
935
+ mesh_dim_names_wo_replicate = tuple (
936
+ n for n in mesh_dim_names if n != self .replicate_dim_name
937
+ )
923
938
assert self .mesh is not None
924
939
return ManagedDeviceMesh (
925
- self .mesh [mesh_dim_names ],
940
+ self .mesh [mesh_dim_names_wo_replicate ],
926
941
mesh_dim_names ,
927
942
self .replicate_pg ,
928
943
mesh_dim_names .index (self .replicate_dim_name ),
@@ -957,14 +972,18 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
957
972
return flatten_mesh
958
973
959
974
def size (self , mesh_dim : Optional [int ] = None ) -> int :
975
+ replicate_pg_size = self .replicate_pg .size ()
976
+ # We have to lie to the users if there are zero participants.
977
+ # This is possible during the initialization stage of training.
978
+ replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
960
979
if mesh_dim is None :
961
980
if self .mesh is None :
962
- return self . replicate_pg . size ()
981
+ return replicate_pg_size
963
982
else :
964
983
assert self .mesh is not None
965
- return self .mesh .size () * self . replicate_pg . size ()
984
+ return self .mesh .size () * replicate_pg_size
966
985
elif mesh_dim == self .replicate_dim :
967
- return self . replicate_pg . size ()
986
+ return replicate_pg_size
968
987
else :
969
988
assert self .mesh is not None
970
989
return self .mesh .size (self ._real_mesh_dim (mesh_dim ))
@@ -1014,7 +1033,16 @@ def get_coordinate(self) -> Optional[List[int]]:
1014
1033
dimensions of the mesh. If this rank is not part of the mesh, return None.
1015
1034
"""
1016
1035
assert self .mesh is not None
1017
- return self .mesh ._coordinate_on_dim if self .mesh ._coordinate_on_dim else None
1036
+ coordinate = (
1037
+ self .mesh ._coordinate_on_dim if self .mesh ._coordinate_on_dim else None
1038
+ )
1039
+ if not coordinate :
1040
+ return coordinate
1041
+
1042
+ # We need to copy because we are going to modify the coordinate.
1043
+ coordinate = coordinate .copy ()
1044
+ coordinate .insert (self .replicate_dim , get_rank (self .replicate_pg ))  # list.insert(index, value): position is replicate_dim, value is get_rank(replicate_pg)
1045
+ return coordinate
1018
1046
1019
1047
def get_all_groups (self ) -> List [BaseProcessGroup ]:
1020
1048
raise NotImplementedError
@@ -1076,19 +1104,11 @@ def ft_init_device_mesh(
1076
1104
mesh_dim_names = tuple (_mesh_dim_names ),
1077
1105
)
1078
1106
1079
- if device_type == "cpu" :
1080
- pg = ProcessGroupGloo ()
1081
- elif device_type == "cuda" :
1082
- pg = ProcessGroupNCCL ()
1083
- else :
1084
- raise ValueError ()
1085
-
1086
- manager ._pg = pg
1087
1107
replicate_pg = ManagedProcessGroup (manager )
1088
- # We have to use MultiProcessTestCase, otherwise c10d will complain
1089
- # the same backend has been registered.
1090
1108
replicate_pg .register (mesh_dim_names [replicate_dim ])
1091
1109
1110
+ ManagedDeviceMesh .replicate_pg_singleton = replicate_pg
1111
+
1092
1112
return ManagedDeviceMesh (
1093
1113
mesh = mesh ,
1094
1114
mesh_dim_names = mesh_dim_names ,
0 commit comments