
Commit 122eb73

more stringent test for CPUOffloadOptimizer (#1650)
* more stringent test for CPUOffloadOptimizer
* fix missing sync
1 parent 3eb18e7 commit 122eb73

2 files changed: +30 -4 lines changed

test/prototype/test_low_bit_optim.py (+28 -4)

```diff
@@ -260,11 +260,24 @@ def test_optim_4bit_correctness(self, optim_name):
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = _DEVICES[-1]
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        # The first two layers are chosen so that they have a terrible arithmetic density.
+        # This means long transfers and comparatively quick computation, increasing the chances
+        # that missing synchronization will lead to test failures.
+        # The third layer is very small, here to validate non-trainable parameters,
+        # but shouldn't influence the timings.
+        model1 = nn.Sequential(
+            nn.Linear(32, 131072),
+            nn.ReLU(),
+            nn.Linear(131072, 64),
+            nn.ReLU(),
+            nn.Linear(64, 64),
+            nn.ReLU(),
+            nn.Linear(64, 128),
+        )
         model1.to(device)

         # make sure it can work in the presence of non-trainable params
-        model1[0].requires_grad_(False)
+        model1[2].requires_grad_(False)
         model2 = copy.deepcopy(model1)

         optim1 = torch.optim.AdamW(model1.parameters())
@@ -274,15 +287,26 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
             offload_gradients=offload_grad,
         )

+        rng = torch.Generator(device=device)
+        rng.manual_seed(42)
+
+        # make sure to run both models separately; otherwise, model1 gives additional
+        # time for operations in model2 to complete, masking potential race conditions.
         for _ in range(2):
             for _ in range(grad_accum):
-                x = torch.randn(4, 32, device=device)
+                x = torch.randn(4, 32, device=device, generator=rng)
                 model1(x).sum().backward()
-                model2(x).sum().backward()

             optim1.step()
             optim1.zero_grad()

+        # reset the rng
+        rng.manual_seed(42)
+        for _ in range(2):
+            for _ in range(grad_accum):
+                x = torch.randn(4, 32, device=device, generator=rng)
+                model2(x).sum().backward()
+
             optim2.step()
             optim2.zero_grad()
```
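The pattern in the diff above is worth calling out: both models see byte-identical, generator-driven inputs, but their GPU work is never interleaved, so missing synchronization in the offloaded run cannot be hidden by extra host-side work from the reference run. Below is a minimal sketch of the same pattern outside the test harness. It assumes `CPUOffloadOptimizer` is importable from `torchao.prototype.low_bit_optim` and accepts the base optimizer class as its second argument (taken from torchao's docs; only `offload_gradients` appears in this diff), and that a CUDA device is available.

```python
import copy

import torch
from torch import nn
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer  # assumed import path

device = "cuda"  # sketch assumes a CUDA device
model_ref = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model_off = copy.deepcopy(model_ref)

optim_ref = torch.optim.AdamW(model_ref.parameters())
# Passing the base optimizer class positionally is assumed from torchao's docs.
optim_off = CPUOffloadOptimizer(model_off.parameters(), torch.optim.AdamW, offload_gradients=True)

rng = torch.Generator(device=device)

def run(model, optim, steps=2):
    # Reseed so each run replays exactly the same inputs, without interleaving
    # the two models' GPU work.
    rng.manual_seed(42)
    for _ in range(steps):
        x = torch.randn(4, 32, device=device, generator=rng)
        model(x).sum().backward()
        optim.step()
        optim.zero_grad()

run(model_ref, optim_ref)
run(model_off, optim_off)

for p_ref, p_off in zip(model_ref.parameters(), model_off.parameters()):
    torch.testing.assert_close(p_off, p_ref)
```

Running the offloaded model second, on its own, is the point: any parameter still in flight when its forward pass launches shows up as a mismatch at the end.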

torchao/prototype/low_bit_optim/cpu_offload.py (+2)

```diff
@@ -107,6 +107,8 @@ def step(self, closure=None):
             with getattr(torch, self.device).stream(self.stream):
                 p_device.copy_(p_host, non_blocking=True)

+        # make sure param H2D finishes before the next forward pass
+        self.stream.synchronize()
         self.queue.clear()
         return loss
```
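For context, the hazard the added `self.stream.synchronize()` guards against is the standard CUDA-stream one: a `copy_(..., non_blocking=True)` issued under a side stream is ordered only with respect to that stream, so a kernel launched afterwards on the default stream may read the destination before the copy has landed. Here is a minimal, self-contained sketch of that pattern; the tensor names are hypothetical stand-ins and a CUDA device is assumed, this is not code from the repo.

```python
import torch

# Hypothetical stand-ins for an offloaded (pinned) host parameter and its GPU copy.
p_host = torch.randn(1024, 1024, pin_memory=True)
p_device = torch.empty_like(p_host, device="cuda")

side_stream = torch.cuda.Stream()

with torch.cuda.stream(side_stream):
    # Async H2D copy: ordered only with respect to `side_stream`.
    p_device.copy_(p_host, non_blocking=True)

# Without this, work launched next on the default stream could read `p_device`
# before the copy has finished.
side_stream.synchronize()

# Safe to consume on the default stream now, e.g. in the next forward pass.
out = p_device.sum()
```

A host-blocking `synchronize()` is the simplest guarantee here; an ordering-only alternative in PyTorch is `torch.cuda.current_stream().wait_stream(side_stream)`, which keeps the host free while still making later default-stream kernels wait for the copy.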
