feat(qwen2): add KV cache management and selective attention #3236
Conversation
- Fix causal mask shape for cached decoding (critical for multi-turn)
- Add extract/restore methods for KV cache manipulation
- Add support for non-contiguous cache positions via `cache_position`
- Add `forward_from_embeds` for custom embedding workflows
- Improve RoPE and softmax precision with F32 intermediates (matching PyTorch)
- Replace `NEG_INFINITY` with `f32::MIN` to avoid NaN propagation
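To make the NaN-propagation point concrete, here is a standalone snippet (not from the PR) showing how infinite mask values poison ordinary mask arithmetic while the large-but-finite `f32::MIN` does not:

```rust
fn main() {
    // Mask values of -inf turn common mask arithmetic into NaN:
    assert!((f32::NEG_INFINITY * 0.0).is_nan()); // e.g. mask * weight
    assert!((f32::NEG_INFINITY - f32::NEG_INFINITY).is_nan()); // mask - mask
    // f32::MIN stays well behaved under the same operations:
    assert_eq!(f32::MIN * 0.0, 0.0);
    assert_eq!(f32::MIN - f32::MIN, 0.0);
}
```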
ivarflakstad left a comment:
There's a lot of good stuff here!
A slight issue is that on CUDA, performance drops by 20%. It would be nice to figure out whether we can easily avoid this drop before merging (I'm ok with accuracy > performance).
```rust
for &abs_query_pos in cache_pos_vec.iter().take(query_length) {
    let abs_query_pos = abs_query_pos as usize;
    for j in 0..key_length {
        // Causal: can't attend to future positions
        let is_future = j > abs_query_pos;
        // Sliding window: can't attend to positions too far in the past
        let is_too_old = j + self.sliding_window < abs_query_pos;

        if is_future || is_too_old {
            mask_data.push(min_dtype);
        } else {
            mask_data.push(0.0);
        }
    }
}
```
This looks pretty close to a `where_cond` with `t: min_dtype` and `f: 0.0`. It would be nice if we didn't have to move `cache_position` onto the CPU for this, but `is_future` and `is_too_old` are hard to express that way.
I'll give it a think.
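For reference, here is a minimal sketch of what the on-device formulation could look like (the helper name `build_mask_on_device` and the exact shapes/dtypes are assumptions, not part of this PR). It keeps `cache_position` on the device by comparing broadcast position tensors and selecting with `where_cond`:

```rust
use candle_core::{DType, Device, Result, Tensor};

// Hypothetical helper: build the causal + sliding-window mask without
// copying cache_position to the CPU.
fn build_mask_on_device(
    cache_position: &Tensor, // (query_length,), u32 absolute positions
    key_length: usize,
    sliding_window: usize,
    min_dtype: f32,
    device: &Device,
) -> Result<Tensor> {
    let q = cache_position.dim(0)?;
    // Key positions 0..key_length, broadcast to (q, key_length).
    let key_pos = Tensor::arange(0u32, key_length as u32, device)?
        .reshape((1, key_length))?
        .broadcast_as((q, key_length))?;
    // Query positions as a column, broadcast to the same shape.
    let query_pos = cache_position
        .reshape((q, 1))?
        .broadcast_as((q, key_length))?;
    // is_future: key strictly after the query position.
    let is_future = key_pos.gt(&query_pos)?;
    // is_too_old: key more than sliding_window before the query.
    let window = Tensor::new(sliding_window as u32, device)?;
    let is_too_old = key_pos.broadcast_add(&window)?.lt(&query_pos)?;
    // Logical OR on the u8 masks, then select min_dtype / 0.0.
    let masked = is_future.maximum(&is_too_old)?;
    let on_true = Tensor::full(min_dtype, (q, key_length), device)?;
    let on_false = Tensor::zeros((q, key_length), DType::F32, device)?;
    masked.where_cond(&on_true, &on_false)
}
```

Whether this recovers the CUDA regression would need benchmarking, but the comparisons and the `where_cond` all stay on-device, so no host round-trip is required.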
Summary
Adds KV cache management and fixes a critical causal mask bug for Qwen2 multi-turn inference. Includes numerical precision improvements for RoPE and attention.
Changes
- Fixed the causal mask shape for cached decoding (was `[tgt, tgt]`, now `[tgt, total]`) - critical for multi-turn conversations
- Added `extract_kv_cache`/`restore_kv_cache` methods for cache manipulation and inspection (see the usage sketch below)
- Added `prepare_4d_causal_attention_mask_with_cache_position` for non-contiguous cache positions
- New `forward_from_embeds` methods enable custom embedding workflows (e.g., multimodal)
- Replaced `NEG_INFINITY` with `f32::MIN` to avoid NaN propagation when combining masks
- Added `shift_kv_cache_first_to_last` for advanced patterns (e.g., negative prompt refresh)
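A hypothetical usage sketch of the snapshot/replay pattern these methods enable. The signatures of `extract_kv_cache`/`restore_kv_cache` are assumptions (this description only names them); `forward`'s signature matches the existing qwen2 model:

```rust
use candle_core::{Result, Tensor};
use candle_transformers::models::qwen2::ModelForCausalLM;

// Decode two alternative continuations from the same cached prefix
// without re-encoding it.
fn replay_two_turns(
    model: &mut ModelForCausalLM,
    prefix_len: usize, // tokens already in the KV cache
    turn_a: &Tensor,
    turn_b: &Tensor,
) -> Result<(Tensor, Tensor)> {
    // Snapshot the KV cache built while encoding the shared prefix.
    let snapshot = model.extract_kv_cache(); // assumed: returns an owned per-layer (K, V) snapshot
    let out_a = model.forward(turn_a, prefix_len)?;
    // Rewind to the prefix and decode an alternative continuation.
    model.restore_kv_cache(snapshot); // assumed: reinstates the snapshot
    let out_b = model.forward(turn_b, prefix_len)?;
    Ok((out_a, out_b))
}
```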
Motivation
The causal mask bug prevented correct multi-turn decoding with the KV cache. The new cache-management APIs enable advanced inference patterns such as streaming audio generation (VibeVoice) and speculative decoding, while preserving precision for F16/BF16 inference.
Breaking Changes
None - all changes are backward-compatible additions or bug fixes.
✅ Validation
Routine
```sh
cargo fmt --all
cargo test -p candle-transformers
cargo clippy -p candle-transformers
```
Test Qwen2 Example
Simple query:
```sh
cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies." --model "2-1.5b"
```
Test with a very short prompt to ensure single-token decode works:
```sh
cargo run --example qwen --features metal --release -- --prompt "Hi" --sample-len 10 --model "2-1.5b"
```