@@ -251,6 +251,346 @@ def fn():
         self.check_output_and_recompiles(fn)
 
+    def test_reorder_acc_grad(self):
+        model = torch.nn.Sequential(
+            torch.nn.Conv2d(4, 4, 3, bias=True),
+            torch.nn.Conv2d(4, 4, 3, bias=True),
+        )
+        compiled_model = torch.compile(model)
+        x = torch.randn([1, 4, 32, 32])
+
+        model(x).sum().backward()
+        ref_res = [
+            model[0].weight.grad,
+            model[0].bias.grad,
+            model[1].weight.grad,
+            model[1].bias.grad,
+        ]
+
+        model[0].weight.grad = None
+        model[0].bias.grad = None
+        model[1].weight.grad = None
+        model[1].bias.grad = None
+        with compiled_autograd.enable(compiler_fn):
+            compiled_model(x).sum().backward(retain_graph=True)
+        res = [
+            model[0].weight.grad,
+            model[0].bias.grad,
+            model[1].weight.grad,
+            model[1].bias.grad,
+        ]
+
+        self.assertEqual(res[0], ref_res[0])
+        self.assertEqual(res[1], ref_res[1])
+        self.assertEqual(res[2], ref_res[2])
+        self.assertEqual(res[3], ref_res[3])
+
+    def test_reorder_post_hook1(self):
+        def grad_div(param):
+            param.grad = param.grad / 4.0
+
+        class Module(torch.nn.Module):
+            def __init__(self, ioc):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(ioc, ioc, bias=False)
+                self.fc2 = torch.nn.Linear(ioc, ioc, bias=False)
+
+                self.grad_acc_hooks = []
+                self.grad_acc = []
+                self.params = [self.fc1.weight, self.fc2.weight]
+                for i, param in enumerate(self.params):
+
+                    def wrapper(param):
+                        param_tmp = param.expand_as(param)
+                        grad_acc = param_tmp.grad_fn.next_functions[0][0]
+
+                        def grad_acc_hook(*notneeded):
+                            grad_div(param)
+
+                        self.grad_acc.append(grad_acc)
+                        self.grad_acc_hooks.append(
+                            grad_acc.register_hook(grad_acc_hook)
+                        )
+
+                    wrapper(param)
+
+            def forward(self, x):
+                x = self.fc1(x)
+                x = self.fc2(x)
+                return x.sum()
+
+        bs = 8
+        ioc = 16
+        model = Module(ioc)
+        input = torch.randn([bs, ioc])
+
+        # eager ref
+        model(input).backward()
+        ref_res = [model.fc1.weight.grad, model.fc2.weight.grad]
+
+        # cag
+        model.fc1.weight.grad = None
+        model.fc2.weight.grad = None
+        model_to_train = torch.compile(model, backend="inductor")
+        with compiled_autograd.enable(compiler_fn):
+            model_to_train(input).backward()
+        res = [model_to_train.fc1.weight.grad, model_to_train.fc2.weight.grad]
+
+        self.assertEqual(res[0], ref_res[0])
+        self.assertEqual(res[1], ref_res[1])
+
+    def test_reorder_post_hook2(self):
+        x = torch.randn([1, 4, 32, 32], requires_grad=True)
+        y = torch.sigmoid(x)
+        z = torch.tanh(y)
+
+        assert isinstance(z.grad_fn, torch.autograd.graph.Node)
+        assert isinstance(y.grad_fn, torch.autograd.graph.Node)
+        handle_z = z.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
+        handle_y = y.grad_fn.register_hook(lambda gI, gO: (gI[0] * 2,))
+        z.sum().backward(retain_graph=True)
+        ref_res = x.grad
+
+        x.grad = None
+        with compiled_autograd.enable(compiler_fn):
+            z.sum().backward(retain_graph=True)
+        res = x.grad
+
+        self.assertEqual(res, ref_res)
+
+    def test_reorder_post_hook3(self):
+        conv = torch.nn.Conv2d(4, 4, 3, bias=False)
+        x = torch.randn([1, 4, 32, 32])
+        y = conv(x)
+
+        assert isinstance(y.grad_fn, torch.autograd.graph.Node)
+        # this hook multiplies the conv weight gradient by 2.0
+        handle_y = y.grad_fn.register_hook(lambda gI, gO: (gI[0], gI[1] * 2, gI[2]))
+        y.sum().backward(retain_graph=True)
+        ref_res = x.grad
+
+        x.grad = None
+        with compiled_autograd.enable(compiler_fn):
+            y.sum().backward(retain_graph=True)
+        res = x.grad
+
+        self.assertEqual(res, ref_res)
+
+    def test_reorder_all_bwd_hooks(self):
+        def tensor_hook(grad):
+            return grad.sub(2.0)
+
+        def acc_grad_node_pre_hook(grad_out):
+            return (grad_out[0].div(5.0),)
+
+        def post_acc_grad_hook(tensor):
+            tensor.grad.add_(3.0)
+
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(4, 4, 3, bias=False)
+                self.conv2 = torch.nn.Conv2d(4, 4, 3, bias=False)
+
+                self.acc_grad1 = self.conv1.weight.view_as(
+                    self.conv1.weight
+                ).grad_fn.next_functions[0][0]
+                self.conv1.weight.register_hook(tensor_hook)
+                self.conv1.weight.register_post_accumulate_grad_hook(post_acc_grad_hook)
+                self.acc_grad1.register_prehook(acc_grad_node_pre_hook)
+
+                def acc_grad_node_post_hook1(grad_in, grad_out):
+                    self.conv1.weight.grad.mul_(0.5)
+
+                self.acc_grad1.register_hook(acc_grad_node_post_hook1)
+
+                self.acc_grad2 = self.conv2.weight.view_as(
+                    self.conv2.weight
+                ).grad_fn.next_functions[0][0]
+                self.conv2.weight.register_hook(tensor_hook)
+                self.conv2.weight.register_post_accumulate_grad_hook(post_acc_grad_hook)
+                self.acc_grad2.register_prehook(acc_grad_node_pre_hook)
+
+                def acc_grad_node_post_hook2(grad_in, grad_out):
+                    self.conv2.weight.grad.mul_(0.5)
+
+                self.acc_grad2.register_hook(acc_grad_node_post_hook2)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = self.conv2(y)
+                return y.sum()
+
+        input = torch.randn([1, 4, 32, 32])
+
+        # eager ref
+        model = TestModel()
+        model(input).backward()
+        ref_results = [model.conv1.weight.grad, model.conv2.weight.grad]
+
+        # cag
+        model.conv1.weight.grad = None
+        model.conv2.weight.grad = None
+        compiled_model = torch.compile(model, backend="inductor")
+        with compiled_autograd.enable(compiler_fn):
+            compiled_model(input).backward()
+        results = [compiled_model.conv1.weight.grad, compiled_model.conv2.weight.grad]
+
+        self.assertEqual(results[0], ref_results[0])
+        self.assertEqual(results[1], ref_results[1])
+
+    def test_reorder_multi_post_hooks(self):
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(4, 4, 3, bias=False)
+                self.conv2 = torch.nn.Conv2d(4, 4, 3, bias=False)
+
+                self.acc_grad1 = self.conv1.weight.view_as(
+                    self.conv1.weight
+                ).grad_fn.next_functions[0][0]
+
+                def acc_grad_node1_post_hook1(grad_in, grad_out):
+                    self.conv1.weight.grad.mul_(0.5)
+
+                def acc_grad_node1_post_hook2(grad_in, grad_out):
+                    self.conv1.weight.grad.sub_(0.3)
+
+                self.acc_grad1.register_hook(acc_grad_node1_post_hook1)
+                self.acc_grad1.register_hook(acc_grad_node1_post_hook2)
+
+                self.acc_grad2 = self.conv2.weight.view_as(
+                    self.conv2.weight
+                ).grad_fn.next_functions[0][0]
+
+                def acc_grad_node2_post_hook1(grad_in, grad_out):
+                    self.conv2.weight.grad.mul_(0.3)
+
+                def acc_grad_node2_post_hook2(grad_in, grad_out):
+                    self.conv2.weight.grad.sub_(0.5)
+
+                self.acc_grad2.register_hook(acc_grad_node2_post_hook1)
+                self.acc_grad2.register_hook(acc_grad_node2_post_hook2)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = self.conv2(y)
+                return y.sum()
+
+        input = torch.randn([1, 4, 32, 32])
+
+        # eager ref
+        model = TestModel()
+        model(input).backward()
+        ref_results = [model.conv1.weight.grad, model.conv2.weight.grad]
+
+        # cag
+        model.conv1.weight.grad = None
+        model.conv2.weight.grad = None
+        compiled_model = torch.compile(model, backend="inductor")
+        with compiled_autograd.enable(compiler_fn):
+            compiled_model(input).backward()
+        results = [compiled_model.conv1.weight.grad, compiled_model.conv2.weight.grad]
+
+        self.assertEqual(results[0], ref_results[0])
+        self.assertEqual(results[1], ref_results[1])
+
+    def test_reorder_multi_pre_hooks(self):
+        def acc_grad_node_pre_hook1(grad_out):
+            return (grad_out[0].div(5.0),)
+
+        def acc_grad_node_pre_hook2(grad_out):
+            return (grad_out[0].sub(0.3),)
+
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(4, 4, 3, bias=False)
+                self.conv2 = torch.nn.Conv2d(4, 4, 3, bias=False)
+
+                self.acc_grad1 = self.conv1.weight.view_as(
+                    self.conv1.weight
+                ).grad_fn.next_functions[0][0]
+                self.acc_grad1.register_prehook(acc_grad_node_pre_hook1)
+                self.acc_grad1.register_prehook(acc_grad_node_pre_hook2)
+
+                self.acc_grad2 = self.conv2.weight.view_as(
+                    self.conv2.weight
+                ).grad_fn.next_functions[0][0]
+                self.acc_grad2.register_prehook(acc_grad_node_pre_hook1)
+                self.acc_grad2.register_prehook(acc_grad_node_pre_hook2)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = self.conv2(y)
+                return y.sum()
+
+        input = torch.randn([1, 4, 32, 32])
+
+        # eager ref
+        model = TestModel()
+        model(input).backward()
+        ref_results = [model.conv1.weight.grad, model.conv2.weight.grad]
+
+        # cag
+        model.conv1.weight.grad = None
+        model.conv2.weight.grad = None
+        compiled_model = torch.compile(model, backend="inductor")
+        with compiled_autograd.enable(compiler_fn):
+            compiled_model(input).backward()
+        results = [compiled_model.conv1.weight.grad, compiled_model.conv2.weight.grad]
+
+        self.assertEqual(results[0], ref_results[0])
+        self.assertEqual(results[1], ref_results[1])
+
+    def test_reorder_multi_tensor_pre_hooks(self):
+        def tensor_hook1(grad):
+            return grad.sub(2.0)
+
+        def tensor_hook2(grad):
+            return grad.mul(0.5)
+
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(4, 4, 3, bias=False)
+                self.conv2 = torch.nn.Conv2d(4, 4, 3, bias=False)
+
+                self.acc_grad1 = self.conv1.weight.view_as(
+                    self.conv1.weight
+                ).grad_fn.next_functions[0][0]
+                self.conv1.weight.register_hook(tensor_hook1)
+                self.conv1.weight.register_hook(tensor_hook2)
+
+                self.acc_grad2 = self.conv2.weight.view_as(
+                    self.conv2.weight
+                ).grad_fn.next_functions[0][0]
+                self.conv2.weight.register_hook(tensor_hook1)
+                self.conv2.weight.register_hook(tensor_hook2)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = self.conv2(y)
+                return y.sum()
+
+        input = torch.randn([1, 4, 32, 32])
+
+        # eager ref
+        model = TestModel()
+        model(input).backward()
+        ref_results = [model.conv1.weight.grad, model.conv2.weight.grad]
+
+        # cag
+        model.conv1.weight.grad = None
+        model.conv2.weight.grad = None
+        compiled_model = torch.compile(model, backend="inductor")
+        with compiled_autograd.enable(compiler_fn):
+            compiled_model(input).backward()
+        results = [compiled_model.conv1.weight.grad, compiled_model.conv2.weight.grad]
+
+        self.assertEqual(results[0], ref_results[0])
+        self.assertEqual(results[1], ref_results[1])
+
     def test_torch_compile(self):
         def fn():
             model = torch.nn.Sequential(
@@ -2990,7 +3330,8 @@ def wrap_test_class(orig_cls):
     "test_save_output_nr",  # output_nr grad passed as None
     "test_setup_context_when_forward_has_default_args",  # autograd.Function with class methods
     "test_lobpcg",  # create_graph
-    "test_grad_nonleaf_register_hook",  # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
+    # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
+    "test_grad_nonleaf_register_hook",
     "test_backward_twice_without_saved_values",  # https://github.com/pytorch/pytorch/issues/129938
     # Category: Dynamo
     "test_accumulate_grad_tensor_reference",  # Out of bounds: frame_state_entry.stride[i] is None