diff --git a/rust/moshi-core/src/transformer.rs b/rust/moshi-core/src/transformer.rs
index c084643..9bc4433 100644
--- a/rust/moshi-core/src/transformer.rs
+++ b/rust/moshi-core/src/transformer.rs
@@ -10,9 +10,9 @@ use crate::nn::{
 use crate::streaming::{StreamTensor, StreamingModule};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 
+use candle::Context;
 use candle_nn;
 use std::sync::Arc;
-
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct Config {
     pub d_model: usize,
@@ -54,6 +54,8 @@ pub enum CrossAttentionGating {
     ConstantGatedSigmoid,
     ConditionalGatedTanh,
     ConditionalGatedSigmoid,
+    ConditionalGatedSigmoidLearnableBias,
+    ConditionalGatedTanhLearnableBias,
 }
 
 #[derive(Debug, Clone)]
@@ -100,45 +102,54 @@ pub enum XaGate {
         in_proj: MaybeQuantizedLinear,
         out_proj: MaybeQuantizedLinear,
         activation: candle_nn::init::NonLinearity,
+        learnable_bias: bool,
     },
 }
 
 impl XaGate {
     pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
-        match cfg.cross_attention.map(|v| v.0) {
-            // no cross attention - shouldn't occur here
-            None => candle::bail!("Invalid cross-attention config specified."),
+        let gating_cfg =
+            cfg.cross_attention.map(|v| v.0).context("no cross-attention specified")?;
+        match gating_cfg {
             // no gating
-            Some(CrossAttentionGating::Normal) => Ok(Self::Normal),
+            CrossAttentionGating::Normal => Ok(Self::Normal),
             // constant (per-layer parameter) with tanh activation
-            Some(CrossAttentionGating::ConstantGatedTanh) => {
-                let alpha = vb.get_unquantized((1, 1, 1), "gate.alpha")?.tanh()?;
+            CrossAttentionGating::ConstantGatedTanh => {
+                let alpha = vb.get_unquantized((1, 1, 1), "alpha")?.tanh()?;
                 Ok(Self::ConstantGated { alpha })
             }
             // constant (per-layer parameter) with sigmoid activation
-            Some(CrossAttentionGating::ConstantGatedSigmoid) => {
-                let alpha = candle_nn::ops::sigmoid(
-                    &(vb.get_unquantized((1, 1, 1), "gate.alpha")? - 4.0)?,
-                )?;
+            CrossAttentionGating::ConstantGatedSigmoid => {
+                let alpha =
+                    candle_nn::ops::sigmoid(&(vb.get_unquantized((1, 1, 1), "alpha")? - 4.0)?)?;
                 Ok(Self::ConstantGated { alpha })
             }
             // input conditional (small MLP) with tanh or sigmoid act
-            Some(CrossAttentionGating::ConditionalGatedTanh)
-            | Some(CrossAttentionGating::ConditionalGatedSigmoid) => {
+            CrossAttentionGating::ConditionalGatedTanh
+            | CrossAttentionGating::ConditionalGatedSigmoid
+            | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias
+            | CrossAttentionGating::ConditionalGatedTanhLearnableBias => {
                 let dim = cfg.d_model;
                 let hidden_dims = (0.125 * dim as f32).floor() as usize;
-                let in_proj = linear(dim, hidden_dims, false, vb.pp("gate.alpha.0"))?;
-                let out_proj = linear(hidden_dims, dim, false, vb.pp("gate.alpha.2"))?;
-                let activation = match cfg.cross_attention.map(|v| v.0) {
-                    Some(CrossAttentionGating::ConditionalGatedTanh) => {
+                let learnable_bias = matches!(
+                    gating_cfg,
+                    CrossAttentionGating::ConditionalGatedSigmoidLearnableBias
+                        | CrossAttentionGating::ConditionalGatedTanhLearnableBias
+                );
+                let in_proj = linear(dim, hidden_dims, false, vb.pp("alpha.0"))?;
+                let out_proj = linear(hidden_dims, dim, learnable_bias, vb.pp("alpha.2"))?;
+                let activation = match gating_cfg {
+                    CrossAttentionGating::ConditionalGatedTanh
+                    | CrossAttentionGating::ConditionalGatedTanhLearnableBias => {
                         candle_nn::init::NonLinearity::Tanh
                     }
-                    Some(CrossAttentionGating::ConditionalGatedSigmoid) => {
+                    CrossAttentionGating::ConditionalGatedSigmoid
+                    | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias => {
                         candle_nn::init::NonLinearity::Sigmoid
                     }
                     _ => candle::bail!("Invalid cross-attention config specified."),
                 };
-                Ok(Self::ConditionalGated { in_proj, out_proj, activation })
+                Ok(Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias })
             }
         }
     }
@@ -149,11 +160,14 @@ impl Module for XaGate {
         match self {
             Self::Normal => Ok(xs.clone()),
             Self::ConstantGated { alpha } => xs.broadcast_mul(alpha),
-            Self::ConditionalGated { in_proj, out_proj, activation } => {
+            Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias } => {
                 let alpha = xs.apply(in_proj)?.relu()?.apply(out_proj)?;
-                let alpha = match activation {
-                    candle_nn::init::NonLinearity::Tanh => alpha.tanh(),
-                    candle_nn::init::NonLinearity::Sigmoid => {
+                let alpha = match (activation, learnable_bias) {
+                    (candle_nn::init::NonLinearity::Tanh, _) => alpha.tanh(),
+                    (candle_nn::init::NonLinearity::Sigmoid, true) => {
+                        candle_nn::ops::sigmoid(&alpha)
+                    }
+                    (candle_nn::init::NonLinearity::Sigmoid, false) => {
                         candle_nn::ops::sigmoid(&(alpha - 4.0)?)
                     }
                     _ => candle::bail!("Invalid non-linearity specified in cross-attention gating"),
@@ -180,7 +194,11 @@ pub struct StreamingMultiheadCrossAttention {
 }
 
 impl StreamingMultiheadCrossAttention {
-    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
+    pub fn new(
+        cfg: &Config,
+        vb: MaybeQuantizedVarBuilder,
+        gate_vb: Option<MaybeQuantizedVarBuilder>,
+    ) -> Result<Self> {
         let embed_dim = cfg.d_model;
         let num_kv = cfg.num_heads / cfg.kv_repeat;
         let out_kv_dim = num_kv * (embed_dim / cfg.num_heads);
@@ -244,7 +262,10 @@ impl StreamingMultiheadCrossAttention {
             MaybeQuantizedVarBuilder::Real(weights) => neg_inf.to_dtype(weights.dtype())?,
             _ => neg_inf,
         };
-        let gate = XaGate::new(cfg, vb)?;
+        let gate = match gate_vb {
+            None => XaGate::new(cfg, vb.pp("gate"))?,
+            Some(layer_gate_vb) => XaGate::new(cfg, layer_gate_vb)?,
+        };
         Ok(Self {
             in_proj_q,
             in_proj_kv,
@@ -690,6 +711,7 @@ impl StreamingTransformerLayer {
         rope: &Option<Arc<RotaryEmbedding>>,
         cfg: &Config,
         vb: MaybeQuantizedVarBuilder,
+        shared_ca_vb: Option<MaybeQuantizedVarBuilder>,
     ) -> Result<Self> {
         if cfg.use_conv_block {
             candle::bail!("conv-block is not supported")
@@ -716,8 +738,16 @@ impl StreamingTransformerLayer {
         let cross_attn = match cfg.cross_attention.map(|v| v.1) {
             Some(norm_type) => {
                 let norm_cross = Norm::new_shortcut(d_model, norm_type, vb.pp("norm_cross"))?;
-                let cross_attn =
-                    StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
+                let cross_attn = match shared_ca_vb {
+                    None => {
+                        StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"), None)?
+                    }
+                    Some(shared_vb) => StreamingMultiheadCrossAttention::new(
+                        cfg,
+                        shared_vb.pp("cross_attention"),
+                        Some(vb.pp("cross_attention.gate")),
+                    )?,
+                };
                 Some((norm_cross, cross_attn))
             }
             None => None,
@@ -799,7 +829,10 @@ impl StreamingTransformer {
         };
         let mut layers = Vec::with_capacity(cfg.num_layers);
         for layer_idx in 0..cfg.num_layers {
-            let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx))?;
+            // Also send weights of first layer as only it contains the KQV proj weights
+            // for shared cross-attention layers
+            let layer =
+                StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx), Some(vb_l.pp(0)))?;
             layers.push(layer)
         }
         Ok(Self {