Commit 45a4087

update

1 parent ffa0a81 commit 45a4087
File tree: 1 file changed, +61 −40 lines

intermediate_source/compiled_autograd_tutorial.py

Lines changed: 61 additions & 40 deletions
@@ -16,22 +16,43 @@
 # Doesn't torch.compile already capture the backward graph?
 # ------------
 # Partially. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
-# - Graph breaks in the forward lead to graph breaks in the backward
-# - `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
+# - Graph breaks in the forward lead to graph breaks in the backward
+# - `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
 #
 # Compiled Autograd addresses these limitations by directly integrating with the autograd engine, allowing
 # it to capture the full backward graph at runtime. Models with these two characteristics should try
 # Compiled Autograd, and potentially observe better performance.
 #
 # However, Compiled Autograd has its own limitations:
-# - Dynamic autograd structure leads to recompiles
+# - Dynamic autograd structure leads to recompiles
 #

+######################################################################
+# Tutorial output cells setup
+# ------------
+#
+
+import os
+
+class ScopedLogging:
+    def __init__(self):
+        assert "TORCH_LOGS" not in os.environ
+        assert "TORCH_LOGS_FORMAT" not in os.environ
+        os.environ["TORCH_LOGS"] = "compiled_autograd_verbose"
+        os.environ["TORCH_LOGS_FORMAT"] = "short"
+
+    def __del__(self):
+        del os.environ["TORCH_LOGS"]
+        del os.environ["TORCH_LOGS_FORMAT"]
+
+
 ######################################################################
 # Basic Usage
 # ------------
 #

+import torch
+
 # NOTE: Must be enabled before using the decorator
 torch._dynamo.config.compiled_autograd = True

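The ScopedLogging helper added above turns on verbose compiled autograd logging for the tutorial's output cells by setting TORCH_LOGS / TORCH_LOGS_FORMAT, and removes the variables again when the object is destroyed. A minimal usage sketch (an assumption for illustration; how the tutorial actually instantiates the helper is outside this hunk):

    # Hypothetical usage of the ScopedLogging helper defined in the hunk above.
    logs = ScopedLogging()   # sets TORCH_LOGS="compiled_autograd_verbose" and TORCH_LOGS_FORMAT="short"
    # ... run the cells whose compiled autograd logs should appear in the output ...
    del logs                 # __del__ removes both environment variables again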
@@ -57,21 +78,12 @@ def train(model, x):
 # ------------
 # Run the script with either TORCH_LOGS environment variables
 #
-"""
-Prints graph:
-TORCH_LOGS="compiled_autograd" python example.py
-
-Performance degrading, prints verbose graph and recompile reasons:
-TORCH_LOGS="compiled_autograd_verbose" python example.py
-"""
-
-######################################################################
-# Or with the set_logs private API:
+# - To print only the compiled autograd graph, use `TORCH_LOGS="compiled_autograd" python example.py`
+# - To print the graph with more tensor metadata and recompile reasons, at the cost of some performance, use `TORCH_LOGS="compiled_autograd_verbose" python example.py`
+#
+# Logs can also be enabled through the private API torch._logging._internal.set_logs.
 #

-# flag must be enabled before wrapping using torch.compile
-torch._logging._internal.set_logs(compiled_autograd=True)
-
 @torch.compile
 def train(model, x):
     loss = model(x).sum()
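The comment added above points at the private set_logs API, and the calls this commit removes show how it was used. A small sketch of the in-process equivalent of the environment variables (the verbose variant mirrors a call removed later in this diff; being a private API, its signature may change between releases):

    import torch

    # Equivalent of TORCH_LOGS="compiled_autograd": print only the compiled autograd graph.
    torch._logging._internal.set_logs(compiled_autograd=True)
    # Equivalent of TORCH_LOGS="compiled_autograd_verbose": also print tensor metadata and recompile reasons.
    # torch._logging._internal.set_logs(compiled_autograd_verbose=True)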
@@ -80,14 +92,15 @@ def train(model, x):
 train(model, x)

 ######################################################################
-# The compiled autograd graph should now be logged to stdout. Certain graph nodes will have names that are prefixed by "aot0_",
+# The compiled autograd graph should now be logged to stdout. Certain graph nodes will have names that are prefixed by aot0_,
 # these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0.
 #
 # NOTE: This is the graph that we will call torch.compile on, NOT the optimized graph. Compiled Autograd basically
 # generated some python code to represent the entire C++ autograd execution.
 #
 """
-INFO:torch._dynamo.compiled_autograd.__compiled_autograd:TRACED GRAPH
+DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
+DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
 ===== Compiled autograd graph =====
 <eval_with_key>.4 class CompiledAutograd(torch.nn.Module):
     def forward(self, inputs, sizes, scalars, hooks):
@@ -178,6 +191,7 @@ def fn(x):
     return temp.sum()

 x = torch.randn(10, 10, requires_grad=True)
+torch._dynamo.utils.counters.clear()
 loss = fn(x)

 # 1. base torch.compile
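The counters.clear() call added above resets torch._dynamo's global counters so the comparison runs that follow start from zero. A sketch of how the counters could be read back afterwards (the exact keys inspected by the tutorial are not visible in this hunk; "stats"/"unique_graphs" and "compiled_autograd"/"captures" are assumptions, and since counters is a defaultdict(Counter), missing keys simply read as 0):

    import torch

    torch._dynamo.utils.counters.clear()
    # ... run the torch.compile / compiled autograd variants being compared ...
    print(torch._dynamo.utils.counters["stats"]["unique_graphs"])         # assumed key: graphs captured by Dynamo
    print(torch._dynamo.utils.counters["compiled_autograd"]["captures"])  # assumed key: backward graphs captured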
@@ -205,7 +219,6 @@ def fn(x):
 x.register_hook(lambda grad: grad+10)
 loss = fn(x)

-torch._logging._internal.set_logs(compiled_autograd=True)
 with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
     loss.backward()

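Pieced together, the hunk above registers a backward hook on x and then compiles only the backward pass by wrapping loss.backward() in the compiled autograd context manager. A self-contained sketch of that pattern (the body of fn lives in an unchanged part of the file, so the one below is a stand-in):

    import torch

    def fn(x):
        return x.sum()  # placeholder body; the tutorial's fn is defined earlier in the file

    x = torch.randn(10, 10, requires_grad=True)
    x.register_hook(lambda grad: grad + 10)   # backward hook, captured as a tensor_pre_hook
    loss = fn(x)

    # Compile only the backward pass, using the aot_eager backend as in the hunk above.
    with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
        loss.backward()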
@@ -214,22 +227,22 @@ def fn(x):
 #

 """
-INFO:torch._dynamo.compiled_autograd.__compiled_autograd:TRACED GRAPH
+DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
+DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
 ===== Compiled autograd graph =====
 <eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
     def forward(self, inputs, sizes, scalars, hooks):
-        ...
-        getitem_2 = hooks[0]; hooks = None
-        call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
-        ...
+        ...
+        getitem_2 = hooks[0]; hooks = None
+        call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
+        ...
 """

 ######################################################################
-# Understanding recompilation reasons for Compiled Autograd
+# Common recompilation reasons for Compiled Autograd
 # ------------
 # 1. Due to change in autograd structure

-torch._logging._internal.set_logs(compiled_autograd_verbose=True)
 torch._dynamo.config.compiled_autograd = True
 x = torch.randn(10, requires_grad=True)
 for op in [torch.add, torch.sub, torch.mul, torch.div]:
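The loop above closes the hunk, so its body is not shown; a minimal sketch of what such a loop would do (the body is an assumption). Each op produces a different backward node (AddBackward0, SubBackward0, MulBackward0, DivBackward0), so every iteration misses the compiled autograd cache and triggers a recompile, which is what the logs in the next hunk report:

    import torch

    torch._dynamo.config.compiled_autograd = True
    x = torch.randn(10, requires_grad=True)
    for op in [torch.add, torch.sub, torch.mul, torch.div]:
        loss = op(x, x).sum()
        # Compiling the backward call captures a graph keyed on the autograd node types.
        torch.compile(lambda: loss.backward(), backend="eager")()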
@@ -238,14 +251,18 @@ def forward(self, inputs, sizes, scalars, hooks):

 ######################################################################
 # You should see some cache miss logs (recompiles):
-# Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
-# ...
-# Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
-# ...
-# Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
-# ...
-# Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
-# ...
+#
+
+"""
+Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
+...
+Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
+...
+Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
+...
+Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
+...
+"""

 ######################################################################
 # 2. Due to dynamic shapes
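The code for case 2 sits between this hunk and the next, so it is not visible in the diff; below is a minimal sketch of the kind of loop that would produce the changed-shapes cache misses shown in the following hunk (the exact tensors and sizes are assumptions). Re-running backward on inputs of different lengths makes compiled autograd mark those sizes as dynamic and recompile:

    import torch

    torch._dynamo.config.compiled_autograd = True
    for sz in [10, 100, 1000]:
        x = torch.randn(sz, requires_grad=True)
        loss = x.sum()
        torch.compile(lambda: loss.backward(), backend="eager")()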
@@ -260,12 +277,16 @@ def forward(self, inputs, sizes, scalars, hooks):

 ######################################################################
 # You should see some cache miss logs (recompiles):
-# ...
-# Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
-# Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
-# Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
-# Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
-# ...
+#
+
+"""
+...
+Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
+Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
+Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
+Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
+...
+"""

 ######################################################################
 # Compatibility and rough edges
