prototype_source/context_parallel.rst (+2 -2)
@@ -26,7 +26,7 @@ Introduction
Context Parallel is an approach used in large language model training to reduce peak activation size by sharding the long input sequence across multiple devices.
It removes the constraint on input sequence length imposed by the peak memory cost of storing activations in Transformer blocks.
- The core of Context Parallel is Ring Attention, a novel parallel implementation of the Attention layer.
+ Ring Attention, a novel parallel implementation of the Attention layer, is critical to performant Context Parallel.
Ring Attention rotates the KV shards around the devices and calculates partial attention scores, repeating until all KV shards have been used on each device.
Two Ring Attention variants have been implemented: `the all-gather based pass-KV <https://arxiv.org/abs/2407.21783>`__ and `the all-to-all based pass-KV <https://openreview.net/forum?id=WsRHpHH4s0>`__:
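The rotate-and-accumulate loop described above can be illustrated with a single-process simulation. This is an editor's sketch under stated assumptions, not PyTorch's implementation: the function ``ring_attention_sim``, the 2-D tensor layout, and the non-causal log-sum-exp merge are all illustrative.

.. code:: python

    import torch

    def ring_attention_sim(q, k, v, world_size):
        # Shard Q/K/V along the sequence dimension: one shard per simulated device.
        q_shards = q.chunk(world_size, dim=0)
        kv_shards = list(zip(k.chunk(world_size, dim=0), v.chunk(world_size, dim=0)))

        outs = []
        for rank in range(world_size):
            q_i = q_shards[rank]
            # Running softmax statistics, one row per local query.
            m = torch.full((q_i.shape[0], 1), float("-inf"))  # running max of logits
            s = torch.zeros(q_i.shape[0], 1)                  # running sum of exp(logits - m)
            o = torch.zeros_like(q_i)                         # running unnormalized output
            for step in range(world_size):
                # A real ring send/recvs its KV shard to the next device each step;
                # the simulation just indexes the shard this rank would hold now.
                k_j, v_j = kv_shards[(rank + step) % world_size]
                logits = q_i @ k_j.T / q_i.shape[-1] ** 0.5
                m_new = torch.maximum(m, logits.max(dim=-1, keepdim=True).values)
                rescale = torch.exp(m - m_new)   # rescale previously accumulated sums
                p = torch.exp(logits - m_new)    # partial attention scores for this shard
                s = s * rescale + p.sum(dim=-1, keepdim=True)
                o = o * rescale + p @ v_j
                m = m_new
            outs.append(o / s)
        return torch.cat(outs, dim=0)

    # The simulation matches full (non-causal) attention on the unsharded inputs.
    q, k, v = (torch.randn(8, 16) for _ in range(3))
    expected = torch.softmax(q @ k.T / 16 ** 0.5, dim=-1) @ v
    assert torch.allclose(ring_attention_sim(q, k, v, world_size=4), expected, atol=1e-5)

Each simulated rank visits every KV shard exactly once, so after ``world_size`` steps the merged partial scores equal full attention over the unsharded sequence.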
@@ -42,7 +42,7 @@ The Context Parallel APIs consist of two parts:
1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
the arguments ``buffers`` and ``buffer_seq_dims``, respectively. We recommend that users add all tensors that are computed along the sequence dimension to ``buffers``
- and shard them along this dimension.
+ and shard them along this dimension. Taking Llama3 training as an example, omitting ``freq_cis`` from ``buffers`` will result in miscalculated rotary embeddings.
2. ``set_rotate_method()`` allows users to choose between the all-gather based pass-KV approach and the all-to-all based pass-KV approach.
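A minimal usage sketch of the two APIs above, assuming a 4-GPU job launched with ``torchrun``; the experimental import paths match recent PyTorch builds but may change while Context Parallel remains a prototype feature.

.. code:: python

    import torch
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    # One mesh dimension spanning all 4 ranks; each rank holds one sequence shard.
    mesh = init_device_mesh("cuda", (4,))

    # Pick the all-to-all based pass-KV variant ("allgather" is the default).
    set_rotate_method("alltoall")

    # (batch, heads, seq_len, head_dim); seq_len is dim 2, the dimension to shard.
    q, k, v = (torch.randn(2, 8, 4096, 64, device="cuda") for _ in range(3))

    # Inside the context, SDPA dispatches to Ring Attention, and the tensors in
    # ``buffers`` are sharded in place along their given sequence dimensions.
    with context_parallel(mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

Passing every sequence-dimension tensor through ``buffers`` (for Llama3, this includes ``freq_cis``) keeps all shards consistent, per the recommendation above.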