Introduction to Context Parallel
======================================
**Authors**: `Xilun Wu <https://github.com/XilunWu>`_, `Chien-Chin Huang <https://github.com/fegin>`__

.. note::
   |edit| View and edit this tutorial in `GitHub <https://github.com/pytorch/tutorials/blob/main/prototype_source/context_parallel.rst>`__.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * `Context Parallel APIs <https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel>`__
       * `1M sequence training in TorchTitan with Context Parallel <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__


    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.7 or later

Introduction
------------

Context Parallel is an approach used in large language model training to reduce peak activation size by sharding the long input sequence across multiple devices.
It removes the constraint on input sequence length imposed by the peak memory needed to store activations in Transformer blocks.

Ring Attention, a novel parallel implementation of the Attention layer, is critical to performant Context Parallel.
Ring Attention shuffles the KV shards and computes partial attention scores, repeating until all KV shards have been used on each device; a minimal sketch of how these partial results are merged follows the list below.
Two Ring Attention variants have been implemented: `the all-gather based pass-KV <https://arxiv.org/abs/2407.21783>`__ and `the all-to-all based pass-KV <https://openreview.net/forum?id=WsRHpHH4s0>`__:

1. The all-gather based pass-KV algorithm is used in Llama3 training. It initially performs an all-gather on the key and value tensors and then computes the attention output for the
   local query tensor chunk. Our modified all-gather based pass-KV algorithm concurrently all-gathers KV shards and computes attention output for the local query tensor chunk
   using the local key and value tensor chunks, followed by a final computation of attention output for the local query tensor and the remaining KV shards. This allows some degree of
   overlap between the attention computation and the all-gather collective. For example, in the case of Llama3 training, we also shard ``freq_cis`` over the sequence dimension.
2. The all-to-all based pass-KV approach uses interleaved all-to-all collectives to ring-shuffle KV shards, overlapping the SDPA (Scaled Dot Product Attention) computation with the all-to-all communication
   needed for the next SDPA.

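In both variants, each device repeatedly computes attention between its local query chunk and one KV shard at a time, and the partial results must be merged into the exact attention output.
The snippet below is a minimal single-process sketch of that merging step using a running log-sum-exp. It is purely illustrative and is not part of the Context Parallel APIs; the helper name ``merge_chunked_attention`` is made up for this tutorial:

.. code:: python

    import torch
    import torch.nn.functional as F


    def merge_chunked_attention(q, k_chunks, v_chunks):
        # Accumulate attention over KV chunks with a running log-sum-exp,
        # rescaling the previously accumulated output whenever a new chunk arrives.
        scale = q.shape[-1] ** -0.5
        out, lse = None, None
        for k, v in zip(k_chunks, v_chunks):
            scores = q @ k.transpose(-2, -1) * scale  # (batch, nheads, q_len, kv_chunk_len)
            chunk_lse = torch.logsumexp(scores, dim=-1, keepdim=True)
            chunk_out = torch.softmax(scores, dim=-1) @ v
            if out is None:
                out, lse = chunk_out, chunk_lse
            else:
                new_lse = torch.logaddexp(lse, chunk_lse)
                out = out * (lse - new_lse).exp() + chunk_out * (chunk_lse - new_lse).exp()
                lse = new_lse
        return out


    q, k, v = (torch.rand(2, 8, 64, 32) for _ in range(3))
    # compare against full attention (no causal mask, to keep the sketch short)
    ref = F.scaled_dot_product_attention(q, k, v)
    out = merge_chunked_attention(q, k.chunk(4, dim=2), v.chunk(4, dim=2))
    torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-4)

In the actual Ring Attention implementation, this merging happens on each rank while the next KV shard is being communicated, which is what allows the attention computation to overlap with the collective.
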
The Context Parallel APIs consist of two parts:

1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
   will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
   the ``buffers`` and ``buffer_seq_dims`` arguments, respectively. We recommend that users add every tensor that is computed along the sequence dimension to ``buffers``
   and shard it along this dimension. Taking Llama3 training as an example, leaving ``freq_cis`` out of ``buffers`` results in a miscalculated rotary embedding
   (an illustrative call that includes ``freq_cis`` follows the full example below).
2. ``set_rotate_method()`` allows users to choose between the all-gather based pass-KV approach and the all-to-all based pass-KV approach.

Setup
-----

With ``torch.distributed.tensor.experimental.context_parallel()``, users can easily shard the Tensor input and parallelize the execution of the SDPA function.
To better demonstrate the usage of this API, we start with a simple code snippet that performs SDPA and then parallelize it using the API:

.. code:: python

    import torch
    import torch.nn.functional as F

    from torch.nn.attention import sdpa_kernel, SDPBackend


    def sdpa_example():
        assert torch.cuda.is_available()
        torch.cuda.set_device("cuda:0")
        torch.cuda.manual_seed(0)

        batch = 8
        nheads = 8
        qkv_len = 8192
        dim = 32
        backend = SDPBackend.FLASH_ATTENTION
        dtype = (
            torch.bfloat16
            if backend == SDPBackend.FLASH_ATTENTION
            or backend == SDPBackend.CUDNN_ATTENTION
            else torch.float32
        )

        qkv = [
            torch.rand(
                (batch, nheads, qkv_len, dim),
                dtype=dtype,
                requires_grad=True,
                device='cuda',
            )
            for _ in range(3)
        ]
        # specify the SDPBackend to use
        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(*qkv, is_causal=True)


    if __name__ == "__main__":
        sdpa_example()


Enable Context Parallel
-----------------------

Now, let's first adapt it to a distributed program where each rank has the same tensor input. Then we apply the context parallel API to
shard the input and distribute the computation across ranks:

.. code:: python

    # file: cp_sdpa_example.py
    import os

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel
    from torch.distributed.tensor.experimental._attention import context_parallel_unshard
    from torch.nn.attention import sdpa_kernel, SDPBackend


    def context_parallel_sdpa_example(world_size: int, rank: int):
        assert torch.cuda.is_available()
        assert dist.is_nccl_available()
        torch.cuda.set_device(f"cuda:{rank}")
        torch.cuda.manual_seed(0)

        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )
        device_mesh = init_device_mesh(
            device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("cp",)
        )

        batch = 8
        nheads = 8
        qkv_len = 64
        dim = 32
        backend = SDPBackend.FLASH_ATTENTION
        dtype = (
            torch.bfloat16
            if backend == SDPBackend.FLASH_ATTENTION
            or backend == SDPBackend.CUDNN_ATTENTION
            else torch.float32
        )

        qkv = [
            torch.rand(
                (batch, nheads, qkv_len, dim),
                dtype=dtype,
                requires_grad=True,
                device='cuda',
            )
            for _ in range(3)
        ]
        # specify the SDPBackend to use
        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(*qkv, is_causal=True)

        # make a clean copy of QKV for output comparison
        cp_qkv = [t.detach().clone() for t in qkv]

        with sdpa_kernel(backend):
            # This `context_parallel()` performs two actions:
            # 1. Shard the tensor objects in `buffers` in-place along the dimensions
            #    specified in `buffer_seq_dims`; the tensors in `buffers` and their
            #    sharding dims in `buffer_seq_dims` are organized in the same order.
            # 2. Replace the execution of `F.scaled_dot_product_attention` with a
            #    context-parallel-enabled Ring Attention.
            with context_parallel(
                device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
            ):
                cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

            # The output `cp_out` is still sharded in the same way as QKV;
            # the `context_parallel_unshard` API allows users to easily
            # unshard it to obtain the full tensor.
            (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])

        assert torch.allclose(
            cp_out,
            out,
            atol=(1e-08 if dtype == torch.float32 else 1e-03 * world_size),
        )


    if __name__ == "__main__":
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        try:
            context_parallel_sdpa_example(world_size, rank)
        finally:
            dist.barrier()
            dist.destroy_process_group()


You can use the command ``torchrun --standalone --nnodes=1 --nproc-per-node=4 cp_sdpa_example.py`` to launch the above context parallel
SDPA on 4 GPUs. We demonstrate numerical correctness by comparing the output of Ring Attention to that of SDPA on a single GPU.

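Note that the example above only shards the query, key, and value tensors. In a real Llama3-style training loop you would also add ``freq_cis`` to ``buffers``, as discussed in the API overview.
The sketch below shows how the ``context_parallel()`` call in the example could be extended to do that; the shape of ``freq_cis`` and its sequence dimension (dim 0) are assumptions made for illustration, not definitions taken from an actual model:

.. code:: python

    # Hypothetical variant of the `context_parallel()` call in the example above.
    # `freq_cis` here is a stand-in tensor whose first dimension is assumed to be
    # the sequence dimension.
    freq_cis = torch.rand(qkv_len, dim // 2, dtype=dtype, device="cuda")

    with sdpa_kernel(backend):
        with context_parallel(
            device_mesh,
            buffers=(*cp_qkv, freq_cis),
            buffer_seq_dims=(2, 2, 2, 0),  # shard `freq_cis` along its sequence dimension
        ):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)
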
Select Rotation Approach
------------------------

You can choose the desired shard rotation approach in Ring Attention by using ``torch.distributed.tensor.experimental._attention.set_rotate_method()``:

.. code:: python

    # file: cp_sdpa_example.py
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    set_rotate_method("alltoall")  # rotate shards using all-to-all

    with sdpa_kernel(backend):
        with context_parallel(
            device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
        ):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)


The default rotation approach is the all-gather based pass-KV.
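If you want to set it explicitly, the same API also accepts the all-gather option; the string ``"allgather"`` below is assumed to be the accepted spelling, mirroring the ``"alltoall"`` string above, since this is an experimental, private API:

.. code:: python

    set_rotate_method("allgather")  # rotate shards using all-gather (the default)
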

Conclusion
----------

In this tutorial, we have learned how to parallelize the SDPA computation along the sequence dimension easily with our Context Parallel APIs. For
design and implementation details, performance analysis, and an end-to-end training example in `TorchTitan <https://github.com/pytorch/torchtitan>`__,
see our post on `PyTorch native long-context training <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__.