@@ -1,5 +1,6 @@
 # Owner(s): ["module: autograd"]
 
+import contextlib
 import gc
 import io
 import math
@@ -31,7 +32,8 @@
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
                                                   slowTest, IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
-                                                  disable_gc, gradcheck, gradgradcheck)
+                                                  disable_gc, gradcheck, gradgradcheck,
+                                                  parametrize, instantiate_parametrized_tests)
 from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
 import torch.autograd.forward_ad as fwAD
@@ -4308,6 +4310,50 @@ def test_checkpointing(self):
         mean_combined = torch.stack(feat_combined).mean()
         mean_combined.backward()
 
+    @slowTest
+    @parametrize("input_requires_grad", [True, False])
+    def test_checkpointing_without_reentrant(self, input_requires_grad):
+        """
+        Basic test for checkpoint without reentrant autograd.
+        """
+        num_inp = 2000
+        nz_inp = 10
+        nz_out = 10
+        nz_bottleneck = 1000
+
+        # small proxy network for some complex reasoning we want to do per input
+        module = nn.Sequential(
+            nn.Linear(nz_inp, nz_bottleneck),
+            nn.ReLU(),
+            nn.Linear(nz_bottleneck, nz_inp)
+        )
+
+        # Run model with and without checkpointing and verify gradients are
+        # equivalent, regardless of whether inputs require grad or not.
+        module_copy = deepcopy(module)
+
+        feat_combined = []
+        feat_combined_no_checkpoint = []
+        for r in range(num_inp):
+            data_r = torch.empty(1, nz_inp)
+            data_r.uniform_()
+            data_r.requires_grad = input_requires_grad
+            data_r_copy = data_r.clone()
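+            # With use_reentrant=False the intermediate activations inside `module`
+            # are not kept; the forward is recomputed during backward, and this
+            # works whether or not the input itself requires grad.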
+            feat_r = checkpoint(module, data_r, use_reentrant=False)
+            feat_combined.append(feat_r)
+            feat_r_no_checkpoint = module_copy(data_r)
+            feat_combined_no_checkpoint.append(feat_r_no_checkpoint)
+
+
+        # compute mean as a proxy for some joint reasoning
+        mean_combined = torch.stack(feat_combined).mean()
+        mean_combined.backward()
+        mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean()
+        mean_combined_no_checkpoint.backward()
+
+        for checkpoint_param, param in zip(module.parameters(), module_copy.parameters()):
+            self.assertEqual(checkpoint_param.grad, param.grad)
+
     def test_checkpoint_valid_reset_on_error(self):
         a = torch.randn(2, 2, requires_grad=True)
 
@@ -4318,6 +4364,156 @@ def test_checkpoint_valid_reset_on_error(self):
         c = checkpoint(torch.exp, a).sum()
         c.backward()
 
+    @parametrize("use_reentrant", [True, False])
+    def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant):
+        class NoGradModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2, bias=False)
+                self.lin2 = nn.Linear(2, 2, bias=False)
+
+            def forward(self, x):
+                with torch.no_grad():
+                    return self.lin2(self.linear(x))
+
+        module = NoGradModule()
+
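+        # The reentrant implementation requires at least one checkpointed output that
+        # needs grad and raises otherwise; the non-reentrant implementation tolerates
+        # fully detached outputs, so no error is expected in that case.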
+        err_ctx = (
+            self.assertRaisesRegex(
+                RuntimeError,
+                "none of output has requires_grad=True"
+            )
+            if use_reentrant
+            else contextlib.suppress()
+        )
+
+        a = torch.randn(2, 2, requires_grad=True)
+        for _ in range(3):
+            with err_ctx:
+                # out does not require grad
+                out = checkpoint(module, a, use_reentrant=use_reentrant)
+                # Make loss require grad, otherwise we would run into
+                # "element 0 of tensors does not require grad and does not have a grad_fn"
+                out += a
+                out.sum().backward()
+
+    def test_checkpointing_without_reentrant_correct_grad(self):
+        """
+        Verifies that correct gradients are calculated for checkpoint
+        without reentrant autograd, for both backward() and autograd.grad().
+        """
+        a = torch.randn(2, 2, requires_grad=True)
+
+        b = torch.exp(a).sum()
+        b.backward()
+        b_grad = a.grad
+
+        a.grad = None
+        c = checkpoint(torch.exp, a, use_reentrant=False).sum()
+        c.backward()
+        c_grad = a.grad
+
+        a.grad = None
+        d = checkpoint(torch.exp, a, use_reentrant=False).sum()
+        d_grad, = torch.autograd.grad(d, (a,))
+
+        self.assertEqual(b_grad, c_grad)
+        self.assertEqual(b_grad, d_grad)
+
+    def test_checkpointing_without_reentrant_dataparallel(self):
+        """
+        Verifies gradient correctness when checkpoint without reentrant autograd
+        is used in conjunction with DataParallel.
+        """
+        class LinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2, bias=False)
+
+            def forward(self, inp):
+                return self.linear(inp)
+
+        a = torch.randn(2, 2, requires_grad=True)
+        if torch.cuda.is_available():
+            a = a.cuda()
+
+        model = LinearModule()
+        if torch.cuda.is_available():
+            model = model.cuda()
+
+        b = deepcopy(model)(a).sum()
+        b.backward()
+        b_grad = a.grad
+
+        a.grad = None
+
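+        # DataParallel wraps the module; on a CPU-only build it simply calls the
+        # underlying module, while with visible GPUs it scatters the input across
+        # replicas. The checkpointed forward must produce the same gradients either way.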
+        module = torch.nn.DataParallel(deepcopy(model))
+        c = checkpoint(module, a, use_reentrant=False).sum()
+        c.backward()
+        c_grad = a.grad
+
+        self.assertEqual(b_grad, c_grad)
+
+    def test_checkpointing_without_reentrant_parameter_used_in_an_out(self):
+        """
+        Ensures that gradient hooks are only called once per tensor.
+        """
+        w = torch.randn(10, 10, requires_grad=True)
+        count = 0
+
+        def hook(grad):
+            nonlocal count
+            count += 1
+
+        w.register_hook(hook)
+        x = torch.rand(10, 10, requires_grad=True)
+        h = w * x  # Using w outside the checkpoint
+        out = checkpoint(lambda x: w * x, h, use_reentrant=False)  # Using w inside the checkpoint
+
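+        # `w` receives gradient both from the path outside the checkpoint and from the
+        # recomputed path inside it, but its accumulation hook must still fire only once.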
+        out.sum().backward()
+        # should only call hook once
+        self.assertEqual(count, 1)
+
+    def test_checkpointing_without_reentrant_arbitrary_input_output(self):
+        """
+        Ensures checkpointing without reentrant autograd works with functions
+        that have arbitrary input/output structures.
+        """
+
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer = torch.nn.Linear(5, 5, bias=False)
+
+            def forward(self, dict_input):
+                tensor = dict_input["tensor"]
+                return {
+                    "result": self.layer(tensor)
+                }
+
+        model_no_checkpoint = MyModel()
+        model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint)
+
+        inp = {
+            "tensor": torch.randn(5, 5)
+        }
+
+        out_no_checkpoint = model_no_checkpoint(inp)["result"].sum()
+
+        out_checkpoint = checkpoint(
+            model_checkpoint_without_reentrant,
+            inp,
+            use_reentrant=False
+        )["result"].sum()
+
+        self.assertEqual(out_checkpoint, out_no_checkpoint)
+
+        out_no_checkpoint.backward()
+        out_checkpoint.backward()
+
+        for param, checkpoint_param in zip(model_no_checkpoint.parameters(), model_checkpoint_without_reentrant.parameters()):
+            self.assertEqual(param.grad, checkpoint_param.grad)
+
     def test_callback_adds_callback(self):
         called = [0]
 
@@ -9108,5 +9304,7 @@ def fn(x1, x2):
     except_for=None
 )
 
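+# Generate the concrete per-parameter test variants for the @parametrize-decorated
+# tests defined on TestAutograd above.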
+instantiate_parametrized_tests(TestAutograd)
+
 if __name__ == '__main__':
     run_tests()