-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathlocal_sgd_integ_test.py
390 lines (334 loc) · 13.1 KB
/
local_sgd_integ_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
import copy
import logging
import re
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
from datetime import timedelta
from typing import Any, Dict
from unittest import TestCase
import torch
from parameterized import parameterized
from torch import nn, optim
from torchft._torchft import LighthouseServer
from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager
from torchft.manager_integ_test import FailureInjector, MyModel, Runner
from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
# Module-level logger for this test module, named after the module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
def local_sgd_train_loop(
    rank: int,
    store_port: int,
    device: torch.device,
    runner: Runner,
) -> Dict[str, Dict[str, object]]:
    """Train ``MyModel`` under ``LocalSGD`` for one replica rank.

    Builds a ``Manager`` for this replica/rank, then repeatedly trains on
    random data through a ``LocalSGD`` wrapper (``sync_every=2``) until the
    manager reports a step of at least 4. After every non-final iteration
    the runner's failure injector is consulted and may raise.

    Args:
        rank: Rank of this worker, forwarded to ``Manager`` and the
            failure injector.
        store_port: Store port forwarded to ``Manager``.
        device: Device for the model and inputs; a ``cuda`` device selects
            ``ProcessGroupBabyNCCL``, any other device ``ProcessGroupGloo``.
        runner: Supplies the replica id, world size, lighthouse address,
            extra ``Manager`` kwargs and the failure injector.

    Returns:
        The final ``{"model": ..., "optim": ...}`` state dict, so the caller
        can compare replicas.
    """
    with ExitStack() as cleanup:
        # These closures look up `model` / `opt` only when called, so they
        # can be handed to the Manager before those objects are created.
        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
            model.load_state_dict(state_dict["model"])
            opt.load_state_dict(state_dict["optim"])

        def state_dict() -> Dict[str, Dict[str, object]]:
            return {"model": model.state_dict(), "optim": opt.state_dict()}

        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

        pg = ProcessGroupBabyNCCL() if device.type == "cuda" else ProcessGroupGloo()
        manager = Manager(
            pg=pg,
            min_replica_size=2,
            load_state_dict=load_state_dict,
            state_dict=state_dict,
            replica_id=str(runner.replica_id),
            store_addr="localhost",
            store_port=store_port,
            rank=rank,
            world_size=runner.world_size,
            lighthouse_addr=runner.lighthouse_address,
            port=19530 + runner.replica_id,
            timeout=timedelta(seconds=10),
            # pyre-fixme[6]: Incompatible parameter type
            **runner.manager_args,
        )
        cleanup.callback(lambda: manager.shutdown(wait=False))

        model: nn.Module = MyModel().to(device)
        opt: optim.Optimizer = optim.Adam(model.parameters())
        loss_fn = nn.CrossEntropyLoss()

        with LocalSGD(manager, model, opt, sync_every=2):
            while True:
                inputs = torch.rand(2, 3).to(device)
                labels = torch.randint(4, (2,)).to(device)

                opt.zero_grad()
                loss = loss_fn(model(inputs), labels)
                loss.backward()
                opt.step()

                # Stop after the manager has reached step 4; otherwise let the
                # failure injector decide whether to fail at this step.
                if manager.current_step() >= 4:
                    break
                runner.failure_injector.check(rank, manager.current_step())

        # Hand back the final state so replicas can be checked for consistency.
        return state_dict()
    return {}
def diloco_train_loop(
    rank: int,
    store_port: int,
    device: torch.device,
    runner: Runner,
) -> Dict[str, Dict[str, object]]:
    """Train ``MyModel(2, 3)`` under ``DiLoCo`` for one replica rank.

    The model is initialised from ``runner.train_loop_args["model_state_dict"]``.
    It is then trained with an AdamW inner optimizer and a Nesterov SGD outer
    optimizer through ``DiLoCo`` (``sync_every=2``) until the manager reports
    a step of at least 4. The first time each manager step is observed, a
    deep copy of the full state is recorded.

    Args:
        rank: Rank of this worker, forwarded to ``Manager`` and the
            failure injector.
        store_port: Store port forwarded to ``Manager``.
        device: Device for the model, inputs and DiLoCo backups; a ``cuda``
            device selects ``ProcessGroupBabyNCCL``, otherwise
            ``ProcessGroupGloo``.
        runner: Supplies the replica id, world size, lighthouse address,
            extra ``Manager`` kwargs, train-loop args and the failure injector.

    Returns:
        Mapping of manager step -> deep-copied state dict (keys ``model``,
        ``original_params``, ``inner_optim``, ``outer_optim``) captured the
        first time that step was seen.
    """
    with ExitStack() as cleanup:
        # All replicas start from the same weights supplied by the test.
        model: nn.Module = MyModel(2, 3)
        initial_weights: Dict[str, Any] = runner.train_loop_args["model_state_dict"]
        model.load_state_dict(initial_weights)
        model.to(device)

        inner_opt: optim.Optimizer = torch.optim.AdamW(
            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
        )
        outer_opt: optim.Optimizer = torch.optim.SGD(
            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
        )

        # `diloco` is bound later by the `with` statement below; these
        # closures only dereference it when the Manager invokes them.
        # pyre-ignore[53]
        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
            model.load_state_dict(state_dict["model"])
            model.to(device)

            # Adopt the received backup parameters and move each one onto
            # this worker's device in place.
            diloco.original_parameters = state_dict["original_params"]
            for key, tensor in diloco.original_parameters.items():
                diloco.original_parameters[key] = tensor.to(device)

            inner_opt.load_state_dict(state_dict["inner_optim"])
            outer_opt.load_state_dict(state_dict["outer_optim"])

        def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
            return {
                "model": model.state_dict(),
                "original_params": diloco.original_parameters,
                "inner_optim": inner_opt.state_dict(),
                "outer_optim": outer_opt.state_dict(),
            }

        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

        pg = ProcessGroupBabyNCCL() if device.type == "cuda" else ProcessGroupGloo()
        manager = Manager(
            pg=pg,
            min_replica_size=2,
            use_async_quorum=False,
            load_state_dict=load_state_dict,
            state_dict=state_dict,
            replica_id=str(runner.replica_id),
            store_addr="localhost",
            store_port=store_port,
            rank=rank,
            world_size=runner.world_size,
            lighthouse_addr=runner.lighthouse_address,
            port=19530 + runner.replica_id,
            connect_timeout=timedelta(seconds=10),
            quorum_timeout=timedelta(seconds=10),
            timeout=timedelta(seconds=10),
            # pyre-fixme[6]: Incompatible parameter type
            **runner.manager_args,
        )
        cleanup.callback(manager.shutdown)

        loss_fn = nn.CrossEntropyLoss()
        snapshots = {}
        batch_size = 1

        with DiLoCo(
            manager,
            model,
            inner_opt,
            outer_opt,
            backup_device=device,
            sync_every=2,
        ) as diloco:
            while True:
                # Snapshot the state once per manager step (the first time
                # that step is observed). The names in this f-string are
                # printed verbatim by the `=` specifier.
                manager_curr_step = manager.current_step()
                if manager_curr_step not in snapshots:
                    print(
                        f"{manager_curr_step=} {diloco._local_step=} {runner.replica_id=} {state_dict()=}"
                    )
                    snapshots[manager_curr_step] = copy.deepcopy(state_dict())

                inputs = model.get_rand_inputs(batch_size).to(device)
                labels = model.get_rand_labels(batch_size).to(device)
                loss = loss_fn(model(inputs), labels)

                inner_opt.zero_grad()
                loss.backward()
                inner_opt.step()

                # Stop after 4 model updates; otherwise let the failure
                # injector decide whether to fail at this step.
                if manager.current_step() >= 4:
                    break
                runner.failure_injector.check(rank, manager.current_step())

        # Hand back the per-step snapshots so replicas can be compared.
        return snapshots
    return {}
class LocalSGDIntegTest(TestCase):
    """Two-replica integration tests for ``LocalSGD`` and ``DiLoCo``.

    Each test starts a ``LighthouseServer`` and runs two replicas in threads
    via ``Runner.run_replica``. It then compares the states the replicas
    returned. The lighthouse is shut down in a ``finally`` so a failing
    replica does not leak it.
    """

    # TODO: race condition due to using NCCL in threads causes manager allreduce to sometimes not be correct
    # Because of that the test is disabled for cuda
    @parameterized.expand(
        [
            # (True,),
            (False,),
        ]
    )
    def test_local_sgd_recovery(self, use_cuda: bool) -> None:
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")

        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        # Replica 1 is configured to fail on rank 0 at step 2.
        failure_injectors = [
            FailureInjector(),
            FailureInjector().fail_at(0, 2),
        ]

        try:
            with ThreadPoolExecutor(max_workers=num_replicas) as executor:
                for replica_id, failure_injector in zip(
                    range(num_replicas), failure_injectors
                ):
                    runner = Runner(
                        replica_id=replica_id,
                        num_replicas=num_replicas,
                        lighthouse_address=lighthouse.address(),
                        failure_injector=failure_injector,
                        train_loop=local_sgd_train_loop,
                        use_cuda=use_cuda,
                        manager_args={
                            "use_async_quorum": False,
                        },
                    )
                    futures.append(executor.submit(runner.run_replica))

                state_dicts = []
                for fut in as_completed(futures):
                    try:
                        state_dicts.append(fut.result())
                    except Exception as e:
                        print(e, flush=True)
                        traceback.print_exc()
                        raise
        finally:
            # Always release the lighthouse, even if a replica raised.
            lighthouse.shutdown()

        for state_dict in state_dicts:
            # LocalSGD only guarantees that the model is consistent across
            # replicas but uses separate optimizer states.
            torch.testing.assert_close(
                state_dict[0]["model"], state_dicts[0][0]["model"], check_device=False
            )

        self.assertEqual(failure_injectors[1].count, 1)

    @parameterized.expand(
        [
            # (True,),
            (False,),
        ]
    )
    def test_diloco_healthy(self, use_cuda: bool) -> None:
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")

        lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
        num_replicas = 2
        futures = []

        torch.manual_seed(42)
        # Initialize the model so we can pass in the state_dict
        m: nn.Module = MyModel(2, 3)

        try:
            with ThreadPoolExecutor(max_workers=num_replicas) as executor:
                for replica_id in range(num_replicas):
                    failure_injector = FailureInjector()
                    runner = Runner(
                        replica_id=replica_id,
                        num_replicas=num_replicas,
                        lighthouse_address=lighthouse.address(),
                        failure_injector=failure_injector,
                        train_loop=diloco_train_loop,
                        use_cuda=use_cuda,
                        train_loop_args={
                            "model_state_dict": m.state_dict(),
                        },
                    )
                    futures.append(executor.submit(runner.run_replica))

                state_dicts = []
                for fut in as_completed(futures):
                    try:
                        # Keep only rank 0's per-step snapshots.
                        state_dicts.append(fut.result()[0])
                    except Exception as e:
                        print(e, flush=True)
                        traceback.print_exc()
                        raise
        finally:
            # Always release the lighthouse, even if a replica raised.
            lighthouse.shutdown()

        rep0, rep1 = state_dicts
        for step, state_dict in rep1.items():
            # inner optimizer will be different, outer optimizer and model should be the same
            torch.testing.assert_close(
                state_dict["model"],
                rep0[step]["model"],
                check_device=False,
            )
            torch.testing.assert_close(
                state_dict["outer_optim"],
                rep0[step]["outer_optim"],
                check_device=False,
            )

    @parameterized.expand(
        [
            # (True,),
            (False,),
        ]
    )
    def test_diloco_recovery(self, use_cuda: bool) -> None:
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")

        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        # Replica 1 is configured to fail on rank 0 at step 2.
        failure_injectors = [
            FailureInjector(),
            FailureInjector().fail_at(0, 2),
        ]

        torch.manual_seed(42)
        # Initialize the model so we can pass in the state_dict
        m: nn.Module = MyModel(2, 3)

        try:
            with ThreadPoolExecutor(max_workers=num_replicas) as executor:
                for replica_id, failure_injector in zip(
                    range(num_replicas), failure_injectors
                ):
                    runner = Runner(
                        replica_id=replica_id,
                        num_replicas=num_replicas,
                        lighthouse_address=lighthouse.address(),
                        failure_injector=failure_injector,
                        train_loop=diloco_train_loop,
                        # Previously omitted, which silently ignored the
                        # parameterized device choice for this test.
                        use_cuda=use_cuda,
                        train_loop_args={
                            "model_state_dict": m.state_dict(),
                        },
                    )
                    futures.append(executor.submit(runner.run_replica))

                state_dicts = []
                for fut in as_completed(futures):
                    try:
                        # Keep only rank 0's per-step snapshots.
                        state_dicts.append(fut.result()[0])
                    except Exception as e:
                        print(e, flush=True)
                        traceback.print_exc()
                        raise
        finally:
            # Always release the lighthouse, even if a replica raised.
            lighthouse.shutdown()

        rep0, rep1 = state_dicts
        for step in rep0.keys():
            # Inner optimizer will be different, outer optimizer and model should be the same
            torch.testing.assert_close(
                rep1[step]["model"],
                rep0[step]["model"],
                check_device=False,
            )
            torch.testing.assert_close(
                rep1[step]["outer_optim"],
                rep0[step]["outer_optim"],
                check_device=False,
            )

        self.assertEqual(failure_injectors[1].count, 1)