Skip to content

Commit 88cba6a

Browse files
authored
Add GPU tests for local_sgd (#128)
1 parent 8f021e1 commit 88cba6a

File tree

1 file changed

+43
-5
lines changed

1 file changed

+43
-5
lines changed

torchft/local_sgd_integ_test.py

+43-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
-from torchft.process_group import ProcessGroupGloo, ProcessGroupNCCL
+from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
 
 
 logger: logging.Logger = logging.getLogger(__name__)
 

@@ -41,7 +41,10 @@ def state_dict() -> Dict[str, Dict[str, object]]:
 
     print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
 
-    pg = ProcessGroupGloo()
+    if device.type == "cuda":
+        pg = ProcessGroupBabyNCCL()
+    else:
+        pg = ProcessGroupGloo()
     manager = Manager(
         pg=pg,
         min_replica_size=2,
@@ -110,7 +113,12 @@ def diloco_train_loop(
     # pyre-ignore[53]
     def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
         m.load_state_dict(state_dict["model"])
+        m.to(device)
         diloco.original_parameters = state_dict["original_params"]
+        for name in diloco.original_parameters.keys():
+            diloco.original_parameters[name] = diloco.original_parameters[name].to(
+                device
+            )
         inner_optimizer.load_state_dict(state_dict["inner_optim"])
         outer_optimizer.load_state_dict(state_dict["outer_optim"])
 
@@ -124,7 +132,10 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
 
     print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
 
-    pg = ProcessGroupGloo()
+    if device.type == "cuda":
+        pg = ProcessGroupBabyNCCL()
+    else:
+        pg = ProcessGroupGloo()
     manager = Manager(
         pg=pg,
         min_replica_size=2,
@@ -138,6 +149,8 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
         world_size=runner.world_size,
         lighthouse_addr=runner.lighthouse_address,
         port=19530 + runner.replica_id,
+        connect_timeout=timedelta(seconds=10),
+        quorum_timeout=timedelta(seconds=10),
         timeout=timedelta(seconds=10),
         # pyre-fixme[6]: Incompatible parameter type
         **runner.manager_args,
@@ -155,6 +168,12 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
         sync_every=2,
     ) as diloco:
         while True:
+            manager_curr_step = manager.current_step()
+            if manager_curr_step not in all_state_dicts:
+                print(
+                    f"{manager_curr_step=} {diloco._local_step=} {runner.replica_id=} {state_dict()=}"
+                )
+                all_state_dicts[manager_curr_step] = copy.deepcopy(state_dict())
             batch_size = 1
             inputs = m.get_rand_inputs(batch_size).to(device)
             labels = m.get_rand_labels(batch_size).to(device)
@@ -164,7 +183,6 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
 
             inner_optimizer.zero_grad()
             loss.backward()
-            all_state_dicts[str(manager.current_step())] = state_dict()
             inner_optimizer.step()
 
             # after 4 model updates then break
@@ -181,10 +199,15 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
 class LocalSGDIntegTest(TestCase):
     @parameterized.expand(
         [
+            (True,),
             (False,),
         ]
     )
     def test_local_sgd_recovery(self, use_cuda: bool) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
         lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
@@ -236,10 +259,15 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
 
     @parameterized.expand(
         [
+            (True,),
             (False,),
         ]
     )
     def test_diloco_healthy(self, use_cuda: bool) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
         lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
         num_replicas = 2
         futures = []
@@ -289,7 +317,17 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
             check_device=False,
         )
 
-    def test_diloco_recovery(self) -> None:
+    @parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
+    def test_diloco_recovery(self, use_cuda: bool) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
         lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,

0 commit comments

Comments
 (0)