This is a JAX adaptation of the ring-attention algorithm introduced in [1], inspired by the original JAX implementation in [2]. The algorithm was adapted to more closely resemble the Flash Attention 2 algorithm [3], but with the movement of data between SRAM and HBM replaced by rotations of query/key/value blocks around the ring of devices that the sequences are sharded across. A minimal sketch of one such ring pass is shown below.
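The sketch below is illustrative rather than this repository's actual code: each device holds one query block and one key/value block, accumulates Flash-Attention-2-style running softmax statistics against the resident K/V block, then passes K/V one hop around the ring with `jax.lax.ppermute`. The function and axis name (`ring_attention_step`, `"ring"`) are assumptions for the example.

```python
import jax
import jax.numpy as jnp

def ring_attention_step(q, k, v, axis_name="ring"):
    # Number of devices in the ring (psum of 1 over the axis is the
    # standard idiom for the axis size).
    n = jax.lax.psum(1, axis_name)
    # Send each device's block to its clockwise neighbour.
    perm = [(j, (j + 1) % n) for j in range(n)]

    def body(_, carry):
        o, m, l, k, v = carry
        # Scores of the resident query block against the current K block.
        s = q @ k.T / jnp.sqrt(q.shape[-1])
        m_new = jnp.maximum(m, s.max(axis=-1))
        p = jnp.exp(s - m_new[:, None])
        alpha = jnp.exp(m - m_new)  # rescale previously accumulated results
        l = l * alpha + p.sum(axis=-1)
        o = o * alpha[:, None] + p @ v
        # Rotate K/V one hop around the ring; queries never move.
        k = jax.lax.ppermute(k, axis_name, perm=perm)
        v = jax.lax.ppermute(v, axis_name, perm=perm)
        return o, m_new, l, k, v

    o = jnp.zeros_like(q)
    m = jnp.full(q.shape[0], -jnp.inf, dtype=q.dtype)
    l = jnp.zeros(q.shape[0], dtype=q.dtype)
    o, m, l, _, _ = jax.lax.fori_loop(0, n, body, (o, m, l, k, v))
    return o / l[:, None]
```

In this sketch the function would run per-shard, for example under `jax.experimental.shard_map.shard_map` with the sequence dimension sharded over a one-dimensional mesh axis named `"ring"`; after `n` hops every query block has attended to every key/value block exactly once.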
This implementation also supports a general mechanism for incorporating arbitrary attention biases from a user-defined function, similar to Flex Attention [4]. Finally, single-device attention block computation is performed with Pallas kernels heavily adapted from the implementations provided in the JAX repository [5].
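As a sketch of that bias mechanism (the names below are illustrative, not this implementation's API): the user supplies a function of global query/key positions, and because each device knows the global offset of its resident query block and of the currently rotated-in key/value block, the bias can be materialized per block pair and added to the scores, in the spirit of Flex Attention's `score_mod`.

```python
import jax.numpy as jnp

def alibi_bias(q_idx, k_idx):
    # Example user-defined bias: an ALiBi-style linear distance penalty
    # (the 0.1 slope is arbitrary for illustration).
    return -0.1 * jnp.abs(q_idx[:, None] - k_idx[None, :])

def biased_scores(q, k, bias_fn, q_offset, k_offset):
    # q_offset/k_offset are the global sequence positions of the first
    # rows of the resident query block and the rotated-in key block, so
    # the bias stays aligned as K/V travel around the ring.
    s = q @ k.T / jnp.sqrt(q.shape[-1])
    q_idx = q_offset + jnp.arange(q.shape[0])
    k_idx = k_offset + jnp.arange(k.shape[0])
    return s + bias_fn(q_idx, k_idx)
```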
References:
- [1] Liu et al. Ring Attention with Blockwise Transformers for Near-Infinite Context. https://arxiv.org/abs/2310.01889
- [2] Ring Attention original JAX implementation: https://github.com/haoliuhl/ringattention
- [3] Tri Dao. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. https://arxiv.org/abs/2307.08691
- [4] Flex Attention: https://pytorch.org/blog/flexattention/
- [5] Pallas Flash Attention kernel implementation in the JAX repository: https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/attention.py