
Commit 049debd

rohan-varma authored and facebook-github-bot committed
[Reland][Autograd/Checkpoint] Checkpoint implementation without reentrant autograd (pytorch#69508)
Summary:
Pull Request resolved: pytorch#69508

Original Phabricator Diff: D32704467 (pytorch@e032dae)

Reland; the fix is to not test traditional checkpoint when the input does not require grad, as that is unsupported (as documented).

Original PR body:

Resubmission of pytorch#62964 with the suggestions and tests discussed in pytorch#65537.

Adds a `use_reentrant=False` flag to the `checkpoint` function. When `use_reentrant=False` is specified, a checkpointing implementation that uses SavedVariableHooks instead of re-entrant autograd is used. This makes it more composable with things such as `autograd.grad` as well as DDP (thorough distributed testing still needs to be added).

As discussed in pytorch#65537, the tests that we need to add are:
- [x] Gradient hooks are called once
- [x] Works when inputs do not require grad but Tensors that do require grad are captured (like the first layer in an nn)
- [x] Works for functions with arbitrary input/output objects
- [x] Distributed tests (next PR)

Note that this is only for `torch.utils.checkpoint`; if this approach looks good overall, we will do something similar for `checkpoint_sequential`.

ghstack-source-id: 144948501

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D32902634

fbshipit-source-id: 2ee87006e5045e5471ff80c36a07fbecc2bea3fe
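To illustrate what the new flag enables (a minimal sketch, not part of this commit, mirroring the `test_checkpointing_without_reentrant_correct_grad` test added below): with `use_reentrant=False`, a checkpointed computation can be differentiated with `torch.autograd.grad`, which the default reentrant implementation does not support (it works only with `backward()`).

```python
import torch
from torch.utils.checkpoint import checkpoint

a = torch.randn(2, 2, requires_grad=True)

# Non-reentrant checkpointing: recomputation is driven by saved-variable hooks
# rather than a nested (re-entrant) backward pass.
out = checkpoint(torch.exp, a, use_reentrant=False).sum()

# Composes with torch.autograd.grad, not just .backward().
grad, = torch.autograd.grad(out, (a,))

# Gradient matches the un-checkpointed computation.
expected, = torch.autograd.grad(torch.exp(a).sum(), (a,))
assert torch.allclose(grad, expected)
```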
1 parent 3456c2c commit 049debd

2 files changed: +313 -18 lines changed


test/test_autograd.py: +199 -1
@@ -1,5 +1,6 @@
 # Owner(s): ["module: autograd"]

+import contextlib
 import gc
 import io
 import math
@@ -31,7 +32,8 @@
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
                                                    slowTest, IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
-                                                   disable_gc, gradcheck, gradgradcheck)
+                                                   disable_gc, gradcheck, gradgradcheck,
+                                                   parametrize, instantiate_parametrized_tests)
 from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
 import torch.autograd.forward_ad as fwAD
@@ -4308,6 +4310,50 @@ def test_checkpointing(self):
         mean_combined = torch.stack(feat_combined).mean()
         mean_combined.backward()

+    @slowTest
+    @parametrize("input_requires_grad", [True, False])
+    def test_checkpointing_without_reentrant(self, input_requires_grad):
+        """
+        Basic test for checkpoint without reentrant autograd.
+        """
+        num_inp = 2000
+        nz_inp = 10
+        nz_out = 10
+        nz_bottleneck = 1000
+
+        # small proxy network for some complex reasoning we want to do per input
+        module = nn.Sequential(
+            nn.Linear(nz_inp, nz_bottleneck),
+            nn.ReLU(),
+            nn.Linear(nz_bottleneck, nz_inp)
+        )
+
+        # Run model with and without checkpointing and verify gradients are
+        # equivalent, regardless of if inputs require grads or not.
+        module_copy = deepcopy(module)
+
+        feat_combined = []
+        feat_combined_no_checkpoint = []
+        for r in range(num_inp):
+            data_r = torch.empty(1, nz_inp)
+            data_r.uniform_()
+            data_r.requires_grad = input_requires_grad
+            data_r_copy = data_r.clone()
+            feat_r = checkpoint(module, data_r, use_reentrant=False)
+            feat_combined.append(feat_r)
+            feat_r_no_checkpoint = module_copy(data_r)
+            feat_combined_no_checkpoint.append(feat_r_no_checkpoint)
+
+
+        # compute mean as a proxy for some joint reasoning
+        mean_combined = torch.stack(feat_combined).mean()
+        mean_combined.backward()
+        mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean()
+        mean_combined_no_checkpoint.backward()
+
+        for checkpoint_param, param in zip(module.parameters(), module_copy.parameters()):
+            self.assertEqual(checkpoint_param.grad, param.grad)
+
     def test_checkpoint_valid_reset_on_error(self):
         a = torch.randn(2, 2, requires_grad=True)

@@ -4318,6 +4364,156 @@ def test_checkpoint_valid_reset_on_error(self):
         c = checkpoint(torch.exp, a).sum()
         c.backward()

+    @parametrize("use_reentrant", [True, False])
+    def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant):
+        class NoGradModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2, bias=False)
+                self.lin2 = nn.Linear(2, 2, bias=False)
+
+            def forward(self, x):
+                with torch.no_grad():
+                    return self.lin2(self.linear(x))
+
+        module = NoGradModule()
+
+        err_ctx = (
+            self.assertRaisesRegex(
+                RuntimeError,
+                "none of output has requires_grad=True"
+            )
+            if use_reentrant
+            else contextlib.suppress()
+        )
+
+        a = torch.randn(2, 2, requires_grad=True)
+        for _ in range(3):
+            with err_ctx:
+                # out does not require grad
+                out = checkpoint(module, a, use_reentrant=use_reentrant)
+                # Make loss require grad, otherwise we would run into
+                # "element 0 of tensors does not require grad and does not have a grad_fn"
+                out += a
+                out.sum().backward()
+
+    def test_checkpointing_without_reentrant_correct_grad(self):
+        """
+        Verifies that correct gradients are calculated for checkpoint
+        without reentrant autograd, for both backward() and autograd.grad().
+        """
+        a = torch.randn(2, 2, requires_grad=True)
+
+        b = torch.exp(a).sum()
+        b.backward()
+        b_grad = a.grad
+
+        a.grad = None
+        c = checkpoint(torch.exp, a, use_reentrant=False).sum()
+        c.backward()
+        c_grad = a.grad
+
+        a.grad = None
+        d = checkpoint(torch.exp, a, use_reentrant=False).sum()
+        d_grad, = torch.autograd.grad(d, (a,))
+
+        self.assertEqual(b_grad, c_grad)
+        self.assertEqual(b_grad, d_grad)
+
+    def test_checkpointing_without_reentrant_dataparallel(self):
+        """
+        Verifies gradient correctness when checkpoint without reentrant autograd
+        is used in conjunction with DataParallel.
+        """
+        class LinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2, bias=False)
+
+            def forward(self, inp):
+                return self.linear(inp)
+
+        a = torch.randn(2, 2, requires_grad=True)
+        if torch.cuda.is_available():
+            a = a.cuda()
+
+        model = LinearModule()
+        if torch.cuda.is_available():
+            model = model.cuda()
+
+        b = deepcopy(model)(a).sum()
+        b.backward()
+        b_grad = a.grad
+
+        a.grad = None
+
+        module = torch.nn.DataParallel(deepcopy(model))
+        c = checkpoint(module, a, use_reentrant=False).sum()
+        c.backward()
+        c_grad = a.grad
+
+        self.assertEqual(b_grad, c_grad)
+
+    def test_checkpointing_without_reentrant_parameter_used_in_an_out(self):
+        """
+        Ensures that gradient hooks are only called once per tensor.
+        """
+        w = torch.randn(10, 10, requires_grad=True)
+        count = 0
+
+        def hook(grad):
+            nonlocal count
+            count += 1
+
+        w.register_hook(hook)
+        x = torch.rand(10, 10, requires_grad=True)
+        h = w * x  # Using w outside the checkpoint
+        out = checkpoint(lambda x: w * x, h, use_reentrant=False)  # Using w inside the checkpoint
+
+        out.sum().backward()
+        # should only call hook once
+        self.assertEqual(count, 1)
+
+    def test_checkpointing_without_reentrant_arbitrary_input_output(self):
+        """
+        Ensures checkpointing without reentrant autograd works with functions
+        with arbitrary input/output structures.
+        """
+
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer = torch.nn.Linear(5, 5, bias=False)
+
+            def forward(self, dict_input):
+                tensor = dict_input["tensor"]
+                return {
+                    "result": self.layer(tensor)
+                }
+
+        model_no_checkpoint = MyModel()
+        model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint)
+
+        inp = {
+            "tensor": torch.randn(5, 5)
+        }
+
+        out_no_checkpoint = model_no_checkpoint(inp)["result"].sum()
+
+        out_checkpoint = checkpoint(
+            model_checkpoint_without_reentrant,
+            inp,
+            use_reentrant=False
+        )["result"].sum()
+
+        self.assertEqual(out_checkpoint, out_no_checkpoint)
+
+        out_no_checkpoint.backward()
+        out_checkpoint.backward()
+
+        for param, checkpoint_param in zip(model_no_checkpoint.parameters(), model_checkpoint_without_reentrant.parameters()):
+            self.assertEqual(param.grad, checkpoint_param.grad)
+
     def test_callback_adds_callback(self):
         called = [0]

@@ -9108,5 +9304,7 @@ def fn(x1, x2):
     except_for=None
 )

+instantiate_parametrized_tests(TestAutograd)
+
 if __name__ == '__main__':
     run_tests()
