Skip to content

Commit 9dec3c0

Browse files
authored
Add blog post "Current and New Activation Checkpointing Techniques in PyTorch" (#1928)
* Add blog post "Current and New Activation Checkpointing Techniques in PyTorch" Signed-off-by: Chris Abraham <[email protected]> * fix Signed-off-by: Chris Abraham <[email protected]> * Update publish date Signed-off-by: Chris Abraham <[email protected]> --------- Signed-off-by: Chris Abraham <[email protected]>
1 parent b3d0b73 commit 9dec3c0

File tree

15 files changed

+233
-0
lines changed

15 files changed

+233
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
---
2+
layout: blog_detail
3+
title: "Current and New Activation Checkpointing Techniques in PyTorch"
4+
---
5+
6+
As models scale in depth, batch size, and sequence length, etc, activation memory becomes an increasingly significant contributor to the overall memory usage. To help address this, PyTorch provides utilities for [activation checkpointing](https://pytorch.org/docs/stable/checkpoint.html), which reduce the number of saved tensors by recomputing them when needed, trading off memory usage for additional compute.
7+
8+
In this post, we’ll walk through the basics of what activation memory is, the high-level ideas behind existing activation checkpointing techniques, and also introduce some newer techniques that aim to improve flexibility and provide more optimization/automation out of the box.
9+
10+
As we look at these techniques, we'll compare how these methods fit into a speed vs. memory trade-off diagram and hopefully provide some insight on how to choose the right strategy for your use case.
11+
12+
*(If you prefer to jump straight to the new APIs, please skip ahead to the “Selective Activation Checkpoint” and “Memory Budget API” sections below.)*
13+
14+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg1.png){:style="width:100%"}
15+
16+
17+
---
18+
19+
20+
## Activation Memory Basics
21+
22+
By default, in eager mode (rather than using `torch.compile`), PyTorch’s autograd preserves intermediate activations for backward computation. For example, if you call `sin` on a tensor `x` during the forward pass, autograd must remember `x` to compute `cos(x)` during backward.
23+
24+
25+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg2.png){:style="max-width:400px; display: block; margin-left: auto; margin-right: auto"}
26+
27+
28+
If this tensor `x` is saved at the beginning of the forward pass, it remains in memory throughout both the forward and backward phases. It can only be cleared after it is used to compute the gradient, which happens at the end of the backward pass (due to the reverse order of execution).
29+
30+
Thus, as you proceed through the forward pass and perform more and more operations, you accumulate more and more activations, resulting in more and more activation memory until it (typically) reaches its peak at the start of backward (at which point activations can start to get cleared).
31+
32+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg3.png){:style="width:100%"}
33+
34+
35+
*In the diagram above, the orange boxes represent operations, black arrows represent their tensor inputs and outputs. The black arrows that cross over the right represent tensors that autograd saves for backward.*
36+
37+
A useful way to visually organize this default saving behavior in eager as well as the techniques we're about to introduce is based on how they trade off speed versus memory.
38+
39+
40+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg4.png){:style="width:100%"}
41+
42+
43+
The ideal place to be on this diagram is the top-left, where you have "high" speed but also low memory usage.
44+
45+
We begin by putting the default saving behavior on the **top-right** (for reasons we'll explain in more detail as we introduce more points for other techniques).
46+
47+
48+
---
49+
50+
51+
## Activation Checkpointing (AC)
52+
53+
**[Activation checkpointing (AC)](https://pytorch.org/docs/stable/checkpoint.html)** is a popular technique to reduce memory usage in PyTorch.
54+
55+
During forward, any operations performed inside the AC'd region do not save tensors for backward. (Only the inputs to the function are saved.) During backward, the intermediate activations needed for gradient computation are rematerialized by running the function a second time.
56+
57+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg5.png){:style="width:100%"}
58+
59+
60+
*In the diagram (right), the black box shows where activation checkpointing is applied. Compared to the default eager approach (left), this setup results in fewer tensors being saved (1 versus 3).*
61+
62+
Applying AC on the right parts of the model has the effect of reducing peak memory, because the intermediate activations are no longer materialized in memory when the memory usage typically peaks (at the beginning of backward).
63+
64+
On the speed-versus-memory tradeoff diagram, AC is plotted on the **bottom-left.** Relative to eager mode, it reduces the amount of memory saved for backward but comes with an added cost in compute due to recomputation.
65+
66+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg6.png){:style="width:100%"}
67+
68+
69+
Note that AC’s speed–memory tradeoff /can/ be adjusted by selecting which parts of the forward pass to checkpoint and by defining how many checkpoint regions to use. However, implementing these changes may require modifying your model’s structure and can be cumbersome depending on how your code is organized. For the purposes of this diagram, we assume only one region is checkpointed; under this assumption, AC appears as a single point on the tradeoff diagram.
70+
71+
Also note that “memory” here does not refer to peak memory usage; rather, it indicates the how much memory is saved for backward for a fixed region.
72+
73+
74+
---
75+
76+
77+
## torch.compile and min-cut partitioner
78+
79+
Another notable approach to keep in mind is **torch.compile** (introduced in PyTorch 2.0). Like activation checkpointing, `torch.compile` can also perform some level of recomputation under the hood. Specifically, it traces the forward and backward computations into a single joint graph, which is then processed by a [“min-cut” partitioner](https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467). This partitioner uses a min-cut/max-flow algorithm to split the graph such that it minimizes the number of tensors that need to be saved for backward.
80+
81+
At first glance, this might sound a lot like what we want for activation memory reduction. However, the reality is more nuanced. By default, the partitioner’s primary goal is to reduce runtime. As a result, it only recomputes certain types of operations—primarily simpler, fusible, and non-compute-intensive ops (like pointwise ops).
82+
83+
Placing "compile" on the speed-versus-memory tradeoff diagram...
84+
85+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg7.png){:style="width:100%"}
86+
87+
88+
It is to the top-left of the eager non-AC point, as we expect `torch.compile` to improve on both speed and memory.
89+
90+
On the other hand, relative to activation checkpointing, torch.compile is more conservative about what it recomputes, placing it closer to the top-left on the speed-versus-memory diagram.
91+
92+
93+
---
94+
95+
96+
## Selective Activation Checkpoint [NEW!]
97+
98+
While normal checkpointing recomputes every op in a chosen region, [selective activation checkpointing (SAC)](https://pytorch.org/docs/main/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts) is an additional setting on top of activation checkpointing that you can apply to have a more granular control over which operations to recompute.
99+
100+
This can be useful if you have certain more expensive operations like matmuls which you prefer to avoid recomputing, but still generally want to recompute cheaper operations like pointwise.
101+
102+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg8.png){:style="width:100%"}
103+
104+
105+
*Where plain AC (left) would save a single tensor and then recompute the entire AC'd region, with SAC (right) you can selectively save specific operations (marked red) in the region, so you can avoid recomputing them.*
106+
107+
To specify what to selectively save, you can specify a policy_fn. To illustrate the additional trade offs you can make with this, we present two simple policy functions.
108+
109+
110+
### Policy 1: Not recomputing matmuls:
111+
112+
113+
```
114+
aten = torch.ops.aten
115+
compute_intensive_ops = [
116+
aten.mm,
117+
aten.bmm,
118+
aten.addmm,
119+
]
120+
def policy_fn(ctx, op, *args, **kwargs):
121+
if op in compute_intensive_ops:
122+
return CheckpointPolicy.MUST_SAVE
123+
else:
124+
return CheckpointPolicy.PREFER_RECOMPUTE
125+
```
126+
127+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg9.png){:style="width:100%"}
128+
129+
130+
### Policy 2: More aggressively save anything compute intensive
131+
132+
133+
```
134+
# torch/_functorch/partitioners.py
135+
aten = torch.ops.aten
136+
compute_intensive_ops = [
137+
aten.mm,
138+
aten.convolution,
139+
aten.convolution_backward,
140+
aten.bmm,
141+
aten.addmm,
142+
aten._scaled_dot_product_flash_attention,
143+
aten._scaled_dot_product_efficient_attention,
144+
aten._flash_attention_forward,
145+
aten._efficient_attention_forward,
146+
aten.upsample_bilinear2d,
147+
aten._scaled_mm
148+
]
149+
def policy_fn(ctx, op, *args, **kwargs):
150+
if op in compute_intensive_ops:
151+
return CheckpointPolicy.MUST_SAVE
152+
else:
153+
return CheckpointPolicy.PREFER_RECOMPUTE
154+
```
155+
156+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg10.png){:style="width:100%"}
157+
158+
159+
On the speed-versus-memory diagram, SAC is plotted as a range of points from closer to AC to closer to Eager, depending on your chosen policy.
160+
161+
162+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg11.png){:style="width:100%"}
163+
164+
165+
**Try it out!** (Available in 2.5 as a prototype feature; see [docs](https://pytorch.org/docs/main/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts) for more info + copy-pastable example)
166+
167+
168+
```
169+
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts
170+
171+
# Create a policy function that returns a CheckpointPolicy
172+
def policy_fn(ctx, op, *args, **kwargs):
173+
if op in ops_to_save:
174+
return CheckpointPolicy.MUST_SAVE
175+
else:
176+
return CheckpointPolicy.PREFER_RECOMPUTE
177+
178+
# Use the context_fn= arg of the existing checkpoint API
179+
out = checkpoint(
180+
fn, *args,
181+
use_reentrant=False,
182+
# Fill in SAC context_fn's policy_fn with functools.partial
183+
context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
184+
)
185+
186+
```
187+
---
188+
189+
190+
191+
## (compile-only) Memory Budget API [NEW!]
192+
193+
As mentioned previously, any given SAC policy can be represented as a point on a speed-memory tradeoff diagram. Not all policies are created equal, however. The "optimal" policies are the ones that fall on a pareto curve, e.g. for all policies that incur the same memory overhead, this policy is the one that minimizes the amount of required compute.
194+
195+
For users who are using torch.compile, we offer a **memory budget API** that automatically applies SAC over your compiled region with a pareto-optimal policy given a user-specified "memory budget" between 0 and 1, where a budget of 0 behaves like plain-AC and a budget of 1 behaves like default torch.compile.
196+
197+
198+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg12.png){:style="width:100%"}
199+
200+
201+
Below are some real results on a transformer model:
202+
203+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg13.png){:style="width:100%"}
204+
205+
206+
We observe a 50% memory reduction by recomputing only pointwise ops, with a steady drop-off as you recompute more and more of your matmuls. Attention is the most expensive, so you tend to want to recompute those last.
207+
208+
**Try it out!** (Available in 2.4 as an experimental feature; see this [comment block](https://github.com/pytorch/pytorch/blob/68a363548409a3ff17965770304ee5e12fe718d9/torch/_functorch/config.py#L110-L122) for more info)
209+
210+
211+
```
212+
torch._dynamo.config.activation_memory_budget = 0.5
213+
214+
out = torch.compile(fn)(inp)
215+
```
216+
217+
---
218+
219+
220+
221+
222+
## Conclusion
223+
224+
225+
![flow diagram](/assets/images/activation-checkpointing-techniques/fg14.png){:style="width:100%"}
226+
227+
228+
In summary, activation checkpointing techniques in PyTorch offer a variety of ways to balance memory and compute demands, from simple region-based checkpointing to more selective and automated methods. By choosing the option that best matches your model’s structure and resource constraints, you can achieve significant memory savings with an acceptable trade-off in compute.
229+
230+
231+
## Acknowledgements
232+
233+
We would like to thank Meta's [xformers](https://github.com/facebookresearch/xformers) team including [Francisco Massa](https://github.com/fmassa) for working on the original version of Selective Activation Checkpoint.
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading

0 commit comments

Comments
 (0)