@@ -260,11 +260,24 @@ def test_optim_4bit_correctness(self, optim_name):
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = _DEVICES[-1]
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        # The first two layers are chosen so that they have a terrible arithmetic density.
+        # This means long transfers and comparatively quick computation, increasing the chances
+        # that missing synchronization will lead to test failures.
+        # The third layer is very small; it is here to validate non-trainable parameters,
+        # but it shouldn't influence the timings.
+        model1 = nn.Sequential(
+            nn.Linear(32, 131072),
+            nn.ReLU(),
+            nn.Linear(131072, 64),
+            nn.ReLU(),
+            nn.Linear(64, 64),
+            nn.ReLU(),
+            nn.Linear(64, 128),
+        )
         model1.to(device)
 
         # make sure it can work in the presence of non-trainable params
-        model1[0].requires_grad_(False)
+        model1[2].requires_grad_(False)
         model2 = copy.deepcopy(model1)
 
         optim1 = torch.optim.AdamW(model1.parameters())
@@ -274,15 +287,26 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
             offload_gradients=offload_grad,
         )
 
+        rng = torch.Generator(device=device)
+        rng.manual_seed(42)
+
+        # make sure to run both models separately; otherwise, model1 gives additional
+        # time for operations in model2 to complete, masking potential race conditions.
         for _ in range(2):
             for _ in range(grad_accum):
-                x = torch.randn(4, 32, device=device)
+                x = torch.randn(4, 32, device=device, generator=rng)
                 model1(x).sum().backward()
-                model2(x).sum().backward()
 
             optim1.step()
             optim1.zero_grad()
 
+        # reset the rng
+        rng.manual_seed(42)
+        for _ in range(2):
+            for _ in range(grad_accum):
+                x = torch.randn(4, 32, device=device, generator=rng)
+                model2(x).sum().backward()
+
             optim2.step()
             optim2.zero_grad()
 
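For context, here is a minimal sketch of the baseline-vs-offload comparison this test relies on, assuming the optimizer under test is torchao's CPUOffloadOptimizer; the import path, the extra constructor arguments, and the final assertion are assumptions for illustration and are not part of this diff. The key point mirrored from the change above is that each model is stepped on its own pass over the (replayed) inputs, so the baseline never gives the offloaded optimizer extra time to finish its asynchronous transfers.

```python
# Minimal sketch of the comparison pattern, under the assumptions stated above.
import copy

import torch
import torch.nn as nn
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer  # assumed import path

device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 131072), nn.ReLU(), nn.Linear(131072, 128)).to(device)
model2 = copy.deepcopy(model1)

optim1 = torch.optim.AdamW(model1.parameters())  # baseline: optimizer state stays on GPU
optim2 = CPUOffloadOptimizer(                    # offloaded: optimizer state kept on CPU
    model2.parameters(), torch.optim.AdamW, offload_gradients=True
)

rng = torch.Generator(device=device)

# Step the baseline model first, on its own ...
rng.manual_seed(42)
for _ in range(2):
    x = torch.randn(4, 32, device=device, generator=rng)
    model1(x).sum().backward()
    optim1.step()
    optim1.zero_grad()

# ... then replay the same inputs through the offloaded model, so running the
# baseline cannot mask missing synchronization in the offload path.
rng.manual_seed(42)
for _ in range(2):
    x = torch.randn(4, 32, device=device, generator=rng)
    model2(x).sum().backward()
    optim2.step()
    optim2.zero_grad()

# If the offload path synchronizes correctly, both models end up with the same weights.
for p1, p2 in zip(model1.parameters(), model2.parameters()):
    torch.testing.assert_close(p2, p1)
```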