@@ -1,5 +1,6 @@
 # Owner(s): ["module: autograd"]
 
+import contextlib
 import gc
 import io
 import math
@@ -31,7 +32,8 @@
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
                                                   slowTest, IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
-                                                  disable_gc, gradcheck, gradgradcheck)
+                                                  disable_gc, gradcheck, gradgradcheck,
+                                                  parametrize, instantiate_parametrized_tests)
 from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
 import torch.autograd.forward_ad as fwAD
@@ -4308,6 +4310,50 @@ def test_checkpointing(self):
         mean_combined = torch.stack(feat_combined).mean()
         mean_combined.backward()
 
+    @slowTest
+    @parametrize("input_requires_grad", [True, False])
+    def test_checkpointing_without_reentrant(self, input_requires_grad):
+        """
+        Basic test for checkpoint without reentrant autograd.
+        """
+        num_inp = 2000
+        nz_inp = 10
+        nz_out = 10
+        nz_bottleneck = 1000
+
+        # small proxy network for some complex reasoning we want to do per input
+        module = nn.Sequential(
+            nn.Linear(nz_inp, nz_bottleneck),
+            nn.ReLU(),
+            nn.Linear(nz_bottleneck, nz_inp)
+        )
+
+        # Run model with and without checkpointing and verify gradients are
+        # equivalent, regardless of whether inputs require grads or not.
+        module_copy = deepcopy(module)
+
+        feat_combined = []
+        feat_combined_no_checkpoint = []
+        for r in range(num_inp):
+            data_r = torch.empty(1, nz_inp)
+            data_r.uniform_()
+            data_r.requires_grad = input_requires_grad
+            data_r_copy = data_r.clone()
+            feat_r = checkpoint(module, data_r, use_reentrant=False)
+            feat_combined.append(feat_r)
+            feat_r_no_checkpoint = module_copy(data_r)
+            feat_combined_no_checkpoint.append(feat_r_no_checkpoint)
+
+
+        # compute mean as a proxy for some joint reasoning
+        mean_combined = torch.stack(feat_combined).mean()
+        mean_combined.backward()
+        mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean()
+        mean_combined_no_checkpoint.backward()
+
+        for checkpoint_param, param in zip(module.parameters(), module_copy.parameters()):
+            self.assertEqual(checkpoint_param.grad, param.grad)
+
     def test_checkpoint_valid_reset_on_error(self):
         a = torch.randn(2, 2, requires_grad=True)
 
@@ -4318,6 +4364,156 @@ def test_checkpoint_valid_reset_on_error(self):
         c = checkpoint(torch.exp, a).sum()
         c.backward()
 
+    @parametrize("use_reentrant", [True, False])
+    def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant):
+        class NoGradModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2, bias=False)
+                self.lin2 = nn.Linear(2, 2, bias=False)
+
+            def forward(self, x):
+                with torch.no_grad():
+                    return self.lin2(self.linear(x))
+
+        module = NoGradModule()
+
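+        # With use_reentrant=True, backward() is expected to fail: the module
+        # runs under no_grad, so none of the recomputed outputs require grad.
+        # The non-reentrant implementation accepts this and raises no error.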
+        err_ctx = (
+            self.assertRaisesRegex(
+                RuntimeError,
+                "none of output has requires_grad=True"
+            )
+            if use_reentrant
+            else contextlib.suppress()
+        )
+
+        a = torch.randn(2, 2, requires_grad=True)
+        for _ in range(3):
+            with err_ctx:
+                # out does not require grad
+                out = checkpoint(module, a, use_reentrant=use_reentrant)
+                # Make loss require grad, otherwise we would run into
+                # "element 0 of tensors does not require grad and does not have a grad_fn"
+                out += a
+                out.sum().backward()
+
+    def test_checkpointing_without_reentrant_correct_grad(self):
+        """
+        Verifies that correct gradients are calculated for checkpoint
+        without reentrant autograd, for both backward() and autograd.grad().
+        """
+        a = torch.randn(2, 2, requires_grad=True)
+
+        b = torch.exp(a).sum()
+        b.backward()
+        b_grad = a.grad
+
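+        # Clear the gradient accumulated by the baseline so the checkpointed
+        # runs below are checked independently.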
+        a.grad = None
+        c = checkpoint(torch.exp, a, use_reentrant=False).sum()
+        c.backward()
+        c_grad = a.grad
+
+        a.grad = None
+        d = checkpoint(torch.exp, a, use_reentrant=False).sum()
+        d_grad, = torch.autograd.grad(d, (a,))
+
+        self.assertEqual(b_grad, c_grad)
+        self.assertEqual(b_grad, d_grad)
+
+    def test_checkpointing_without_reentrant_dataparallel(self):
+        """
+        Verifies gradient correctness when checkpoint without reentrant autograd
+        is used in conjunction with DataParallel.
+        """
+        class LinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2, bias=False)
+
+            def forward(self, inp):
+                return self.linear(inp)
+
+        a = torch.randn(2, 2, requires_grad=True)
+        if torch.cuda.is_available():
+            a = a.cuda()
+
+        model = LinearModule()
+        if torch.cuda.is_available():
+            model = model.cuda()
+
+        b = deepcopy(model)(a).sum()
+        b.backward()
+        b_grad = a.grad
+
+        a.grad = None
+
+        module = torch.nn.DataParallel(deepcopy(model))
+        c = checkpoint(module, a, use_reentrant=False).sum()
+        c.backward()
+        c_grad = a.grad
+
+        self.assertEqual(b_grad, c_grad)
+
+    def test_checkpointing_without_reentrant_parameter_used_in_an_out(self):
+        """
+        Ensures that gradient hooks are only called once per tensor.
+        """
+        w = torch.randn(10, 10, requires_grad=True)
+        count = 0
+
+        def hook(grad):
+            nonlocal count
+            count += 1
+
+        w.register_hook(hook)
+        x = torch.rand(10, 10, requires_grad=True)
+        h = w * x  # Using w outside the checkpoint
+        out = checkpoint(lambda x: w * x, h, use_reentrant=False)  # Using w inside the checkpoint
+
+        out.sum().backward()
+        # should only call hook once
+        self.assertEqual(count, 1)
+
+    def test_checkpointing_without_reentrant_arbitrary_input_output(self):
+        """
+        Ensures checkpointing without reentrant autograd works with functions
+        with arbitrary input/output structures.
+        """
+
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer = torch.nn.Linear(5, 5, bias=False)
+
+            def forward(self, dict_input):
+                tensor = dict_input["tensor"]
+                return {
+                    "result": self.layer(tensor)
+                }
+
+        model_no_checkpoint = MyModel()
+        model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint)
+
+        inp = {
+            "tensor": torch.randn(5, 5)
+        }
+
+        out_no_checkpoint = model_no_checkpoint(inp)["result"].sum()
+
+        out_checkpoint = checkpoint(
+            model_checkpoint_without_reentrant,
+            inp,
+            use_reentrant=False
+        )["result"].sum()
+
+        self.assertEqual(out_checkpoint, out_no_checkpoint)
+
+        out_no_checkpoint.backward()
+        out_checkpoint.backward()
+
+        for param, checkpoint_param in zip(model_no_checkpoint.parameters(), model_checkpoint_without_reentrant.parameters()):
+            self.assertEqual(param.grad, checkpoint_param.grad)
+
     def test_callback_adds_callback(self):
         called = [0]
 
@@ -9108,5 +9304,7 @@ def fn(x1, x2):
     except_for=None
 )
 
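+# Expand the @parametrize-decorated tests above into concrete test cases on TestAutograd.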
+instantiate_parametrized_tests(TestAutograd)
+
 if __name__ == '__main__':
     run_tests()