diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index 796f3a1d1f..f14d448e42 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -338,7 +338,8 @@ fn main() -> Result<()> { Model::Moe(ModelMoe::new(&config, vb)?) } WhichModel::W3_0_6b | WhichModel::W3_1_7b | WhichModel::W3_4b | WhichModel::W3_8b => { - let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?; + let mut config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?; + config.use_flash_attn = args.use_flash_attn; Model::Base3(Model3::new(&config, vb)?) } WhichModel::W3MoeA3b => { diff --git a/candle-nn/src/attention/cpu_flash/causal.rs b/candle-nn/src/attention/cpu_flash/causal.rs new file mode 100644 index 0000000000..315386c466 --- /dev/null +++ b/candle-nn/src/attention/cpu_flash/causal.rs @@ -0,0 +1,389 @@ +//! Optimized causal attention using loop-bound masking. +//! +//! Instead of materializing a mask tensor and checking each position, +//! this implementation computes the causal boundary and only iterates +//! over valid positions. This skips ~50% of work for causal attention. + +use candle::{Device, Result, Storage, Tensor, WithDType}; +use rayon::prelude::*; +use std::iter::Sum; + +use super::standard::{vec_dot, FLASH_ATTN_POOL}; + +/// Size (in KV positions) processed by each inner-tile job. +const TILE_KV: usize = 16; + +/// Causal attention optimized with loop-bound masking. +/// +/// Dispatches to decode (q_len=1) or prefill (q_len>1) paths. +#[allow(clippy::too_many_arguments)] +pub fn run_causal_attn_cpu( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + kv_offset: usize, + max_bias: Option, + softcap: Option, +) -> Result +where + T: WithDType + Sum + num_traits::real::Real, +{ + // Extract CPU slices for q, k, v + let (q_guard, q_layout) = q.storage_and_layout(); + let q_data: &[T] = match &*q_guard { + Storage::Cpu(cpu) => &cpu.as_slice::()?[q_layout.start_offset()..], + _ => return Err(candle::Error::Msg("Expected CPU storage for q".into())), + }; + + let (k_guard, k_layout) = k.storage_and_layout(); + let k_data: &[T] = match &*k_guard { + Storage::Cpu(cpu) => &cpu.as_slice::()?[k_layout.start_offset()..], + _ => return Err(candle::Error::Msg("Expected CPU storage for k".into())), + }; + + let (v_guard, v_layout) = v.storage_and_layout(); + let v_data: &[T] = match &*v_guard { + Storage::Cpu(cpu) => &cpu.as_slice::()?[v_layout.start_offset()..], + _ => return Err(candle::Error::Msg("Expected CPU storage for v".into())), + }; + + let q_stride = q.stride(); + let k_stride = k.stride(); + let v_stride = v.stride(); + + let q_len = q.shape().dims()[1]; + + if q_len == 1 { + causal_attn_decode( + q_data, + k_data, + v_data, + q.shape().dims(), + k.shape().dims(), + v.shape().dims(), + q_stride, + k_stride, + v_stride, + softmax_scale, + kv_offset, + max_bias.unwrap_or(0.0), + softcap.unwrap_or(0.0), + ) + } else { + causal_attn_prefill( + q_data, + k_data, + v_data, + q.shape().dims(), + k.shape().dims(), + v.shape().dims(), + q_stride, + k_stride, + v_stride, + softmax_scale, + kv_offset, + max_bias.unwrap_or(0.0), + softcap.unwrap_or(0.0), + ) + } +} + +/// Decode path: q_len == 1, attends to all kv_len positions. +/// +/// For decode, the single query token is conceptually at position `kv_offset`, +/// so it can attend to all KV positions [0, kv_len). 
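A worked decode example (editorial sketch, not part of the patch): the KV cache is appended before the kernel runs, so after 12 cached tokens a decode step sees kv_offset = 12 and kv_len = 13, and the lone query row, being the newest token, scores every position. The causal bound used by the prefill path reduces to kv_len here, which is why this path can ignore kv_offset:

    // Illustrative only: decode never excludes any KV position.
    let (q_pos, kv_offset) = (0usize, 12usize);
    let kv_len = kv_offset + 1; // cache + the token being decoded
    assert_eq!((q_pos + kv_offset + 1).min(kv_len), kv_len);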
+#[allow(clippy::too_many_arguments)] +fn causal_attn_decode( + q_data: &[T], + k_data: &[T], + v_data: &[T], + qshape: &[usize], + kshape: &[usize], + vshape: &[usize], + qstride: &[usize], + kstride: &[usize], + vstride: &[usize], + scale: f32, + _kv_offset: usize, // Not used for decode - query sees all KV + max_bias: f32, + logit_softcap: f32, +) -> Result { + let (b, _q_len, h, d) = (qshape[0], qshape[1], qshape[2], qshape[3]); + let kv_len = kshape[1]; + let k_h = kshape[2]; + let v_h = vshape[2]; + let rk = h / k_h; + let rv = h / v_h; + let dv = d; + + // ALiBi slope calculation + let n2 = 2_usize.pow((h as f32).log2().ceil() as u32); + + // Output buffer: (B, H, 1, D) + let mut out = vec![0f32; b * h * dv]; + let kv_tiles = kv_len.div_ceil(TILE_KV); + + FLASH_ATTN_POOL.install(|| { + out.par_chunks_mut(dv) + .with_min_len(64) + .enumerate() + .for_each(|(row_idx, out_chunk)| { + let b_i = row_idx / h; + let h_i = row_idx % h; + + // ALiBi slope + let slope = if max_bias > 0.0 { + 2.0f32.powf(-max_bias * ((h_i + 1) as f32) / n2 as f32) + } else { + 0.0 + }; + + // GQA head mapping + let k_head = h_i / rk; + let v_head = h_i / rv; + + // Q row - for decode q_len=1, q_pos=0 + let q_base = b_i * qstride[0] + h_i * qstride[2]; + let q_row = &q_data[q_base..q_base + d]; + + // Parallel reduce over KV tiles + let (vkq, s_tot, _m_tot) = (0..kv_tiles) + .into_par_iter() + .map(|tile_idx| { + let start = tile_idx * TILE_KV; + let end = (start + TILE_KV).min(kv_len); + + let mut vkq = vec![0f32; dv]; + let mut s = 0.0f32; + let mut m = f32::NEG_INFINITY; + + for kv_pos in start..end { + // ALiBi bias + let alibi_bias = if max_bias > 0.0 { + slope * (kv_pos as f32 - (kv_len - 1) as f32) + } else { + 0.0 + }; + + // K row + let k_base = + b_i * kstride[0] + kv_pos * kstride[1] + k_head * kstride[2]; + let k_row = &k_data[k_base..k_base + d]; + + // QK dot product + let mut s_val = vec_dot::(q_row, k_row).to_f64() as f32; + + // Scale + softcap + let mut scale_applied = scale; + if logit_softcap != 0.0 { + scale_applied /= logit_softcap; + } + s_val *= scale_applied; + if logit_softcap != 0.0 { + s_val = logit_softcap * s_val.tanh(); + } + s_val += alibi_bias; + + // Online softmax + let m_old = m; + let (ms, vs) = if s_val > m { + m = s_val; + let ms = (m_old - m).exp(); + for v in vkq.iter_mut() { + *v *= ms; + } + (ms, 1.0f32) + } else { + (1.0f32, (s_val - m).exp()) + }; + + // Accumulate V + let v_base = + b_i * vstride[0] + kv_pos * vstride[1] + v_head * vstride[2]; + for d_i in 0..dv { + vkq[d_i] += v_data[v_base + d_i * vstride[3]].to_f64() as f32 * vs; + } + + s = s * ms + vs; + } + + (vkq, s, m) + }) + .reduce( + || (vec![0f32; dv], 0.0f32, f32::NEG_INFINITY), + merge_softmax_accumulators, + ); + + // Final normalization + let inv_s = if s_tot > 0.0 { 1.0 / s_tot } else { 0.0 }; + for (out_v, acc_v) in out_chunk.iter_mut().zip(vkq.iter()) { + *out_v = *acc_v * inv_s; + } + }); + }); + + Tensor::from_vec(out, (b, h, 1usize, dv), &Device::Cpu) +} + +/// Prefill path: q_len > 1, uses loop bounds to skip masked positions. +/// +/// Each query at position q_pos can attend to KV positions [0, q_pos + kv_offset]. 
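A minimal sketch of the loop bound the prefill path computes below (editorial; the helper name is made up for illustration). With an empty cache and q_len = kv_len = 8, the exclusive bounds are 1 through 8, so only 36 of the 64 query/key pairs are ever scored, which is the roughly 50% saving the module docs mention; with a populated cache the bound simply shifts by kv_offset.

    /// Exclusive upper bound of the KV loop for one query row
    /// (mirrors `kv_end` in `causal_attn_prefill`).
    fn causal_bound(q_pos: usize, kv_offset: usize, kv_len: usize) -> usize {
        (q_pos + kv_offset + 1).min(kv_len)
    }
    // e.g. chunked prefill with 4 cached tokens and 3 new queries (kv_len = 7):
    // causal_bound(0, 4, 7) == 5, causal_bound(1, 4, 7) == 6, causal_bound(2, 4, 7) == 7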
+#[allow(clippy::too_many_arguments)] +fn causal_attn_prefill( + q_data: &[T], + k_data: &[T], + v_data: &[T], + qshape: &[usize], + kshape: &[usize], + vshape: &[usize], + qstride: &[usize], + kstride: &[usize], + vstride: &[usize], + scale: f32, + kv_offset: usize, + max_bias: f32, + logit_softcap: f32, +) -> Result { + let (b, q_len, h, d) = (qshape[0], qshape[1], qshape[2], qshape[3]); + let kv_len = kshape[1]; + let k_h = kshape[2]; + let v_h = vshape[2]; + let rk = h / k_h; + let rv = h / v_h; + let dv = d; + + // ALiBi slope calculation + let n2 = 2_usize.pow((h as f32).log2().ceil() as u32); + + let mut out = vec![0f32; b * q_len * h * dv]; + + FLASH_ATTN_POOL.install(|| { + out.par_chunks_mut(dv) + .with_min_len(64) + .enumerate() + .for_each(|(row_idx, out_chunk)| { + // Decode flat index to (batch, head, q_pos) + let rows_per_batch = h * q_len; + let b_i = row_idx / rows_per_batch; + let rem = row_idx % rows_per_batch; + let h_i = rem / q_len; + let q_pos = rem % q_len; + + // ALiBi slope + let slope = if max_bias > 0.0 { + 2.0f32.powf(-max_bias * ((h_i + 1) as f32) / n2 as f32) + } else { + 0.0 + }; + + // GQA head mapping + let k_head = h_i / rk; + let v_head = h_i / rv; + + // Buffers + let mut vkq = vec![0f32; dv]; + let mut s = 0.0f32; + let mut m = f32::NEG_INFINITY; + + let mut q_row: Vec = Vec::with_capacity(d); + let mut k_row: Vec = Vec::with_capacity(d); + + // Gather Q (strided) + let q_base = b_i * qstride[0] + q_pos * qstride[1] + h_i * qstride[2]; + for di in 0..d { + q_row.push(q_data[q_base + di * qstride[3]]); + } + + // LOOP-BOUND CAUSAL: only iterate up to causal boundary + let kv_end = (q_pos + kv_offset + 1).min(kv_len); + + for kv_pos in 0..kv_end { + // ALiBi bias + let alibi_bias = if max_bias > 0.0 { + slope * (kv_pos as i64 - (q_pos + kv_offset) as i64) as f32 + } else { + 0.0 + }; + + // K row (strided) + let k_base = b_i * kstride[0] + kv_pos * kstride[1] + k_head * kstride[2]; + k_row.clear(); + for di in 0..d { + k_row.push(k_data[k_base + di * kstride[3]]); + } + + // QK dot product + let mut s_val = vec_dot::(&q_row, &k_row); + + // Scale + softcap + let mut scale_applied = scale; + if logit_softcap != 0.0 { + scale_applied /= logit_softcap; + } + s_val *= T::from_f64(scale_applied as f64); + if logit_softcap != 0.0 { + s_val = T::from_f64(logit_softcap as f64 * s_val.to_f64().tanh()); + } + s_val += T::from_f64(alibi_bias as f64); + + // Online softmax + let m_old = m; + let mut ms = 1.0f32; + let mut vs = 1.0f32; + if s_val.to_f64() as f32 > m { + m = s_val.to_f64() as f32; + ms = (m_old - m).exp(); + for v in vkq.iter_mut() { + *v *= ms; + } + } else { + vs = (s_val.to_f64() as f32 - m).exp(); + } + + // V row (strided) + let v_base = b_i * vstride[0] + kv_pos * vstride[1] + v_head * vstride[2]; + for d_i in 0..dv { + vkq[d_i] += v_data[v_base + d_i * vstride[3]].to_f64() as f32 * vs; + } + + s = s * ms + vs; + } + + // Normalize & write + let inv_s = if s > 0.0 { 1.0 / s } else { 0.0 }; + for v in vkq.iter_mut() { + *v *= inv_s; + } + out_chunk.copy_from_slice(&vkq); + }); + }); + + Tensor::from_vec(out, (b, h, q_len, dv), &Device::Cpu) +} + +/// Merge two online softmax accumulators. 
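The reducer defined next implements the standard online-softmax merge (editorial note): each partial accumulator stores its sums relative to its own running maximum, so the side with the smaller maximum is rescaled before the two halves are added.

    m = max(m_a, m_b)
    s = s_a * exp(m_a - m) + s_b * exp(m_b - m)
    v = v_a * exp(m_a - m) + v_b * exp(m_b - m)   // element-wise over the value accumulators

In the code only one of the two exponential factors is applied explicitly, because the other is exp(0) = 1.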
+#[inline]
+fn merge_softmax_accumulators(
+    a: (Vec<f32>, f32, f32),
+    b: (Vec<f32>, f32, f32),
+) -> (Vec<f32>, f32, f32) {
+    let (vkq_a, s_a, m_a) = a;
+    let (vkq_b, s_b, m_b) = b;
+
+    if m_a >= m_b {
+        let factor = (m_b - m_a).exp();
+        let mut vkq = vkq_a;
+        for (va, vb) in vkq.iter_mut().zip(vkq_b.iter()) {
+            *va += *vb * factor;
+        }
+        (vkq, s_a + s_b * factor, m_a)
+    } else {
+        let factor = (m_a - m_b).exp();
+        let mut vkq = vkq_b;
+        for (vb, va) in vkq.iter_mut().zip(vkq_a.iter()) {
+            *vb += va * factor;
+        }
+        (vkq, s_b + s_a * factor, m_b)
+    }
+}
diff --git a/candle-nn/src/attention/cpu_flash/mod.rs b/candle-nn/src/attention/cpu_flash/mod.rs
new file mode 100644
index 0000000000..f41234fb8d
--- /dev/null
+++ b/candle-nn/src/attention/cpu_flash/mod.rs
@@ -0,0 +1,57 @@
+//! CPU flash attention implementations.
+//!
+//! - `standard`: General-purpose with explicit mask tensor
+//! - `causal`: Optimized loop-bound causal masking (no tensor allocation)
+
+pub mod causal;
+pub mod standard;
+
+use candle::{Result, Tensor, WithDType};
+use std::iter::Sum;
+
+use super::AttnMask;
+
+/// Flash attention with automatic dispatch.
+///
+/// Selects optimal implementation:
+/// - `AttnMask::Causal` → `causal.rs` (loop-bound, no mask tensor)
+/// - `AttnMask::None` or `AttnMask::Mask` → `standard.rs`
+///
+/// # Arguments
+/// * `q` - Query tensor, shape `(B, S, H, D)`
+/// * `k` - Key tensor, shape `(B, KV_S, KV_H, D)`
+/// * `v` - Value tensor, shape `(B, KV_S, KV_H, D)`
+/// * `softmax_scale` - Scale factor (typically `1/sqrt(head_dim)`)
+/// * `attn_mask` - Masking strategy
+/// * `max_bias` - ALiBi max bias (`None` to disable)
+/// * `softcap` - Logit soft-capping (`None` to disable)
+///
+/// # Returns
+/// Output tensor with shape `(B, H, S, D)`
+pub fn flash_attn<T>(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    attn_mask: AttnMask<'_>,
+    max_bias: Option<f32>,
+    softcap: Option<f32>,
+) -> Result<Tensor>
+where
+    T: WithDType + Sum + num_traits::real::Real,
+{
+    match attn_mask {
+        AttnMask::Causal { kv_offset } => {
+            // Optimized path: loop-bound causal masking
+            causal::run_causal_attn_cpu::<T>(q, k, v, softmax_scale, kv_offset, max_bias, softcap)
+        }
+        AttnMask::None => {
+            // No masking
+            standard::run_flash_attn_cpu::<T>(q, k, v, None, softmax_scale, max_bias, softcap)
+        }
+        AttnMask::Mask(mask) => {
+            // Explicit mask tensor
+            standard::run_flash_attn_cpu::<T>(q, k, v, Some(mask), softmax_scale, max_bias, softcap)
+        }
+    }
+}
diff --git a/candle-nn/src/attention/cpu_flash/standard.rs b/candle-nn/src/attention/cpu_flash/standard.rs
new file mode 100644
index 0000000000..94dc38d0a2
--- /dev/null
+++ b/candle-nn/src/attention/cpu_flash/standard.rs
@@ -0,0 +1,488 @@
+#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
+
+use candle::{Device, Result, Storage, Tensor, WithDType};
+use std::sync::LazyLock;
+use std::{f32, iter::Sum};
+
+use rayon::prelude::*;
+use rayon::ThreadPool;
+
+#[cfg(target_os = "macos")]
+/// Elevate the thread QoS so macOS prefers running it on Performance (P) cores.
+unsafe fn set_thread_affinity() {
+    // USER_INTERACTIVE has the highest scheduling priority that user code
+    // can request and is most likely to be scheduled on P‑cores.
+    use libc::{pthread_set_qos_class_self_np, qos_class_t::QOS_CLASS_USER_INTERACTIVE};
+    // The second argument is a relative priority within the QoS class (0 = default).
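    // (Editorial note: this is a quality-of-service hint rather than a hard core pin; macOS
    // still decides actual placement, and the non-macOS stub below leaves scheduling untouched.)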
+ pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0); +} + +#[cfg(not(target_os = "macos"))] +#[inline(always)] +unsafe fn set_thread_affinity() { + // On non‑macOS platforms we currently leave affinity untouched. +} + +/// Rayon pool used by the flash‑attention CPU kernels, with a per‑thread +/// start handler that applies our affinity hint exactly once. +pub(crate) static FLASH_ATTN_POOL: LazyLock = LazyLock::new(|| { + rayon::ThreadPoolBuilder::new() + .start_handler(|_| unsafe { + set_thread_affinity(); + }) + .build() + .expect("Failed to build custom Rayon thread‑pool for flash‑attention") +}); + +const DOT_CHUNK: usize = 4; + +/// Size (in KV positions) processed by each inner‑tile job. +const TILE_KV: usize = 16; + +#[inline] +pub(crate) fn vec_dot>( + a: &[T], + b: &[T], +) -> T { + let mut sum = T::zero(); + let chunks = a.len() / DOT_CHUNK; + + for i in 0..chunks { + let i_chunk = i * DOT_CHUNK; + sum = sum + + a[i_chunk] * b[i_chunk] + + a[i_chunk + 1] * b[i_chunk + 1] + + a[i_chunk + 2] * b[i_chunk + 2] + + a[i_chunk + 3] * b[i_chunk + 3]; + } + + for i in (chunks * DOT_CHUNK)..a.len() { + sum += a[i] * b[i]; + } + sum +} + +/// Fused attention optimized for CPU. +/// +/// Computes softmax(qk^T*scale)v. +/// +/// **Inputs shapes:** +/// - `q`: (bs, seq, qhead, hidden) +/// - `k`: (bs, kv_seq, v_head, hidden) +/// - `k`: (bs, kv_seq, kv_head_seq, v_hidden) +/// - `scale` is applied before softmax. +/// +/// - This supports ALiBi with `max_bias` as well as softcapping with `softcap`. +/// +/// **Output shape:** (bs, qhead, seq, v_hidden) +pub fn run_flash_attn_cpu( + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + softmax_scale: f32, + max_bias: Option, + softcap: Option, +) -> Result +where + T: WithDType + Sum + num_traits::real::Real, +{ + // Inline CPU slice extraction for q, k, v, and optional mask + let (q_guard, q_layout) = q.storage_and_layout(); + let q_data: &[T] = if let Storage::Cpu(cpu) = &*q_guard { + let data = cpu.as_slice::()?; + &data[q_layout.start_offset()..] + } else { + return Err(candle::Error::Msg("Expected CPU storage for q".into())); + }; + let (k_guard, k_layout) = k.storage_and_layout(); + let k_data: &[T] = if let Storage::Cpu(cpu) = &*k_guard { + let data = cpu.as_slice::()?; + &data[k_layout.start_offset()..] + } else { + return Err(candle::Error::Msg("Expected CPU storage for k".into())); + }; + let (v_guard, v_layout) = v.storage_and_layout(); + let v_data: &[T] = if let Storage::Cpu(cpu) = &*v_guard { + let data = cpu.as_slice::()?; + &data[v_layout.start_offset()..] 
+ } else { + return Err(candle::Error::Msg("Expected CPU storage for v".into())); + }; + let mask_guard = mask.map(|mask| mask.storage_and_layout().0); + let mask_data: Option<&[T]> = if let Some(mask_guard) = &mask_guard { + let mask = mask.as_ref().unwrap(); + + if let Storage::Cpu(cpu) = &**mask_guard { + let data = cpu.as_slice::()?; + Some(&data[mask.layout().start_offset()..]) + } else { + return Err(candle::Error::Msg("Expected CPU storage for mask".into())); + } + } else { + None + }; + // q_guard, k_guard, v_guard, and m_guard (if any) are kept in scope to hold storage alive + + let q_stride = q.stride(); + let k_stride = k.stride(); + let v_stride = v.stride(); + + // Fast path for decode: q_len == 1 + if q.shape().dims()[1] == 1 { + return flash_attn_cpu_single_q( + q_data, + k_data, + v_data, + mask_data, + q.shape().dims(), + k.shape().dims(), + v.shape().dims(), + q_stride, + k_stride, + v_stride, + softmax_scale, + max_bias.unwrap_or(0.0), + softcap.unwrap_or(0.0), + ); + } + + flash_attn_cpu( + q_data, + k_data, + v_data, + mask_data, + q.shape().dims(), + k.shape().dims(), + v.shape().dims(), + q_stride, + k_stride, + v_stride, + softmax_scale, + max_bias.unwrap_or(0.0), + softcap.unwrap_or(0.0), + ) +} + +/// Optimised path for the common decode case: q_len == 1 but kv_len ≫ 1. +/// We drop the inner q‑position loop and parallelise over `(batch, head)`. +#[allow(clippy::too_many_arguments)] +fn flash_attn_cpu_single_q( + q_data: &[T], + k_data: &[T], + v_data: &[T], + mask_vec: Option<&[T]>, + qshape: &[usize], + kshape: &[usize], + vshape: &[usize], + qstride: &[usize], + kstride: &[usize], + vstride: &[usize], + scale: f32, + max_bias: f32, + logit_softcap: f32, +) -> Result { + // Shapes: (B, 1, H, D) + let (b, _q_len, h, d) = ( + qshape[0], qshape[1], // == 1 + qshape[2], qshape[3], + ); + let kv_len = kshape[1]; + let k_h = kshape[2]; + let v_h = vshape[2]; + let rk2 = h / k_h; + let rv2 = h / v_h; + let dv = d; + + let n2 = 2_usize.pow((h as f32).log2().ceil() as u32); + + // Output buffer: (B, H, 1, D) + let mut out = vec![0f32; b * h * dv]; + + // Expose a second dimension of work: split the KV axis into tiles that + // fit in the last‑level cache and let Rayon schedule them. + let kv_tiles = kv_len.div_ceil(TILE_KV); + + // SAFETY: `par_chunks_mut` hands out non‑overlapping &mut slices, so no two + // threads write the same output area. + FLASH_ATTN_POOL.install(|| { + out.par_chunks_mut(dv) + .with_min_len(64) + .enumerate() + .for_each(|(row_idx, out_chunk)| { + let b_i = row_idx / h; + let h_i = row_idx % h; + + // ALiBi positional bias (standard formula) + let slope = if max_bias > 0.0 { + 2.0f32.powf(-max_bias * ((h_i + 1) as f32) / n2 as f32) + } else { + 1.0 + }; + + // For grouped‑KV we collapse multiple query heads into the same K/V head. + let k_head = h_i / rk2; + let v_head = h_i / rv2; + + // ------------------------------------------------------------------ + // Nested parallelism: each KV tile is mapped independently, then we + // reduce the partial results with the correct soft‑max algebra. 
+ // ------------------------------------------------------------------ + let (vkq, s_tot, _m_tot) = (0..kv_tiles) + .into_par_iter() + .map(|tile_idx| { + // ---- per‑tile scratch ------------------------------------------------- + let start = tile_idx * TILE_KV; + let end = (start + TILE_KV).min(kv_len); + + let mut vkq = vec![0f32; dv]; + let mut s = 0.0f32; + let mut m = f32::NEG_INFINITY; + + // ---------------- single‑Q row (already contiguous) ------------------- + let q_base = + b_i * qstride[0] /*batch*/ + h_i * qstride[2] /*head*/; + let q_row = &q_data[q_base..q_base + d]; + + // ---------------- iterate over this KV slice -------------------------- + for kv_pos in start..end { + // Mask + let mv = if let Some(mv_vec) = mask_vec { + let mval = mv_vec[(b_i * kv_len) + kv_pos]; + slope * mval.to_f64() as f32 + } else { + 0.0 + }; + if mv == f32::NEG_INFINITY { + continue; + } + + // K row + let k_base = + b_i * kstride[0] + kv_pos * kstride[1] + k_head * kstride[2]; + let k_row = &k_data[k_base..k_base + d]; + + // dot(Q, K) + let mut s_val = vec_dot::(q_row, k_row).to_f64() as f32; + + let mut scale_applied = scale; + if logit_softcap != 0.0 { + scale_applied /= logit_softcap; + } + s_val *= scale_applied; + if logit_softcap != 0.0 { + s_val = logit_softcap * s_val.tanh(); + } + s_val += mv; + + // Tile‑local online softmax ------------------------------------------ + let m_old = m; + let mut ms = 1.0f32; + let mut vs = 1.0f32; + if s_val > m { + m = s_val; + ms = (m_old - m).exp(); + for v in vkq.iter_mut() { + *v *= ms; + } + } else { + vs = (s_val - m).exp(); + } + + // V row + let v_base = + b_i * vstride[0] + kv_pos * vstride[1] + v_head * vstride[2]; + for d_i in 0..dv { + vkq[d_i] += v_data[v_base + d_i * vstride[3]].to_f64() as f32 * vs; + } + + s = s * ms + vs; + } + + // Return per‑tile accumulator + softmax stats + (vkq, s, m) + }) + // -------- reduce two tiles ----------------------------------------------- + .reduce( + || (vec![0f32; dv], 0.0f32, f32::NEG_INFINITY), + |mut a, b| { + let (ref mut vkq_a, mut s_a, m_a) = a; + let (vkq_b, s_b, m_b) = b; + if m_a >= m_b { + let factor = (m_b - m_a).exp(); + for (va, vb) in vkq_a.iter_mut().zip(vkq_b) { + *va += vb * factor; + } + s_a += s_b * factor; + (vkq_a.clone(), s_a, m_a) + } else { + let factor = (m_a - m_b).exp(); + let mut vkq_new = vkq_b; + for (vb, va) in vkq_new.iter_mut().zip(vkq_a) { + *vb += *va * factor; + } + (vkq_new, s_b + s_a * factor, m_b) + } + }, + ); + + // ---------------- final normalisation --------------------------------------- + let inv_s = 1.0 / s_tot; + for v in out_chunk.iter_mut().zip(vkq.iter()) { + *v.0 = *v.1 * inv_s; + } + }); + }); + + let out_shape = (b, h, 1usize, dv); + Tensor::from_vec(out, out_shape, &Device::Cpu) +} + +/// Main forward flash-attention CPU routine. +/// Shapes follow Candle convention: (B, S, H, D) +#[allow(clippy::too_many_arguments)] +fn flash_attn_cpu( + q_data: &[T], + k_data: &[T], + v_data: &[T], + mask_vec: Option<&[T]>, + qshape: &[usize], + kshape: &[usize], + vshape: &[usize], + qstride: &[usize], + kstride: &[usize], + vstride: &[usize], + scale: f32, + max_bias: f32, + logit_softcap: f32, +) -> Result { + let (b, q_len, h, d) = (qshape[0], qshape[1], qshape[2], qshape[3]); + let kv_len = kshape[1]; + // --- Head broadcasting factors ---------------------------------------------------- + // Allows K and V to have fewer heads than Q (grouped‑KV); the ratio is an + // integer factor. rk2 = #Q‑heads / #K‑heads, rv2 = #Q‑heads / #V‑heads. 
+ let k_h = kshape[2]; + let v_h = vshape[2]; + let rk2 = h / k_h; // must divide exactly; panic otherwise + let rv2 = h / v_h; + let dv = d; // value dim = key dim in this kernel + + // Precompute value for ALiBi slope calculation + let n2 = 2_usize.pow((h as f32).log2().ceil() as u32); + + let mut out = vec![0f32; b * q_len * h * dv]; + + // ------------------------------------------------------------------ + // Rayon‑parallel version: each (b_i, h_i, q_pos) row is independent. + // ------------------------------------------------------------------ + + let _rows = b * h * q_len; // total independent work items + + // SAFETY: `par_chunks_mut` hands out non‑overlapping &mut [f32] slices, + // so no two threads can write the same output area. + FLASH_ATTN_POOL.install(|| { + out.par_chunks_mut(dv) + .with_min_len(64) + .enumerate() + .for_each(|(row_idx, out_chunk)| { + // Decode flat index back to (batch, head, q_pos) + let rows_per_batch = h * q_len; + let b_i = row_idx / rows_per_batch; + let rem = row_idx % rows_per_batch; + let h_i = rem / q_len; + let q_pos = rem % q_len; + + let slope = if max_bias > 0.0 { + 2.0f32.powf(-max_bias * ((h_i + 1) as f32) / n2 as f32) + } else { + 1.0 + }; + + // For grouped‑KV we collapse multiple query heads into the same K/V head. + let k_head = h_i / rk2; + let v_head = h_i / rv2; + + // Buffers local to this row + let mut vkq = vec![0f32; dv]; + let mut s = 0.0f32; + let mut m = f32::NEG_INFINITY; + + // Allocate q_row and k_row once per row + let mut q_row: Vec = Vec::with_capacity(d); + let mut k_row: Vec = Vec::with_capacity(d); + + // ------------------- gather Q (strided) -------------------- + let q_base = b_i * qstride[0] + q_pos * qstride[1] + h_i * qstride[2]; + q_row.clear(); + for di in 0..d { + q_row.push(q_data[q_base + di * qstride[3]]); + } + + // ---------------- iterate over keys/values ----------------- + for kv_pos in 0..kv_len { + // Mask (optional) + let mv = if let Some(mv_vec) = mask_vec { + let mval = mv_vec[((b_i * q_len + q_pos) * kv_len) + kv_pos]; + slope * mval.to_f64() as f32 + } else { + 0.0 + }; + if mv == f32::NEG_INFINITY { + continue; + } + + // K row (strided) + let k_base = b_i * kstride[0] + kv_pos * kstride[1] + k_head * kstride[2]; + k_row.clear(); + for di in 0..d { + k_row.push(k_data[k_base + di * kstride[3]]); + } + + // dot(Q, K) + let mut s_val = vec_dot::(&q_row, &k_row); + let mut scale_applied = scale; + if logit_softcap != 0.0 { + scale_applied /= logit_softcap; + } + s_val *= T::from_f64(scale_applied as f64); + if logit_softcap != 0.0 { + s_val = T::from_f64(logit_softcap as f64 * s_val.to_f64().tanh()); + } + s_val += T::from_f64(mv as f64); + + // online softmax + let m_old = m; + let mut ms = 1.0f32; + let mut vs = 1.0f32; + if s_val.to_f64() as f32 > m { + m = s_val.to_f64() as f32; + ms = (m_old - m).exp(); + for v in vkq.iter_mut() { + *v *= ms; + } + } else { + vs = (s_val.to_f64() as f32 - m).exp(); + } + + // V row (strided) + let v_base = b_i * vstride[0] + kv_pos * vstride[1] + v_head * vstride[2]; + for d_i in 0..dv { + vkq[d_i] += v_data[v_base + d_i * vstride[3]].to_f64() as f32 * vs; + } + + s = s * ms + vs; + } + + // ------------------- normalise & write out ------------------ + let inv_s = 1.0 / s; + for v in vkq.iter_mut() { + *v *= inv_s; + } + out_chunk.copy_from_slice(&vkq); + }); + }); + + // Build output tensor with shape (B, H, S, D) to match standard (permute 0,2,1,3) + let out_shape = (b, h, q_len, dv); + Tensor::from_vec(out, out_shape, &Device::Cpu) +} diff --git 
a/candle-nn/src/attention/mod.rs b/candle-nn/src/attention/mod.rs new file mode 100644 index 0000000000..a1e36b1918 --- /dev/null +++ b/candle-nn/src/attention/mod.rs @@ -0,0 +1,82 @@ +//! Attention implementations for Candle. +//! +//! # Usage +//! +//! ```ignore +//! use candle_nn::attention::{flash_attn, AttnMask}; +//! +//! // Causal attention (uses optimized loop-bound path) +//! let out = flash_attn::( +//! &q, &k, &v, +//! 1.0 / (head_dim as f32).sqrt(), +//! AttnMask::causal_with_offset(kv_cache_len), +//! None, None, +//! )?; +//! +//! // Custom mask tensor +//! let out = flash_attn::( +//! &q, &k, &v, scale, +//! AttnMask::Mask(&mask_tensor), +//! None, None, +//! )?; +//! ``` + +pub mod cpu_flash; + +use candle::Tensor; + +// Re-export main API +pub use cpu_flash::flash_attn; + +/// Attention mask specification. +/// +/// Using an enum instead of raw tensors enables optimizations: +/// - `Causal`: Loop-bound masking (skips ~50% of positions, no tensor allocation) +/// - `Mask`: Explicit tensor for arbitrary patterns (sliding window, block-sparse) +/// - `None`: Full bidirectional attention +#[derive(Default, Debug, Clone, Copy)] +pub enum AttnMask<'a> { + /// No masking — full bidirectional attention. + #[default] + None, + + /// Causal masking via efficient loop bounds (no tensor allocation). + /// + /// `kv_offset`: Number of prior KV positions when using KV cache. + /// - Prefill: `kv_offset = 0` + /// - Decode: `kv_offset = cached_kv_len` + Causal { kv_offset: usize }, + + /// Custom mask tensor for arbitrary patterns. + /// + /// Shape: `(B, Q_LEN, KV_LEN)` or broadcastable. + /// Values: `0.0` to attend, `NEG_INFINITY` to mask. + Mask(&'a Tensor), +} + +impl<'a> AttnMask<'a> { + /// Causal mask for prefill (no KV offset). + #[inline] + pub fn causal() -> Self { + AttnMask::Causal { kv_offset: 0 } + } + + /// Causal mask for decode with KV cache. + #[inline] + pub fn causal_with_offset(kv_offset: usize) -> Self { + AttnMask::Causal { kv_offset } + } + + #[inline] + pub fn is_causal(&self) -> bool { + matches!(self, AttnMask::Causal { .. }) + } + + #[inline] + pub fn kv_offset(&self) -> usize { + match self { + AttnMask::Causal { kv_offset } => *kv_offset, + _ => 0, + } + } +} diff --git a/candle-nn/src/cpu_flash_attention.rs b/candle-nn/src/cpu_flash_attention.rs index f69b0fbae6..6a5f8bba38 100644 --- a/candle-nn/src/cpu_flash_attention.rs +++ b/candle-nn/src/cpu_flash_attention.rs @@ -1,485 +1,9 @@ -#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] - -use candle::{Device, Result, Storage, Tensor, WithDType}; -use std::sync::LazyLock; -use std::{f32, iter::Sum}; - -use rayon::prelude::*; -use rayon::ThreadPool; - -#[cfg(target_os = "macos")] -/// Elevate the thread QoS so macOS prefers running it on Performance (P) cores. -unsafe fn set_thread_affinity() { - // USER_INTERACTIVE has the highest scheduling priority that user code - // can request and is most likely to be scheduled on P‑cores. - use libc::{pthread_set_qos_class_self_np, qos_class_t::QOS_CLASS_USER_INTERACTIVE}; - // The second argument is a relative priority within the QoS class (0 = default). - pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0); -} - -#[cfg(not(target_os = "macos"))] -#[inline(always)] -unsafe fn set_thread_affinity() { - // On non‑macOS platforms we currently leave affinity untouched. -} - -/// Rayon pool used by the flash‑attention CPU kernels, with a per‑thread -/// start handler that applies our affinity hint exactly once. 
-static FLASH_ATTN_POOL: LazyLock = LazyLock::new(|| { - rayon::ThreadPoolBuilder::new() - .start_handler(|_| unsafe { - set_thread_affinity(); - }) - .build() - .expect("Failed to build custom Rayon thread‑pool for flash‑attention") -}); - -const DOT_CHUNK: usize = 4; - -/// Size (in KV positions) processed by each inner‑tile job. -const TILE_KV: usize = 16; - -#[inline] -fn vec_dot>(a: &[T], b: &[T]) -> T { - let mut sum = T::zero(); - let chunks = a.len() / DOT_CHUNK; - - for i in 0..chunks { - let i_chunk = i * DOT_CHUNK; - sum = sum - + a[i_chunk] * b[i_chunk] - + a[i_chunk + 1] * b[i_chunk + 1] - + a[i_chunk + 2] * b[i_chunk + 2] - + a[i_chunk + 3] * b[i_chunk + 3]; - } - - for i in (chunks * DOT_CHUNK)..a.len() { - sum += a[i] * b[i]; - } - sum -} - -/// Fused attention optimized for CPU. -/// -/// Computes softmax(qk^T*scale)v. -/// -/// **Inputs shapes:** -/// - `q`: (bs, seq, qhead, hidden) -/// - `k`: (bs, kv_seq, v_head, hidden) -/// - `k`: (bs, kv_seq, kv_head_seq, v_hidden) -/// - `scale` is applied before softmax. -/// -/// - This supports ALiBi with `max_bias` as well as softcapping with `softcap`. -/// -/// **Output shape:** (bs, qhead, seq, v_hidden) -pub fn run_flash_attn_cpu( - q: &Tensor, - k: &Tensor, - v: &Tensor, - mask: Option<&Tensor>, - softmax_scale: f32, - max_bias: Option, - softcap: Option, -) -> Result -where - T: WithDType + Sum + num_traits::real::Real, -{ - // Inline CPU slice extraction for q, k, v, and optional mask - let (q_guard, q_layout) = q.storage_and_layout(); - let q_data: &[T] = if let Storage::Cpu(cpu) = &*q_guard { - let data = cpu.as_slice::()?; - &data[q_layout.start_offset()..] - } else { - return Err(candle::Error::Msg("Expected CPU storage for q".into())); - }; - let (k_guard, k_layout) = k.storage_and_layout(); - let k_data: &[T] = if let Storage::Cpu(cpu) = &*k_guard { - let data = cpu.as_slice::()?; - &data[k_layout.start_offset()..] - } else { - return Err(candle::Error::Msg("Expected CPU storage for k".into())); - }; - let (v_guard, v_layout) = v.storage_and_layout(); - let v_data: &[T] = if let Storage::Cpu(cpu) = &*v_guard { - let data = cpu.as_slice::()?; - &data[v_layout.start_offset()..] - } else { - return Err(candle::Error::Msg("Expected CPU storage for v".into())); - }; - let mask_guard = mask.map(|mask| mask.storage_and_layout().0); - let mask_data: Option<&[T]> = if let Some(mask_guard) = &mask_guard { - let mask = mask.as_ref().unwrap(); - - if let Storage::Cpu(cpu) = &**mask_guard { - let data = cpu.as_slice::()?; - Some(&data[mask.layout().start_offset()..]) - } else { - return Err(candle::Error::Msg("Expected CPU storage for mask".into())); - } - } else { - None - }; - // q_guard, k_guard, v_guard, and m_guard (if any) are kept in scope to hold storage alive - - let q_stride = q.stride(); - let k_stride = k.stride(); - let v_stride = v.stride(); - - // Fast path for decode: q_len == 1 - if q.shape().dims()[1] == 1 { - return flash_attn_cpu_single_q( - q_data, - k_data, - v_data, - mask_data, - q.shape().dims(), - k.shape().dims(), - v.shape().dims(), - q_stride, - k_stride, - v_stride, - softmax_scale, - max_bias.unwrap_or(0.0), - softcap.unwrap_or(0.0), - ); - } - - flash_attn_cpu( - q_data, - k_data, - v_data, - mask_data, - q.shape().dims(), - k.shape().dims(), - v.shape().dims(), - q_stride, - k_stride, - v_stride, - softmax_scale, - max_bias.unwrap_or(0.0), - softcap.unwrap_or(0.0), - ) -} - -/// Optimised path for the common decode case: q_len == 1 but kv_len ≫ 1. 
-/// We drop the inner q‑position loop and parallelise over `(batch, head)`. -#[allow(clippy::too_many_arguments)] -fn flash_attn_cpu_single_q( - q_data: &[T], - k_data: &[T], - v_data: &[T], - mask_vec: Option<&[T]>, - qshape: &[usize], - kshape: &[usize], - vshape: &[usize], - qstride: &[usize], - kstride: &[usize], - vstride: &[usize], - scale: f32, - max_bias: f32, - logit_softcap: f32, -) -> Result { - // Shapes: (B, 1, H, D) - let (b, _q_len, h, d) = ( - qshape[0], qshape[1], // == 1 - qshape[2], qshape[3], - ); - let kv_len = kshape[1]; - let k_h = kshape[2]; - let v_h = vshape[2]; - let rk2 = h / k_h; - let rv2 = h / v_h; - let dv = d; - - let n2 = 2_usize.pow((h as f32).log2().ceil() as u32); - - // Output buffer: (B, H, 1, D) - let mut out = vec![0f32; b * h * dv]; - - // Expose a second dimension of work: split the KV axis into tiles that - // fit in the last‑level cache and let Rayon schedule them. - let kv_tiles = kv_len.div_ceil(TILE_KV); - - // SAFETY: `par_chunks_mut` hands out non‑overlapping &mut slices, so no two - // threads write the same output area. - FLASH_ATTN_POOL.install(|| { - out.par_chunks_mut(dv) - .with_min_len(64) - .enumerate() - .for_each(|(row_idx, out_chunk)| { - let b_i = row_idx / h; - let h_i = row_idx % h; - - // ALiBi positional bias (standard formula) - let slope = if max_bias > 0.0 { - 2.0f32.powf(-max_bias * ((h_i + 1) as f32) / n2 as f32) - } else { - 1.0 - }; - - // For grouped‑KV we collapse multiple query heads into the same K/V head. - let k_head = h_i / rk2; - let v_head = h_i / rv2; - - // ------------------------------------------------------------------ - // Nested parallelism: each KV tile is mapped independently, then we - // reduce the partial results with the correct soft‑max algebra. - // ------------------------------------------------------------------ - let (vkq, s_tot, _m_tot) = (0..kv_tiles) - .into_par_iter() - .map(|tile_idx| { - // ---- per‑tile scratch ------------------------------------------------- - let start = tile_idx * TILE_KV; - let end = (start + TILE_KV).min(kv_len); - - let mut vkq = vec![0f32; dv]; - let mut s = 0.0f32; - let mut m = f32::NEG_INFINITY; - - // ---------------- single‑Q row (already contiguous) ------------------- - let q_base = - b_i * qstride[0] /*batch*/ + h_i * qstride[2] /*head*/; - let q_row = &q_data[q_base..q_base + d]; - - // ---------------- iterate over this KV slice -------------------------- - for kv_pos in start..end { - // Mask - let mv = if let Some(mv_vec) = mask_vec { - let mval = mv_vec[(b_i * kv_len) + kv_pos]; - slope * mval.to_f64() as f32 - } else { - 0.0 - }; - if mv == f32::NEG_INFINITY { - continue; - } - - // K row - let k_base = - b_i * kstride[0] + kv_pos * kstride[1] + k_head * kstride[2]; - let k_row = &k_data[k_base..k_base + d]; - - // dot(Q, K) - let mut s_val = vec_dot::(q_row, k_row).to_f64() as f32; - - let mut scale_applied = scale; - if logit_softcap != 0.0 { - scale_applied /= logit_softcap; - } - s_val *= scale_applied; - if logit_softcap != 0.0 { - s_val = logit_softcap * s_val.tanh(); - } - s_val += mv; - - // Tile‑local online softmax ------------------------------------------ - let m_old = m; - let mut ms = 1.0f32; - let mut vs = 1.0f32; - if s_val > m { - m = s_val; - ms = (m_old - m).exp(); - for v in vkq.iter_mut() { - *v *= ms; - } - } else { - vs = (s_val - m).exp(); - } - - // V row - let v_base = - b_i * vstride[0] + kv_pos * vstride[1] + v_head * vstride[2]; - for d_i in 0..dv { - vkq[d_i] += v_data[v_base + d_i * vstride[3]].to_f64() as f32 * 
vs; - } - - s = s * ms + vs; - } - - // Return per‑tile accumulator + softmax stats - (vkq, s, m) - }) - // -------- reduce two tiles ----------------------------------------------- - .reduce( - || (vec![0f32; dv], 0.0f32, f32::NEG_INFINITY), - |mut a, b| { - let (ref mut vkq_a, mut s_a, m_a) = a; - let (vkq_b, s_b, m_b) = b; - if m_a >= m_b { - let factor = (m_b - m_a).exp(); - for (va, vb) in vkq_a.iter_mut().zip(vkq_b) { - *va += vb * factor; - } - s_a += s_b * factor; - (vkq_a.clone(), s_a, m_a) - } else { - let factor = (m_a - m_b).exp(); - let mut vkq_new = vkq_b; - for (vb, va) in vkq_new.iter_mut().zip(vkq_a) { - *vb += *va * factor; - } - (vkq_new, s_b + s_a * factor, m_b) - } - }, - ); - - // ---------------- final normalisation --------------------------------------- - let inv_s = 1.0 / s_tot; - for v in out_chunk.iter_mut().zip(vkq.iter()) { - *v.0 = *v.1 * inv_s; - } - }); - }); - - let out_shape = (b, h, 1usize, dv); - Tensor::from_vec(out, out_shape, &Device::Cpu) -} - -/// Main forward flash-attention CPU routine. -/// Shapes follow Candle convention: (B, S, H, D) -#[allow(clippy::too_many_arguments)] -fn flash_attn_cpu( - q_data: &[T], - k_data: &[T], - v_data: &[T], - mask_vec: Option<&[T]>, - qshape: &[usize], - kshape: &[usize], - vshape: &[usize], - qstride: &[usize], - kstride: &[usize], - vstride: &[usize], - scale: f32, - max_bias: f32, - logit_softcap: f32, -) -> Result { - let (b, q_len, h, d) = (qshape[0], qshape[1], qshape[2], qshape[3]); - let kv_len = kshape[1]; - // --- Head broadcasting factors ---------------------------------------------------- - // Allows K and V to have fewer heads than Q (grouped‑KV); the ratio is an - // integer factor. rk2 = #Q‑heads / #K‑heads, rv2 = #Q‑heads / #V‑heads. - let k_h = kshape[2]; - let v_h = vshape[2]; - let rk2 = h / k_h; // must divide exactly; panic otherwise - let rv2 = h / v_h; - let dv = d; // value dim = key dim in this kernel - - // Precompute value for ALiBi slope calculation - let n2 = 2_usize.pow((h as f32).log2().ceil() as u32); - - let mut out = vec![0f32; b * q_len * h * dv]; - - // ------------------------------------------------------------------ - // Rayon‑parallel version: each (b_i, h_i, q_pos) row is independent. - // ------------------------------------------------------------------ - - let _rows = b * h * q_len; // total independent work items - - // SAFETY: `par_chunks_mut` hands out non‑overlapping &mut [f32] slices, - // so no two threads can write the same output area. - FLASH_ATTN_POOL.install(|| { - out.par_chunks_mut(dv) - .with_min_len(64) - .enumerate() - .for_each(|(row_idx, out_chunk)| { - // Decode flat index back to (batch, head, q_pos) - let rows_per_batch = h * q_len; - let b_i = row_idx / rows_per_batch; - let rem = row_idx % rows_per_batch; - let h_i = rem / q_len; - let q_pos = rem % q_len; - - let slope = if max_bias > 0.0 { - 2.0f32.powf(-max_bias * ((h_i + 1) as f32) / n2 as f32) - } else { - 1.0 - }; - - // For grouped‑KV we collapse multiple query heads into the same K/V head. 
- let k_head = h_i / rk2; - let v_head = h_i / rv2; - - // Buffers local to this row - let mut vkq = vec![0f32; dv]; - let mut s = 0.0f32; - let mut m = f32::NEG_INFINITY; - - // Allocate q_row and k_row once per row - let mut q_row: Vec = Vec::with_capacity(d); - let mut k_row: Vec = Vec::with_capacity(d); - - // ------------------- gather Q (strided) -------------------- - let q_base = b_i * qstride[0] + q_pos * qstride[1] + h_i * qstride[2]; - q_row.clear(); - for di in 0..d { - q_row.push(q_data[q_base + di * qstride[3]]); - } - - // ---------------- iterate over keys/values ----------------- - for kv_pos in 0..kv_len { - // Mask (optional) - let mv = if let Some(mv_vec) = mask_vec { - let mval = mv_vec[((b_i * q_len + q_pos) * kv_len) + kv_pos]; - slope * mval.to_f64() as f32 - } else { - 0.0 - }; - if mv == f32::NEG_INFINITY { - continue; - } - - // K row (strided) - let k_base = b_i * kstride[0] + kv_pos * kstride[1] + k_head * kstride[2]; - k_row.clear(); - for di in 0..d { - k_row.push(k_data[k_base + di * kstride[3]]); - } - - // dot(Q, K) - let mut s_val = vec_dot::(&q_row, &k_row); - let mut scale_applied = scale; - if logit_softcap != 0.0 { - scale_applied /= logit_softcap; - } - s_val *= T::from_f64(scale_applied as f64); - if logit_softcap != 0.0 { - s_val = T::from_f64(logit_softcap as f64 * s_val.to_f64().tanh()); - } - s_val += T::from_f64(mv as f64); - - // online softmax - let m_old = m; - let mut ms = 1.0f32; - let mut vs = 1.0f32; - if s_val.to_f64() as f32 > m { - m = s_val.to_f64() as f32; - ms = (m_old - m).exp(); - for v in vkq.iter_mut() { - *v *= ms; - } - } else { - vs = (s_val.to_f64() as f32 - m).exp(); - } - - // V row (strided) - let v_base = b_i * vstride[0] + kv_pos * vstride[1] + v_head * vstride[2]; - for d_i in 0..dv { - vkq[d_i] += v_data[v_base + d_i * vstride[3]].to_f64() as f32 * vs; - } - - s = s * ms + vs; - } - - // ------------------- normalise & write out ------------------ - let inv_s = 1.0 / s; - for v in vkq.iter_mut() { - *v *= inv_s; - } - out_chunk.copy_from_slice(&vkq); - }); - }); - - // Build output tensor with shape (B, H, S, D) to match standard (permute 0,2,1,3) - let out_shape = (b, h, q_len, dv); - Tensor::from_vec(out, out_shape, &Device::Cpu) -} +//! Backward compatibility shim for CPU flash attention. +//! +//! **Deprecated:** Use `candle_nn::attention::{flash_attn, AttnMask}` instead. + +#[deprecated( + since = "0.9.2", + note = "Use `candle_nn::attention::{flash_attn, AttnMask}` instead" +)] +pub use crate::attention::cpu_flash::standard::run_flash_attn_cpu; diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index c7a76fbd7a..ad178b0d4e 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -16,6 +16,7 @@ //! 
pub mod activation; +pub mod attention; pub mod batch_norm; pub mod conv; pub mod cpu_flash_attention; diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index 9f018939ae..99905b410c 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -6,6 +6,12 @@ use candle::{DType, Device, Module, Result, Tensor}; use candle_nn::{kv_cache::ConcatKvCache, Activation, VarBuilder}; use std::sync::Arc; +#[cfg(feature = "flash-attn")] +use candle_flash_attn; + +#[cfg(not(feature = "flash-attn"))] +use candle_nn::attention::{flash_attn, AttnMask}; + #[derive(Debug, Clone, PartialEq, serde::Deserialize)] pub struct Config { pub vocab_size: usize, @@ -24,6 +30,8 @@ pub struct Config { pub rms_norm_eps: f64, pub use_sliding_window: bool, pub hidden_act: Activation, + #[serde(default)] + pub use_flash_attn: bool, } #[derive(Debug, Clone)] @@ -106,6 +114,7 @@ pub(crate) struct Qwen3Attention { num_kv_groups: usize, head_dim: usize, hidden_size: usize, + use_flash_attn: bool, // utils rotary_emb: Arc, kv_cache: ConcatKvCache, @@ -173,6 +182,7 @@ impl Qwen3Attention { num_kv_groups, head_dim, hidden_size, + use_flash_attn: cfg.use_flash_attn, rotary_emb, kv_cache, }) @@ -202,8 +212,8 @@ impl Qwen3Attention { .reshape((b, l, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; - // 3. Per‑head RMSNorm - let q_flat = q.flatten(0, 2)?; // (B*H, L, D) -> (BHL, D) after transpose later + // 3. Per-head RMSNorm + let q_flat = q.flatten(0, 2)?; let k_flat = k.flatten(0, 2)?; let q_flat = self.q_norm.forward(&q_flat)?; let k_flat = self.k_norm.forward(&k_flat)?; @@ -216,11 +226,159 @@ impl Qwen3Attention { // 5. Accumulate KV cache let (k, v) = self.kv_cache.append(&k, &v)?; - // 6. GQA repeat_kv - let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?; - let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?; + // 6. 
Attention dispatch based on device and features + let on_cpu = x.device().is_cpu(); + + if on_cpu { + if self.use_flash_attn { + // CPU with flash flag: use optimized CPU flash attention + self.forward_cpu_flash_attn(&q, &k, &v, offset, b, l) + } else { + // CPU without flash flag: use standard matmul (for comparison/testing) + self.forward_standard_attn(&q, &k, &v, attn_mask, b, l) + } + } else if self.use_flash_attn { + // GPU with flash-attn flag: use GPU flash attention + self.forward_flash_attn(&q, &k, &v, offset, b, l) + } else { + // GPU without flash-attn: use standard matmul attention + self.forward_standard_attn(&q, &k, &v, attn_mask, b, l) + } + } + + /// GPU flash attention path (requires flash-attn feature) + #[cfg(feature = "flash-attn")] + fn forward_flash_attn( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + _offset: usize, + b: usize, + l: usize, + ) -> Result { + // Flash attention expects (B, S, H, D) format + let q = q.transpose(1, 2)?.contiguous()?; + let k = k.transpose(1, 2)?.contiguous()?; + let v = v.transpose(1, 2)?.contiguous()?; + + let scale = 1.0 / (self.head_dim as f32).sqrt(); + + let ctx = candle_flash_attn::flash_attn(&q, &k, &v, scale, true)?; + + // Output: (B, S, H, D) -> (B, L, hidden_size) + ctx.reshape((b, l, self.hidden_size))?.apply(&self.o_proj) + } + + /// Fallback when flash-attn feature not enabled but use_flash_attn was requested + #[cfg(not(feature = "flash-attn"))] + fn forward_flash_attn( + &self, + _q: &Tensor, + _k: &Tensor, + _v: &Tensor, + _offset: usize, + _b: usize, + _l: usize, + ) -> Result { + candle::bail!( + "use_flash_attn=true requires compiling with --features flash-attn. \ + For CPU, omit --use-flash-attn flag to use optimized CPU attention." + ) + } + + /// CPU flash attention - optimized fused kernel for CPU + #[cfg(not(feature = "flash-attn"))] + fn forward_cpu_flash_attn( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + offset: usize, + b: usize, + l: usize, + ) -> Result { + // CPU flash attention expects (B, S, H, D) format + let q = q.transpose(1, 2)?.contiguous()?; + let k = k.transpose(1, 2)?.contiguous()?; + let v = v.transpose(1, 2)?.contiguous()?; + + let scale = 1.0 / (self.head_dim as f32).sqrt(); + + let ctx = match q.dtype() { + DType::F32 => flash_attn::( + &q, + &k, + &v, + scale, + AttnMask::causal_with_offset(offset), + None, + None, + )?, + DType::F64 => flash_attn::( + &q, + &k, + &v, + scale, + AttnMask::causal_with_offset(offset), + None, + None, + )?, + DType::BF16 => { + let q_f32 = q.to_dtype(DType::F32)?; + let k_f32 = k.to_dtype(DType::F32)?; + let v_f32 = v.to_dtype(DType::F32)?; + let ctx_f32 = flash_attn::( + &q_f32, + &k_f32, + &v_f32, + scale, + AttnMask::causal_with_offset(offset), + None, + None, + )?; + ctx_f32.to_dtype(DType::BF16)? 
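                // Editorial note: BF16 inputs are widened to F32 so the hot loops run in a
                // hardware-native float; the kernels accumulate in f32 either way, so the
                // round trip costs extra copies rather than accuracy.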
+ } + dtype => candle::bail!("Unsupported dtype for CPU flash attention: {:?}", dtype), + }; + + // Output from CPU flash attention is (B, H, S, D), transpose to (B, S, H, D) + let ctx = ctx.transpose(1, 2)?; + + ctx.reshape((b, l, self.hidden_size))?.apply(&self.o_proj) + } + + /// Stub for when flash-attn is enabled (CPU path not needed) + #[cfg(feature = "flash-attn")] + fn forward_cpu_flash_attn( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + _offset: usize, + b: usize, + l: usize, + ) -> Result { + // When flash-attn feature is enabled, fall back to standard attention for CPU + // This path is rarely hit since GPU is typically used with flash-attn + self.forward_standard_attn(q, k, v, None, b, l) + } + + /// Standard matmul-based attention (works on any device) + fn forward_standard_attn( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + attn_mask: Option<&Tensor>, + b: usize, + l: usize, + ) -> Result { + // GQA repeat_kv + let k = repeat_kv(k.clone(), self.num_kv_groups)?.contiguous()?; + let v = repeat_kv(v.clone(), self.num_kv_groups)?.contiguous()?; - // 7. Attention score + // Attention score let scale = 1.0 / (self.head_dim as f64).sqrt(); let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?; if let Some(m) = attn_mask { @@ -229,7 +387,7 @@ impl Qwen3Attention { let probs = candle_nn::ops::softmax_last_dim(&scores)?; let ctx = probs.matmul(&v)?; // (B, H, L, D) - // 8. Output proj + // Output proj ctx.transpose(1, 2)? .reshape((b, l, self.hidden_size))? .apply(&self.o_proj) @@ -287,6 +445,7 @@ pub struct Model { norm: RmsNorm, device: Device, dtype: DType, + use_flash_attn: bool, } impl Model { @@ -305,6 +464,7 @@ impl Model { norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, device: vb.device().clone(), dtype: vb.dtype(), + use_flash_attn: cfg.use_flash_attn, }) } @@ -345,10 +505,13 @@ impl Model { let (b, l) = input.dims2()?; let mut h = self.embed_tokens.forward(input)?; - let causal = if l == 1 { - None - } else { + // Build causal mask for standard attention path + // Flash attention (CPU or GPU) handles masking internally + let needs_mask = !self.use_flash_attn && l > 1; + let causal = if needs_mask { Some(self.causal_mask(b, l, offset, None)?) + } else { + None }; for layer in &mut self.layers { diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index b76ce92de4..eba715ba2b 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -51,6 +51,7 @@ impl From<&Config> for Qwen3Config { rms_norm_eps: val.rms_norm_eps, use_sliding_window: val.use_sliding_window, hidden_act: val.hidden_act, + use_flash_attn: false, } } }
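To tie the pieces together, here is a minimal editorial sketch (not from the patch) of driving the new CPU path directly, mirroring what `Qwen3Attention::forward_cpu_flash_attn` does above: inputs in `(B, S, H, D)`, the causal offset set to the number of cached tokens, and the output returned as `(B, H, S, D)`.

    use candle::{Device, Result, Tensor};
    use candle_nn::attention::{flash_attn, AttnMask};

    fn cpu_causal_attention_demo() -> Result<Tensor> {
        let dev = Device::Cpu;
        // (batch, seq, heads, head_dim), the layout the CPU kernels expect.
        let (b, s, h, d) = (1, 4, 8, 64);
        let q = Tensor::randn(0f32, 1f32, (b, s, h, d), &dev)?;
        let k = Tensor::randn(0f32, 1f32, (b, s, h, d), &dev)?;
        let v = Tensor::randn(0f32, 1f32, (b, s, h, d), &dev)?;
        let scale = 1.0 / (d as f32).sqrt();
        // Prefill with an empty cache: kv_offset = 0; output shape is (B, H, S, D).
        flash_attn::<f32>(&q, &k, &v, scale, AttnMask::causal(), None, None)
    }

For decode against a populated cache, pass `AttnMask::causal_with_offset(cached_len)` instead. On the model side, `use_flash_attn` is `#[serde(default)]`, so existing `config.json` files deserialize unchanged and the qwen example simply overwrites the field from its `--use-flash-attn` flag; `qwen3_moe` pins it to `false`, so the MoE path stays on standard matmul attention for now.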