Skip to content

Commit b86d2c5

Browse files
committed
Add compiled autograd tutorial
1 parent 748e52b commit b86d2c5

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
5+
==========================================================================
6+
7+
"""
8+
9+
######################################################################
10+
# Compiled Autograd is a torch.compile extension introduced in PyTorch 2.4
11+
# that allows the capture of a larger backward graph. It is highly recommended
12+
# to familiarize yourself with `torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_.
13+
#
14+
15+
######################################################################
16+
# Doesn't torch.compile already capture the backward graph?
17+
# ------------
18+
# Partially. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
19+
# - Graph breaks in the forward lead to graph breaks in the backward
20+
# - `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
21+
#
22+
# Compiled Autograd addresses these limitations by directly integrating with the autograd engine, allowing
23+
# it to capture the full backward graph at runtime. Models with these two characteristics should try
24+
# Compiled Autograd, and potentially observe better performance.
25+
#
26+
27+
######################################################################
28+
# Basic Usage
29+
# ------------
30+
#
31+
32+
import torch
33+
34+
class Model(torch.nn.Module):
35+
def __init__(self):
36+
super().__init__()
37+
self.linear = torch.nn.Linear(10, 10)
38+
39+
def forward(self, x):
40+
return self.linear(x)
41+
42+
torch._dynamo.config.compiled_autograd = True
43+
44+
@torch.compile
45+
def train(model, x):
46+
loss = model(x).sum()
47+
loss.backward()
48+
49+
######################################################################
50+
# TODO: add an image of the graph
51+
# Note: In a future release, we will prevent the graph break caused by loss.backward()
52+
#
53+
54+
######################################################################
55+
# Compiling the forward and backward pass using different flags
56+
# ------------
57+
#
58+
59+
def train(model, x):
60+
model = torch.compile(model)
61+
loss = model(x).sum()
62+
torch.compile(lambda: loss.backward(), fullgraph=True)(loss)
63+
64+
######################################################################
65+
# Appendix: Compatibility
66+
# ------------
67+
#
68+
# Compiled Autograd is not yet compatible with all existing PyTorch features.
69+
# Below is a list of known incompatibilities.

0 commit comments

Comments
 (0)