
Commit aebeff4

Authored by XilunWu, svekars, justinchuby, and desertfire
Add Context Parallel tutorial (#3319)
Summary: The compiled model run takes the same input as Eager. No need to explicitly compose args as a tuple. Address comments: improve the pass-KV description.

Co-authored-by: Svetlana Karslioglu <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
Co-authored-by: Bin Bao <[email protected]>
1 parent 7cb6915 commit aebeff4

File tree

2 files changed: +236 -0 lines changed

Diff for: prototype_source/context_parallel.rst

+228
@@ -0,0 +1,228 @@
Introduction to Context Parallel
======================================
**Authors**: `Xilun Wu <https://github.com/XilunWu>`_, `Chien-Chin Huang <https://github.com/fegin>`__

.. note::
   |edit| View and edit this tutorial in `GitHub <https://github.com/pytorch/tutorials/blob/main/prototype_source/context_parallel.rst>`__.

.. grid:: 2

   .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
      :class-card: card-prerequisites

      * `Context Parallel APIs <https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel>`__
      * `1M sequence training in TorchTitan with Context Parallel <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__

   .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
      :class-card: card-prerequisites

      * PyTorch 2.7 or later

Introduction
------------

Context Parallel is an approach used in large language model training to reduce peak activation size by sharding the long input sequence across multiple devices.
It removes the constraint on input sequence length that comes from the peak memory needed to store activations in Transformer blocks.

Ring Attention, a novel parallel implementation of the Attention layer, is critical to performant Context Parallel.
Ring Attention shuffles the KV shards across devices and computes partial attention scores, repeating until all KV shards have been used on each device.
Two Ring Attention variants have been implemented: `the all-gather based pass-KV <https://arxiv.org/abs/2407.21783>`__ and `the all-to-all based pass-KV <https://openreview.net/forum?id=WsRHpHH4s0>`__ (a simplified sketch of the all-gather variant follows this list):

1. The all-gather based pass-KV algorithm is used in Llama3 training. It first performs an all-gather on the key and value tensors, then computes the attention output for the
   local query tensor chunk. Our modified all-gather based pass-KV algorithm all-gathers the KV shards while concurrently computing the attention output for the local query tensor chunk
   against the local key and value tensor chunks, and finishes with a final attention computation for the local query tensor against the remaining KV shards. This allows some degree of
   overlap between the attention computation and the all-gather collective. For example, in the case of Llama3 training, we also shard ``freq_cis`` over the sequence dimension.
2. The all-to-all approach uses interleaved all-to-all collectives to ring-shuffle KV shards, overlapping the SDPA (Scaled Dot Product Attention) computation with the all-to-all communication
   needed for the next SDPA.

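The snippet below is a minimal sketch of the data movement in the plain all-gather based pass-KV idea. It omits the compute/communication overlap and the causal masking of the real implementation, and the function name ``all_gather_pass_kv_sdpa`` is an illustrative assumption rather than a PyTorch API. It assumes ``torch.distributed`` has already been initialized and that Q, K, and V are sharded evenly along the sequence dimension (dim 2):

.. code:: python

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F


    def all_gather_pass_kv_sdpa(q_local, k_local, v_local, world_size):
        # Each rank holds one shard of Q, K, V along the sequence dimension (dim 2).
        k_shards = [torch.empty_like(k_local) for _ in range(world_size)]
        v_shards = [torch.empty_like(v_local) for _ in range(world_size)]

        # Gather every rank's KV shard; the real algorithm overlaps this
        # collective with attention over the local KV chunk.
        dist.all_gather(k_shards, k_local)
        dist.all_gather(v_shards, v_local)
        k_full = torch.cat(k_shards, dim=2)
        v_full = torch.cat(v_shards, dim=2)

        # Attention for the local Q chunk against the full K/V sequence.
        # Causal masking is omitted here for simplicity.
        return F.scaled_dot_product_attention(q_local, k_full, v_full)
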
The Context Parallel APIs consist of two parts:

1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
   will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
   the ``buffers`` and ``buffer_seq_dims`` arguments, respectively. We recommend adding all tensors that participate in computation along the sequence dimension to ``buffers``
   and sharding them along this dimension (see the sketch after this list). Taking Llama3 training as an example, omitting ``freq_cis`` from ``buffers`` will result in a miscalculated rotary embedding.
2. ``set_rotate_method()`` allows users to choose between the all-gather based pass-KV approach and the all-to-all based pass-KV approach.

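As a minimal sketch of that recommendation, both the input tokens and the rotary-embedding table would be listed in ``buffers`` with their matching sequence dimensions. The ``model``, ``inputs``, ``freq_cis``, their shapes, and the pre-built ``device_mesh`` below are hypothetical placeholders for a Llama3-style training step, not part of this tutorial's runnable example:

.. code:: python

    from torch.distributed.tensor.experimental import context_parallel

    # Hypothetical shapes: `inputs` is (batch, seq_len) and `freq_cis` is
    # (seq_len, head_dim), so their sequence dimensions are 1 and 0.
    with context_parallel(
        device_mesh,
        buffers=(inputs, freq_cis),
        buffer_seq_dims=(1, 0),
    ):
        out = model(inputs, freq_cis)
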
Setup
---------------------

With ``torch.distributed.tensor.experimental.context_parallel()``, users can easily shard the Tensor input and parallelize the execution of the SDPA function.
To better demonstrate the usage of this API, we start with a simple code snippet doing SDPA and then parallelize it using the API:

.. code:: python

    import torch
    import torch.nn.functional as F

    from torch.nn.attention import sdpa_kernel, SDPBackend


    def sdpa_example():
        assert torch.cuda.is_available()
        torch.cuda.set_device("cuda:0")
        torch.cuda.manual_seed(0)

        batch = 8
        nheads = 8
        qkv_len = 8192
        dim = 32
        backend = SDPBackend.FLASH_ATTENTION
        dtype = (
            torch.bfloat16
            if backend == SDPBackend.FLASH_ATTENTION
            or backend == SDPBackend.CUDNN_ATTENTION
            else torch.float32
        )

        qkv = [
            torch.rand(
                (batch, nheads, qkv_len, dim),
                dtype=dtype,
                requires_grad=True,
                device='cuda',
            )
            for _ in range(3)
        ]
        # specify the SDPBackend to use
        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(*qkv, is_causal=True)


    if __name__ == "__main__":
        sdpa_example()

Enable Context Parallel
-----------------------

Now, let's first adapt it to a distributed program where each rank has the same tensor input. Then we apply the context parallel API to
shard the input and distribute the computation across ranks:

.. code:: python

    # file: cp_sdpa_example.py
    import os

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel
    from torch.distributed.tensor.experimental._attention import context_parallel_unshard
    from torch.nn.attention import sdpa_kernel, SDPBackend


    def context_parallel_sdpa_example(world_size: int, rank: int):
        assert torch.cuda.is_available()
        assert dist.is_nccl_available()
        torch.cuda.set_device(f"cuda:{rank}")
        torch.cuda.manual_seed(0)

        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )
        device_mesh = init_device_mesh(
            device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("cp",)
        )

        batch = 8
        nheads = 8
        qkv_len = 64
        dim = 32
        backend = SDPBackend.FLASH_ATTENTION
        dtype = (
            torch.bfloat16
            if backend == SDPBackend.FLASH_ATTENTION
            or backend == SDPBackend.CUDNN_ATTENTION
            else torch.float32
        )

        qkv = [
            torch.rand(
                (batch, nheads, qkv_len, dim),
                dtype=dtype,
                requires_grad=True,
                device='cuda',
            )
            for _ in range(3)
        ]
        # specify the SDPBackend to use
        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(*qkv, is_causal=True)

        # make a clean copy of QKV for output comparison
        cp_qkv = [t.detach().clone() for t in qkv]

        with sdpa_kernel(backend):
            # This `context_parallel()` performs two actions:
            # 1. Shard the tensor objects in `buffers` in-place along the dimensions
            #    specified in `buffer_seq_dims`; the tensors in `buffers` and their
            #    sharding dims in `buffer_seq_dims` are organized in the same order.
            # 2. Replace the execution of `F.scaled_dot_product_attention` with a
            #    context-parallel-enabled Ring Attention.
            with context_parallel(
                device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
            ):
                cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

            # The output `cp_out` is still sharded in the same way as QKV;
            # the `context_parallel_unshard` API allows users to easily
            # unshard it to obtain the full output tensor.
            (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])

        assert torch.allclose(
            cp_out,
            out,
            atol=(1e-08 if dtype == torch.float32 else 1e-03 * world_size),
        )


    if __name__ == "__main__":
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        try:
            context_parallel_sdpa_example(world_size, rank)
        finally:
            dist.barrier()
            dist.destroy_process_group()

You can use the command ``torchrun --standalone --nnodes=1 --nproc-per-node=4 cp_sdpa_example.py`` to launch the above context parallel
SDPA on 4 GPUs. We demonstrate the numeric correctness by comparing the output of Ring Attention to that of SDPA on a single GPU.

Select Rotation Approach
------------------------

You can choose the desired shard rotation approach in Ring Attention by using ``torch.distributed.tensor.experimental._attention.set_rotate_method()``:

.. code:: python

    # file: cp_sdpa_example.py
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    set_rotate_method("alltoall")  # rotate shards using all-to-all

    with sdpa_kernel(backend):
        with context_parallel(
            device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
        ):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)


The default rotation approach is the all-gather based pass-KV.

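To switch back explicitly, a minimal sketch is shown below; the ``"allgather"`` string is assumed to mirror the ``"alltoall"`` option above, so verify the accepted values against your PyTorch build since this is an experimental, private API:

.. code:: python

    # Sketch: explicitly request the default all-gather based rotation.
    # The "allgather" value is an assumption; check the API in your PyTorch version.
    set_rotate_method("allgather")  # rotate shards using all-gather (the default)
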
Conclusion
----------

In this tutorial, we have learned how to parallelize the SDPA computation along the sequence dimension easily with our Context Parallel APIs. For
design and implementation details, performance analysis, and an end-to-end training example in `TorchTitan <https://github.com/pytorch/torchtitan>`__,
see our post on `PyTorch native long-context training <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__.

Diff for: prototype_source/prototype_index.rst

+8
@@ -239,6 +239,13 @@ Prototype features are not available as part of binary distributions like PyPI o
    :link: ../prototype/flight_recorder_tutorial.html
    :tags: Distributed, Debugging, FlightRecorder

+.. customcarditem::
+   :header: Context Parallel Tutorial
+   :card_description: Parallelize the attention computation along sequence dimension
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../prototype/context_parallel.html
+   :tags: Distributed, Context Parallel
+
 .. Integration
 .. customcarditem::
    :header: Out-of-tree extension autoloading in Python

@@ -265,6 +272,7 @@ Prototype features are not available as part of binary distributions like PyPI o
 .. toctree::
    :hidden:

+   prototype/context_parallel.html
    prototype/fx_graph_mode_quant_guide.html
    prototype/fx_graph_mode_ptq_dynamic.html
    prototype/fx_graph_mode_ptq_static.html
