diff --git a/rust/moshi-core/src/transformer.rs b/rust/moshi-core/src/transformer.rs
index c084643..9bc4433 100644
--- a/rust/moshi-core/src/transformer.rs
+++ b/rust/moshi-core/src/transformer.rs
@@ -10,9 +10,9 @@ use crate::nn::{
 };
 use crate::streaming::{StreamTensor, StreamingModule};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle::Context;
 use candle_nn;
 use std::sync::Arc;
-
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct Config {
     pub d_model: usize,
@@ -54,6 +54,8 @@ pub enum CrossAttentionGating {
     ConstantGatedSigmoid,
     ConditionalGatedTanh,
     ConditionalGatedSigmoid,
+    ConditionalGatedSigmoidLearnableBias,
+    ConditionalGatedTanhLearnableBias,
 }
 
 #[derive(Debug, Clone)]
@@ -100,45 +102,54 @@ pub enum XaGate {
         in_proj: MaybeQuantizedLinear,
         out_proj: MaybeQuantizedLinear,
         activation: candle_nn::init::NonLinearity,
+        learnable_bias: bool,
     },
 }
 
 impl XaGate {
     pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
-        match cfg.cross_attention.map(|v| v.0) {
-            // no cross attention - shouldn't occur here
-            None => candle::bail!("Invalid cross-attention config specified."),
+        let gating_cfg =
+            cfg.cross_attention.map(|v| v.0).context("no cross-attention specified")?;
+        match gating_cfg {
             // no gating
-            Some(CrossAttentionGating::Normal) => Ok(Self::Normal),
+            CrossAttentionGating::Normal => Ok(Self::Normal),
             // constant (per-layer parameter) with tanh activation
-            Some(CrossAttentionGating::ConstantGatedTanh) => {
-                let alpha = vb.get_unquantized((1, 1, 1), "gate.alpha")?.tanh()?;
+            CrossAttentionGating::ConstantGatedTanh => {
+                let alpha = vb.get_unquantized((1, 1, 1), "alpha")?.tanh()?;
                 Ok(Self::ConstantGated { alpha })
             }
             // constant (per-layer parameter) with sigmoid activation
-            Some(CrossAttentionGating::ConstantGatedSigmoid) => {
-                let alpha = candle_nn::ops::sigmoid(
-                    &(vb.get_unquantized((1, 1, 1), "gate.alpha")? - 4.0)?,
-                )?;
+            CrossAttentionGating::ConstantGatedSigmoid => {
+                let alpha =
+                    candle_nn::ops::sigmoid(&(vb.get_unquantized((1, 1, 1), "alpha")? - 4.0)?)?;
                 Ok(Self::ConstantGated { alpha })
             }
             // input conditional (small MLP) with tanh or sigmoid act
-            Some(CrossAttentionGating::ConditionalGatedTanh)
-            | Some(CrossAttentionGating::ConditionalGatedSigmoid) => {
+            CrossAttentionGating::ConditionalGatedTanh
+            | CrossAttentionGating::ConditionalGatedSigmoid
+            | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias
+            | CrossAttentionGating::ConditionalGatedTanhLearnableBias => {
                 let dim = cfg.d_model;
                 let hidden_dims = (0.125 * dim as f32).floor() as usize;
-                let in_proj = linear(dim, hidden_dims, false, vb.pp("gate.alpha.0"))?;
-                let out_proj = linear(hidden_dims, dim, false, vb.pp("gate.alpha.2"))?;
-                let activation = match cfg.cross_attention.map(|v| v.0) {
-                    Some(CrossAttentionGating::ConditionalGatedTanh) => {
+                let learnable_bias = matches!(
+                    gating_cfg,
+                    CrossAttentionGating::ConditionalGatedSigmoidLearnableBias
+                        | CrossAttentionGating::ConditionalGatedTanhLearnableBias
+                );
+                let in_proj = linear(dim, hidden_dims, false, vb.pp("alpha.0"))?;
+                let out_proj = linear(hidden_dims, dim, learnable_bias, vb.pp("alpha.2"))?;
+                let activation = match gating_cfg {
+                    CrossAttentionGating::ConditionalGatedTanh
+                    | CrossAttentionGating::ConditionalGatedTanhLearnableBias => {
                         candle_nn::init::NonLinearity::Tanh
                     }
-                    Some(CrossAttentionGating::ConditionalGatedSigmoid) => {
+                    CrossAttentionGating::ConditionalGatedSigmoid
+                    | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias => {
                         candle_nn::init::NonLinearity::Sigmoid
                     }
                     _ => candle::bail!("Invalid cross-attention config specified."),
                 };
-                Ok(Self::ConditionalGated { in_proj, out_proj, activation })
+                Ok(Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias })
             }
         }
     }
@@ -149,11 +160,14 @@ impl Module for XaGate {
         match self {
             Self::Normal => Ok(xs.clone()),
            Self::ConstantGated { alpha } => xs.broadcast_mul(alpha),
-            Self::ConditionalGated { in_proj, out_proj, activation } => {
+            Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias } => {
                 let alpha = xs.apply(in_proj)?.relu()?.apply(out_proj)?;
-                let alpha = match activation {
-                    candle_nn::init::NonLinearity::Tanh => alpha.tanh(),
-                    candle_nn::init::NonLinearity::Sigmoid => {
+                let alpha = match (activation, learnable_bias) {
+                    (candle_nn::init::NonLinearity::Tanh, _) => alpha.tanh(),
+                    (candle_nn::init::NonLinearity::Sigmoid, true) => {
+                        candle_nn::ops::sigmoid(&alpha)
+                    }
+                    (candle_nn::init::NonLinearity::Sigmoid, false) => {
                         candle_nn::ops::sigmoid(&(alpha - 4.0)?)
                     }
                     _ => candle::bail!("Invalid non-linearity specified in cross-attention gating"),
@@ -180,7 +194,11 @@ pub struct StreamingMultiheadCrossAttention {
 }
 
 impl StreamingMultiheadCrossAttention {
-    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
+    pub fn new(
+        cfg: &Config,
+        vb: MaybeQuantizedVarBuilder,
+        gate_vb: Option<MaybeQuantizedVarBuilder>,
+    ) -> Result<Self> {
         let embed_dim = cfg.d_model;
         let num_kv = cfg.num_heads / cfg.kv_repeat;
         let out_kv_dim = num_kv * (embed_dim / cfg.num_heads);
@@ -244,7 +262,10 @@ impl StreamingMultiheadCrossAttention {
             MaybeQuantizedVarBuilder::Real(weights) => neg_inf.to_dtype(weights.dtype())?,
             _ => neg_inf,
         };
-        let gate = XaGate::new(cfg, vb)?;
+        let gate = match gate_vb {
+            None => XaGate::new(cfg, vb.pp("gate"))?,
+            Some(layer_gate_vb) => XaGate::new(cfg, layer_gate_vb)?,
+        };
         Ok(Self {
             in_proj_q,
             in_proj_kv,
@@ -690,6 +711,7 @@ impl StreamingTransformerLayer {
         rope: &Option<Arc<RotaryEmbedding>>,
         cfg: &Config,
         vb: MaybeQuantizedVarBuilder,
+        shared_ca_vb: Option<MaybeQuantizedVarBuilder>,
     ) -> Result<Self> {
         if cfg.use_conv_block {
             candle::bail!("conv-block is not supported")
@@ -716,8 +738,16 @@ impl StreamingTransformerLayer {
         let cross_attn = match cfg.cross_attention.map(|v| v.1) {
             Some(norm_type) => {
                 let norm_cross = Norm::new_shortcut(d_model, norm_type, vb.pp("norm_cross"))?;
-                let cross_attn =
-                    StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
+                let cross_attn = match shared_ca_vb {
+                    None => {
+                        StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"), None)?
+                    }
+                    Some(shared_vb) => StreamingMultiheadCrossAttention::new(
+                        cfg,
+                        shared_vb.pp("cross_attention"),
+                        Some(vb.pp("cross_attention.gate")),
+                    )?,
+                };
                 Some((norm_cross, cross_attn))
             }
             None => None,
@@ -799,7 +829,10 @@ impl StreamingTransformer {
         };
         let mut layers = Vec::with_capacity(cfg.num_layers);
         for layer_idx in 0..cfg.num_layers {
-            let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx))?;
+            // Also pass the first layer's weights, as only it contains the KQV proj
+            // weights for shared cross-attention layers
+            let layer =
+                StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx), Some(vb_l.pp(0)))?;
             layers.push(layer)
         }
         Ok(Self {
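
Note on the new gating variants: the existing ConditionalGatedSigmoid path builds the gate as sigmoid(alpha - 4.0), so the cross-attention contribution starts almost fully closed at initialization, while the new *LearnableBias variants drop the fixed -4.0 shift and instead give out_proj a learnable bias. A minimal scalar sketch of that activation logic (plain Rust, no candle; gate_value and the values in main are illustrative, not part of the crate):

// Scalar sketch of the sigmoid branch of XaGate's forward pass.
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// Mirrors the match on (activation, learnable_bias) above: without a learnable
// bias the fixed -4.0 shift keeps the gate nearly closed at init; with one,
// the offset is left to the learned bias of `out_proj`.
fn gate_value(raw: f32, learnable_bias: bool) -> f32 {
    if learnable_bias {
        sigmoid(raw)
    } else {
        sigmoid(raw - 4.0)
    }
}

fn main() {
    // A roughly zero-centered MLP output at initialization:
    println!("fixed shift:    {:.4}", gate_value(0.0, false)); // ~0.0180
    println!("learnable bias: {:.4}", gate_value(0.0, true)); // 0.5000
}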
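
The weight-sharing change can also be read off the VarBuilder prefixes: every layer builds its cross-attention projections from the first layer's sub-builder, while the gate still comes from the layer's own prefix (XaGate now loads "alpha", "alpha.0", "alpha.2" and the "gate" segment is supplied by the caller). A small sketch of the keys each layer would resolve, assuming vb_l.pp(layer_idx) yields the usual "layers.{i}" prefix; the exact root depends on the checkpoint layout:

// Illustrative only: prints which tensor-name prefixes each layer resolves
// under the shared cross-attention scheme, assuming a "layers.{i}" root.
fn main() {
    let num_layers = 3;
    for i in 0..num_layers {
        // Q/KV/out projection weights always come from layer 0 (shared)...
        let proj_prefix = "layers.0.cross_attention";
        // ...while the per-layer gate keeps its own prefix.
        let gate_prefix = format!("layers.{i}.cross_attention.gate");
        println!("layer {i}: projections <- {proj_prefix}, gate <- {gate_prefix}.alpha.*");
    }
}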