From 900e0d81801b4bcc1172560086224b40d3307263 Mon Sep 17 00:00:00 2001
From: Junjun Dong
Date: Mon, 1 Dec 2025 18:26:19 -0800
Subject: [PATCH] fix: add gguf wrapper in quantized gemma3

---
 .../src/models/quantized_gemma3.rs | 68 +++++++++++--------
 1 file changed, 40 insertions(+), 28 deletions(-)

diff --git a/candle-transformers/src/models/quantized_gemma3.rs b/candle-transformers/src/models/quantized_gemma3.rs
index bc5b9e7ff0..270d1e6c71 100644
--- a/candle-transformers/src/models/quantized_gemma3.rs
+++ b/candle-transformers/src/models/quantized_gemma3.rs
@@ -20,6 +20,7 @@ use candle::quantized::QTensor;
 use candle::D;
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module};
+use std::io::{Read, Seek};
 
 pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
 pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
@@ -27,6 +28,26 @@ pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
 pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
 pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;
 
+struct Gguf<R: Read + Seek> {
+    ct: gguf_file::Content,
+    reader: R,
+    device: Device,
+}
+
+impl<R: Read + Seek> Gguf<R> {
+    fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self {
+        Self { ct, reader, device }
+    }
+
+    fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> {
+        &self.ct.metadata
+    }
+
+    fn tensor(&mut self, name: &str) -> Result<QTensor> {
+        self.ct.tensor(&mut self.reader, name, &self.device)
+    }
+}
+
 #[derive(Debug, Clone)]
 struct QMatMul {
     inner: candle::quantized::QMatMul,
@@ -263,7 +284,8 @@ impl ModelWeights {
         reader: &mut R,
         device: &Device,
     ) -> Result<Self> {
-        let md_get = |s: &str| match ct.metadata.get(s) {
+        let mut gg = Gguf::new(ct, reader, device.clone());
+        let md_get = |s: &str| match gg.metadata().get(s) {
             None => candle::bail!("cannot find {s} in metadata"),
             Some(v) => Ok(v),
         };
@@ -301,66 +323,56 @@ impl ModelWeights {
         let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
         // Load token embeddings and output projection
-        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
+        let tok_embeddings = gg.tensor("token_embd.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
-        let norm = RmsNorm::from_qtensor(
-            ct.tensor(reader, "output_norm.weight", device)?,
-            rms_norm_eps,
-        )?;
-        let output = match ct.tensor(reader, "output.weight", device) {
+        let norm = RmsNorm::from_qtensor(gg.tensor("output_norm.weight")?, rms_norm_eps)?;
+        let output = match gg.tensor("output.weight") {
             Ok(tensor) => tensor,
-            Err(_) => ct.tensor(reader, "token_embd.weight", device)?, // Use tied weights if output.weight doesn't exist
+            Err(_) => gg.tensor("token_embd.weight")?, // Use tied weights if output.weight doesn't exist
         };
 
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
 
-            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
-            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
-            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
-            let attention_wo =
-                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
+            let attention_wq = gg.tensor(&format!("{prefix}.attn_q.weight"))?;
+            let attention_wk = gg.tensor(&format!("{prefix}.attn_k.weight"))?;
+            let attention_wv = gg.tensor(&format!("{prefix}.attn_v.weight"))?;
+            let attention_wo = gg.tensor(&format!("{prefix}.attn_output.weight"))?;
 
             let attention_q_norm = RmsNorm::from_qtensor(
-                ct.tensor(reader, &format!("{prefix}.attn_q_norm.weight"), device)?,
+                gg.tensor(&format!("{prefix}.attn_q_norm.weight"))?,
                 rms_norm_eps,
             )?;
             let attention_k_norm = RmsNorm::from_qtensor(
-                ct.tensor(reader, &format!("{prefix}.attn_k_norm.weight"), device)?,
+                gg.tensor(&format!("{prefix}.attn_k_norm.weight"))?,
                 rms_norm_eps,
             )?;
             let attention_norm = RmsNorm::from_qtensor(
-                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
+                gg.tensor(&format!("{prefix}.attn_norm.weight"))?,
                 rms_norm_eps,
             )?;
             let post_attention_norm = RmsNorm::from_qtensor(
-                ct.tensor(
-                    reader,
-                    &format!("{prefix}.post_attention_norm.weight"),
-                    device,
-                )?,
+                gg.tensor(&format!("{prefix}.post_attention_norm.weight"))?,
                 rms_norm_eps,
             )?;
             let ffn_norm = RmsNorm::from_qtensor(
-                ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
+                gg.tensor(&format!("{prefix}.ffn_norm.weight"))?,
                 rms_norm_eps,
             )?;
             let post_ffn_norm = RmsNorm::from_qtensor(
-                ct.tensor(reader, &format!("{prefix}.post_ffw_norm.weight"), device)?,
+                gg.tensor(&format!("{prefix}.post_ffw_norm.weight"))?,
                 rms_norm_eps,
            )?;
 
-            let feed_forward_gate =
-                ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
-            let feed_forward_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
-            let feed_forward_down =
-                ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+            let feed_forward_gate = gg.tensor(&format!("{prefix}.ffn_gate.weight"))?;
+            let feed_forward_up = gg.tensor(&format!("{prefix}.ffn_up.weight"))?;
+            let feed_forward_down = gg.tensor(&format!("{prefix}.ffn_down.weight"))?;
 
             let mlp = Mlp {
                 feed_forward_gate: QMatMul::from_qtensor(feed_forward_gate)?,