Skip to content

Commit 2a67d66

Browse files
authored
Fix ManagedDeviceMesh composability issues (#86)
* Fix ManagedDeviceMesh composability issues. There are gaps that prevent ManagedDeviceMesh from actually being used in TorchTitan. This PR fixes the gaps: 1. ManagedDeviceMesh can now be saved and loaded with torch.save()/torch.load(). 2. ManagedDeviceMesh will lie (report a replicate size of 1) if there are zero replicated group participants, because a size-0 DeviceMesh will cause confusion for training loops. 3. Correctly returns coordinates. 4. Removes the pg reinitialization issue.
1 parent 4bdb8a7 commit 2a67d66

File tree

3 files changed

+122
-64
lines changed

3 files changed

+122
-64
lines changed

torchft/device_mesh_test.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import io
8+
import os
9+
from concurrent.futures import ProcessPoolExecutor
10+
from typing import cast
11+
from unittest import TestCase
12+
from unittest.mock import Mock
13+
14+
import torch
15+
import torch.distributed as dist
16+
17+
from torchft.manager import Manager
18+
from torchft.process_group import (
19+
ManagedProcessGroup,
20+
ProcessGroupGloo,
21+
ft_init_device_mesh,
22+
)
23+
24+
25+
class DeviceMeshTest(TestCase):
    """Multi-process checks for the mesh returned by ``ft_init_device_mesh``."""

    @staticmethod
    def _test_init_device_mesh(world_size: int, rank: int) -> None:
        """Run the device-mesh assertions inside a single worker process.

        Args:
            world_size: Number of worker processes taking part in the test.
            rank: Rank assigned to this worker.
        """
        # Rendezvous settings for this worker. WORLD_SIZE follows the
        # ``world_size`` argument so the two can never disagree.
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(12346)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)

        # This runs in a child process, so a standalone TestCase instance is
        # used purely for its assertion helpers.
        testcase = TestCase()

        manager = Mock(spec=Manager)
        manager._pg = ProcessGroupGloo()
        # Even though we only have ``world_size`` workers, we can still
        # initialize a (2, world_size) mesh. That's because the replicate
        # group is NOT physically created in the real mesh but is virtually
        # added to the mesh via ManagedDeviceMesh.
        device_mesh = ft_init_device_mesh(
            device_type="cpu",
            mesh_shape=(2, world_size),
            mesh_dim_names=("dp_replicate", "dp_shard"),
            replicate_dim=0,
            manager=manager,
        )

        # Only the replicate dimension is backed by the managed process group.
        testcase.assertIsInstance(
            device_mesh.get_group("dp_replicate"), ManagedProcessGroup
        )
        testcase.assertNotIsInstance(
            device_mesh.get_group("dp_shard"), ManagedProcessGroup
        )
        replicate_group = device_mesh.get_group("dp_replicate")
        testcase.assertEqual(
            cast(ManagedProcessGroup, replicate_group)._manager, manager
        )
        replicate_mesh = device_mesh["dp_replicate"]
        testcase.assertEqual(replicate_mesh.get_group(), replicate_group)

        # The flattened size scales with the number of participants reported
        # by the manager; zero participants is reported as if there were one.
        flatten_mesh = device_mesh._flatten("dp")
        expected_sizes = (
            (0, world_size),
            (1, world_size),
            (2, world_size * 2),
        )
        for num_participants, expected_size in expected_sizes:
            manager.num_participants.return_value = num_participants
            testcase.assertEqual(flatten_mesh.size(), expected_size)

        testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())

        # Smoke checks: neither the coordinate lookup nor a torch.save() /
        # torch.load() round trip of the mesh may raise.
        device_mesh.get_coordinate()
        buffer = io.BytesIO()
        torch.save(device_mesh, buffer)
        buffer.seek(0)
        torch.load(buffer, weights_only=False)

    def test_init_device_mesh(self) -> None:
        """Spawn one process per rank and surface any worker failure."""
        world_size = 4
        with ProcessPoolExecutor(max_workers=world_size) as executor:
            futures = [
                executor.submit(self._test_init_device_mesh, world_size, rank)
                for rank in range(world_size)
            ]
            # result() re-raises any exception raised inside the worker.
            for future in futures:
                future.result()

torchft/process_group.py

+37-17
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import queue
2121
import threading
2222
from datetime import timedelta
23-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
23+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
2424

2525
import torch
2626
import torch.distributed as dist
@@ -871,6 +871,8 @@ def extend_device_mesh(
871871

872872

873873
class ManagedDeviceMesh(DeviceMesh):
874+
replicate_pg_singleton: Optional["ManagedProcessGroup"] = None
875+
874876
def __init__(
875877
self,
876878
mesh: Optional[DeviceMesh],
@@ -899,6 +901,16 @@ def __init__(
899901
self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
900902
self._thread_id: Optional[int] = None
901903

904+
def __getstate__(self) -> Dict[str, Any]:
    """Return a snapshot of this mesh's attributes for pickling.

    The returned dict is a shallow copy of ``self.__dict__`` in which the
    ``replicate_pg`` entry is replaced by ``None``; the live instance is
    left untouched.
    """
    return {**self.__dict__, "replicate_pg": None}
908+
909+
def __setstate__(self, state: Dict[str, Any]) -> None:
    """Restore pickled attributes and re-attach the replicate process group.

    ``__getstate__`` stores ``None`` for ``replicate_pg``, so after the
    pickled attributes are restored the group is taken from the class-level
    ``replicate_pg_singleton`` (assigned by ``ft_init_device_mesh``).

    Args:
        state: Attribute dict previously produced by ``__getstate__``.

    Raises:
        RuntimeError: If ``replicate_pg_singleton`` has not been set.
    """
    self.__dict__.update(state)
    replicate_pg = self.replicate_pg_singleton
    # An explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently leave ``replicate_pg`` as None.
    if replicate_pg is None:
        raise RuntimeError(
            "Cannot unpickle ManagedDeviceMesh: replicate_pg_singleton is not "
            "set. Call ft_init_device_mesh() in this process before loading."
        )
    self.replicate_pg = replicate_pg
913+
902914
def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
903915
if isinstance(mesh_dim_names, str):
904916
if mesh_dim_names == self.replicate_dim_name:
@@ -916,13 +928,16 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
916928
return self.mesh[mesh_dim_names]
917929
else:
918930
assert isinstance(mesh_dim_names, tuple)
919-
if self.replicate_dim_name in mesh_dim_names:
931+
if self.replicate_dim_name not in mesh_dim_names:
920932
assert self.mesh is not None
921933
return self.mesh[mesh_dim_names]
922934
else:
935+
mesh_dim_names_wo_replicate = tuple(
936+
n for n in mesh_dim_names if n != self.replicate_dim_name
937+
)
923938
assert self.mesh is not None
924939
return ManagedDeviceMesh(
925-
self.mesh[mesh_dim_names],
940+
self.mesh[mesh_dim_names_wo_replicate],
926941
mesh_dim_names,
927942
self.replicate_pg,
928943
mesh_dim_names.index(self.replicate_dim_name),
@@ -957,14 +972,18 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
957972
return flatten_mesh
958973

959974
def size(self, mesh_dim: Optional[int] = None) -> int:
975+
replicate_pg_size = self.replicate_pg.size()
976+
# We have to lie to the users if there are zero participants.
977+
# This is possible during the initialization stage of training.
978+
replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
960979
if mesh_dim is None:
961980
if self.mesh is None:
962-
return self.replicate_pg.size()
981+
return replicate_pg_size
963982
else:
964983
assert self.mesh is not None
965-
return self.mesh.size() * self.replicate_pg.size()
984+
return self.mesh.size() * replicate_pg_size
966985
elif mesh_dim == self.replicate_dim:
967-
return self.replicate_pg.size()
986+
return replicate_pg_size
968987
else:
969988
assert self.mesh is not None
970989
return self.mesh.size(self._real_mesh_dim(mesh_dim))
@@ -1014,7 +1033,16 @@ def get_coordinate(self) -> Optional[List[int]]:
10141033
dimensions of the mesh. If this rank is not part of the mesh, return None.
10151034
"""
10161035
assert self.mesh is not None
1017-
return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
1036+
coordinate = (
1037+
self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
1038+
)
1039+
if not coordinate:
1040+
return coordinate
1041+
1042+
# We need to copy because we are going to modify the coordinate.
1043+
coordinate = coordinate.copy()
1044+
coordinate.insert(get_rank(self.replicate_pg), self.replicate_dim)
1045+
return coordinate
10181046

10191047
def get_all_groups(self) -> List[BaseProcessGroup]:
10201048
raise NotImplementedError
@@ -1076,19 +1104,11 @@ def ft_init_device_mesh(
10761104
mesh_dim_names=tuple(_mesh_dim_names),
10771105
)
10781106

1079-
if device_type == "cpu":
1080-
pg = ProcessGroupGloo()
1081-
elif device_type == "cuda":
1082-
pg = ProcessGroupNCCL()
1083-
else:
1084-
raise ValueError()
1085-
1086-
manager._pg = pg
10871107
replicate_pg = ManagedProcessGroup(manager)
1088-
# We have to use MultiProcessTestCase, otherwise c10d will complain
1089-
# the same backend has been registered.
10901108
replicate_pg.register(mesh_dim_names[replicate_dim])
10911109

1110+
ManagedDeviceMesh.replicate_pg_singleton = replicate_pg
1111+
10921112
return ManagedDeviceMesh(
10931113
mesh=mesh,
10941114
mesh_dim_names=mesh_dim_names,

torchft/process_group_test.py

+1-47
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import io
78
import multiprocessing
89
import os
910
import unittest
@@ -369,50 +370,3 @@ def test_managed_process_group(self) -> None:
369370
self.assertEqual(manager.report_error.call_count, 0)
370371
self.assertEqual(manager.wrap_future.call_count, 1)
371372
self.assertEqual(manager.wait_quorum.call_count, 1)
372-
373-
374-
class DeviceMeshTest(TestCase):
375-
@staticmethod
376-
def _test_init_device_mesh(world_size: int, rank: int) -> None:
377-
os.environ["MASTER_ADDR"] = "127.0.0.1"
378-
os.environ["MASTER_PORT"] = str(12346)
379-
os.environ["RANK"] = str(rank)
380-
os.environ["WORLD_SIZE"] = str(4)
381-
382-
testcase = TestCase()
383-
384-
manager = Mock(spec=Manager)
385-
# Even though we only have 4 workers, we can still initialize (2, 4) mesh.
386-
# That's because the replicate group is NOT phystically created in the
387-
# real mesh but is virtually added to the mesh via ManagedDeviceMesh.
388-
device_mesh = ft_init_device_mesh(
389-
device_type="cpu",
390-
mesh_shape=(2, world_size),
391-
mesh_dim_names=("dp_replicate", "dp_shard"),
392-
replicate_dim=0,
393-
manager=manager,
394-
)
395-
396-
testcase.assertTrue(
397-
isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
398-
)
399-
testcase.assertTrue(
400-
not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
401-
)
402-
replicate_group = device_mesh.get_group("dp_replicate")
403-
testcase.assertEqual(
404-
cast(ManagedProcessGroup, replicate_group)._manager, manager
405-
)
406-
replicate_mesh = device_mesh["dp_replicate"]
407-
testcase.assertEqual(replicate_mesh.get_group(), replicate_group)
408-
flatten_mesh = device_mesh._flatten("dp")
409-
manager.num_participants.return_value = 1
410-
testcase.assertEqual(flatten_mesh.size(), world_size)
411-
testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())
412-
413-
def test_init_device_mesh(self) -> None:
414-
with ProcessPoolExecutor(max_workers=4) as executor:
415-
futures = []
416-
for i in range(4):
417-
future = executor.submit(self._test_init_device_mesh, 4, i)
418-
futures.append(future)

0 commit comments

Comments
 (0)