
Conversation

@DrJesseGlass
Contributor

This work centers on implementing flash attention in the Qwen3 transformer.

Implementing flash attention for GPU was straightforward, but cpu_flash_attention had a fundamentally different API structure. So I shimmed the old CPU flash attention and moved it into a modular attention/cpu_flash module at candle-nn/src/attention/cpu_flash/standard.rs, with only two extremely minor changes (two functions changed from private to pub(crate)). This made it possible to cleanly implement causal attention for CPU, leveraging a loop-bound approach and dispatch logic. This saves ~14% peak memory compared to the prior cpu_flash method and gives a 3-4% speed improvement (see experiments below).
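
To illustrate the loop-bound idea: for causal attention, the key/value loop for query position `i` only needs to cover positions `0..=i`, so no mask tensor is ever materialized, and the online-softmax rescaling keeps per-row state at O(head_dim). Below is a minimal single-head sketch of that approach; the function name, arguments, and the flat row-major `[seq_len, head_dim]` layout are illustrative assumptions, not the actual candle-nn code.

```rust
// Minimal sketch of loop-bound causal attention for a single head.
// Assumes row-major q/k/v slices of shape [seq_len, head_dim] and a
// precomputed scale (typically 1/sqrt(head_dim)). Illustrative only.
fn causal_flash_attn_1head(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    seq_len: usize,
    head_dim: usize,
    scale: f32,
) -> Vec<f32> {
    let mut out = vec![0.0f32; seq_len * head_dim];
    for i in 0..seq_len {
        let qi = &q[i * head_dim..(i + 1) * head_dim];
        // Online softmax state: running max, running denominator, accumulator.
        let mut m = f32::NEG_INFINITY;
        let mut l = 0.0f32;
        let mut acc = vec![0.0f32; head_dim];
        // Causality via the loop bound: keys j > i are never visited,
        // so no mask tensor is ever built.
        for j in 0..=i {
            let kj = &k[j * head_dim..(j + 1) * head_dim];
            let vj = &v[j * head_dim..(j + 1) * head_dim];
            let s = scale * qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>();
            let m_new = m.max(s);
            let correction = (m - m_new).exp(); // rescale previous partial sums
            let p = (s - m_new).exp();
            l = l * correction + p;
            for d in 0..head_dim {
                acc[d] = acc[d] * correction + p * vj[d];
            }
            m = m_new;
        }
        for d in 0..head_dim {
            out[i * head_dim + d] = acc[d] / l;
        }
    }
    out
}
```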

This memory improvement also applies relative to standard CPU attention. Unfortunately, all existing CPU flash attention implementations are still slower than standard CPU attention. (Future work: improve this with fused kernels.)

Experiments were run on an NVIDIA DGX (CPU and GPU benchmarks) with Qwen3-0.6B; the prefill was the first ~1,500 words of Ulysses.

CPU Throughput

| Implementation | Throughput | Notes |
|---|---|---|
| Causal (new) | 3.54 t/s | Loop-bound, no mask tensor |
| Prior flash (mask) | 3.42 t/s | Requires mask tensor |
| Standard matmul | 3.69 t/s | Baseline |

Causal is ~3.5% faster than prior mask-based flash attention.

CPU Memory (Peak RSS)

| Implementation | Peak Memory | Δ vs Standard |
|---|---|---|
| Flash attention | 2.91 GB | -14% |
| Standard matmul | 3.37 GB | baseline |

Flash attention reduces peak memory by ~460 MB for the 1,685-token prefill by avoiding materialization of the full QKᵀ attention score matrix.

GPU Throughput

| Implementation | Throughput |
|---|---|
| GPU flash attention | 40.39 t/s |
| GPU standard matmul | 38.04 t/s |

Future Work

- Unified SDPA module: add scaled dot-product attention with dispatch across all backends (CPU flash, GPU flash, GPU matmul). This will simplify the Qwen3 transformer and other model implementations (see the sketch after this list).
- Improve causal performance: pursue fused SIMD kernels that handle GQA head broadcasting internally. A broadcast-matmul approach ran into the limitations described in #3253, but fused kernels are likely more performant.
- Explore interleaved KV-cached attention.
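
As a rough sketch of the unified SDPA idea, the entry point could branch on the device of the query tensor. The names below (`sdpa`, `cpu_flash_sdpa`, `gpu_flash_sdpa`, `matmul_sdpa`) are placeholders, not existing candle-nn APIs, and the backend functions are stubbed only to keep the sketch self-contained.

```rust
// Hypothetical sketch of a unified SDPA entry point with backend dispatch.
use candle::{Device, Result, Tensor};

/// Dispatch scaled dot-product attention to an appropriate backend.
pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64, causal: bool) -> Result<Tensor> {
    match q.device() {
        // CPU: loop-bound causal flash path (this PR) or the standard CPU flash path.
        Device::Cpu => cpu_flash_sdpa(q, k, v, scale, causal),
        // CUDA: flash kernel when the feature is compiled in.
        Device::Cuda(_) if cfg!(feature = "flash-attn") => gpu_flash_sdpa(q, k, v, scale, causal),
        // Everything else falls back to the plain matmul implementation.
        _ => matmul_sdpa(q, k, v, scale, causal),
    }
}

// Placeholder backends, assumed to live elsewhere in the attention module.
fn cpu_flash_sdpa(_q: &Tensor, _k: &Tensor, _v: &Tensor, _s: f64, _c: bool) -> Result<Tensor> {
    unimplemented!("loop-bound CPU flash attention")
}
fn gpu_flash_sdpa(_q: &Tensor, _k: &Tensor, _v: &Tensor, _s: f64, _c: bool) -> Result<Tensor> {
    unimplemented!("GPU flash attention kernel")
}
fn matmul_sdpa(_q: &Tensor, _k: &Tensor, _v: &Tensor, _s: f64, _c: bool) -> Result<Tensor> {
    unimplemented!("standard matmul attention")
}
```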

@DrJesseGlass
Contributor Author

I'm also thinking of including rotary embeddings and other shared attention-related operations in this component.
