16
16
from torchft .local_sgd import DiLoCo , LocalSGD
17
17
from torchft .manager import Manager
18
18
from torchft .manager_integ_test import FailureInjector , MyModel , Runner
19
- from torchft .process_group import ProcessGroupGloo , ProcessGroupNCCL
19
+ from torchft .process_group import ProcessGroupBabyNCCL , ProcessGroupGloo
20
20
21
21
logger : logging .Logger = logging .getLogger (__name__ )
22
22
@@ -41,7 +41,10 @@ def state_dict() -> Dict[str, Dict[str, object]]:
41
41
42
42
print (f"worker { runner .replica_id = } { rank = } { runner .world_size = } starting" )
43
43
44
- pg = ProcessGroupGloo ()
44
+ if device .type == "cuda" :
45
+ pg = ProcessGroupBabyNCCL ()
46
+ else :
47
+ pg = ProcessGroupGloo ()
45
48
manager = Manager (
46
49
pg = pg ,
47
50
min_replica_size = 2 ,
@@ -110,7 +113,12 @@ def diloco_train_loop(
110
113
# pyre-ignore[53]
111
114
def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
    """Restore the model, DiLoCo reference parameters and both optimizers.

    Args:
        state_dict: Mapping with the keys ``"model"``, ``"original_params"``,
            ``"inner_optim"`` and ``"outer_optim"``; each entry is handed to
            the matching component below.
    """
    # Model weights first, then make sure the module lives on the target device.
    m.load_state_dict(state_dict["model"])
    m.to(device)

    # Adopt the saved reference parameters and move every entry onto
    # ``device`` in place (values are replaced, keys are untouched).
    # NOTE(review): presumably needed because the incoming state may have been
    # produced on a different device -- confirm against the caller.
    diloco.original_parameters = state_dict["original_params"]
    saved_params = diloco.original_parameters
    for param_name, param_value in saved_params.items():
        saved_params[param_name] = param_value.to(device)

    inner_optimizer.load_state_dict(state_dict["inner_optim"])
    outer_optimizer.load_state_dict(state_dict["outer_optim"])
116
124
@@ -124,7 +132,10 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
124
132
125
133
print (f"worker { runner .replica_id = } { rank = } { runner .world_size = } starting" )
126
134
127
- pg = ProcessGroupGloo ()
135
+ if device .type == "cuda" :
136
+ pg = ProcessGroupBabyNCCL ()
137
+ else :
138
+ pg = ProcessGroupGloo ()
128
139
manager = Manager (
129
140
pg = pg ,
130
141
min_replica_size = 2 ,
@@ -138,6 +149,8 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
138
149
world_size = runner .world_size ,
139
150
lighthouse_addr = runner .lighthouse_address ,
140
151
port = 19530 + runner .replica_id ,
152
+ connect_timeout = timedelta (seconds = 10 ),
153
+ quorum_timeout = timedelta (seconds = 10 ),
141
154
timeout = timedelta (seconds = 10 ),
142
155
# pyre-fixme[6]: Incompatible parameter type
143
156
** runner .manager_args ,
@@ -155,6 +168,12 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
155
168
sync_every = 2 ,
156
169
) as diloco :
157
170
while True :
171
+ manager_curr_step = manager .current_step ()
172
+ if manager_curr_step not in all_state_dicts :
173
+ print (
174
+ f"{ manager_curr_step = } { diloco ._local_step = } { runner .replica_id = } { state_dict ()= } "
175
+ )
176
+ all_state_dicts [manager_curr_step ] = copy .deepcopy (state_dict ())
158
177
batch_size = 1
159
178
inputs = m .get_rand_inputs (batch_size ).to (device )
160
179
labels = m .get_rand_labels (batch_size ).to (device )
@@ -164,7 +183,6 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
164
183
165
184
inner_optimizer .zero_grad ()
166
185
loss .backward ()
167
- all_state_dicts [str (manager .current_step ())] = state_dict ()
168
186
inner_optimizer .step ()
169
187
170
188
# after 4 model updates then break
@@ -181,10 +199,15 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
181
199
class LocalSGDIntegTest (TestCase ):
182
200
@parameterized .expand (
183
201
[
202
+ (True ,),
184
203
(False ,),
185
204
]
186
205
)
187
206
def test_local_sgd_recovery (self , use_cuda : bool ) -> None :
207
+ # Skip the test if use_cuda is True and there are not enough GPUs
208
+ if use_cuda and torch .cuda .device_count () < 2 :
209
+ self .skipTest ("Not enough GPUs for CUDA test" )
210
+
188
211
lighthouse = LighthouseServer (
189
212
bind = "[::]:0" ,
190
213
min_replicas = 2 ,
@@ -236,10 +259,15 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
236
259
237
260
@parameterized .expand (
238
261
[
262
+ (True ,),
239
263
(False ,),
240
264
]
241
265
)
242
266
def test_diloco_healthy (self , use_cuda : bool ) -> None :
267
+ # Skip the test if use_cuda is True and there are not enough GPUs
268
+ if use_cuda and torch .cuda .device_count () < 2 :
269
+ self .skipTest ("Not enough GPUs for CUDA test" )
270
+
243
271
lighthouse = LighthouseServer (bind = "[::]:0" , min_replicas = 2 )
244
272
num_replicas = 2
245
273
futures = []
@@ -289,7 +317,17 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
289
317
check_device = False ,
290
318
)
291
319
292
- def test_diloco_recovery (self ) -> None :
320
+ @parameterized .expand (
321
+ [
322
+ (True ,),
323
+ (False ,),
324
+ ]
325
+ )
326
+ def test_diloco_recovery (self , use_cuda : bool ) -> None :
327
+ # Skip the test if use_cuda is True and there are not enough GPUs
328
+ if use_cuda and torch .cuda .device_count () < 2 :
329
+ self .skipTest ("Not enough GPUs for CUDA test" )
330
+
293
331
lighthouse = LighthouseServer (
294
332
bind = "[::]:0" ,
295
333
min_replicas = 2 ,
0 commit comments