Include some recent transformer changes. #235

Merged 1 commit on Feb 28, 2025.
rust/moshi-core/src/transformer.rs: 89 changes (61 additions, 28 deletions)
@@ -10,9 +10,9 @@ use crate::nn::{
 use crate::streaming::{StreamTensor, StreamingModule};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
-
+use candle::Context;
 use candle_nn;
 use std::sync::Arc;
 
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct Config {
     pub d_model: usize,
@@ -54,6 +54,8 @@ pub enum CrossAttentionGating
     ConstantGatedSigmoid,
     ConditionalGatedTanh,
     ConditionalGatedSigmoid,
+    ConditionalGatedSigmoidLearnableBias,
+    ConditionalGatedTanhLearnableBias,
 }
 
 #[derive(Debug, Clone)]
@@ -100,45 +102,54 @@ pub enum XaGate
         in_proj: MaybeQuantizedLinear,
         out_proj: MaybeQuantizedLinear,
         activation: candle_nn::init::NonLinearity,
+        learnable_bias: bool,
     },
 }
 
 impl XaGate {
     pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
-        match cfg.cross_attention.map(|v| v.0) {
-            // no cross attention - shouldn't occur here
-            None => candle::bail!("Invalid cross-attention config specified."),
+        let gating_cfg =
+            cfg.cross_attention.map(|v| v.0).context("no cross-attention specified")?;
+        match gating_cfg {
             // no gating
-            Some(CrossAttentionGating::Normal) => Ok(Self::Normal),
+            CrossAttentionGating::Normal => Ok(Self::Normal),
             // constant (per-layer parameter) with tanh activation
-            Some(CrossAttentionGating::ConstantGatedTanh) => {
-                let alpha = vb.get_unquantized((1, 1, 1), "gate.alpha")?.tanh()?;
+            CrossAttentionGating::ConstantGatedTanh => {
+                let alpha = vb.get_unquantized((1, 1, 1), "alpha")?.tanh()?;
                 Ok(Self::ConstantGated { alpha })
             }
             // constant (per-layer parameter) with sigmoid activation
-            Some(CrossAttentionGating::ConstantGatedSigmoid) => {
-                let alpha = candle_nn::ops::sigmoid(
-                    &(vb.get_unquantized((1, 1, 1), "gate.alpha")? - 4.0)?,
-                )?;
+            CrossAttentionGating::ConstantGatedSigmoid => {
+                let alpha =
+                    candle_nn::ops::sigmoid(&(vb.get_unquantized((1, 1, 1), "alpha")? - 4.0)?)?;
                 Ok(Self::ConstantGated { alpha })
             }
             // input conditional (small MLP) with tanh or sigmoid act
-            Some(CrossAttentionGating::ConditionalGatedTanh)
-            | Some(CrossAttentionGating::ConditionalGatedSigmoid) => {
+            CrossAttentionGating::ConditionalGatedTanh
+            | CrossAttentionGating::ConditionalGatedSigmoid
+            | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias
+            | CrossAttentionGating::ConditionalGatedTanhLearnableBias => {
                 let dim = cfg.d_model;
                 let hidden_dims = (0.125 * dim as f32).floor() as usize;
-                let in_proj = linear(dim, hidden_dims, false, vb.pp("gate.alpha.0"))?;
-                let out_proj = linear(hidden_dims, dim, false, vb.pp("gate.alpha.2"))?;
-                let activation = match cfg.cross_attention.map(|v| v.0) {
-                    Some(CrossAttentionGating::ConditionalGatedTanh) => {
+                let learnable_bias = matches!(
+                    gating_cfg,
+                    CrossAttentionGating::ConditionalGatedSigmoidLearnableBias
+                        | CrossAttentionGating::ConditionalGatedTanhLearnableBias
+                );
+                let in_proj = linear(dim, hidden_dims, false, vb.pp("alpha.0"))?;
+                let out_proj = linear(hidden_dims, dim, learnable_bias, vb.pp("alpha.2"))?;
+                let activation = match gating_cfg {
+                    CrossAttentionGating::ConditionalGatedTanh
+                    | CrossAttentionGating::ConditionalGatedTanhLearnableBias => {
                         candle_nn::init::NonLinearity::Tanh
                     }
-                    Some(CrossAttentionGating::ConditionalGatedSigmoid) => {
+                    CrossAttentionGating::ConditionalGatedSigmoid
+                    | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias => {
                         candle_nn::init::NonLinearity::Sigmoid
                     }
                     _ => candle::bail!("Invalid cross-attention config specified."),
                 };
-                Ok(Self::ConditionalGated { in_proj, out_proj, activation })
+                Ok(Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias })
             }
         }
     }
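
For orientation, the conditional gate constructed above is a small bottleneck MLP: `in_proj` maps `d_model` down to `floor(d_model / 8)` units and `out_proj` maps back up, and only the new `*LearnableBias` variants enable a bias on `out_proj`. A minimal sketch of the resulting shapes, using an illustrative `d_model = 1024` that is not taken from this PR:

```rust
// Sketch only: shapes of the conditional-gate MLP built in XaGate::new.
// The d_model value in main() is an example, not a configuration from this PR.
fn gate_mlp_shapes(d_model: usize, learnable_bias: bool) {
    // Mirrors `(0.125 * dim as f32).floor() as usize` above.
    let hidden = (0.125 * d_model as f32).floor() as usize;
    println!("in_proj : {d_model} -> {hidden}, no bias");
    println!("out_proj: {hidden} -> {d_model}, bias enabled: {learnable_bias}");
}

fn main() {
    gate_mlp_shapes(1024, true); // hidden width = 128
}
```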
@@ -149,11 +160,14 @@ impl Module for XaGate
         match self {
             Self::Normal => Ok(xs.clone()),
             Self::ConstantGated { alpha } => xs.broadcast_mul(alpha),
-            Self::ConditionalGated { in_proj, out_proj, activation } => {
+            Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias } => {
                 let alpha = xs.apply(in_proj)?.relu()?.apply(out_proj)?;
-                let alpha = match activation {
-                    candle_nn::init::NonLinearity::Tanh => alpha.tanh(),
-                    candle_nn::init::NonLinearity::Sigmoid => {
+                let alpha = match (activation, learnable_bias) {
+                    (candle_nn::init::NonLinearity::Tanh, _) => alpha.tanh(),
+                    (candle_nn::init::NonLinearity::Sigmoid, true) => {
+                        candle_nn::ops::sigmoid(&alpha)
+                    }
+                    (candle_nn::init::NonLinearity::Sigmoid, false) => {
                         candle_nn::ops::sigmoid(&(alpha - 4.0)?)
                     }
                     _ => candle::bail!("Invalid non-linearity specified in cross-attention gating"),
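
A scalar sketch of the gating arithmetic matched above may help: the original sigmoid gate subtracts 4.0 from the MLP output, biasing the gate toward closed (sigmoid(-4) ≈ 0.018), while the learnable-bias variants drop that fixed shift because `out_proj` now carries its own bias. This is only an illustration of the scalar math, not the tensor code:

```rust
// Scalar illustration of the gate activations dispatched above.
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// Tanh gate: identical with or without a learnable bias.
fn gate_tanh(mlp_out: f32) -> f32 {
    mlp_out.tanh()
}

// Sigmoid gate, fixed offset: the -4.0 shift biases the gate toward zero.
fn gate_sigmoid_fixed(mlp_out: f32) -> f32 {
    sigmoid(mlp_out - 4.0)
}

// Sigmoid gate, learnable bias: the MLP output (with out_proj's bias) is used as-is.
fn gate_sigmoid_learnable(mlp_out_with_bias: f32) -> f32 {
    sigmoid(mlp_out_with_bias)
}

fn main() {
    println!("{:.3}", gate_sigmoid_fixed(0.0)); // ~0.018, nearly closed
    println!("{:.3}", gate_sigmoid_learnable(0.0)); // 0.500 unless the bias moves it
    println!("{:.3}", gate_tanh(0.0)); // 0.000
}
```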
@@ -180,7 +194,11 @@ pub struct StreamingMultiheadCrossAttention
 }
 
 impl StreamingMultiheadCrossAttention {
-    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
+    pub fn new(
+        cfg: &Config,
+        vb: MaybeQuantizedVarBuilder,
+        gate_vb: Option<MaybeQuantizedVarBuilder>,
+    ) -> Result<Self> {
         let embed_dim = cfg.d_model;
         let num_kv = cfg.num_heads / cfg.kv_repeat;
         let out_kv_dim = num_kv * (embed_dim / cfg.num_heads);
@@ -244,7 +262,10 @@ impl StreamingMultiheadCrossAttention
             MaybeQuantizedVarBuilder::Real(weights) => neg_inf.to_dtype(weights.dtype())?,
             _ => neg_inf,
         };
-        let gate = XaGate::new(cfg, vb)?;
+        let gate = match gate_vb {
+            None => XaGate::new(cfg, vb.pp("gate"))?,
+            Some(layer_gate_vb) => XaGate::new(cfg, layer_gate_vb)?,
+        };
         Ok(Self {
             in_proj_q,
             in_proj_kv,
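
The new `gate_vb` argument only changes where the gate parameters are read from: with `None` the gate stays under `gate` inside this cross-attention's var-builder (keeping the previous `gate.alpha.*` paths), and with `Some(...)` the caller passes a var-builder already pointed at the layer's own gate while the projections come from a shared builder. A hypothetical string-level sketch of that prefix resolution; the `layers.N` naming is an assumption for illustration and does not appear in this diff:

```rust
// Sketch of the gate weight-prefix resolution, using plain strings.
// The "layers.N" prefixes are assumed for illustration only.
fn gate_prefix(ca_prefix: &str, layer_gate_prefix: Option<&str>) -> String {
    match layer_gate_prefix {
        // No shared weights: the gate lives under this cross-attention.
        None => format!("{ca_prefix}.gate"),
        // Shared projection weights: the gate still comes from the layer itself.
        Some(own_gate) => own_gate.to_string(),
    }
}

fn main() {
    assert_eq!(
        gate_prefix("layers.3.cross_attention", None),
        "layers.3.cross_attention.gate"
    );
    assert_eq!(
        gate_prefix("layers.0.cross_attention", Some("layers.3.cross_attention.gate")),
        "layers.3.cross_attention.gate"
    );
}
```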
@@ -690,6 +711,7 @@ impl StreamingTransformerLayer
         rope: &Option<Arc<RotaryEmbedding>>,
         cfg: &Config,
         vb: MaybeQuantizedVarBuilder,
+        shared_ca_vb: Option<MaybeQuantizedVarBuilder>,
     ) -> Result<Self> {
         if cfg.use_conv_block {
             candle::bail!("conv-block is not supported")
@@ -716,8 +738,16 @@
         let cross_attn = match cfg.cross_attention.map(|v| v.1) {
             Some(norm_type) => {
                 let norm_cross = Norm::new_shortcut(d_model, norm_type, vb.pp("norm_cross"))?;
-                let cross_attn =
-                    StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
+                let cross_attn = match shared_ca_vb {
+                    None => {
+                        StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"), None)?
+                    }
+                    Some(shared_vb) => StreamingMultiheadCrossAttention::new(
+                        cfg,
+                        shared_vb.pp("cross_attention"),
+                        Some(vb.pp("cross_attention.gate")),
+                    )?,
+                };
                 Some((norm_cross, cross_attn))
             }
             None => None,
@@ -799,7 +829,10 @@ impl StreamingTransformer
         };
         let mut layers = Vec::with_capacity(cfg.num_layers);
         for layer_idx in 0..cfg.num_layers {
-            let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx))?;
+            // Also send weights of first layer as only it contains the KQV proj weights
+            // for shared cross-attention layers
+            let layer =
+                StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx), Some(vb_l.pp(0)))?;
             layers.push(layer)
         }
         Ok(Self {
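
As the new comment in `StreamingTransformer::new` spells out, every layer is now built with `Some(vb_l.pp(0))`, so when cross-attention is enabled its projection weights are read from the first layer's tensors and shared, while each layer still loads its own gate. A hedged sketch of the resulting weight layout; `LayerWeights`, `load_layer`, and the `layers.N` prefixes are illustrative names, not moshi-core's API:

```rust
// Illustrative sketch of the sharing scheme set up in the loop above.
// LayerWeights and load_layer are placeholder names, not moshi-core types.
struct LayerWeights {
    // Prefix the shared cross-attention projections are read from.
    shared_ca_prefix: String,
    // Prefix this layer's own gate parameters are read from.
    gate_prefix: String,
}

fn load_layer(layer_idx: usize) -> LayerWeights {
    LayerWeights {
        // Only the first layer's checkpoint entries hold the KQV projection weights.
        shared_ca_prefix: "layers.0.cross_attention".to_string(),
        // Every layer keeps its own gate.
        gate_prefix: format!("layers.{layer_idx}.cross_attention.gate"),
    }
}

fn main() {
    for idx in 0..4 {
        let w = load_layer(idx);
        println!(
            "layer {idx}: projections from {}, gate from {}",
            w.shared_ca_prefix, w.gate_prefix
        );
    }
}
```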