jondeaton/ring-attention-jax-pallas
Ring Attention in JAX / Pallas with flexible attention.

This is a JAX adaptation of the ring-attention algorithm introduced in [1], inspired by the original JAX implementation in [2]. The algorithm was adapted to more closely resemble the FlashAttention-2 algorithm [3], but with the movement of data between SRAM/HBM replaced by rotations of query/key/value blocks around the ring of devices across which the sequence is sharded.
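
The core idea can be sketched on a single device: key/value blocks are visited one at a time (on real hardware they would rotate between devices via `jax.lax.ppermute`), and running max/sum statistics merge each block's contribution in the FlashAttention-2 online-softmax style. `ring_attention_sim` below is a hypothetical illustration of the algorithm, not this repository's API:

```python
import jax
import jax.numpy as jnp

def ring_attention_sim(q, k, v, num_blocks):
    """Single-device sketch of blockwise ring attention.

    Each loop iteration stands in for one ring rotation: the local query
    block attends to one KV block, and running (max, denominator, output)
    statistics are merged online, as in FlashAttention-2.
    """
    seq, d = q.shape
    blk = seq // num_blocks
    k_blocks = k.reshape(num_blocks, blk, d)
    v_blocks = v.reshape(num_blocks, blk, d)

    o = jnp.zeros((seq, d))
    m = jnp.full((seq,), -jnp.inf)  # running row-max of the logits
    l = jnp.zeros((seq,))           # running softmax denominator

    for i in range(num_blocks):     # one iteration per "ring rotation"
        s = q @ k_blocks[i].T / jnp.sqrt(d)   # logits for this KV block
        m_new = jnp.maximum(m, s.max(axis=-1))
        alpha = jnp.exp(m - m_new)            # rescale previous statistics
        p = jnp.exp(s - m_new[:, None])
        l = l * alpha + p.sum(axis=-1)
        o = o * alpha[:, None] + p @ v_blocks[i]
        m = m_new
    return o / l[:, None]
```

Because the running statistics are rescaled whenever a larger logit appears, the result matches full softmax attention regardless of how many blocks the sequence is split into.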

This implementation also supports a general mechanism for incorporating arbitrary attention biases from a user-defined function, similar to Flex Attention [4]. Finally, single-device attention block computation is performed with Pallas kernels heavily adapted from the implementations provided in the JAX repository [5].
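
To show the flavor of a user-defined bias, here is a minimal dense-attention sketch (the function names are illustrative, not this repository's API): a `bias_fn` maps query/key index arrays to an additive logit bias, in the spirit of Flex Attention's `score_mod`:

```python
import jax
import jax.numpy as jnp

def attention_with_bias(q, k, v, bias_fn):
    """Dense attention with an arbitrary additive bias on the logits.

    `bias_fn(q_idx, k_idx)` receives broadcastable index arrays and returns
    a bias added to the attention logits before the softmax.
    """
    d = q.shape[-1]
    logits = q @ k.T / jnp.sqrt(d)
    q_idx = jnp.arange(q.shape[0])[:, None]
    k_idx = jnp.arange(k.shape[0])[None, :]
    logits = logits + bias_fn(q_idx, k_idx)
    return jax.nn.softmax(logits, axis=-1) @ v

# Example: a causal mask expressed as a bias function.
causal = lambda q_idx, k_idx: jnp.where(k_idx <= q_idx, 0.0, -jnp.inf)
```

The same signature can express relative-position biases, sliding windows, or document masks without changing the attention kernel itself.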

References:

  1. Ring Attention with Blockwise Transformers for Near-Infinite Context. Liu et al. https://arxiv.org/abs/2310.01889
  2. Ring Attention JAX code: https://github.com/haoliuhl/ringattention
  3. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. Tri Dao. https://arxiv.org/abs/2307.08691
  4. Flex Attention: https://pytorch.org/blog/flexattention/
  5. Flash attention Pallas kernels in the JAX repository: https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/attention.py
