2121import threading
2222from abc import ABC
2323from datetime import timedelta
24- from typing import TYPE_CHECKING , Dict , List , Optional , Tuple , Type , Union
24+ from typing import Any , TYPE_CHECKING , Dict , List , Optional , Tuple , Type , Union
2525
2626import torch
2727import torch .distributed as dist
@@ -858,6 +858,8 @@ def extend_device_mesh(
858858
859859
860860class ManagedDeviceMesh (DeviceMesh ):
861+ replicate_pg_singleton : Optional ["ManagedProcessGroup" ]
862+
861863 def __init__ (
862864 self ,
863865 mesh : Optional [DeviceMesh ],
@@ -886,6 +888,15 @@ def __init__(
886888 self ._flatten_mesh_list : Tuple [DeviceMesh , ...] = tuple ()
887889 self ._thread_id : Optional [int ] = None
888890
def __getstate__(self) -> Dict[str, Any]:
    """Build the state mapping used when this mesh is pickled.

    Returns:
        A shallow copy of the instance ``__dict__`` in which the
        ``"replicate_pg"`` entry is replaced by ``None``. The live
        instance itself is not modified.
    """
    # NOTE(review): presumably the replicate process group cannot be
    # pickled, which is why it is blanked out here -- confirm.
    return {**self.__dict__, "replicate_pg": None}
895+
def __setstate__(self, state: Dict[str, Any]) -> None:
    """Restore an instance from a mapping produced by ``__getstate__``.

    Args:
        state: Attribute mapping to load into the instance ``__dict__``.

    The ``replicate_pg`` attribute is not taken from ``state``; it is
    re-bound to ``replicate_pg_singleton``, which ``ft_init_device_mesh``
    assigns on the class.
    """
    vars(self).update(state)
    # NOTE(review): the visible class body only annotates
    # ``replicate_pg_singleton`` without a default, so this lookup looks
    # like it raises AttributeError if nothing has assigned it yet.
    live_pg = self.replicate_pg_singleton
    self.replicate_pg = live_pg
899+
889900 def __getitem__ (self , mesh_dim_names : Union [str , Tuple [str , ...]]) -> DeviceMesh :
890901 if isinstance (mesh_dim_names , str ):
891902 if mesh_dim_names == self .replicate_dim_name :
@@ -903,13 +914,14 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
903914 return self .mesh [mesh_dim_names ]
904915 else :
905916 assert isinstance (mesh_dim_names , tuple )
906- if self .replicate_dim_name in mesh_dim_names :
917+ if self .replicate_dim_name not in mesh_dim_names :
907918 assert self .mesh is not None
908919 return self .mesh [mesh_dim_names ]
909920 else :
910921 assert self .mesh is not None
922+ mesh_dim_names_wo_replicate = tuple (n for n in mesh_dim_names if n != self .replicate_dim_name )
911923 return ManagedDeviceMesh (
912- self .mesh [mesh_dim_names ],
924+ self .mesh [mesh_dim_names_wo_replicate ],
913925 mesh_dim_names ,
914926 self .replicate_pg ,
915927 mesh_dim_names .index (self .replicate_dim_name ),
@@ -944,14 +956,16 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
944956 return flatten_mesh
945957
def size(self, mesh_dim: Optional[int] = None) -> int:
    """Return the size of the managed mesh or of one of its dimensions.

    Args:
        mesh_dim: Dimension index to query, or ``None`` for the total
            size.

    Returns:
        * ``mesh_dim is None``: the replicate group size, multiplied by
          ``self.mesh.size()`` when a wrapped mesh is present.
        * ``mesh_dim == self.replicate_dim``: the replicate group size.
        * otherwise: ``self.mesh.size(...)`` for the index returned by
          ``self._real_mesh_dim(mesh_dim)``.

    A replicate group that reports a size of 0 is counted as 1.
    """
    # ``or 1`` maps a reported size of 0 to 1 and leaves other sizes alone.
    replicate_size = self.replicate_pg.size() or 1

    if mesh_dim is None:
        if self.mesh is None:
            return replicate_size
        return self.mesh.size() * replicate_size

    if mesh_dim == self.replicate_dim:
        return replicate_size

    assert self.mesh is not None
    return self.mesh.size(self._real_mesh_dim(mesh_dim))
def get_coordinate(self) -> Optional[List[int]]:
    """
    Return this rank's coordinate along every dimension of the managed
    mesh, including the replicate dimension. If this rank is not part of
    the mesh, return None.

    The coordinate of the wrapped mesh (``self.mesh._coordinate_on_dim``)
    is copied, and the value of ``get_rank(self.replicate_pg)`` is
    inserted at position ``self.replicate_dim``.
    """
    assert self.mesh is not None
    inner_coordinate = self.mesh._coordinate_on_dim
    if not inner_coordinate:
        return None
    # Work on a copy so the wrapped mesh's own coordinate list is untouched.
    coordinate = list(inner_coordinate)
    # list.insert takes (index, value): the replicate dimension is the
    # position and the replicate rank is the value stored there.
    coordinate.insert(self.replicate_dim, get_rank(self.replicate_pg))
    return coordinate
10051023
def get_all_groups(self) -> List[BaseProcessGroup]:
    """Not implemented for this mesh; always raises NotImplementedError."""
    raise NotImplementedError()
@@ -1076,6 +1094,8 @@ def ft_init_device_mesh(
10761094 # the same backend has been registered.
10771095 replicate_pg .register (mesh_dim_names [replicate_dim ])
10781096
1097+ ManagedDeviceMesh .replicate_pg_singleton = replicate_pg
1098+
10791099 return ManagedDeviceMesh (
10801100 mesh = mesh ,
10811101 mesh_dim_names = mesh_dim_names ,
0 commit comments