21 changes: 20 additions & 1 deletion backends/candle/src/lib.rs
@@ -25,7 +25,7 @@ use crate::models::{
BertConfig, BertModel, Dense, DenseConfig, DenseLayer, DistilBertConfig, DistilBertModel,
GTEConfig, GTEModel, Gemma3Config, Gemma3Model, JinaBertModel, JinaCodeBertModel, MPNetConfig,
MPNetModel, MistralConfig, Model, ModernBertConfig, ModernBertModel, NomicBertModel,
NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model,
NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model, StaticEmbeddingConfig, StaticEmbeddingModel,
};
#[cfg(feature = "cuda")]
use crate::models::{
@@ -113,6 +113,7 @@ enum Config {
Qwen3(Qwen3Config),
Roberta(BertConfig),
XlmRoberta(BertConfig),
StaticEmbedding(StaticEmbeddingConfig),
}

pub struct CandleBackend {
@@ -131,12 +132,17 @@ impl CandleBackend {
// Default files
let default_safetensors = model_path.join("model.safetensors");
let default_pytorch = model_path.join("pytorch_model.bin");
let static_embedding_safetensors = model_path
.join("0_StaticEmbedding")
.join("model.safetensors");

// Single Files
let model_files = if default_safetensors.exists() {
vec![default_safetensors]
} else if default_pytorch.exists() {
vec![default_pytorch]
} else if static_embedding_safetensors.exists() {
vec![static_embedding_safetensors]
}
// Sharded weights
else {
@@ -305,6 +311,12 @@ impl CandleBackend {
tracing::info!("Starting Qwen3 model on {:?}", device);
Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
}
(Config::StaticEmbedding(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting StaticEmbedding model on {:?}", device);
Ok(Box::new(
StaticEmbeddingModel::load(vb, &config, model_type).s()?,
))
}
#[cfg(feature = "cuda")]
(Config::Bert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -509,6 +521,13 @@ impl CandleBackend {
))
}
}
#[cfg(feature = "cuda")]
(Config::StaticEmbedding(config), Device::Cuda(_)) => {
tracing::info!("Starting StaticEmbedding model on {:?}", device);
Ok(Box::new(
StaticEmbeddingModel::load(vb, &config, model_type).s()?,
))
}
};

let mut dense_layers = Vec::new();
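The lib.rs changes extend single-file weight resolution so that Sentence Transformers static-embedding repositories, which keep their weights under `0_StaticEmbedding/model.safetensors`, are picked up when neither `model.safetensors` nor `pytorch_model.bin` exists at the model root. A minimal standalone sketch of that resolution order (names and error handling here are illustrative assumptions, not the backend's exact code):

```rust
use std::path::{Path, PathBuf};

// Hedged sketch of the single-file resolution order added in this diff:
// model.safetensors -> pytorch_model.bin -> 0_StaticEmbedding/model.safetensors.
// The sharded-weights fallback that follows in lib.rs is omitted here.
fn resolve_single_file(model_path: &Path) -> Option<Vec<PathBuf>> {
    let candidates = [
        model_path.join("model.safetensors"),
        model_path.join("pytorch_model.bin"),
        model_path.join("0_StaticEmbedding").join("model.safetensors"),
    ];

    candidates
        .into_iter()
        .find(|candidate| candidate.exists())
        .map(|candidate| vec![candidate])
}
```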
2 changes: 2 additions & 0 deletions backends/candle/src/models/mod.rs
@@ -20,6 +20,7 @@ mod mpnet;
mod nomic;
mod qwen2;
mod qwen3;
mod static_embedding;

#[cfg(feature = "cuda")]
mod flash_bert;
@@ -64,6 +65,7 @@ pub use mpnet::{MPNetConfig, MPNetModel};
pub use nomic::{NomicBertModel, NomicConfig};
pub use qwen2::Qwen2Config;
pub use qwen3::{Qwen3Config, Qwen3Model};
pub use static_embedding::{StaticEmbeddingConfig, StaticEmbeddingModel};

#[cfg(feature = "cuda")]
pub use flash_bert::FlashBertModel;
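The new exports make `StaticEmbeddingConfig` and `StaticEmbeddingModel` available to the config-detection code in lib.rs. The config struct only carries `vocab_size` and `hidden_size`, so deserializing it needs nothing more than the sketch below (the JSON values are invented for the example, and `serde_json` being available is an assumption):

```rust
use serde::Deserialize;

// Mirror of the two fields StaticEmbeddingConfig declares; values are illustrative only.
#[derive(Deserialize)]
struct StaticEmbeddingConfig {
    vocab_size: usize,
    hidden_size: usize,
}

fn main() -> Result<(), serde_json::Error> {
    let raw = r#"{ "vocab_size": 30522, "hidden_size": 1024 }"#;
    let config: StaticEmbeddingConfig = serde_json::from_str(raw)?;
    assert_eq!(config.vocab_size, 30522);
    assert_eq!(config.hidden_size, 1024);
    Ok(())
}
```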
249 changes: 249 additions & 0 deletions backends/candle/src/models/static_embedding.rs
@@ -0,0 +1,249 @@
use crate::models::Model;
use candle::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_nn::{Embedding, VarBuilder};
use serde::Deserialize;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct StaticEmbeddingConfig {
pub vocab_size: usize,
pub hidden_size: usize,
}

#[derive(Debug)]
pub struct StaticEmbedding {
embedding: Embedding,

span: tracing::Span,
}

impl StaticEmbedding {
pub fn load(
vb: VarBuilder,
config: &StaticEmbeddingConfig,
weight_name: String,
) -> Result<Self> {
Ok(Self {
embedding: Embedding::new(
vb.get((config.vocab_size, config.hidden_size), &weight_name)?,
config.hidden_size,
),
span: tracing::span!(tracing::Level::TRACE, "embedding"),
})
}

pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();

self.embedding.forward(input_ids)
}
}

pub struct StaticEmbeddingModel {
pool: Pool,
embedding: StaticEmbedding,

device: Device,
dtype: DType,

span: tracing::Span,
}

impl StaticEmbeddingModel {
pub fn load(
vb: VarBuilder,
config: &StaticEmbeddingConfig,
model_type: ModelType,
) -> Result<Self> {
let pool = match model_type {
ModelType::Classifier => {
candle::bail!("`Classifier` model type is not supported for Static models")
}
ModelType::Embedding(pool) => pool,
};

let embedding = StaticEmbedding::load(vb.pp("embedding"), config, "weight".to_string())
.or_else(|_| StaticEmbedding::load(vb.clone(), config, "embeddings".to_string()))?;

Ok(Self {
pool,
embedding,
device: vb.device().clone(),
dtype: vb.dtype(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
}

pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
let _enter = self.span.enter();

let batch_size = batch.len();
let max_length = batch.max_length as usize;

let shape = (batch_size, max_length);

let (input_ids, input_lengths, attention_mask) = if batch_size > 1 {
// Prepare padded batch
let elems = batch_size * max_length;

let mut input_ids = Vec::with_capacity(elems);
let mut attention_mask = Vec::with_capacity(elems);
let mut input_lengths = Vec::with_capacity(batch_size);
// Bool to know if we need to use the attention mask
let mut masking = false;

for i in 0..batch_size {
let start = batch.cumulative_seq_lengths[i] as usize;
let end = batch.cumulative_seq_lengths[i + 1] as usize;
let seq_length = (end - start) as u32;
input_lengths.push(seq_length as f32);

// Copy values
for j in start..end {
input_ids.push(batch.input_ids[j]);
attention_mask.push(1.0_f32);
}

// Add padding if needed
let padding = batch.max_length - seq_length;
if padding > 0 {
// Set bool to use attention mask
masking = true;
for _ in 0..padding {
input_ids.push(0);
attention_mask.push(0.0_f32);
}
}
}

let attention_mask = match masking {
true => {
// We only need the mask if we use mean pooling
// For CLS pooling, only the first token is read, so no mask is needed
if self.pool == Pool::Mean {
let attention_mask = Tensor::from_vec(
attention_mask,
(batch_size, max_length, 1),
&self.device,
)?
.to_dtype(self.dtype)?;

Some(attention_mask)
} else {
None
}
}
false => None,
};

(input_ids, input_lengths, attention_mask)
} else {
(batch.input_ids, vec![batch.max_length as f32], None)
};

let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
let mut input_lengths =
Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;

let outputs = self.embedding.forward(&input_ids)?;

let has_pooling_requests = !batch.pooled_indices.is_empty();
let has_raw_requests = !batch.raw_indices.is_empty();

let pooled_embeddings = if has_pooling_requests {
let pooled_indices_length = batch.pooled_indices.len();
let mut outputs = outputs.clone();

// Only use pooled_indices if at least one member of the batch asks for raw embeddings
let pooled_indices = if has_raw_requests {
let pooled_indices =
Tensor::from_vec(batch.pooled_indices, pooled_indices_length, &self.device)?;

// Select values in the batch
outputs = outputs.index_select(&pooled_indices, 0)?;
Some(pooled_indices)
} else {
None
};

let pooled_embeddings = match self.pool {
// CLS pooling
Pool::Cls => outputs.i((.., 0))?,
// Mean pooling
Pool::Mean => {
if let Some(ref attention_mask) = attention_mask {
let mut attention_mask = attention_mask.clone();

if let Some(pooled_indices) = pooled_indices {
// Select values in the batch
attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
input_lengths = input_lengths.index_select(&pooled_indices, 0)?;
};

// Mask padded values
outputs = outputs.broadcast_mul(&attention_mask)?;
}

(outputs.sum(1)?.broadcast_div(&input_lengths))?
}
// Last token and splade pooling are not supported for this model
Pool::LastToken | Pool::Splade => {
candle::bail!("`LastToken` and `Splade` pooling are not supported for this model")
}
};
Some(pooled_embeddings)
} else {
None
};

let raw_embeddings = if has_raw_requests {
// Reshape outputs
let (b, l, h) = outputs.shape().dims3()?;
let outputs = outputs.reshape((b * l, h))?;

// We need to remove the padding tokens only if batch_size > 1 and some
// members of the batch require pooling,
// or if batch_size > 1 and the members of the batch have different lengths
if (attention_mask.is_some() || has_pooling_requests) && batch_size > 1 {
let mut final_indices: Vec<u32> = Vec::with_capacity(batch_size * max_length);

for i in batch.raw_indices.into_iter() {
let start = i * batch.max_length;
let i = i as usize;
let length =
batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i];

for j in start..start + length {
// Add indices for the tokens of this specific member of the batch
final_indices.push(j);
}
}

let final_indices_length = final_indices.len();
let final_indices =
Tensor::from_vec(final_indices, final_indices_length, &self.device)?;

// Select the tokens with final indices
Some(outputs.index_select(&final_indices, 0)?)
} else {
Some(outputs)
}
} else {
None
};

Ok((pooled_embeddings, raw_embeddings))
}
}

impl Model for StaticEmbeddingModel {
fn is_padded(&self) -> bool {
true
}

fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}

fn predict(&self, _batch: Batch) -> Result<Tensor> {
candle::bail!("`predict` is not implemented for this model")
}
}
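The forward pass is just a table lookup followed by pooling: token ids index into the embedding matrix, padded positions are zeroed through the attention mask, and mean pooling divides the per-sequence sum over the masked embeddings by the true sequence length. A self-contained sketch of that pooling step, with made-up shapes and values rather than the model's actual code path:

```rust
use candle::{Device, Result, Tensor};

fn mean_pool_sketch() -> Result<Tensor> {
    let device = Device::Cpu;

    // (batch_size = 2, max_length = 3, hidden_size = 2) token embeddings;
    // the second sequence has one padded position at the end.
    let embeddings = Tensor::from_vec(
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0],
        (2, 3, 2),
        &device,
    )?;

    // (batch_size, max_length, 1): 1.0 for real tokens, 0.0 for padding.
    let attention_mask =
        Tensor::from_vec(vec![1.0f32, 1.0, 1.0, 1.0, 1.0, 0.0], (2, 3, 1), &device)?;

    // (batch_size, 1): true sequence lengths used as the divisor.
    let input_lengths = Tensor::from_vec(vec![3.0f32, 2.0], (2, 1), &device)?;

    // Zero out padded positions, sum over the sequence axis, divide by length.
    let masked = embeddings.broadcast_mul(&attention_mask)?;
    masked.sum(1)?.broadcast_div(&input_lengths)
}
```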
30 changes: 28 additions & 2 deletions backends/candle/tests/common.rs
@@ -113,6 +113,8 @@ enum ModuleType {
Pooling,
#[serde(rename = "sentence_transformers.models.Transformer")]
Transformer,
#[serde(rename = "sentence_transformers.models.StaticEmbedding")]
StaticEmbedding,
}

#[derive(Deserialize)]
@@ -153,7 +155,11 @@ pub fn download_artifacts(
};

api_repo.get("config.json")?;
api_repo.get("tokenizer.json")?;

match api_repo.get("tokenizer.json") {
Ok(path) => path,
Err(_) => api_repo.get("0_StaticEmbedding/tokenizer.json")?,
};

let model_files = match download_safetensors(&api_repo) {
Ok(p) => p,
@@ -203,6 +209,17 @@ fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
Ok(p) => return Ok(vec![p]),
Err(err) => tracing::warn!("Could not download `model.safetensors`: {}", err),
};
if let Ok(path) = api.get("model.safetensors") {
return Ok(vec![path]);
}

tracing::warn!("Could not download `model.safetensors`");
tracing::info!("Downloading `0_StaticEmbedding/model.safetensors`");
if let Ok(path) = api.get("0_StaticEmbedding/model.safetensors") {
return Ok(vec![path]);
}

tracing::warn!("Could not download `model.safetensors`");

// Sharded weights
// Download and parse index file
@@ -279,7 +296,16 @@ pub fn cosine_matcher() -> YamlMatcher<SnapshotEmbeddings> {
pub fn load_tokenizer(model_root: &Path) -> Result<Tokenizer> {
// Load tokenizer
let tokenizer_path = model_root.join("tokenizer.json");
let mut tokenizer = Tokenizer::from_file(tokenizer_path).expect("tokenizer.json not found");
let mut tokenizer = match Tokenizer::from_file(&tokenizer_path) {
Ok(t) => t,
Err(e) if e.to_string().contains("No such file") || e.to_string().contains("not found") => {
let fallback_path = model_root.join("0_StaticEmbedding").join("tokenizer.json");
Tokenizer::from_file(&fallback_path)
.expect("0_StaticEmbedding/tokenizer.json not found.")
}
Err(_) => anyhow::bail!("text-embeddings-inference only supports fast tokenizers"),
};

// See https://github.com/huggingface/tokenizers/pull/1357
if let Some(pre_tokenizer) = tokenizer.get_pre_tokenizer() {
if let PreTokenizerWrapper::Metaspace(m) = pre_tokenizer {
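The test helpers now fall back to the `0_StaticEmbedding/` subdirectory for both the tokenizer and the weights, matching the layout Sentence Transformers uses for static-embedding models. A simplified sketch of the same tokenizer fallback, checking for the file up front rather than inspecting the error string (an assumption made for brevity, not the diff's exact approach):

```rust
use std::path::Path;
use tokenizers::Tokenizer;

// Prefer tokenizer.json at the model root, otherwise look under 0_StaticEmbedding/.
fn load_tokenizer_with_fallback(model_root: &Path) -> anyhow::Result<Tokenizer> {
    let default_path = model_root.join("tokenizer.json");
    let path = if default_path.exists() {
        default_path
    } else {
        model_root.join("0_StaticEmbedding").join("tokenizer.json")
    };

    Tokenizer::from_file(&path)
        .map_err(|e| anyhow::anyhow!("failed to load fast tokenizer from {path:?}: {e}"))
}
```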