# Doesn't torch.compile already capture the backward graph?
# ------------
# Partially. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
#
# - Graph breaks in the forward lead to graph breaks in the backward (illustrated briefly below)
# - `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
#
# Compiled Autograd addresses these limitations by directly integrating with the autograd engine, allowing
# it to capture the full backward graph at runtime. Models with these two characteristics should try
# Compiled Autograd and may observe better performance.
#
# However, Compiled Autograd has its own limitations:
#
# - Dynamic autograd structure leads to recompiles
#
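# For instance, a forward with an explicit graph break like the one below would, without
# Compiled Autograd, also have its backward split into multiple graphs (a minimal
# illustration using a hypothetical function, not code from this tutorial):
#
# .. code-block:: python
#
#    @torch.compile
#    def forward_with_break(x):  # hypothetical example
#        y = x + 1
#        torch._dynamo.graph_break()  # forces a graph break in the forward
#        return (y * 2).sum()
#
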
######################################################################
# Tutorial output cells setup
# ------------
#

import os


class ScopedLogging:
    # Temporarily enables verbose compiled autograd logs by setting the TORCH_LOGS
    # environment variables for the lifetime of the object.
    def __init__(self):
        assert "TORCH_LOGS" not in os.environ
        assert "TORCH_LOGS_FORMAT" not in os.environ
        os.environ["TORCH_LOGS"] = "compiled_autograd_verbose"
        os.environ["TORCH_LOGS_FORMAT"] = "short"

    def __del__(self):
        del os.environ["TORCH_LOGS"]
        del os.environ["TORCH_LOGS_FORMAT"]
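
# A hypothetical usage sketch (not executed here): an output cell would keep verbose
# logging enabled for its duration, for example
#
#   logs = ScopedLogging()   # sets TORCH_LOGS and TORCH_LOGS_FORMAT
#   ...                      # run the code whose compiled autograd logs we want to see
#   del logs                 # restores the previous environment
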
######################################################################
# Basic Usage
# ------------
#

import torch

# NOTE: Must be enabled before using the decorator
torch._dynamo.config.compiled_autograd = True

# ------------
# Run the script with one of the following ``TORCH_LOGS`` settings:
#
# - To only print the compiled autograd graph, use ``TORCH_LOGS="compiled_autograd" python example.py``
# - To print the graph with additional tensor metadata and recompile reasons, at the cost of some performance, use ``TORCH_LOGS="compiled_autograd_verbose" python example.py``
#
# Logs can also be enabled through the private API ``torch._logging._internal.set_logs``, as sketched below.
#
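# A rough sketch of that route (not executed here, since this tutorial drives logging
# through ``TORCH_LOGS``). Note that ``set_logs`` is a private API that may change, and
# the flag must be enabled before wrapping the function with torch.compile:
#
# .. code-block:: python
#
#    torch._logging._internal.set_logs(compiled_autograd=True)
#
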
@torch.compile
def train(model, x):
    loss = model(x).sum()
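    # The rest of this cell is not shown in this excerpt; the backward call that
    # Compiled Autograd captures would go here, in its minimal assumed form:
    loss.backward()
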
train(model, x)

######################################################################
# The compiled autograd graph should now be logged to stdout. Certain graph nodes will have
# names prefixed by ``aot0_``; these correspond to the nodes previously compiled ahead of
# time in AOTAutograd backward graph 0.
#
# NOTE: This is the graph on which we will call torch.compile, NOT the optimized graph.
# Compiled Autograd essentially generates some Python code to represent the entire C++
# autograd execution.
#
"""
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
  ===== Compiled autograd graph =====
  <eval_with_key>.4 class CompiledAutograd(torch.nn.Module):
    def forward(self, inputs, sizes, scalars, hooks):
        ...
"""

def fn(x):
    ...
    return temp.sum()


x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)

# 1. base torch.compile
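# A sketch of what this baseline check might look like (illustrative; the exact
# counter assertions are not shown in this excerpt):
loss.backward(retain_graph=True)  # retain the graph so the loss can be reused below
# Graph breaks in the forward produce several compiled graphs, visible in the counters
print(torch._dynamo.utils.counters["stats"]["unique_graphs"])
torch._dynamo.utils.counters.clear()
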
x.register_hook(lambda grad: grad + 10)
loss = fn(x)

with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
    loss.backward()

#

"""
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
  ===== Compiled autograd graph =====
  <eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
    def forward(self, inputs, sizes, scalars, hooks):
        ...
        getitem_2 = hooks[0]; hooks = None
        call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
        ...
"""

######################################################################
# Common recompilation reasons for Compiled Autograd
# ------------
# 1. Due to change in autograd structure

torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
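    # The loop body is not shown in this excerpt. A minimal sketch of the idea (the exact
    # code is an assumption): each op creates a different backward node type, so every
    # iteration misses the compiled autograd cache and triggers a recompile.
    loss = op(x, x).sum()
    torch.compile(lambda: loss.backward(), backend="eager")()
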
######################################################################
# You should see some cache miss logs (recompiles):
#

"""
Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
...
Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
...
Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
...
Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
...
"""

######################################################################
# 2. Due to dynamic shapes
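#
# The code for this case is not shown in this excerpt. A minimal sketch of the idea
# (the sizes and compile settings below are illustrative assumptions): rerun the same
# compiled backward with tensors of different shapes, so the sizes recorded by the
# autograd nodes change and end up being marked dynamic.

torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
    x = torch.randn(i, i, requires_grad=True)
    loss = x.sum()
    torch.compile(lambda: loss.backward(), backend="eager")()
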
######################################################################
# You should see some cache miss logs (recompiles):
#

"""
...
Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
...
"""

######################################################################
# Compatibility and rough edges