
Commit 271b8f2: address comments
1 parent 50a6978 commit 271b8f2

File tree: 1 file changed, +43 -37 lines

intermediate_source/compiled_autograd_tutorial.py (+43 -37)
@@ -11,37 +11,35 @@
 .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
    :class-card: card-prerequisites
 
-   * How compiled autograd interacts with torch.compile
+   * How compiled autograd interacts with ``torch.compile``
    * How to use the compiled autograd API
-   * How to inspect logs using TORCH_LOGS
+   * How to inspect logs using ``TORCH_LOGS``
 
 .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
    :class-card: card-prerequisites
 
    * PyTorch 2.4
-   * `torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_ familiarity
+   * Complete the `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
 
 """
 
 ######################################################################
 # Overview
 # ------------
-# Compiled Autograd is a torch.compile extension introduced in PyTorch 2.4
+# Compiled Autograd is a ``torch.compile`` extension introduced in PyTorch 2.4
 # that allows the capture of a larger backward graph.
 #
-# Doesn't torch.compile already capture the backward graph?
-# ------------
-# And it does, **partially**. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
-# 1. Graph breaks in the forward lead to graph breaks in the backward
-# 2. `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
+# While ``torch.compile`` does capture the backward graph, it does so **partially**. The AOTAutograd component captures the backward graph ahead-of-time, with certain limitations:
+# * Graph breaks in the forward lead to graph breaks in the backward
+# * `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
 #
 # Compiled Autograd addresses these limitations by directly integrating with the autograd engine, allowing
 # it to capture the full backward graph at runtime. Models with these two characteristics should try
 # Compiled Autograd, and potentially observe better performance.
 #
-# However, Compiled Autograd has its own limitations:
-# 1. Additional runtime overhead at the start of the backward
-# 2. Dynamic autograd structure leads to recompiles
+# However, Compiled Autograd introduces its own limitations:
+# * Added runtime overhead at the start of the backward for cache lookup
+# * More prone to recompiles and graph breaks in dynamo due to the larger capture
 #
 # .. note:: Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features. For the latest status on a particular feature, refer to `Compiled Autograd Landing Page <https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY>`_.
 #
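Of the two AOTAutograd limitations listed above, the backward-hook case is the easiest to see concretely. A minimal sketch (illustrative only, not code from this commit; it assumes PyTorch 2.4 and the ``torch._dynamo.config.compiled_autograd`` flag introduced later in the tutorial) of a tensor hook that ahead-of-time capture would miss but Compiled Autograd records at runtime:

import torch

torch._dynamo.config.compiled_autograd = True

x = torch.randn(10, requires_grad=True)
# A backward (tensor) hook: AOTAutograd does not capture hooks ahead of time,
# but Compiled Autograd traces them when the backward actually runs.
x.register_hook(lambda grad: grad * 2)

@torch.compile
def train_step(x):
    loss = (x * x).sum()
    loss.backward()

train_step(x)
print(x.grad)  # the hook doubles the gradient: 2 * (2 * x)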
@@ -50,8 +48,9 @@
 ######################################################################
 # Setup
 # ------------
-# In this tutorial, we'll base our examples on this toy model.
-#
+# In this tutorial, we will base our examples on this simple neural network model.
+# It takes a 10-dimensional input vector, processes it through a single linear layer, and outputs another 10-dimensional vector.
+#
 
 import torch
 
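The hunk stops at ``import torch``, so the model definition itself is not shown here. Based on the description added above (and the ``def forward(self, x)`` context in the next hunk), a sketch of the toy model could look like this; the exact class body is an assumption:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # a single linear layer: 10-dimensional input -> 10-dimensional output
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)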
@@ -67,7 +66,7 @@ def forward(self, x):
 ######################################################################
 # Basic usage
 # ------------
-# .. note:: The ``torch._dynamo.config.compiled_autograd = True`` config must be enabled before calling the torch.compile API.
+# Before calling the ``torch.compile`` API, make sure to set ``torch._dynamo.config.compiled_autograd`` to ``True``:
 #
 
 model = Model()
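The hunk ends at ``model = Model()``. For reference, the full basic-usage snippet can be pieced together from the context lines and from the old code removed in the next hunk; roughly:

import torch

torch._dynamo.config.compiled_autograd = True

model = Model()
x = torch.randn(10)

@torch.compile
def train(model, x):
    loss = model(x).sum()
    loss.backward()

train(model, x)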
@@ -82,23 +81,30 @@ def train(model, x):
 train(model, x)
 
 ######################################################################
-# Inspecting the compiled autograd logs
-# ------------
-# Run the script with the TORCH_LOGS environment variables:
-# - To only print the compiled autograd graph, use ``TORCH_LOGS="compiled_autograd" python example.py``
-# - To print the graph with more tensor medata and recompile reasons, at the cost of performance, use ``TORCH_LOGS="compiled_autograd_verbose" python example.py``
+# In the code above, we create an instance of the ``Model`` class and generate a random 10-dimensional tensor ``x`` by using ``torch.randn(10)``.
+# We define the training loop function ``train`` and decorate it with ``@torch.compile`` to optimize its execution.
+#
+# When ``train(model, x)`` is called:
+# * The Python interpreter calls Dynamo, since this call was decorated with ``@torch.compile``
+# * Dynamo intercepts the Python bytecode, simulates its execution, and records the operations into a graph
+# * AOTDispatcher disables hooks and calls the autograd engine to compute gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph. Using ``torch.autograd.Function``, AOTDispatcher rewrites the forward and backward implementation of ``train``.
+# * Inductor generates a function corresponding to an optimized implementation of the AOTDispatcher forward and backward
+# * Dynamo sets the optimized function to be evaluated next by the Python interpreter
+# * The Python interpreter executes the optimized function, which essentially executes ``loss = model(x).sum()``
+# * The Python interpreter executes ``loss.backward()``, calling into the autograd engine, which routes to the Compiled Autograd engine because we enabled the config ``torch._dynamo.config.compiled_autograd = True``
+# * Compiled Autograd computes the gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph, including any hooks it encounters. During this, it will record the backward previously rewritten by AOTDispatcher. Compiled Autograd then generates a new function corresponding to a fully traced implementation of ``loss.backward()`` and executes it with ``torch.compile`` in inference mode
+# * The same steps recursively apply to the Compiled Autograd graph, but this time AOTDispatcher does not need to partition the graph into a forward and backward
 #
-
-@torch.compile
-def train(model, x):
-    loss = model(x).sum()
-    loss.backward()
-
-train(model, x)
 
 ######################################################################
-# The compiled autograd graph should now be logged to stderr. Certain graph nodes will have names that are prefixed by ``aot0_``,
-# these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0 e.g. ``aot0_view_2`` corresponds to ``view_2`` of the AOT backward graph with id=0.
+# Inspecting the compiled autograd logs
+# -------------------------------------
+# Run the script with the ``TORCH_LOGS`` environment variable:
+# - To only print the compiled autograd graph, use ``TORCH_LOGS="compiled_autograd" python example.py``
+# - To print the graph with more tensor metadata and recompile reasons, at the cost of performance, use ``TORCH_LOGS="compiled_autograd_verbose" python example.py``
+#
+# Rerun the snippet above; the compiled autograd graph should now be logged to ``stderr``. Certain graph nodes will have names that are prefixed by ``aot0_``,
+# these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0, for example, ``aot0_view_2`` corresponds to ``view_2`` of the AOT backward graph with id=0.
 #
 
 stderr_output = """
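The ``TORCH_LOGS`` instructions above assume the script is launched from a shell. If you prefer to stay inside Python, one option (an assumption on my part, not something this commit adds) is to set the environment variable before ``torch`` is first imported, which has the same effect:

import os

# Must run before the first ``import torch``; TORCH_LOGS is read when
# torch's logging machinery initializes.
os.environ["TORCH_LOGS"] = "compiled_autograd"  # or "compiled_autograd_verbose"

import torch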
@@ -156,17 +162,19 @@ def forward(self, inputs, sizes, scalars, hooks):
 """
 
 ######################################################################
-# .. note:: This is the graph that we will call torch.compile on, NOT the optimized graph. Compiled Autograd generates some python code to represent the entire C++ autograd execution.
+# .. note:: This is the graph on which we will call ``torch.compile``, **NOT** the optimized graph. Compiled Autograd essentially generates some unoptimized Python code to represent the entire C++ autograd execution.
 #
 
 ######################################################################
 # Compiling the forward and backward pass using different flags
-# ------------
-#
+# -------------------------------------------------------------
+# You can use different compiler configs for the two compilations, for example, the backward may be compiled as a full graph even if there are graph breaks in the forward.
+#
 
 def train(model, x):
     model = torch.compile(model)
     loss = model(x).sum()
+    torch._dynamo.config.compiled_autograd = True
     torch.compile(lambda: loss.backward(), fullgraph=True)()
 
 ######################################################################
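To make "different flags" concrete, a hedged variant of the function above could compile the forward with the built-in ``aot_eager`` backend while forcing a full-graph compile of the backward; the final call reuses the ``Model`` class from the Setup section and is purely illustrative:

def train(model, x):
    # forward: compiled with the aot_eager backend (no Inductor code generation)
    model = torch.compile(model, backend="aot_eager")
    loss = model(x).sum()
    # backward: captured by Compiled Autograd and compiled as a full graph
    torch._dynamo.config.compiled_autograd = True
    torch.compile(lambda: loss.backward(), fullgraph=True)()

train(Model(), torch.randn(10))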
@@ -182,7 +190,7 @@ def train(model, x):
 
 ######################################################################
 # Compiled Autograd addresses certain limitations of AOTAutograd
-# ------------
+# --------------------------------------------------------------
 # 1. Graph breaks in the forward lead to graph breaks in the backward
 #
 
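The example code for limitation 1 sits outside this hunk. A minimal sketch of the idea (my own illustration, assuming the ``torch._dynamo.config.compiled_autograd`` flag from earlier): ``torch._dynamo.graph_break()`` forces a break in the forward, yet the backward can still be captured as one graph.

import torch

torch._dynamo.config.compiled_autograd = True

@torch.compile
def train_step(x):
    temp = x + 10
    torch._dynamo.graph_break()  # splits the forward into two graphs
    loss = (temp * temp).sum()
    # Compiled Autograd can still capture the whole backward in one graph at runtime.
    loss.backward()

x = torch.randn(10, requires_grad=True)
train_step(x)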
@@ -252,7 +260,7 @@ def forward(self, inputs, sizes, scalars, hooks):
 
 ######################################################################
 # Common recompilation reasons for Compiled Autograd
-# ------------
+# --------------------------------------------------
 # 1. Due to change in autograd structure
 
 torch._dynamo.config.compiled_autograd = True
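The concrete example for reason 1 also sits outside this hunk. One way to trigger it (a sketch mirroring the pattern used elsewhere in the tutorial; the ``eager`` backend here is an assumption) is to run the backward through a different operator on each iteration, so the node types in the autograd graph keep changing:

import torch

x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
    loss = op(x, x).sum()
    # Each iteration yields a backward graph with different node types,
    # so Compiled Autograd recompiles its captured graph.
    torch.compile(lambda: loss.backward(), backend="eager")()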
@@ -302,7 +310,5 @@
 ######################################################################
 # Conclusion
 # ----------
-# In this tutorial, we went over the high-level ecosystem of torch.compile with compiled autograd, the basics of compiled autograd and a few common recompilation reasons.
-#
-# For feedback on this tutorial, please file an issue on https://github.com/pytorch/tutorials.
+# In this tutorial, we went over the high-level ecosystem of ``torch.compile`` with compiled autograd, the basics of compiled autograd and a few common recompilation reasons.
 #
