Skip to content

Commit ffbcf86

Browse files
committed
Initial checkout.
0 parents  commit ffbcf86

File tree

125 files changed

+9286
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

125 files changed

+9286
-0
lines changed

.gitignore

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
_skbuild
6+
7+
# C extensions
8+
*.so
9+
*.o
10+
*.swp
11+
*~
12+
13+
# Distribution / packaging
14+
.Python
15+
build/
16+
develop-eggs/
17+
dist/
18+
downloads/
19+
eggs/
20+
.eggs/
21+
lib/
22+
lib64/
23+
parts/
24+
sdist/
25+
var/
26+
wheels/
27+
share/python-wheels/
28+
*.egg-info/
29+
.installed.cfg
30+
*.egg
31+
MANIFEST
32+
33+
# PyInstaller
34+
# Usually these files are written by a python script from a template
35+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
36+
*.manifest
37+
*.spec
38+
39+
# Installer logs
40+
pip-log.txt
41+
pip-delete-this-directory.txt
42+
43+
# Unit test / coverage reports
44+
htmlcov/
45+
.tox/
46+
.nox/
47+
.coverage
48+
.coverage.*
49+
.cache
50+
nosetests.xml
51+
coverage.xml
52+
*.cover
53+
*.py,cover
54+
.hypothesis/
55+
.pytest_cache/
56+
cover/
57+
58+
# Translations
59+
*.mo
60+
*.pot
61+
62+
# Django stuff:
63+
*.log
64+
local_settings.py
65+
db.sqlite3
66+
db.sqlite3-journal
67+
68+
# Flask stuff:
69+
instance/
70+
.webassets-cache
71+
72+
# Scrapy stuff:
73+
.scrapy
74+
75+
# Sphinx documentation
76+
docs/_build/
77+
78+
# PyBuilder
79+
.pybuilder/
80+
target/
81+
82+
# Jupyter Notebook
83+
.ipynb_checkpoints
84+
85+
# IPython
86+
profile_default/
87+
ipython_config.py
88+
89+
# pyenv
90+
# For a library or package, you might want to ignore these files since the code is
91+
# intended to run in multiple environments; otherwise, check them in:
92+
# .python-version
93+
94+
# pipenv
95+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
97+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
98+
# install all needed dependencies.
99+
#Pipfile.lock
100+
101+
# poetry
102+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103+
# This is especially recommended for binary packages to ensure reproducibility, and is more
104+
# commonly ignored for libraries.
105+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106+
#poetry.lock
107+
108+
# pdm
109+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110+
#pdm.lock
111+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112+
# in version control.
113+
# https://pdm.fming.dev/#use-with-ide
114+
.pdm.toml
115+
116+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117+
__pypackages__/
118+
119+
# Celery stuff
120+
celerybeat-schedule
121+
celerybeat.pid
122+
123+
# SageMath parsed files
124+
*.sage.py
125+
126+
# Environments
127+
.env
128+
.venv
129+
env/
130+
venv/
131+
ENV/
132+
env.bak/
133+
venv.bak/
134+
135+
# Spyder project settings
136+
.spyderproject
137+
.spyproject
138+
139+
# Rope project settings
140+
.ropeproject
141+
142+
# mkdocs documentation
143+
/site
144+
145+
# mypy
146+
.mypy_cache/
147+
.dmypy.json
148+
dmypy.json
149+
150+
# Pyre type checker
151+
.pyre/
152+
153+
# pytype static type analyzer
154+
.pytype/
155+
156+
# Cython debug symbols
157+
cython_debug/
158+
159+
# PyCharm
160+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162+
# and can be added to the global gitignore or merged into this file. For a more nuclear
163+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
164+
#.idea/
165+
166+
Cargo.lock
167+
*.sqlite
168+
*.nsys-rep

.gitmodules

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# NVIDIA CUTLASS (CUDA template library) is vendored as a git submodule;
# CMakeLists.txt and the Makefile both add cutlass/include to the include path.
[submodule "cutlass"]
	path = cutlass
	url = https://github.com/NVIDIA/cutlass.git

CMakeLists.txt

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Adapted from https://github.com/dfm/extending-jax/blob/main/CMakeLists.txt
#
# Builds two artifacts from the flash-attention CUDA sources in csrc/:
#   * _jax_flash_attn   — pybind11 Python extension module for the JAX bindings
#   * cc_jax_flash_attn — static library for C++ consumers
# CUDA is mandatory; configuration aborts when no CUDA compiler is found.

cmake_minimum_required(VERSION 3.15...3.26)
project(jax_flash_attn LANGUAGES C CXX)
message(STATUS "Using CMake version: " ${CMAKE_VERSION})

# Default to RelWithDebInfo, but do not stomp a user-provided build type.
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE RelWithDebInfo)
endif()

set(PYBIND11_NEWPYTHON ON)
find_package(pybind11 CONFIG REQUIRED)

include(CheckLanguage)
check_language(CUDA)

if(NOT CMAKE_CUDA_COMPILER)
  message(FATAL_ERROR "Cannot be built without CUDA")
endif()
enable_language(CUDA)

# Pre-instantiated forward/backward kernels, one .cu per head-dim/dtype
# combination. CONFIGURE_DEPENDS re-runs the glob when files are added/removed.
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
     ${CMAKE_CURRENT_LIST_DIR}/csrc/flash_*wd_hdim*.cu)
message(VERBOSE "Kernel files: ${KERNEL_FILES}")

# Python extension module.
pybind11_add_module(
  _jax_flash_attn
  ${KERNEL_FILES}
  ${CMAKE_CURRENT_LIST_DIR}/csrc/flash_attn_ops.cpp
  ${CMAKE_CURRENT_LIST_DIR}/csrc/flash_api.cu)

# Static library with the same kernels but without the pybind11 glue, for
# linking from C++ code.
add_library(
  cc_jax_flash_attn
  STATIC
  ${KERNEL_FILES}
  ${CMAKE_CURRENT_LIST_DIR}/csrc/flash_api.cu)

foreach(tgt IN ITEMS _jax_flash_attn cc_jax_flash_attn)
  # Target-scoped includes instead of directory-wide include_directories().
  target_include_directories(${tgt} PRIVATE
    ${CMAKE_CURRENT_LIST_DIR}/csrc
    ${CMAKE_CURRENT_LIST_DIR}/cutlass/include
    ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
  # --expt-relaxed-constexpr is required by CUTLASS; applied only to CUDA TUs.
  # (The legacy CUDA_NVCC_FLAGS variable is not used by first-class CUDA
  # language support, so this target option is the sole source of the flag.)
  target_compile_options(${tgt} PRIVATE
    $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>)
  # sm_90: Hopper (H100) only.
  set_property(TARGET ${tgt} PROPERTY CUDA_ARCHITECTURES 90)
  install(TARGETS ${tgt} LIBRARY DESTINATION .)
endforeach()

Cargo.toml

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Rust side of the bindings: builds the `jflash_attn` Python extension via pyo3.
[package]
name = "jax_flash_attn"
version = "0.1.0"
edition = "2021"

[lib]
name = "jflash_attn"
# cdylib: emit a shared library loadable as a Python extension module.
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.20.0", features = ["extension-module"] }
serde = { version = "1.0", features = ["derive"] }
bincode = "1.3"

# build.rs dependencies (parallel compilation of the CUDA kernels).
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
num_cpus = "1.15.0"
rayon = "1.7.0"
pyo3-build-config = "0.20"

Makefile

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Builds the flash-attention extension (build/flash_attn.so) from the CUDA
# kernels in src/. Requires nvcc targeting sm_90 (H100) and pybind11.

# Active Python interpreter and its pybind11 include directory, resolved once
# at parse time.
python_executable=$(shell python -c 'import sys; print(sys.executable)')
pybind_include_path=$(shell python -c "import pybind11; print(pybind11.get_include())")

# One object file per pre-instantiated forward kernel
# (head-dim x {fp16,bf16} x {plain,causal}).
FLASH_FWD_CU = \
	build/flash_fwd_hdim32_fp16_sm80.cu.o \
	build/flash_fwd_hdim64_fp16_sm80.cu.o \
	build/flash_fwd_hdim96_fp16_sm80.cu.o \
	build/flash_fwd_hdim128_fp16_sm80.cu.o \
	build/flash_fwd_hdim160_fp16_sm80.cu.o \
	build/flash_fwd_hdim192_fp16_sm80.cu.o \
	build/flash_fwd_hdim224_fp16_sm80.cu.o \
	build/flash_fwd_hdim256_fp16_sm80.cu.o \
	build/flash_fwd_hdim32_bf16_sm80.cu.o \
	build/flash_fwd_hdim64_bf16_sm80.cu.o \
	build/flash_fwd_hdim96_bf16_sm80.cu.o \
	build/flash_fwd_hdim128_bf16_sm80.cu.o \
	build/flash_fwd_hdim160_bf16_sm80.cu.o \
	build/flash_fwd_hdim192_bf16_sm80.cu.o \
	build/flash_fwd_hdim224_bf16_sm80.cu.o \
	build/flash_fwd_hdim256_bf16_sm80.cu.o \
	build/flash_fwd_hdim32_fp16_causal_sm80.cu.o \
	build/flash_fwd_hdim64_fp16_causal_sm80.cu.o \
	build/flash_fwd_hdim96_fp16_causal_sm80.cu.o \
	build/flash_fwd_hdim128_fp16_causal_sm80.cu.o \
	build/flash_fwd_hdim160_fp16_causal_sm80.cu.o \
	build/flash_fwd_hdim192_fp16_causal_sm80.cu.o \
	build/flash_fwd_hdim224_fp16_causal_sm80.cu.o \
	build/flash_fwd_hdim256_fp16_causal_sm80.cu.o \
	build/flash_fwd_hdim32_bf16_causal_sm80.cu.o \
	build/flash_fwd_hdim64_bf16_causal_sm80.cu.o \
	build/flash_fwd_hdim96_bf16_causal_sm80.cu.o \
	build/flash_fwd_hdim128_bf16_causal_sm80.cu.o \
	build/flash_fwd_hdim160_bf16_causal_sm80.cu.o \
	build/flash_fwd_hdim192_bf16_causal_sm80.cu.o \
	build/flash_fwd_hdim224_bf16_causal_sm80.cu.o \
	build/flash_fwd_hdim256_bf16_causal_sm80.cu.o
# Backward kernels (no causal variants: causality is handled inside the
# backward kernels themselves — TODO confirm against src/).
FLASH_BWD_CU = \
	build/flash_bwd_hdim32_fp16_sm80.cu.o \
	build/flash_bwd_hdim64_fp16_sm80.cu.o \
	build/flash_bwd_hdim96_fp16_sm80.cu.o \
	build/flash_bwd_hdim128_fp16_sm80.cu.o \
	build/flash_bwd_hdim160_fp16_sm80.cu.o \
	build/flash_bwd_hdim192_fp16_sm80.cu.o \
	build/flash_bwd_hdim224_fp16_sm80.cu.o \
	build/flash_bwd_hdim256_fp16_sm80.cu.o \
	build/flash_bwd_hdim32_bf16_sm80.cu.o \
	build/flash_bwd_hdim64_bf16_sm80.cu.o \
	build/flash_bwd_hdim96_bf16_sm80.cu.o \
	build/flash_bwd_hdim128_bf16_sm80.cu.o \
	build/flash_bwd_hdim160_bf16_sm80.cu.o \
	build/flash_bwd_hdim192_bf16_sm80.cu.o \
	build/flash_bwd_hdim224_bf16_sm80.cu.o \
	build/flash_bwd_hdim256_bf16_sm80.cu.o

# all/clean produce no files named "all"/"clean"; mark them phony so a stray
# file of that name cannot mask them.
.PHONY: all clean

all: build/flash_attn.so

clean:
	rm -Rf build/*

# Compile one CUDA translation unit; --expt-relaxed-constexpr is required by
# CUTLASS, and the code is generated for sm_90 (H100) only.
build/%.cu.o : src/%.cu
	nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 \
	-Icutlass/include -std=c++17 \
	--generate-code=arch=compute_90,code=[compute_90,sm_90] \
	-Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c $< -o $@

# pybind11 glue. BUG FIX: the cflags query must go through $(shell ...);
# the previous $(${python_executable}-config --cflags) was a make variable
# reference to a variable named "<python path>-config --cflags", which is
# undefined and silently expanded to nothing.
build/flash_attn_ops.cpp.o: src/flash_attn_ops.cpp src/pybind11_kernel_helpers.h src/kernels.h
	c++ -I/usr/local/cuda/include -I/usr/include/python3.10 -std=c++17 \
	-I$(pybind_include_path) $(shell $(python_executable)-config --cflags) \
	-O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects \
	-o build/flash_attn_ops.cpp.o -c src/flash_attn_ops.cpp

# Final shared object: kernels + glue, linked against the static CUDA runtime.
build/flash_attn.so: $(FLASH_FWD_CU) $(FLASH_BWD_CU) build/flash_attn_ops.cpp.o build/flash_api.cu.o
	c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o $@ -std=c++17 \
	build/flash_attn_ops.cpp.o $(FLASH_FWD_CU) $(FLASH_BWD_CU) build/flash_api.cu.o -L/usr/local/cuda/lib64 \
	-lcudadevrt -lcudart_static -lrt -lpthread -ldl

README.md

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# jax-flash-attn
2+
3+
This repo contains bindings for [FlashAttention2](https://github.com/Dao-AILab/flash-attention)
4+
in JAX. There are two versions of these bindings: a C++ version
5+
`jax_flash_attn` and a Rust version `jflash_attn`.
6+
7+
The BSD-3 license that holds for the flash-attention repo also applies here.
8+
9+
## Building the C++ Version
10+
11+
Build a wheel file. `-j32` will compile 32 cuda kernels in parallel which could exhaust memory on boxes with
12+
less than 100GB.
13+
```bash
14+
python setup.py bdist_wheel -- -- -j32
15+
```
16+
17+
Build locally for development.
18+
```bash
19+
python setup.py build_ext -i -- -- -j32
20+
python test.py # run some tests and benchmarks
21+
```
22+
23+
This may require you to install the two following pip packages:
24+
```bash
25+
pip install scikit_build
26+
pip install "pybind11[global]"
27+
```
28+
29+
## Building the Rust Version
30+
31+
In order to build a python package as a wheel, run `maturin build --release`.
32+
In order to build a python package and install it in the current virtual
33+
environment, run `maturin develop`.
34+
35+
## Running the Tests and Benchmarks
36+
37+
First compile the C++ and/or Rust package and install them locally. Use the
38+
following to run the tests.
39+
```bash
40+
python test.py --bindings cpp
41+
python test.py --bindings rust
42+
```
43+
44+
And use the `--bench` flag to run the benchmarks instead of the tests.
45+
46+
```bash
47+
python test.py --bindings cpp --bench True
48+
python test.py --bindings rust --bench True
49+
```
50+
51+
## Benchmarks (H100 80G HBM3)
52+
53+
This measures the time spent in the attention layer for three different implementations.
54+
- `flash-attn`: uses the optimized flash-attention kernel.
55+
- `attn-einsum`: uses a simple attention implementation based on einsum.
56+
- `attn-flax`: uses `flax.linen.dot_product_attention`.
57+
Timings include the forward pass only for the first lines and both the forward
58+
and backward passes for the lines that start with `bwd`. The second column is the
59+
sequence length (the batch size is adapted so as to have a reasonable amount of
60+
computation).
61+
62+
```
63+
flash-attn 512 1.23ms 55.8 TFLOPS (std 0.54ms, min 0.79ms, max 2.38ms)
64+
attn-flax 512 1.83ms 37.6 TFLOPS (std 0.58ms, min 1.54ms, max 3.88ms)
65+
flash-attn 1024 1.24ms 110.7 TFLOPS (std 0.38ms, min 0.89ms, max 2.14ms)
66+
attn-flax 1024 2.40ms 57.2 TFLOPS (std 0.49ms, min 1.81ms, max 3.58ms)
67+
flash-attn 2048 1.59ms 173.2 TFLOPS (std 0.34ms, min 1.37ms, max 2.44ms)
68+
attn-flax 2048 3.46ms 79.4 TFLOPS (std 0.30ms, min 3.04ms, max 4.42ms)
69+
flash-attn 4096 2.40ms 229.2 TFLOPS (std 0.22ms, min 2.23ms, max 3.24ms)
70+
attn-flax 4096 6.08ms 90.4 TFLOPS (std 0.45ms, min 5.76ms, max 7.32ms)
71+
flash-attn 8192 4.26ms 258.3 TFLOPS (std 0.25ms, min 4.08ms, max 4.96ms)
72+
attn-flax 8192 11.19ms 98.3 TFLOPS (std 0.31ms, min 10.85ms, max 12.08ms)
73+
flash-attn 16384 7.86ms 279.8 TFLOPS (std 0.35ms, min 7.63ms, max 8.81ms)
74+
attn-flax 16384 26.56ms 82.8 TFLOPS (std 0.48ms, min 25.96ms, max 27.62ms)
75+
bwd flash-attn 512 3.01ms 79.9 TFLOPS (std 0.44ms, min 2.74ms, max 4.42ms)
76+
bwd attn-flax 512 4.26ms 56.4 TFLOPS (std 0.43ms, min 3.88ms, max 5.50ms)
77+
bwd flash-attn 1024 3.90ms 123.3 TFLOPS (std 0.53ms, min 3.30ms, max 4.92ms)
78+
bwd attn-flax 1024 5.43ms 88.6 TFLOPS (std 0.53ms, min 5.05ms, max 6.70ms)
79+
bwd flash-attn 2048 5.22ms 184.4 TFLOPS (std 0.61ms, min 4.52ms, max 6.51ms)
80+
bwd attn-flax 2048 8.69ms 110.6 TFLOPS (std 0.62ms, min 8.22ms, max 10.66ms)
81+
bwd flash-attn 4096 7.58ms 253.9 TFLOPS (std 0.30ms, min 7.35ms, max 8.47ms)
82+
bwd attn-flax 4096 15.08ms 127.6 TFLOPS (std 0.55ms, min 14.55ms, max 16.43ms)
83+
bwd flash-attn 8192 14.22ms 270.7 TFLOPS (std 0.76ms, min 13.56ms, max 16.65ms)
84+
bwd attn-flax 8192 28.03ms 137.3 TFLOPS (std 0.58ms, min 27.51ms, max 29.30ms)
85+
bwd flash-attn 16384 26.42ms 291.4 TFLOPS (std 0.45ms, min 26.03ms, max 27.50ms)
86+
bwd attn-flax 16384 57.84ms 133.1 TFLOPS (std 0.61ms, min 57.28ms, max 59.24ms)
87+
```

0 commit comments

Comments
 (0)