candle-transformers/src/models/quantized_gemma3.rs (68 changes: 40 additions & 28 deletions)
@@ -20,13 +20,34 @@ use candle::quantized::QTensor;
 use candle::D;
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module};
+use std::io::{Read, Seek};
 
 pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
 pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
 pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
 pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
 pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;
 
+struct Gguf<R: Read + Seek> {
+    ct: gguf_file::Content,
+    reader: R,
+    device: Device,
+}
+
+impl<R: Read + Seek> Gguf<R> {
+    fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self {
+        Self { ct, reader, device }
+    }
+
+    fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> {
+        &self.ct.metadata
+    }
+
+    fn tensor(&mut self, name: &str) -> Result<QTensor> {
+        self.ct.tensor(&mut self.reader, name, &self.device)
+    }
+}
+
 #[derive(Debug, Clone)]
 struct QMatMul {
     inner: candle::quantized::QMatMul,
@@ -263,7 +284,8 @@ impl ModelWeights
         reader: &mut R,
         device: &Device,
     ) -> Result<Self> {
-        let md_get = |s: &str| match ct.metadata.get(s) {
+        let mut gg = Gguf::new(ct, reader, device.clone());
+        let md_get = |s: &str| match gg.metadata().get(s) {
             None => candle::bail!("cannot find {s} in metadata"),
             Some(v) => Ok(v),
         };
@@ -301,66 +323,56 @@ impl ModelWeights
         let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
         // Load token embeddings and output projection
-        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
+        let tok_embeddings = gg.tensor("token_embd.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
-        let norm = RmsNorm::from_qtensor(
-            ct.tensor(reader, "output_norm.weight", device)?,
-            rms_norm_eps,
-        )?;
-        let output = match ct.tensor(reader, "output.weight", device) {
+        let norm = RmsNorm::from_qtensor(gg.tensor("output_norm.weight")?, rms_norm_eps)?;
+        let output = match gg.tensor("output.weight") {
             Ok(tensor) => tensor,
-            Err(_) => ct.tensor(reader, "token_embd.weight", device)?, // Use tied weights if output.weight doesn't exist
+            Err(_) => gg.tensor("token_embd.weight")?, // Use tied weights if output.weight doesn't exist
         };
 
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
 
-            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
-            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
-            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
-            let attention_wo =
-                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
+            let attention_wq = gg.tensor(&format!("{prefix}.attn_q.weight"))?;
+            let attention_wk = gg.tensor(&format!("{prefix}.attn_k.weight"))?;
+            let attention_wv = gg.tensor(&format!("{prefix}.attn_v.weight"))?;
+            let attention_wo = gg.tensor(&format!("{prefix}.attn_output.weight"))?;
 
             let attention_q_norm = RmsNorm::from_qtensor(
-                ct.tensor(reader, &format!("{prefix}.attn_q_norm.weight"), device)?,
+                gg.tensor(&format!("{prefix}.attn_q_norm.weight"))?,
                 rms_norm_eps,
             )?;
 
             let attention_k_norm = RmsNorm::from_qtensor(
-                ct.tensor(reader, &format!("{prefix}.attn_k_norm.weight"), device)?,
+                gg.tensor(&format!("{prefix}.attn_k_norm.weight"))?,
                 rms_norm_eps,
             )?;
 
             let attention_norm = RmsNorm::from_qtensor(
-                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
+                gg.tensor(&format!("{prefix}.attn_norm.weight"))?,
                 rms_norm_eps,
             )?;
 
             let post_attention_norm = RmsNorm::from_qtensor(
-                ct.tensor(
-                    reader,
-                    &format!("{prefix}.post_attention_norm.weight"),
-                    device,
-                )?,
+                gg.tensor(&format!("{prefix}.post_attention_norm.weight"))?,
                 rms_norm_eps,
             )?;
 
             let ffn_norm = RmsNorm::from_qtensor(
-                ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
+                gg.tensor(&format!("{prefix}.ffn_norm.weight"))?,
                 rms_norm_eps,
             )?;
 
             let post_ffn_norm = RmsNorm::from_qtensor(
-                ct.tensor(reader, &format!("{prefix}.post_ffw_norm.weight"), device)?,
+                gg.tensor(&format!("{prefix}.post_ffw_norm.weight"))?,
                 rms_norm_eps,
             )?;
 
-            let feed_forward_gate =
-                ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
-            let feed_forward_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
-            let feed_forward_down =
-                ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+            let feed_forward_gate = gg.tensor(&format!("{prefix}.ffn_gate.weight"))?;
+            let feed_forward_up = gg.tensor(&format!("{prefix}.ffn_up.weight"))?;
+            let feed_forward_down = gg.tensor(&format!("{prefix}.ffn_down.weight"))?;
 
             let mlp = Mlp {
                 feed_forward_gate: QMatMul::from_qtensor(feed_forward_gate)?,
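The pattern this diff introduces is easy to exercise in isolation. Below is a minimal, self-contained sketch: since `Gguf` is private to `quantized_gemma3.rs`, the sketch re-declares an equivalent wrapper, and the `model.gguf` path is a placeholder for any local Gemma 3 GGUF checkpoint.

```rust
use candle::quantized::{gguf_file, QTensor};
use candle::{Device, Result};
use std::io::{Read, Seek};

// Self-contained copy of the private helper added in this diff: it bundles
// the parsed GGUF content with its reader and target device so that tensor
// lookups only need a name.
struct Gguf<R: Read + Seek> {
    ct: gguf_file::Content,
    reader: R,
    device: Device,
}

impl<R: Read + Seek> Gguf<R> {
    fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self {
        Self { ct, reader, device }
    }

    fn tensor(&mut self, name: &str) -> Result<QTensor> {
        // Content::tensor still needs the reader and device on every call;
        // the wrapper just threads the stored ones through.
        self.ct.tensor(&mut self.reader, name, &self.device)
    }
}

fn main() -> Result<()> {
    // "model.gguf" is a placeholder path for any Gemma 3 GGUF file.
    let mut file = std::fs::File::open("model.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    let mut gg = Gguf::new(content, &mut file, Device::Cpu);
    let tok_embeddings = gg.tensor("token_embd.weight")?;
    println!("token_embd.weight shape: {:?}", tok_embeddings.shape());
    Ok(())
}
```

Bundling the `Content`, reader, and `Device` into one value turns every three-argument `ct.tensor(reader, name, device)` call into `gg.tensor(name)`, which is what shrinks the per-layer weight-loading code in the hunks above.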