
Commit ab04f3a

xmfan authored and pytorchmergebot committed
[ca] set autograd graph task state (pytorch#143108)
GraphTask holds the metadata needed for a single execution of backward(); it is 1:1 with backward() calls, at least for compiled autograd. It is used by certain torch._C global autograd state APIs.

In SAC, we use torch._C._current_graph_task_id() as a dict key to store information during unpack hook execution: https://github.com/pytorch/pytorch/blob/a5fb07af2718285a2d6406535e22fc4035ed7854/torch/utils/checkpoint.py#L1128

If we don't set an active graph task, the key is randomized, and the unpack logic behaves as if each unpacked tensor came from a different graph task: https://github.com/pytorch/pytorch/blob/a5fb07af2718285a2d6406535e22fc4035ed7854/torch/utils/checkpoint.py#L1112-L1115

The sketchy part of this PR is that in eager autograd, GraphTask is mutated during execution. Inspecting the struct, however, the mutation appears to be used only to communicate between autograd threads (created when multiple devices are involved) or for deprecated uses; we shouldn't hit the mutation case at all in compiled autograd. Also, only the graph task id is accessible from Python hooks.

FIXES pytorch#142862

Pull Request resolved: pytorch#143108
Approved by: https://github.com/jansel, https://github.com/albanD
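For context, a minimal sketch of the keying pattern described above. It paraphrases the referenced checkpoint.py logic rather than reproducing it: the dict and hook names are hypothetical, and only torch._C._current_graph_task_id() is the real API.

# Illustrative sketch only: "_state_per_graph_task" and "_on_unpack" are
# made-up names; torch._C._current_graph_task_id() is the real API used
# by torch.utils.checkpoint.
import uuid

import torch

_state_per_graph_task = {}  # per-backward bookkeeping, keyed by graph task id


def _on_unpack(handle):
    gid = torch._C._current_graph_task_id()
    if gid == -1:
        # No active GraphTask: before this PR, compiled autograd hit this
        # branch, so every unpack got a fresh random key and was treated as
        # belonging to a separate backward() call.
        gid = int(uuid.uuid4())
    # With GraphTaskGuard set in Engine::execute, compiled autograd reports a
    # stable gid for the whole backward, so state is shared across unpacks.
    return _state_per_graph_task.setdefault(gid, {})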
1 parent dbe4b69 commit ab04f3a

2 files changed: +52 -0 lines changed

test/inductor/test_compiled_autograd.py

Lines changed: 51 additions & 0 deletions
@@ -3322,6 +3322,57 @@ def make_post_acc_grad_hook(id):
 
         self.check_output_and_recompiles(fn)
 
+    def test_sac(self):
+        # circular import
+        from torch.utils.checkpoint import (
+            checkpoint,
+            CheckpointPolicy,
+            create_selective_checkpoint_contexts,
+        )
+
+        def fn():
+            class mlp(nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.layer1 = nn.Linear(10, 10)
+                    self.layer2 = nn.Linear(10, 10)
+                    self.layer3 = nn.Linear(10, 10)
+                    self.layer4 = nn.Linear(10, 10)
+
+                def forward(self, x):
+                    x = self.layer1(x)
+                    x = self.layer2(x)
+                    x = self.layer3(x)
+                    x = self.layer4(x)
+                    return x
+
+            recompute_list = [torch.ops.aten.addmm.default]
+
+            def recompute_policy(ctx, op, *args, **kwargs):
+                if op in recompute_list:
+                    return CheckpointPolicy.MUST_RECOMPUTE
+                else:
+                    return CheckpointPolicy.PREFER_SAVE
+
+            def context_fn():
+                return create_selective_checkpoint_contexts(recompute_policy)
+
+            model = mlp()
+            input = torch.randn(1, 10)
+
+            out = checkpoint(model, input, use_reentrant=False, context_fn=context_fn)
+            out.sum().backward()
+            yield model.layer1.weight.grad
+            yield model.layer1.bias.grad
+            yield model.layer2.weight.grad
+            yield model.layer2.bias.grad
+            yield model.layer3.weight.grad
+            yield model.layer3.bias.grad
+            yield model.layer4.weight.grad
+            yield model.layer4.bias.grad
+
+        self.check_output_and_recompiles(fn)
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent

torch/csrc/autograd/engine.cpp

Lines changed: 1 addition & 0 deletions
@@ -1321,6 +1321,7 @@ auto Engine::execute(
     TORCH_CHECK(
         !AnomalyMode::is_enabled(),
         "compiled_autograd does not support AnomalyMode")
+    GraphTaskGuard guard(graph_task);
     return (*compiled_autograd)(
         graph_root, *graph_task, accumulate_grad, outputs);
   }
