
Commit 62ede1e

CP docs typos fixed (#3761)
1 parent 9f9c490 commit 62ede1e

2 files changed: +9 -9 lines changed


docs/source/concept_guides/context_parallelism.md

Lines changed: 7 additions & 7 deletions
@@ -19,7 +19,7 @@ This guide will cover basics of using context parallelism in 🤗`accelerate`, f

## Why context parallelism?

-With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has lead to a need for more efficient ways to train models with long sequences.
+With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.
With sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` for `bf16` precision, given vanilla attention implementation. Granted, with usage of `flash attention` or `SDPA` which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.

Context parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. With this, we can train models with long sequences, scaling potentially to 1M+ sequence length.
@@ -44,7 +44,7 @@ accelerator = Accelerator(
)
```

-As with any other feature in 🤗`accelerate`, you can enabled context parallelism also by passing the corresponding flags to `accelerate launch`.
+As with any other feature in 🤗`accelerate`, you can enable context parallelism also by passing the corresponding flags to `accelerate launch`.
In this case, it's no different:

```bash
@@ -65,7 +65,7 @@ accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-st
> [!Warning]
> Context parallelism works only with [SDPA](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and only with no mask or causal mask. We can't properly detect this for you, so it's your responsibility to ensure that you are using `SDPA` with no mask or causal mask. If you use any other attention implementation, it will raise an error.

-After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later). To minimize the changes you have to do your training loop, we provide a context manager than is a `noop` if context parallelism is not enabled, and applies the context parallelism if it is enabled. This way, you can use it in your training loop without changing any code based on your parallelism configuration.
+After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later). To minimize the changes you have to do in your training loop, we provide a context manager that is a `noop` if context parallelism is not enabled, and applies the context parallelism if it is enabled. This way, you can use it in your training loop without changing any code based on your parallelism configuration.
You can use it as follows:

```python
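# NOTE: the guide's own example is cut off by this hunk; the lines below are a
# minimal sketch of the usage described above. The `maybe_context_parallel`
# method name and its buffer arguments are assumptions, not shown in this diff.
for batch in dataloader:
    with accelerator.maybe_context_parallel(
        buffers=[batch["input_ids"], batch["labels"]],  # tensors to shard along the sequence dim
        buffer_seq_dims=[1, 1],                         # sequence dimension of each buffer
    ):
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```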
@@ -82,7 +82,7 @@ for batch in dataloader:
> [!Warning]
> This context manager has to be recreated with each training step, as shown in the example above. It's crucial to do so.

-This can scale your context size to 1M+ sequence length potentially. Below, we showcase speed and memory usage of context parallelism for up-to 256k context size. We can see that when we double the context size and number of GPUs, we can achieve consistent memory usage, potentiall enabling endless context length scaling.
+This can scale your context size to 1M+ sequence length potentially. Below, we showcase speed and memory usage of context parallelism for up-to 256k context size. We can see that when we double the context size and number of GPUs, we can achieve consistent memory usage, potentially enabling endless context length scaling.

<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
@@ -120,7 +120,7 @@ The context manager takes a few arguments, that are used to configure the contex

## Configurable options
-Accelerate provides only a single option to configure context parallelism (except of `cp_size`)
+Accelerate provides only a single option to configure context parallelism (except for `cp_size`)

- `cp_comm_strategy`: The rotation method to use for the shards. We strongly recommend keeping this as `"allgather"`, as it's very likely it will outperform `"alltoall"` in most cases.
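Spelled out as launch flags (the same options as above; the strategy value and the script name here are illustrative, since the hunk header further up truncates the command):

```bash
accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-strategy allgather train.py
```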
@@ -142,7 +142,7 @@ We're going to be using word `shard` extensively in the following sections, so l
Context parallelism works on sharding the `Q, K and V` matrices across the sequence dimension. Each rank has its assigned shard of `Q`, let's call it `Q_i`. This matrix stays only on this rank, during the whole computation. Similarly, each rank has its own shard of `K` and `V`, let's call them `K_i` and `V_i`. Then, each rank calculates attention with its own shard of `Q_i`, `K_i` and `V_i`, let's call it `attn_i`. During this computation, a communication kernel is launched to gather the `Ks` and `Vs` from all other ranks. What communication primitive is used, depends on the `context_parallel_shard_rotation` option.
This way, each rank gets to calculate local attention, first with `Q_i`, `K_i` and `V_i`, then with `K_j` and `V_j` from all other ranks. As each rank holds `Q, K and V` matrices that are sharded across the sequence dimension, the resulting matrices are smaller and can fit on a single GPU.

-We can formalize this in a following pseudocode:
+We can formalize this in the following pseudocode:
```python
comm_kernel = {"allgather": allgather, "alltoall": alltoall}[context_parallel_shard_rotation]
Qi, Ki, Vi = shard(Q, K, V, seq_dim)
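# (sketch) The guide's pseudocode continues beyond this hunk. One plausible
# continuation of the idea described in the prose above; attn() and merge()
# are illustrative placeholders, not the guide's exact code.
attn_i = attn(Qi, Ki, Vi)                      # attention over this rank's own shards
for Kj, Vj in comm_kernel(Ki, Vi):             # K/V shards gathered or rotated from other ranks
    attn_i = merge(attn_i, attn(Qi, Kj, Vj))   # fold in attention against the remote shards
```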
@@ -164,7 +164,7 @@ In ideal scenario, all-gather finishes in the exact moment as the calculation of
All-to-all, or sometimes called `ring-rotation` utilizes a ring-like communication pattern. After concluding `attn_i` computation, an all-to-all is launched to send `K_i` and `V_i` to the neighbouring ranks. We then repeat this `context_parallel_size-1` times, so that each rank sees all the shards of `K` and `V` from all other ranks once. In ideal scenario, we prefetch shards `K_i+1` and `V_i+1` from the neighbouring rank and this communication is exactly overlapped with computation of our current `attn_i`. Again, realistically, this perfect overlap doesn't ever happen. Given the nature of this approach, if we don't achieve perfect overlap, the penalty is way larger than with all-gather.

## How to choose the right rotation method?
-In theory, all-to-all should be the better choice. Though in practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also shows that all-to-all rarely outperforms all-gather. Though, we still provide both options, as you might find one to be better for your use case.
+In theory, all-to-all should be the better choice. Though in practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also show that all-to-all rarely outperforms all-gather. Though, we still provide both options, as you might find one to be better for your use case.

You can directly see this issue in the profiler output in the image below:
<p align="center">

examples/torch_native_parallelism/README.md

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
## Torch Native Parallelism

-With recent versions of Torch, there has been steady improvements in native parallelism using `DeviceMesh` and `DTensor`. 🤗 accelerate allows you to use these with our `ParallelismConfig` abstraction and/or `FullyShardedDataParallelPlugin(fsdp_version=2)`
+With recent versions of Torch, there have been steady improvements in native parallelism using `DeviceMesh` and `DTensor`. 🤗 accelerate allows you to use these with our `ParallelismConfig` abstraction and/or `FullyShardedDataParallelPlugin(fsdp_version=2)`
This folder contains various examples of such use-cases: such as composing multiple parallelism strategies, low-bit training etc.

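As a quick orientation before the individual examples, a minimal sketch of the `FullyShardedDataParallelPlugin(fsdp_version=2)` entry point mentioned above; the keyword wiring is an assumption, see the example scripts for exact usage:

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

# FSDP2 path: torch's DTensor-based fully sharded data parallelism.
fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```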
### ND Parallelism
@@ -51,7 +51,7 @@ gaining even more speed and memory savings, as `ao` doesn't ship with any kernel
Replacing linear layers with `Float8Linear` can greatly improve performance, if used correctly and on hardware that supports FP8 tensor cores. This highly depends on the model dimensions and sequence length used for training.
You can view the performance of `Float8Linear` as a function of matrix dimensions in [this document](https://github.com/pytorch/ao/blob/main/torchao/float8/README.md#performance).

-In our example, we use a 8B Llama3.1 model, which has a hidden dimension of 4096 and we train on sequence length of 8192. In the below images, we can see that this improves performance by ~25% compared to `bf16`, reaching ~10000 tokens per second, per device on 8x H100 GPUs, compared to ~8000 tokens per second using `bf16`, while loss function stays roughly the same. We can also see that the FLOPS raise by using FP8.
+In our example, we use a 8B Llama3.1 model, which has a hidden dimension of 4096 and we train on sequence length of 8192. In the below images, we can see that this improves performance by ~25% compared to `bf16`, reaching ~10000 tokens per second, per device on 8x H100 GPUs, compared to ~8000 tokens per second using `bf16`, while loss function stays roughly the same. We can also see that the FLOPS rise by using FP8.

<div style="display: flex; gap: 25px;">
<div style="text-align: center; width: 49%;">
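For context, swapping `nn.Linear` layers for `Float8Linear` is typically done with `torchao`; a minimal sketch of that conversion is below. It is illustrative only and not necessarily how the example script in this folder performs the swap.

```python
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# Toy stand-in for a transformer block; any module containing nn.Linear works.
model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))

# Replace eligible nn.Linear modules with Float8Linear in place so that their
# matmuls can run through FP8 tensor cores on supported hardware (e.g. H100).
convert_to_float8_training(model)
```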

Comments (0)