@@ -864,7 +864,7 @@ def __init__(
864
864
raise ValueError (
865
865
"ManagedDeviceMesh doesn't support both mesh and parent are None."
866
866
)
867
- self .mesh = mesh
867
+ self ._mesh = mesh
868
868
self .mesh_dim_names = mesh_dim_names
869
869
self .replicate_pg = replicate_pg
870
870
self .replicate_dim = replicate_dim
@@ -893,17 +893,17 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
893
893
elif mesh_dim_names in self .flatten_meshes :
894
894
return self .flatten_meshes [mesh_dim_names ]
895
895
else :
896
- assert self .mesh is not None
897
- return self .mesh [mesh_dim_names ]
896
+ assert self ._mesh is not None
897
+ return self ._mesh [mesh_dim_names ]
898
898
else :
899
899
assert isinstance (mesh_dim_names , tuple )
900
900
if self .replicate_dim_name in mesh_dim_names :
901
- assert self .mesh is not None
902
- return self .mesh [mesh_dim_names ]
901
+ assert self ._mesh is not None
902
+ return self ._mesh [mesh_dim_names ]
903
903
else :
904
- assert self .mesh is not None
904
+ assert self ._mesh is not None
905
905
return ManagedDeviceMesh (
906
- self .mesh [mesh_dim_names ],
906
+ self ._mesh [mesh_dim_names ],
907
907
mesh_dim_names ,
908
908
self .replicate_pg ,
909
909
mesh_dim_names .index (self .replicate_dim_name ),
@@ -924,8 +924,8 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGr
924
924
elif dim == self .replicate_dim :
925
925
return self .replicate_pg
926
926
else :
927
- assert self .mesh is not None
928
- return self .mesh .get_group (self ._real_mesh_dim (dim ))
927
+ assert self ._mesh is not None
928
+ return self ._mesh .get_group (self ._real_mesh_dim (dim ))
929
929
930
930
def _flatten (self , mesh_dim_name : Optional [str ]) -> "DeviceMesh" :
931
931
flatten_mesh = _FlattenDeviceMesh (self )
@@ -939,32 +939,32 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
939
939
940
940
def size(self, mesh_dim: Optional[int] = None) -> int:
    """Return the number of ranks in the whole mesh or along one dimension.

    Args:
        mesh_dim: Index of the dimension to measure. ``None`` (the default)
            asks for the total size across all dimensions.

    Returns:
        * ``mesh_dim is None``: the wrapped mesh's total size multiplied by
          the replicate process group's size, or just the replicate group's
          size when there is no wrapped mesh.
        * ``mesh_dim == self.replicate_dim``: the replicate group's size.
        * otherwise: the wrapped mesh's size along the dimension obtained
          from ``self._real_mesh_dim(mesh_dim)``.

    Raises:
        AssertionError: if a non-replicate dimension is requested while
            there is no wrapped mesh.
    """
    if mesh_dim is None:
        if self._mesh is None:
            # Only the replicate dimension exists.
            return self.replicate_pg.size()
        return self._mesh.size() * self.replicate_pg.size()

    if mesh_dim == self.replicate_dim:
        return self.replicate_pg.size()

    assert self._mesh is not None
    # The index is translated by _real_mesh_dim before it reaches the wrapped
    # mesh (presumably to skip over the replicate dimension -- confirm against
    # that helper).
    return self._mesh.size(self._real_mesh_dim(mesh_dim))
952
952
953
953
@property
def ndim(self) -> int:
    """Number of dimensions: the wrapped mesh's ``ndim`` plus one.

    The extra dimension is the one added on top of the wrapped mesh
    (``self._mesh``).

    Raises:
        AssertionError: if there is no wrapped mesh.
    """
    assert self._mesh is not None
    return self._mesh.ndim + 1
957
957
958
958
@property
def shape(self) -> Tuple[int, ...]:
    """Shape of the mesh, including the replicate dimension.

    Built from the wrapped mesh's shape with the replicate process group's
    size inserted at index ``self.replicate_dim``.

    Raises:
        AssertionError: if there is no wrapped mesh.
    """
    assert self._mesh is not None
    ret: List[int] = list(self._mesh.shape)
    ret.insert(self.replicate_dim, self.replicate_pg.size())
    return tuple(ret)
964
964
965
965
def get_rank(self) -> int:
    """Return the rank reported by the wrapped mesh's ``get_rank()``.

    Raises:
        AssertionError: if there is no wrapped mesh.
    """
    assert self._mesh is not None
    return self._mesh.get_rank()
968
968
969
969
def get_local_rank (self , mesh_dim : Optional [Union [int , str ]] = None ) -> int :
970
970
if isinstance (mesh_dim , str ):
@@ -973,33 +973,37 @@ def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
973
973
dim = 0 if mesh_dim is None else int (mesh_dim )
974
974
975
975
if mesh_dim is None :
976
- if self .mesh is None :
976
+ if self ._mesh is None :
977
977
return get_rank (self .replicate_pg )
978
978
979
979
assert self .replicate_dim == 0 , "replicate_dim must be the first one"
980
- assert self .mesh is not None
981
- other_dim_size = self .mesh .size ()
982
- assert self .mesh is not None
983
- other_dim_rank = self .mesh .get_local_rank ()
980
+ assert self ._mesh is not None
981
+ other_dim_size = self ._mesh .size ()
982
+ assert self ._mesh is not None
983
+ other_dim_rank = self ._mesh .get_local_rank ()
984
984
replicate_pg_rank = get_rank (self .replicate_pg )
985
985
return other_dim_size * replicate_pg_rank + other_dim_rank
986
986
elif dim == self .replicate_dim :
987
987
return get_rank (self .replicate_pg )
988
988
else :
989
- assert self .mesh is not None
990
- return self .mesh .get_local_rank (self ._real_mesh_dim (dim ))
989
+ assert self ._mesh is not None
990
+ return self ._mesh .get_local_rank (self ._real_mesh_dim (dim ))
991
991
992
992
def get_coordinate(self) -> Optional[List[int]]:
    """
    Return the indices of this rank relative to all dimensions of the
    mesh. If this rank is not part of the mesh, return None.

    The value is the wrapped mesh's ``_coordinate_on_dim``, passed through
    unchanged; an empty or ``None`` value is reported as ``None``.
    NOTE(review): no replicate-dimension coordinate is inserted here --
    confirm that callers expect coordinates of the wrapped mesh only.

    Raises:
        AssertionError: if there is no wrapped mesh.
    """
    assert self._mesh is not None
    # `or None` normalises every falsy value (None, empty list) to None.
    return self._mesh._coordinate_on_dim or None
999
999
1000
1000
def get_all_groups(self) -> List["BaseProcessGroup"]:
    """Not supported: always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError
1002
1002
1003
@property
def mesh(self):
    """Read-only view of the wrapped mesh's own ``mesh`` attribute.

    The wrapped mesh object itself is stored in ``self._mesh``; this
    property forwards to its ``.mesh`` attribute.

    Raises:
        AssertionError: if there is no wrapped mesh.
    """
    # Same guard as the other accessors, so a missing wrapped mesh fails
    # here rather than with an AttributeError on None.
    assert self._mesh is not None
    return self._mesh.mesh
1006
+
1003
1007
1004
1008
class _FlattenDeviceMesh (DeviceMesh ):
1005
1009
def __init__ (self , managed_mesh : ManagedDeviceMesh ) -> None :
0 commit comments