# jax-flash-attn

This repo contains bindings for [FlashAttention2](https://github.com/Dao-AILab/flash-attention)
in JAX. There are two versions of these bindings: a C++ version,
`jax_flash_attn`, and a Rust version, `jflash_attn`.
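
Semantically, the fused kernel computes standard scaled dot-product attention. As a
point of reference, here is a minimal JAX sketch of that computation (essentially the
`attn-einsum` baseline used in the benchmarks below); it illustrates the semantics and
is not the bindings' API:

```python
import math

import jax
import jax.numpy as jnp

def attention_reference(q, k, v):
    """Plain scaled dot-product attention over (batch, seq_len, num_heads, head_dim) arrays."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale  # per-head (query, key) scores
    probs = jax.nn.softmax(scores, axis=-1)               # attention weights
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)        # weighted sum of values
```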

The BSD-3 license that covers the flash-attention repo also applies here.

## Building the C++ Version

Build a wheel file. `-j32` compiles 32 CUDA kernels in parallel, which could exhaust
memory on machines with less than 100GB of RAM.
```bash
python setup.py bdist_wheel -- -- -j32
```

Build locally for development:
```bash
python setup.py build_ext -i -- -- -j32
python test.py # run some tests and benchmarks
```

This may require you to first install the following two pip packages:
```bash
pip install scikit_build
pip install "pybind11[global]"
```

## Building the Rust Version

To build a Python package as a wheel, run `maturin build --release`.
To build the package and install it in the current virtual
environment, run `maturin develop`.

## Running the Tests and Benchmarks

First compile the C++ and/or Rust packages and install them locally, then use the
following to run the tests:
```bash
python test.py --bindings cpp
python test.py --bindings rust
```

Use the `--bench` flag to run the benchmarks instead of the tests:

```bash
python test.py --bindings cpp --bench True
python test.py --bindings rust --bench True
```
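
For reference, timing a jitted attention call in JAX typically follows the pattern
sketched below: warm up once so compilation is excluded, then average over several
iterations and call `jax.block_until_ready` so asynchronous dispatch does not skew
the numbers. This is a generic sketch, not necessarily what `test.py` does internally:

```python
import time

import jax

def bench(f, *args, iters=20):
    out = f(*args)              # first call triggers compilation
    jax.block_until_ready(out)
    start = time.perf_counter()
    for _ in range(iters):
        out = f(*args)
    jax.block_until_ready(out)  # drain the async dispatch queue before stopping the clock
    return (time.perf_counter() - start) / iters
```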

## Benchmarks (H100 80G HBM3)

This measures the time spent in the attention layer for three different implementations.
- `flash-attn`: uses the optimized flash-attention kernel.
- `attn-einsum`: uses a simple attention implementation based on einsum.
- `attn-flax`: uses `flax.linen.dot_product_attention` (a short sketch of this baseline
  follows the results table).

Timings include the forward pass only for the first lines, and both the forward
and backward passes for the lines that start with `bwd`. The second column is the
sequence length (the batch size is adapted so as to have a reasonable amount of
computation).

```
flash-attn 512 1.23ms 55.8 TFLOPS (std 0.54ms, min 0.79ms, max 2.38ms)
attn-flax 512 1.83ms 37.6 TFLOPS (std 0.58ms, min 1.54ms, max 3.88ms)
flash-attn 1024 1.24ms 110.7 TFLOPS (std 0.38ms, min 0.89ms, max 2.14ms)
attn-flax 1024 2.40ms 57.2 TFLOPS (std 0.49ms, min 1.81ms, max 3.58ms)
flash-attn 2048 1.59ms 173.2 TFLOPS (std 0.34ms, min 1.37ms, max 2.44ms)
attn-flax 2048 3.46ms 79.4 TFLOPS (std 0.30ms, min 3.04ms, max 4.42ms)
flash-attn 4096 2.40ms 229.2 TFLOPS (std 0.22ms, min 2.23ms, max 3.24ms)
attn-flax 4096 6.08ms 90.4 TFLOPS (std 0.45ms, min 5.76ms, max 7.32ms)
flash-attn 8192 4.26ms 258.3 TFLOPS (std 0.25ms, min 4.08ms, max 4.96ms)
attn-flax 8192 11.19ms 98.3 TFLOPS (std 0.31ms, min 10.85ms, max 12.08ms)
flash-attn 16384 7.86ms 279.8 TFLOPS (std 0.35ms, min 7.63ms, max 8.81ms)
attn-flax 16384 26.56ms 82.8 TFLOPS (std 0.48ms, min 25.96ms, max 27.62ms)
bwd flash-attn 512 3.01ms 79.9 TFLOPS (std 0.44ms, min 2.74ms, max 4.42ms)
bwd attn-flax 512 4.26ms 56.4 TFLOPS (std 0.43ms, min 3.88ms, max 5.50ms)
bwd flash-attn 1024 3.90ms 123.3 TFLOPS (std 0.53ms, min 3.30ms, max 4.92ms)
bwd attn-flax 1024 5.43ms 88.6 TFLOPS (std 0.53ms, min 5.05ms, max 6.70ms)
bwd flash-attn 2048 5.22ms 184.4 TFLOPS (std 0.61ms, min 4.52ms, max 6.51ms)
bwd attn-flax 2048 8.69ms 110.6 TFLOPS (std 0.62ms, min 8.22ms, max 10.66ms)
bwd flash-attn 4096 7.58ms 253.9 TFLOPS (std 0.30ms, min 7.35ms, max 8.47ms)
bwd attn-flax 4096 15.08ms 127.6 TFLOPS (std 0.55ms, min 14.55ms, max 16.43ms)
bwd flash-attn 8192 14.22ms 270.7 TFLOPS (std 0.76ms, min 13.56ms, max 16.65ms)
bwd attn-flax 8192 28.03ms 137.3 TFLOPS (std 0.58ms, min 27.51ms, max 29.30ms)
bwd flash-attn 16384 26.42ms 291.4 TFLOPS (std 0.45ms, min 26.03ms, max 27.50ms)
bwd attn-flax 16384 57.84ms 133.1 TFLOPS (std 0.61ms, min 57.28ms, max 59.24ms)
```
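
For concreteness, the `attn-flax` baseline boils down to a call like the one below.
The shapes are illustrative, and `attention_reference` is the einsum sketch from the
top of this README; the two baselines should agree up to numerical tolerance:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (8, 1024, 16, 64)  # (batch, seq_len, num_heads, head_dim), illustrative
q = jax.random.normal(kq, shape)
k = jax.random.normal(kk, shape)
v = jax.random.normal(kv, shape)

# The attn-flax baseline: flax's stock attention over the same layout.
out_flax = nn.dot_product_attention(q, k, v)

# attention_reference is the attn-einsum sketch from the top of this README.
out_einsum = attention_reference(q, k, v)
assert jnp.allclose(out_flax, out_einsum, atol=1e-2)
```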