
Commit a19bdfb

YangQun1 authored and pytorchmergebot committed
[compiled autograd] reorder backward hooks to match eager behavior (pytorch#138553)
Fixes pytorch#138538

Pull Request resolved: pytorch#138553
Approved by: https://github.com/xmfan
1 parent b71ab3f commit a19bdfb
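
For context, eager autograd fires the hooks around a parameter's AccumulateGrad node in a fixed relative order, and this change makes compiled autograd replay them in that same order. Below is a minimal illustrative sketch, not part of the commit, that observes the ordering in eager mode; it assumes a PyTorch build with Tensor.register_post_accumulate_grad_hook and Node.register_prehook, and the order list exists purely for demonstration.

import torch

# Hypothetical demo, not from this commit: record the order in which each
# hook kind fires for a single leaf parameter in eager mode.
order = []

w = torch.nn.Parameter(torch.randn(3))
w.register_hook(lambda g: order.append("tensor pre-hook") or g)
w.register_post_accumulate_grad_hook(lambda t: order.append("post-accumulate-grad hook"))

# A leaf has no grad_fn, but a no-op view of it does; the view's next
# function is the leaf's AccumulateGrad node (the idiom the new tests use).
acc = w.view_as(w).grad_fn.next_functions[0][0]
acc.register_prehook(lambda grad_outputs: order.append("node pre-hook"))
acc.register_hook(lambda grad_inputs, grad_outputs: order.append("node post-hook"))

(w * 2.0).sum().backward()
print(order)
# Expected in eager mode: tensor pre-hook, node pre-hook,
# post-accumulate-grad hook, node post-hook.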

File tree

2 files changed (+564, -3 lines)


test/inductor/test_compiled_autograd.py

Lines changed: 342 additions & 1 deletion
@@ -251,6 +251,346 @@ def fn():
 
         self.check_output_and_recompiles(fn)
 
+    def test_reorder_acc_grad(self):
+        model = torch.nn.Sequential(
+            torch.nn.Conv2d(4, 4, 3, bias=True),
+            torch.nn.Conv2d(4, 4, 3, bias=True),
+        )
+        compiled_model = torch.compile(model)
+        x = torch.randn([1, 4, 32, 32])
+
+        model(x).sum().backward()
+        ref_res = [
+            model[0].weight.grad,
+            model[0].bias.grad,
+            model[1].weight.grad,
+            model[1].bias.grad,
+        ]
+
+        model[0].weight.grad = None
+        model[0].bias.grad = None
+        model[1].weight.grad = None
+        model[1].bias.grad = None
+        with compiled_autograd.enable(compiler_fn):
+            compiled_model(x).sum().backward(retain_graph=True)
+        res = [
+            model[0].weight.grad,
+            model[0].bias.grad,
+            model[1].weight.grad,
+            model[1].bias.grad,
+        ]
+
+        self.assertEqual(res[0], ref_res[0])
+        self.assertEqual(res[1], ref_res[1])
+        self.assertEqual(res[2], ref_res[2])
+        self.assertEqual(res[3], ref_res[3])
+
+    def test_reorder_post_hook1(self):
+        def grad_div(param):
+            param.grad = param.grad / 4.0
+
+        class Module(torch.nn.Module):
+            def __init__(self, ioc):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(ioc, ioc, bias=False)
+                self.fc2 = torch.nn.Linear(ioc, ioc, bias=False)
+
+                self.grad_acc_hooks = []
+                self.grad_acc = []
+                self.params = [self.fc1.weight, self.fc2.weight]
+                for i, param in enumerate(self.params):
+
+                    def wrapper(param):
+                        param_tmp = param.expand_as(param)
+                        grad_acc = param_tmp.grad_fn.next_functions[0][0]
+
+                        def grad_acc_hook(*notneeded):
+                            grad_div(param)
+
+                        self.grad_acc.append(grad_acc)
+                        self.grad_acc_hooks.append(
+                            grad_acc.register_hook(grad_acc_hook)
+                        )
+
+                    wrapper(param)
+
+            def forward(self, x):
+                x = self.fc1(x)
+                x = self.fc2(x)
+                return x.sum()
+
+        bs = 8
+        ioc = 16
+        model = Module(ioc)
+        input = torch.randn([bs, ioc])
+
+        # eager reference
+        model(input).backward()
+        ref_res = [model.fc1.weight.grad, model.fc2.weight.grad]
+
+        # compiled autograd
+        model.fc1.weight.grad = None
+        model.fc2.weight.grad = None
+        model_to_train = torch.compile(model, backend="inductor")
+        with compiled_autograd.enable(compiler_fn):
+            model_to_train(input).backward()
+        res = [model_to_train.fc1.weight.grad, model_to_train.fc2.weight.grad]
+
+        self.assertEqual(res[0], ref_res[0])
+        self.assertEqual(res[1], ref_res[1])
+
+    def test_reorder_post_hook2(self):
+        x = torch.randn([1, 4, 32, 32], requires_grad=True)
+        y = torch.sigmoid(x)
+        z = torch.tanh(y)
+
+        assert isinstance(z.grad_fn, torch.autograd.graph.Node)
+        assert isinstance(y.grad_fn, torch.autograd.graph.Node)
+        handle_z = z.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
+        handle_y = y.grad_fn.register_hook(lambda gI, gO: (gI[0] * 2,))
+        z.sum().backward(retain_graph=True)
+        ref_res = x.grad
+
+        x.grad = None
+        with compiled_autograd.enable(compiler_fn):
+            z.sum().backward(retain_graph=True)
+        res = x.grad
+
+        self.assertEqual(res, ref_res)
+
+    def test_reorder_post_hook3(self):
+        conv = torch.nn.Conv2d(4, 4, 3, bias=False)
+        x = torch.randn([1, 4, 32, 32])
+        y = conv(x)
+
+        assert isinstance(y.grad_fn, torch.autograd.graph.Node)
+        # this hook multiplies the conv weight gradient by 2.0
+        handle_y = y.grad_fn.register_hook(lambda gI, gO: (gI[0], gI[1] * 2, gI[2]))
+        y.sum().backward(retain_graph=True)
+        ref_res = x.grad
+
+        x.grad = None
+        with compiled_autograd.enable(compiler_fn):
+            y.sum().backward(retain_graph=True)
+        res = x.grad
+
+        self.assertEqual(res, ref_res)
+
+    def test_reorder_all_bwd_hooks(self):
+        def tensor_hook(grad):
+            return grad.sub(2.0)
+
+        def acc_grad_node_pre_hook(grad_out):
+            return (grad_out[0].div(5.0),)
+
+        def post_acc_grad_hook(tensor):
+            tensor.grad.add_(3.0)
+
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(4, 4, 3, bias=False)
+                self.conv2 = torch.nn.Conv2d(4, 4, 3, bias=False)
+
+                self.acc_grad1 = self.conv1.weight.view_as(
+                    self.conv1.weight
+                ).grad_fn.next_functions[0][0]
+                self.conv1.weight.register_hook(tensor_hook)
+                self.conv1.weight.register_post_accumulate_grad_hook(post_acc_grad_hook)
+                self.acc_grad1.register_prehook(acc_grad_node_pre_hook)
+
+                def acc_grad_node_post_hook1(grad_in, grad_out):
+                    self.conv1.weight.grad.mul_(0.5)
+
+                self.acc_grad1.register_hook(acc_grad_node_post_hook1)
+
+                self.acc_grad2 = self.conv2.weight.view_as(
+                    self.conv2.weight
+                ).grad_fn.next_functions[0][0]
+                self.conv2.weight.register_hook(tensor_hook)
+                self.conv2.weight.register_post_accumulate_grad_hook(post_acc_grad_hook)
+                self.acc_grad2.register_prehook(acc_grad_node_pre_hook)
+
+                def acc_grad_node_post_hook2(grad_in, grad_out):
+                    self.conv2.weight.grad.mul_(0.5)
+
+                self.acc_grad2.register_hook(acc_grad_node_post_hook2)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = self.conv2(y)
+                return y.sum()
+
+        input = torch.randn([1, 4, 32, 32])
+
+        # eager reference
+        model = TestModel()
+        model(input).backward()
+        ref_results = [model.conv1.weight.grad, model.conv2.weight.grad]
+
+        # compiled autograd
+        model.conv1.weight.grad = None
+        model.conv2.weight.grad = None
+        compiled_model = torch.compile(model, backend="inductor")
+        with compiled_autograd.enable(compiler_fn):
+            compiled_model(input).backward()
+        results = [compiled_model.conv1.weight.grad, compiled_model.conv2.weight.grad]
+
+        self.assertEqual(results[0], ref_results[0])
+        self.assertEqual(results[1], ref_results[1])
+
+    def test_reorder_multi_post_hooks(self):
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(4, 4, 3, bias=False)
+                self.conv2 = torch.nn.Conv2d(4, 4, 3, bias=False)
+
+                self.acc_grad1 = self.conv1.weight.view_as(
+                    self.conv1.weight
+                ).grad_fn.next_functions[0][0]
+
+                def acc_grad_node1_post_hook1(grad_in, grad_out):
+                    self.conv1.weight.grad.mul_(0.5)
+
+                def acc_grad_node1_post_hook2(grad_in, grad_out):
+                    self.conv1.weight.grad.sub_(0.3)
+
+                self.acc_grad1.register_hook(acc_grad_node1_post_hook1)
+                self.acc_grad1.register_hook(acc_grad_node1_post_hook2)
+
+                self.acc_grad2 = self.conv2.weight.view_as(
+                    self.conv2.weight
+                ).grad_fn.next_functions[0][0]
+
+                def acc_grad_node2_post_hook1(grad_in, grad_out):
+                    self.conv2.weight.grad.mul_(0.3)
+
+                def acc_grad_node2_post_hook2(grad_in, grad_out):
+                    self.conv2.weight.grad.sub_(0.5)
+
+                self.acc_grad2.register_hook(acc_grad_node2_post_hook1)
+                self.acc_grad2.register_hook(acc_grad_node2_post_hook2)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = self.conv2(y)
+                return y.sum()
+
+        input = torch.randn([1, 4, 32, 32])
+
+        # eager reference
+        model = TestModel()
+        model(input).backward()
+        ref_results = [model.conv1.weight.grad, model.conv2.weight.grad]
+
+        # compiled autograd
+        model.conv1.weight.grad = None
+        model.conv2.weight.grad = None
+        compiled_model = torch.compile(model, backend="inductor")
+        with compiled_autograd.enable(compiler_fn):
+            compiled_model(input).backward()
+        results = [compiled_model.conv1.weight.grad, compiled_model.conv2.weight.grad]
+
+        self.assertEqual(results[0], ref_results[0])
+        self.assertEqual(results[1], ref_results[1])
+
+    def test_reorder_multi_pre_hooks(self):
+        def acc_grad_node_pre_hook1(grad_out):
+            return (grad_out[0].div(5.0),)
+
+        def acc_grad_node_pre_hook2(grad_out):
+            return (grad_out[0].sub(0.3),)
+
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(4, 4, 3, bias=False)
+                self.conv2 = torch.nn.Conv2d(4, 4, 3, bias=False)
+
+                self.acc_grad1 = self.conv1.weight.view_as(
+                    self.conv1.weight
+                ).grad_fn.next_functions[0][0]
+                self.acc_grad1.register_prehook(acc_grad_node_pre_hook1)
+                self.acc_grad1.register_prehook(acc_grad_node_pre_hook2)
+
+                self.acc_grad2 = self.conv2.weight.view_as(
+                    self.conv2.weight
+                ).grad_fn.next_functions[0][0]
+                self.acc_grad2.register_prehook(acc_grad_node_pre_hook1)
+                self.acc_grad2.register_prehook(acc_grad_node_pre_hook2)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = self.conv2(y)
+                return y.sum()
+
+        input = torch.randn([1, 4, 32, 32])
+
+        # eager reference
+        model = TestModel()
+        model(input).backward()
+        ref_results = [model.conv1.weight.grad, model.conv2.weight.grad]
+
+        # compiled autograd
+        model.conv1.weight.grad = None
+        model.conv2.weight.grad = None
+        compiled_model = torch.compile(model, backend="inductor")
+        with compiled_autograd.enable(compiler_fn):
+            compiled_model(input).backward()
+        results = [compiled_model.conv1.weight.grad, compiled_model.conv2.weight.grad]
+
+        self.assertEqual(results[0], ref_results[0])
+        self.assertEqual(results[1], ref_results[1])
+
+    def test_reorder_multi_tensor_pre_hooks(self):
+        def tensor_hook1(grad):
+            return grad.sub(2.0)
+
+        def tensor_hook2(grad):
+            return grad.mul(0.5)
+
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(4, 4, 3, bias=False)
+                self.conv2 = torch.nn.Conv2d(4, 4, 3, bias=False)
+
+                self.acc_grad1 = self.conv1.weight.view_as(
+                    self.conv1.weight
+                ).grad_fn.next_functions[0][0]
+                self.conv1.weight.register_hook(tensor_hook1)
+                self.conv1.weight.register_hook(tensor_hook2)
+
+                self.acc_grad2 = self.conv2.weight.view_as(
+                    self.conv2.weight
+                ).grad_fn.next_functions[0][0]
+                self.conv2.weight.register_hook(tensor_hook1)
+                self.conv2.weight.register_hook(tensor_hook2)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = self.conv2(y)
+                return y.sum()
+
+        input = torch.randn([1, 4, 32, 32])
+
+        # eager reference
+        model = TestModel()
+        model(input).backward()
+        ref_results = [model.conv1.weight.grad, model.conv2.weight.grad]
+
+        # compiled autograd
+        model.conv1.weight.grad = None
+        model.conv2.weight.grad = None
+        compiled_model = torch.compile(model, backend="inductor")
+        with compiled_autograd.enable(compiler_fn):
+            compiled_model(input).backward()
+        results = [compiled_model.conv1.weight.grad, compiled_model.conv2.weight.grad]
+
+        self.assertEqual(results[0], ref_results[0])
+        self.assertEqual(results[1], ref_results[1])
+
     def test_torch_compile(self):
         def fn():
             model = torch.nn.Sequential(
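
A note on an idiom the new tests rely on throughout: a leaf parameter has no grad_fn, so the tests reach its AccumulateGrad node through a no-op view (view_as or expand_as). A standalone sketch of just that trick, for illustration only:

import torch

w = torch.nn.Parameter(torch.randn(2, 2))
# w.grad_fn is None because w is a leaf, but a view of w has a grad_fn
# whose next function is the AccumulateGrad node for w.
acc_grad = w.view_as(w).grad_fn.next_functions[0][0]
print(type(acc_grad).__name__)  # AccumulateGrad
# Node-level hooks can then be attached with acc_grad.register_prehook(...)
# and acc_grad.register_hook(...), as the tests above do.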
@@ -2990,7 +3330,8 @@ def wrap_test_class(orig_cls):
     "test_save_output_nr",  # output_nr grad passed as None
     "test_setup_context_when_forward_has_default_args",  # autograd.Function with class methods
     "test_lobpcg",  # create_graph
-    "test_grad_nonleaf_register_hook",  # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
+    # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
+    "test_grad_nonleaf_register_hook",
     "test_backward_twice_without_saved_values",  # https://github.com/pytorch/pytorch/issues/129938
     # Category: Dynamo
     "test_accumulate_grad_tensor_reference",  # Out of bounds: frame_state_entry.stride[i] is None
