Commit 6891c57
Feat: context parallel v2.0 (#3700)
* Cleanup: context parallel
* Feat: cleanup
* Feat: concept guide
* Fix: rename + version check
* Style
* Fix: add to namespace in a test
* Fix: add skip_if on dataclass tests
* Fix: proper version for version check
* Feat: add tests and cleanup
* Fix: properly version check added tests
* Feat: address comments
* Fix: add both shift_labels and labels to make the model.forward calculate loss
* Fix: remove import, improve comment
* Fix: final checks
* Fix: style
* Fix: style
1 parent 24e48f3 commit 6891c57

18 files changed (+683, -218 lines)

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -92,6 +92,8 @@
     title: FSDP vs DeepSpeed
   - local: concept_guides/fsdp1_vs_fsdp2
     title: FSDP1 vs FSDP2
+  - local: concept_guides/context_parallelism
+    title: Context parallelism
   - local: concept_guides/low_precision_training
     title: Low precision training methods
   - local: concept_guides/training_tpu
docs/source/concept_guides/context_parallelism.md (new file)

Lines changed: 204 additions & 0 deletions
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Context Parallel in 🤗`accelerate`

This guide will cover the basics of using context parallelism in 🤗`accelerate`. For the more curious readers, we will also cover some of the technical details in the later sections.

## Why context parallelism?

With the advent of large language models, and recently reasoning models, sequence lengths have been growing rapidly. This, combined with the quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.
With a sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` for `bf16` precision, given a vanilla attention implementation. Granted, with the usage of `flash attention` or `SDPA`, which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.
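To make the arithmetic above concrete, here is a quick back-of-the-envelope check (purely illustrative, not part of any accelerate API):

```python
seq_len = 128_000        # 128k tokens
bytes_per_element = 2    # bf16

attn_matrix_bytes = seq_len * seq_len * bytes_per_element  # per attention head
print(f"{attn_matrix_bytes / 1e9:.1f} GB per head")        # ~32.8 GB, before multiplying by num_heads
```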
Context parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. With this, we can train models with long sequences, potentially scaling to 1M+ sequence length.

## How to use context parallelism?

```diff
from accelerate.utils import ParallelismConfig, TorchContextParallelConfig

+ cp_config = TorchContextParallelConfig(
+     cp_comm_strategy="alltoall",  # no need to use cp_config at all, if you want to use the default "allgather"
+ )

+ parallelism_config = ParallelismConfig(
+     cp_size=8,
+     cp_handler=cp_config,  # or just cp_size=8, if you want to use the default "allgather"
+ )

accelerator = Accelerator(
    ...,
    parallelism_config=parallelism_config,
)
```

As with any other feature in 🤗`accelerate`, you can also enable context parallelism by passing the corresponding flags to `accelerate launch`.
In this case, it's no different:

```bash
accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-strategy [allgather|alltoall] ...
```

> [!Tip]
> You can also set the `cp_size` and `cp_comm_strategy` in the `accelerate config` command, which will save them in your `accelerate` configuration file, so you don't have to pass them every time you launch your script.

> [!Tip]
> Context parallelism is compatible with other parallelism strategies, such as data parallelism, tensor parallelism and FSDP2.
> You can simply combine them by setting your parallelism sizes to the desired values, e.g. `--parallelism-config-dp-size 8 --parallelism-config-tp-size 2 --parallelism-config-cp-size 8`. Or you can use the `ParallelismConfig` class to set them programmatically.

> [!Warning]
> Context parallelism is tightly coupled with `FSDP2`, which you can learn more about in the [FSDP2 introduction](fsdp1_vs_fsdp2.md). This means context parallelism only works if you use `FullyShardedDataParallelPlugin` or `--use-fsdp` with the FSDP version set to 2 in your program. If `FSDP2` is not used, an error will be raised.

> [!Warning]
> Context parallelism works only with [SDPA](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and only with no mask or a causal mask. We can't properly detect this for you, so it's your responsibility to ensure that you are using `SDPA` with no mask or a causal mask. If you use any other attention implementation, it will raise an error.

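Putting the warnings above together, a minimal programmatic setup might look roughly like the sketch below. It mirrors the pattern from the `nd_parallel.py` example in this commit; the wrap policy and the `LlamaDecoderLayer` class name are taken from that example and should be adapted to your own model.

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from accelerate.utils import ParallelismConfig

# Context parallelism requires FSDP2, so we pass an FSDP plugin with fsdp_version=2
fsdp2_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],  # adjust to your model's decoder layer class
)

parallelism_config = ParallelismConfig(cp_size=8)  # uses the default "allgather" rotation

accelerator = Accelerator(
    mixed_precision="bf16",
    parallelism_config=parallelism_config,
    fsdp_plugin=fsdp2_plugin,
)
```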
After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop and that abstracts away some of the complexity of using it (more on this later). To minimize the changes you have to make to your training loop, we provide a context manager that is a `noop` if context parallelism is not enabled, and that applies context parallelism if it is. This way, you can use it in your training loop without changing any code based on your parallelism configuration.
You can use it as follows:

```python
for batch in dataloader:
    with accelerator.maybe_context_parallel(
        buffers=[batch["input_ids"], batch["attention_mask"]],
        buffer_seq_dims=[1, 1],
        no_restore_buffers={batch["input_ids"], batch["labels"]},
    ):
        outputs = model(**batch)
        ...
```

> [!Warning]
> This context manager has to be recreated with each training step, as shown in the example above. It's crucial to do so.

This can potentially scale your context size to 1M+ sequence length. Below, we showcase the speed and memory usage of context parallelism for up to 256k context size. We can see that when we double the context size and the number of GPUs, we achieve consistent memory usage, potentially enabling endless context length scaling.

<p align="center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
    <br>
    <em>Figure 1: Memory usage and speed of context parallelism for up to 256k context size.</em>
</p>

> [!Tip]
> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py). To run the example on 8 H100 GPUs (128k sequence length), you can use the following command:
> ```bash
> accelerate launch --use-fsdp --fsdp-activation-checkpointing=TRUE examples/fsdp2/nd_parallel.py --cp-size=8 --sequence-length=128000
> ```

## Accelerate's interface

The context manager takes a few arguments that are used to configure the context parallelism.

- `buffers`: This is a list of tensors that are to be sharded across the sequence dimension. These tensors are usually input ids, labels and attention mask (see the sketch below).
- `buffer_seq_dims`: This is a list of integers that specify the sequence dimension of the buffers, in the order of the `buffers` list. If you pass `buffers=[input_ids, shift_labels]`, with both having shape `[batch_size, sequence_length]`, you would pass `buffer_seq_dims=[1, 1]`, as the sequence dimension is the second dimension of the tensors. This is required for correct computation of the model outputs.
- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.DTensor`s. After the context manager exits, a communication kernel would need to be launched to restore the buffers to their original state (usually an all-gather). This takes some time, so it is recommended to pass the same tensors as in the `buffers` argument to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager exits.

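For instance, the `nd_parallel.py` example in this commit builds these arguments from the batch roughly like this (shortened for illustration; `accelerator`, `model` and `batch` come from the surrounding training loop):

```python
# Shard input_ids, shift_labels and labels along their sequence dimension (dim 1)
buffers = [batch["input_ids"], batch["shift_labels"], batch["labels"]]

with accelerator.maybe_context_parallel(
    buffers=buffers,
    buffer_seq_dims=[1, 1, 1],
    no_restore_buffers=set(buffers),  # we don't need the unsharded buffers back after this step
):
    outputs = model(**batch)
```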
> [!Warning]
> Context parallelism is not compatible with `labels` that are a copy of `input_ids`, which models from 🤗 transformers can shift themselves to enable causal language modeling.
> Imagine this case:
> labels = [l1, l2, l3, l4, ... li]
> If we apply context parallelism, each rank ends up with only a part of the labels, such as:
> labels_rank_0 = [l1, l2], labels_rank_1 = [l3, l4], ...
> After the transformers modelling code shifts the labels, each rank ends up with:
> labels_rank_0 = [l2, PAD], labels_rank_1 = [l4, PAD], ...
> where `PAD` is a padding token. This results in incorrect loss computation, as the labels are no longer aligned with the inputs.
> Because of this, you need to manually shift the labels before passing them to the model.

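A minimal sketch of one way to do this shift yourself, before the batch enters `maybe_context_parallel` (this helper is hypothetical, not an accelerate API; it assumes `-100` as the ignore index, as 🤗 transformers uses):

```python
import torch


def add_shift_labels(batch, ignore_index=-100):
    # Shift labels once, globally, BEFORE the sequence is sharded across CP ranks,
    # so that every rank's shard stays aligned with its inputs.
    labels = batch["labels"]
    shift_labels = torch.full_like(labels, ignore_index)
    shift_labels[:, :-1] = labels[:, 1:]
    batch["shift_labels"] = shift_labels
    return batch
```

The example script in this commit then passes both `labels` and `shift_labels` to the model, so that the loss is computed from the correctly shifted ones.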
## Configurable options
Accelerate provides only a single option to configure context parallelism (apart from `cp_size`):

- `cp_comm_strategy`: The rotation method to use for the shards. We strongly recommend keeping this as `"allgather"`, as it's likely to outperform `"alltoall"` in most cases.

Context parallel size is rather self-explanatory: it's the number of ranks across which the inputs are to be sharded.
Context parallel shard rotation defines how the shards of the inputs are rotated across ranks. We'll cover the two options in more detail in the next section.

You can see an end-to-end example in the [ND parallel example](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py) file, where you can train an 8B model with up to 128k context length on a single 8xH100 node. Using multi-node training, you can scale this to 1M+ sequence length on multiple GPUs. You can also seamlessly combine it with other parallelism strategies to fit your needs.

## Technical details

> [!Tip]
> This section is fairly technical, so if you don't need to learn the internals of context parallelism, you can skip it and start building 🚀

We're going to be using the word `shard` extensively in the following sections, so let's define it first. If we say a tensor is `sharded` across its `D`-th dimension over `N` ranks, we mean that the tensor is split into `N` parts along that dimension, so that each part has shape `[..., D//N, ...]` (where `D` here stands for the size of that dimension).

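For illustration, here is what sharding along the sequence dimension looks like in plain PyTorch (a standalone sketch, not accelerate API):

```python
import torch

num_ranks = 4                 # N
x = torch.randn(2, 1024, 64)  # [batch, seq_len, hidden]; seq_len is the sharded dimension

# Split into N parts along dim=1; each rank would keep exactly one of them
shards = torch.chunk(x, num_ranks, dim=1)
print(shards[0].shape)        # torch.Size([2, 256, 64]) -> [..., 1024 // 4, ...]
```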
## So how does it work?

Context parallelism works by sharding the `Q`, `K` and `V` matrices across the sequence dimension. Each rank has its assigned shard of `Q`, let's call it `Q_i`. This matrix stays only on this rank during the whole computation. Similarly, each rank has its own shard of `K` and `V`, let's call them `K_i` and `V_i`. Each rank then calculates attention with its own shards `Q_i`, `K_i` and `V_i`, let's call it `attn_i`. During this computation, a communication kernel is launched to gather the `K`s and `V`s from all other ranks. Which communication primitive is used depends on the `context_parallel_shard_rotation` option.
This way, each rank gets to calculate local attention, first with `Q_i`, `K_i` and `V_i`, then with `K_j` and `V_j` from all other ranks. As each rank holds `Q`, `K` and `V` matrices that are sharded across the sequence dimension, the resulting matrices are smaller and can fit on a single GPU.

We can formalize this in the following pseudocode:
```python
# Pseudocode: each rank i holds Q_i, K_i, V_i, sharded along the sequence dimension
comm_kernel = {"allgather": allgather, "alltoall": alltoall}[context_parallel_shard_rotation]

Qi, Ki, Vi = shard(Q, K, V, seq_dim)
attn[i] = attn(Qi, Ki, Vi)  # local attention with this rank's own shards
for j in range(context_parallel_size):
    Kj, Vj = comm_kernel()  # receive K/V shards from the other ranks
    attn[j] = attn(Qi, Kj, Vj)  # [batch, num_heads, seq_len // context_parallel_size, head_dim]

final_attn = combine(attn)  # combine the partial attention results
```

## all-to-all vs all-gather

### all-gather
So what's the difference between all-to-all and all-gather? With all-gather, the communication is very simple. Around the time we compute the local attention `attn_i` (usually before, as the communication takes longer), we launch an all-gather to collect the `K`s and `V`s from all other ranks. Once this communication is done, each rank has the `K`s and `V`s from all other ranks and can compute attention with them sequentially.
In an ideal scenario, the all-gather finishes at the exact moment the calculation of `attn_i` is done. However, this never happens in practice, so the best realistic overlap is achieved when the whole computation of `attn_i` overlaps with a part of the communication; then, to start the computation with `K_j` and `V_j`, we wait for the all-gather to finish.

### all-to-all
All-to-all, sometimes called `ring-rotation`, utilizes a ring-like communication pattern. After concluding the `attn_i` computation, an all-to-all is launched to send `K_i` and `V_i` to the neighbouring ranks. We then repeat this `context_parallel_size - 1` times, so that each rank sees every shard of `K` and `V` from all other ranks exactly once. In an ideal scenario, we prefetch the shards `K_i+1` and `V_i+1` from the neighbouring rank, and this communication is exactly overlapped with the computation of our current `attn_i`. Again, realistically, this perfect overlap never happens. Given the nature of this approach, if we don't achieve perfect overlap, the penalty is much larger than with all-gather.

## How to choose the right rotation method?
In theory, all-to-all should be the better choice. In practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also show that all-to-all rarely outperforms all-gather. Still, we provide both options, as you might find one to be better for your use case.

You can see this issue directly in the profiler output in the image below:
<p align="center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_all_to_all.png" alt="all-to-all profiler output" />
    <br>
    <em>Figure 2: In red you can see the idle time while we wait for the all-to-all kernel to finish. Highlighted in the first blue bar, you can see that it takes ~250us to finish, which is repeated N-1 times for each attention call, where N is the context parallel size.</em>
</p>

## Why only FSDP2?

We only support context parallelism with `FSDP2`, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to utilize its full potential.
It works like this: we shard the model across the joint mesh of size `cp_size * dp_shard_size`, which maximizes the memory savings.
This is a "free lunch" of sorts, as the `FSDP` communication is fully overlapped with the computation of attention, as shown in the images below.

<p align="center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_why_fsdp2.png" alt="why FSDP2+CP" />
    <br>
    <em>Figure 3: In the blue rectangles (Stream 23), you can see that the pre-fetch of the `FSDP` shard is fully overlapped with the computation of attention (Stream 7), while in the red rectangles (Stream 24), you can see that the all-gather kernel results in a bubble of idle time, during which our compute stream (7) is idle.</em>
</p>

In the figures above, you can also note the difference between all-to-all and all-gather. While with all-to-all (Figure 2), we launch a communication kernel N-1 times for each attention call, with all-gather (Figure 3), we launch a communication kernel only once. This results in a bigger bubble, but it only happens once per attention call, while with all-to-all, it happens N-1 times.

## Data dispatching in joint mesh

We make sure to dispatch the same batch of data to the whole `cp` subgroup, so that the results are correct: each rank in a `cp` subgroup gets the same batch of data. However, we dispatch a different batch to each rank of the `dp_shard` group.
Imagine it like this:
```
# 8 GPUs, --dp_shard_size 4, --cp_size 2
# mesh = [[0, 1], [2, 3], [4, 5], [6, 7]]
# model is sharded across the whole mesh (each GPU holds 1/8 of the model)
# GPUs 0,1 = batch 0
# GPUs 2,3 = batch 1
# ... and so on.
```
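
As a rough illustration of this dispatch pattern (a sketch only; the rank-to-mesh mapping shown is an assumption that matches the layout above, not accelerate's actual dispatching code):

```python
dp_shard_size, cp_size = 4, 2
world_size = dp_shard_size * cp_size  # 8 GPUs

for rank in range(world_size):
    dp_index = rank // cp_size  # which dp_shard group this rank belongs to
    cp_index = rank % cp_size   # position inside its cp subgroup
    # All ranks with the same dp_index (i.e. the same cp subgroup) get the same batch,
    # while different dp_shard groups get different batches.
    print(f"rank {rank}: batch {dp_index}, cp position {cp_index}")
```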
examples/fsdp2/nd_parallel.py

Lines changed: 29 additions & 17 deletions

@@ -43,6 +43,7 @@ def parse_args():
     parser.add_argument("--dp-replicate-size", type=int, default=1)
     parser.add_argument("--dp-shard-size", type=int, default=1)
     parser.add_argument("--tp-size", type=int, default=1)
+    parser.add_argument("--cp-size", type=int, default=1)
     parser.add_argument("--sequence-length", type=int, default=1024)
     parser.add_argument("--num-steps", type=int, default=1000)
     parser.add_argument("--save-dir", type=str, default="./outputs")
@@ -52,17 +53,28 @@ def parse_args():
     return parser.parse_args()


-def forward(model, batch, optimizer, accelerator):
-    # To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
-    loss_reduce_grp = (
-        accelerator.torch_device_mesh["dp_cp"].get_group() if accelerator.parallelism_config.dp_cp_dim_names else None
-    )
-    outputs = model(**batch)
-    loss = outputs.loss
-    accelerator.backward(loss)
-    optimizer.step()
-    optimizer.zero_grad()
-    dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
+def forward(model, batch, optimizer, accelerator: Accelerator):
+    # We need both labels and shift_labels, as the loss computation in the model is hidden behind `if labels is not None`,
+    # but the loss computation itself prioritizes shift_labels (if provided), which are the correct ones (labels are wrong if cp is enabled)
+    buffers = [batch["input_ids"], batch["shift_labels"], batch["labels"]]
+    with accelerator.maybe_context_parallel(
+        buffers=buffers, buffer_seq_dims=[1, 1, 1], no_restore_buffers=set(buffers)
+    ):
+        # To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
+        # As for DP we have a different batch on each device and for CP we essentially have a different part of the sequences on each device
+        # I.e. with causal modelling and seq_len 1024, this dimension becomes another batch dimension of sorts
+        loss_reduce_grp = (
+            accelerator.torch_device_mesh["dp_cp"].get_group()
+            if accelerator.parallelism_config.dp_cp_dim_names
+            else None
+        )
+        outputs = model(**batch)
+        loss = outputs.loss
+        accelerator.backward(loss)
+        optimizer.step()
+        optimizer.zero_grad()
+        dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)

     return loss

@@ -71,21 +83,21 @@ def train(args):
         dp_replicate_size=args.dp_replicate_size,
         dp_shard_size=args.dp_shard_size,
         tp_size=args.tp_size,
+        cp_size=args.cp_size,
     )

     # FSDP needs extra configuration, so we properly shard the model
-    if parallelism_config.dp_shard_enabled:
+    fsdp2_plugin = None
+    if parallelism_config.dp_shard_enabled or parallelism_config.cp_enabled:
         fsdp2_plugin = FullyShardedDataParallelPlugin(
             fsdp_version=2,
             auto_wrap_policy="transformer_based_wrap",
             transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
+            state_dict_type="SHARDED_STATE_DICT",
         )

     accelerator = Accelerator(
-        log_with=["wandb"],
-        mixed_precision="bf16",
-        parallelism_config=parallelism_config,
-        fsdp_plugin=fsdp2_plugin if parallelism_config.dp_shard_enabled else None,
+        log_with=["wandb"], mixed_precision="bf16", parallelism_config=parallelism_config, fsdp_plugin=fsdp2_plugin
     )
     accelerator.init_trackers("nd_parallel_training")

@@ -146,7 +158,7 @@ def train(args):
 if __name__ == "__main__":
     set_seed(42)
     args = parse_args()
-    if args.dp_shard_size == 1:
+    if args.dp_shard_size == 1 and args.tp_size > 1:
         # We currently don't support saving with `save_state` when using only
         # tensor parallelism, fsdp must be enabled
         warnings.warn(
