Optimization for CPU Causal Flash Attention (integrated into Qwen3) #3254
This work centered on implementing flash attention in the Qwen3 transformer.
Implementing flash attention for the GPU was straightforward, but cpu_flash_attention has a fundamentally different API structure. So I shimmed the old CPU flash attention, moving it into a modular attention/cpu_flash module at candle-nn/src/attention/cpu_flash/standard.rs with only two extremely minor changes (two functions changed from private to pub(crate)). This made it possible to cleanly implement a causal attention path for the CPU that enforces causality through loop bounds (sketched below), together with dispatch logic to select between the paths. Compared to the prior cpu_flash method this saves 14% peak memory and gives a 3-4% speed improvement (see experiments below).
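To make the loop-bound idea concrete, here is a minimal single-head sketch in plain Rust. It is not the candle-nn API: the function name, signature, and slice-based layout are illustrative only. Causality comes from the `0..=i` bound on the key loop rather than from a materialized `[seq, seq]` mask, and the online-softmax accumulators keep per-query state at O(head_dim).

```rust
// Minimal single-head illustration of the loop-bound approach (hypothetical
// free function, not the candle-nn API). `q`, `k`, `v` are row-major
// [seq, head_dim] slices.
fn causal_attention_single_head(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    seq: usize,
    head_dim: usize,
    scale: f32,
) -> Vec<f32> {
    let mut out = vec![0.0f32; seq * head_dim];
    for i in 0..seq {
        let qi = &q[i * head_dim..(i + 1) * head_dim];
        // Online-softmax accumulators, as in flash attention.
        let mut running_max = f32::NEG_INFINITY;
        let mut running_sum = 0.0f32;
        let mut acc = vec![0.0f32; head_dim];
        // The `0..=i` bound replaces the causal mask: keys after the query
        // position are never visited, so no [seq, seq] mask or score matrix
        // is ever materialized.
        for j in 0..=i {
            let kj = &k[j * head_dim..(j + 1) * head_dim];
            let score = scale * qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>();
            let new_max = running_max.max(score);
            let correction = (running_max - new_max).exp(); // rescale old terms
            let p = (score - new_max).exp();
            running_sum = running_sum * correction + p;
            let vj = &v[j * head_dim..(j + 1) * head_dim];
            for d in 0..head_dim {
                acc[d] = acc[d] * correction + p * vj[d];
            }
            running_max = new_max;
        }
        for d in 0..head_dim {
            out[i * head_dim + d] = acc[d] / running_sum;
        }
    }
    out
}

fn main() {
    // Tiny smoke test: 3 tokens, head_dim 2.
    let q = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    let k = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    let v = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out = causal_attention_single_head(&q, &k, &v, 3, 2, 1.0);
    println!("{out:?}");
}
```

A tiled flash-attention version applies the same bound per key block; the dispatch logic then simply routes between this causal path and the untouched standard path.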
The memory savings also hold relative to standard CPU attention. Unfortunately, all existing CPU flash attention paths are still slower than standard CPU attention. (Future work is to close this gap with fused kernels.)
Experiments were run on an NVIDIA DGX (CPU and GPU benchmarks) with Qwen3-0.6B; the prefill was the first ~1,500 words of Ulysses.
CPU Throughput
The causal variant is ~3.5% faster than the prior mask-based flash attention.
CPU Memory (Peak RSS)
Flash attention reduces peak memory by ~460 MB for the 1685-token prefill by avoiding materialization of the full Q×Kᵀ attention-score matrix.
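For scale, a rough back-of-envelope (assuming 16 query heads and f32 scores for Qwen3-0.6B; both are assumptions, not measured values) shows why the materialized score matrices alone reach the hundreds-of-MB range:

```rust
fn main() {
    let seq: usize = 1685;        // prefill length from the benchmark
    let heads: usize = 16;        // assumed query-head count for Qwen3-0.6B
    let bytes_per_f32: usize = 4;
    // One full [seq, seq] score matrix per head:
    let scores_bytes = seq * seq * heads * bytes_per_f32;
    println!("score matrices: {:.0} MB", scores_bytes as f64 / 1e6); // ~182 MB
    // A second buffer of the same size for the softmax output, plus
    // GQA-broadcast copies of K/V, pushes peak usage toward the observed range.
}
```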
GPU Throughput
Future Work
Unified SDPA module: add scaled dot-product attention with dispatch across all backends (CPU flash, GPU flash, GPU matmul). This will simplify the Qwen3 transformer and other model implementations (a hypothetical sketch of such a dispatch follows this list).
Improve causal performance: pursue fused SIMD kernels that handle GQA head broadcasting internally. Broadcast matmul was tried but ran into limitations (#3253); fused kernels are likely more performant.
Explore interleaved KV-cached attention.
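As a rough illustration of the first bullet, a unified SDPA entry point might dispatch along these lines. All names here (`SdpaBackend`, `SdpaConfig`, `sdpa`) are hypothetical, and tensor arguments and kernel calls are elided; this is a sketch of the intent, not the planned candle-nn API.

```rust
// Illustrative sketch of a unified SDPA entry point. `SdpaBackend`,
// `SdpaConfig`, and `sdpa` are hypothetical names; real tensor arguments
// and kernel calls are elided.
#[derive(Clone, Copy)]
enum SdpaBackend {
    CpuFlash,
    GpuFlash,
    GpuMatmul,
}

struct SdpaConfig {
    causal: bool,
    scale: f32,
}

// A single entry point would let Qwen3 (and other models) stop hand-picking
// attention paths; the model code just describes what it needs.
fn sdpa(backend: SdpaBackend, cfg: &SdpaConfig) -> &'static str {
    match (backend, cfg.causal) {
        (SdpaBackend::CpuFlash, true) => "cpu_flash causal path (loop-bound)",
        (SdpaBackend::CpuFlash, false) => "cpu_flash standard path (masked)",
        (SdpaBackend::GpuFlash, _) => "GPU flash-attention kernel",
        (SdpaBackend::GpuMatmul, _) => "GPU masked matmul + softmax",
    }
}

fn main() {
    let cfg = SdpaConfig { causal: true, scale: 1.0 / 8.0 };
    for backend in [SdpaBackend::CpuFlash, SdpaBackend::GpuFlash, SdpaBackend::GpuMatmul] {
        println!("scale {} -> {}", cfg.scale, sdpa(backend, &cfg));
    }
}
```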