diff --git a/Cargo.lock b/Cargo.lock index cd05ea15..70c906fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2292,6 +2292,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +[[package]] +name = "murmur3" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9252111cf132ba0929b6f8e030cac2a24b507f3a4d6db6fb2896f27b354c714b" + [[package]] name = "native-tls" version = "0.2.12" @@ -3476,6 +3482,16 @@ dependencies = [ "walkdir", ] +[[package]] +name = "rust-stemmers" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e46a2036019fdb888131db7a4c847a1063a7493f971ed94ea82c67eada63ca54" +dependencies = [ + "serde", + "serde_derive", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -4063,9 +4079,11 @@ name = "text-embeddings-backend-ort" version = "1.4.0" dependencies = [ "anyhow", + "murmur3", "ndarray", "nohash-hasher", "ort", + "rust-stemmers", "serde", "serde_json", "text-embeddings-backend-core", diff --git a/backends/Cargo.toml b/backends/Cargo.toml index f29283ad..8d6c727a 100644 --- a/backends/Cargo.toml +++ b/backends/Cargo.toml @@ -17,6 +17,7 @@ tokio = { workspace = true } tracing = { workspace = true } [features] +default = ["ort"] clap = ["dep:clap", "text-embeddings-backend-core/clap"] python = ["dep:text-embeddings-backend-python"] ort = ["dep:text-embeddings-backend-ort"] diff --git a/backends/candle/src/models/bert.rs b/backends/candle/src/models/bert.rs index 32880d44..303fce16 100644 --- a/backends/candle/src/models/bert.rs +++ b/backends/candle/src/models/bert.rs @@ -856,6 +856,7 @@ impl BertModel { (outputs.sum(1)?.broadcast_div(&input_lengths))? } + Pool::BM42 => unreachable!(), Pool::Splade => { // Unwrap is safe here let splade_head = self.splade.as_ref().unwrap(); @@ -874,7 +875,7 @@ impl BertModel { } relu_log.max(1)? 
- } + }, }; Some(pooled_embeddings) } else { diff --git a/backends/candle/src/models/distilbert.rs b/backends/candle/src/models/distilbert.rs index 2cf62081..819d44b6 100644 --- a/backends/candle/src/models/distilbert.rs +++ b/backends/candle/src/models/distilbert.rs @@ -587,6 +587,7 @@ impl DistilBertModel { (outputs.sum(1)?.broadcast_div(&input_lengths))? } + Pool::BM42 => unreachable!(), Pool::Splade => { // Unwrap is safe here let splade_head = self.splade.as_ref().unwrap(); diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs index ecee8bfe..acbe3676 100644 --- a/backends/candle/src/models/jina.rs +++ b/backends/candle/src/models/jina.rs @@ -616,6 +616,7 @@ impl JinaBertModel { (outputs.sum(1)?.broadcast_div(&input_lengths))? } + Pool::BM42 => unreachable!(), Pool::Splade => unreachable!(), }; Some(pooled_embeddings) diff --git a/backends/candle/src/models/jina_code.rs b/backends/candle/src/models/jina_code.rs index 5f13fe08..b1f074b9 100644 --- a/backends/candle/src/models/jina_code.rs +++ b/backends/candle/src/models/jina_code.rs @@ -604,7 +604,8 @@ impl JinaCodeBertModel { } (outputs.sum(1)?.broadcast_div(&input_lengths))? - } + }, + Pool::BM42 => unreachable!(), Pool::Splade => unreachable!(), }; Some(pooled_embeddings) diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index cdaaea92..ea69fa90 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -631,6 +631,7 @@ impl NomicBertModel { (outputs.sum(1)?.broadcast_div(&input_lengths))? } + Pool::BM42 => unreachable!(), Pool::Splade => unreachable!(), }; Some(pooled_embeddings) diff --git a/backends/core/src/lib.rs b/backends/core/src/lib.rs index 932c0083..61b4ac12 100644 --- a/backends/core/src/lib.rs +++ b/backends/core/src/lib.rs @@ -63,16 +63,34 @@ pub enum Pool { /// This option is only available if the loaded model is a `ForMaskedLM` Transformer /// model. 
Splade, + /// Apply BM42 to the model embeddings. + /// This option is only available if the loaded model is Qdrant/all_miniLM_L6_v2_with_attentions + BM42, /// Select the last token as embedding LastToken, } +#[derive(Debug, Clone)] +pub struct Bm42Params { + pub invert_vocab: std::collections::HashMap, + pub stopwords: Vec, + pub special_tokens: Vec, +} + +#[derive(Debug, Clone)] +pub enum ModelParams { + Bm42(Bm42Params), + None +} + + impl fmt::Display for Pool { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Pool::Cls => write!(f, "cls"), Pool::Mean => write!(f, "mean"), Pool::Splade => write!(f, "splade"), + Pool::BM42 => write!(f, "bm42"), Pool::LastToken => write!(f, "last_token"), } } diff --git a/backends/ort/Cargo.toml b/backends/ort/Cargo.toml index 9cbde87e..01fb250d 100644 --- a/backends/ort/Cargo.toml +++ b/backends/ort/Cargo.toml @@ -15,3 +15,5 @@ tracing = { workspace = true } thiserror = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +rust-stemmers = "1.2.0" +murmur3 = "0.5.2" diff --git a/backends/ort/src/bm42.rs b/backends/ort/src/bm42.rs new file mode 100644 index 00000000..d5cd1734 --- /dev/null +++ b/backends/ort/src/bm42.rs @@ -0,0 +1,512 @@ +use ndarray::s; +use nohash_hasher::BuildNoHashHasher; +use ort::{GraphOptimizationLevel, Session}; +use std::collections::HashMap; +use std::path::PathBuf; +use text_embeddings_backend_core::{ + Backend, BackendError, Batch, Bm42Params, Embedding, Embeddings, ModelType, Pool, Predictions, +}; + +pub struct Bm42Backend { + session: Session, + pool: Pool, + type_id_name: Option, + invert_vocab: HashMap, + punctuation: Vec, + alpha: f32, + stemmer: rust_stemmers::Stemmer, + stopwords: Vec, + special_tokens: Vec, +} + +impl Bm42Backend { + pub fn new( + model_path: PathBuf, + dtype: String, + model_type: ModelType, + model_params: Bm42Params, + ) -> Result { + // Check dtype + if &dtype == "float32" { + } else { + return 
Err(BackendError::Start(format!( + "DType {dtype} is not supported" + ))); + }; + + // Check model type + let pool = match model_type { + ModelType::Classifier => Pool::Cls, + ModelType::Embedding(pool) => match pool { + Pool::Splade | Pool::LastToken => { + return Err(BackendError::Start(format!( + "Pooling {pool} is not supported for this backend. Use `candle` backend instead." + ))); + } + pool => pool, + }, + }; + + // Get model path + let onnx_path = { + let default_path = model_path.join("model.onnx"); + match default_path.exists() { + true => default_path, + false => model_path.join("onnx/model.onnx"), + } + }; + + // Start onnx session + let session = Session::builder() + .s()? + .with_optimization_level(GraphOptimizationLevel::Level3) + .s()? + .commit_from_file(onnx_path) + .s()?; + + // Check if the model requires type tokens + let mut type_id_name = None; + for input in &session.inputs { + if &input.name == "token_type_ids" || &input.name == "input_type" { + type_id_name = Some(input.name.clone()); + break; + } + } + + let punctuation = &[ + "!", "\"", "#", "$", "%", "&", "\'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", + "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~", + ]; + let punctuation: Vec = punctuation.iter().map(|x| x.to_string()).collect(); + let punctuation = punctuation.to_vec(); + + Ok(Self { + session, + pool, + type_id_name, + invert_vocab: model_params.invert_vocab, + punctuation, + stemmer: rust_stemmers::Stemmer::create(rust_stemmers::Algorithm::English), + alpha: 0.5, + stopwords: model_params.stopwords, + special_tokens: model_params.special_tokens, + }) + } +} + +impl Backend for Bm42Backend { + fn max_batch_size(&self) -> Option { + Some(8) + } + + fn health(&self) -> Result<(), BackendError> { + Ok(()) + } + + fn is_padded(&self) -> bool { + true + } + + fn embed(&self, batch: Batch) -> Result { + println!("himoadf"); + let batch_size = batch.len(); + let max_length = batch.max_length as usize; + 
+ // Whether at least one of the requests in the batch is padded + let mut masking = true; + + let (input_ids, type_ids, input_lengths, attention_mask) = { + let elems = batch_size * max_length; + + if batch_size > 1 { + // Prepare padded batch + let mut input_ids = Vec::with_capacity(elems); + let mut type_ids = Vec::with_capacity(elems); + let mut attention_mask = Vec::with_capacity(elems); + let mut input_lengths = Vec::with_capacity(batch_size); + + for i in 0..batch_size { + let start = batch.cumulative_seq_lengths[i] as usize; + let end = batch.cumulative_seq_lengths[i + 1] as usize; + let seq_length = (end - start) as u32; + input_lengths.push(seq_length as f32); + + // Copy values + for j in start..end { + input_ids.push(batch.input_ids[j] as i64); + type_ids.push(batch.token_type_ids[j] as i64); + attention_mask.push(1_i64); + } + + // Add padding if needed + let padding = batch.max_length - seq_length; + if padding > 0 { + // Set bool to use attention mask + masking = true; + for _ in 0..padding { + input_ids.push(0); + type_ids.push(0); + attention_mask.push(0_i64); + } + } + } + (input_ids, type_ids, input_lengths, attention_mask) + } else { + let attention_mask = vec![1_i64; elems]; + + ( + batch.input_ids.into_iter().map(|v| v as i64).collect(), + batch.token_type_ids.into_iter().map(|v| v as i64).collect(), + vec![batch.max_length as f32], + attention_mask, + ) + } + }; + + // Create ndarrays + let input_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), input_ids).e()?; + let attention_mask = + ndarray::Array2::from_shape_vec((batch_size, max_length), attention_mask).e()?; + let input_lengths = ndarray::Array1::from_vec(input_lengths); + + // Create onnx inputs + let inputs = match (self.type_id_name.as_ref(), &self.pool) { + (_, Pool::BM42) => { + println!("input_ids {:?}", input_ids); + (ort::inputs![ + "input_ids" => ort::Value::from_array(input_ids.clone()).unwrap() + ]) + .e()? 
+ } + (Some(type_id_name), _) => { + // Add type ids to inputs + let type_ids = + ndarray::Array2::from_shape_vec((batch_size, max_length), type_ids).e()?; + ort::inputs!["input_ids" => input_ids.clone(), "attention_mask" => attention_mask.clone(), type_id_name => type_ids].e()? + } + (None, _) => { + ort::inputs!["input_ids" => input_ids.clone(), "attention_mask" => attention_mask.clone()] + .e()? + } + }; + // Run model + let session_outputs = self.session.run(inputs).e()?; + + // Final embeddings struct + let mut embeddings = + HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); + + let outputs = { + let mut output_final: Vec>> = Vec::with_capacity(batch_size); + + let outputs = session_outputs + .get("attention_6") + .ok_or(BackendError::Inference(format!( + "Unknown output keys: {:?}", + self.session.outputs + )))? + .try_extract_tensor::() + .e()? + .to_owned(); + + for i in 0..batch_size { + let output: Vec> = outputs + .view() + .slice(s![i, .., 0, ..]) + .rows() + .into_iter() + .map(|row| row.to_vec()) + .collect(); + + output_final.push(output); + } + + output_final + }; + + let has_pooling_requests = !batch.pooled_indices.is_empty(); + + if has_pooling_requests { + let outputs = outputs.clone(); + + let pooled_embeddings: Vec> = match self.pool { + Pool::BM42 => outputs + .iter() + .map(|output| { + let mean = + output + .iter() + .fold(vec![0.0; output[0].len()], |acc, inner_vec| { + acc.iter().zip(inner_vec).map(|(&a, &b)| a + b).collect() + }); + let mean = mean + .iter() + .map(|&sum| sum / output.len() as f32) + .collect::>(); + + mean.iter() + .zip(attention_mask.clone()) + .map(|(m, a)| m * a as f32) + .collect() + }) + .collect(), + _ => unreachable!(), + }; + + println!("fda"); + let mut rescored_vectors = vec![]; + + for i in 0..input_ids.slice(s![.., 0]).len() { + let document_token_ids = input_ids.slice(s![i, ..]); + let attention_value = &pooled_embeddings[i]; + + let doc_tokens_with_ids: Vec<(usize, String)> = 
document_token_ids + .iter() + .enumerate() + .map(|(idx, &id)| (idx, self.invert_vocab[&(id as u32)].clone())) + .collect(); + + println!("spe {:?}", self.special_tokens); + let reconstructed = + reconstruct_bpe(doc_tokens_with_ids, &self.special_tokens.clone()); + println!("rec {:?}", reconstructed); + + let filtered = + filter_pair_tokens(reconstructed, &self.stopwords, &self.punctuation); + println!("fffc {:?}", filtered); + + let stemmed = stem_pair_tokens(&self.stemmer, filtered); + println!("stemme {:?}", stemmed); + + let weighted = aggregate_weights(&stemmed, attention_value); + println!("weights {:?}", weighted); + + let mut max_token_weight: HashMap = HashMap::new(); + + weighted.into_iter().for_each(|(token, weight)| { + let weight = max_token_weight.get(&token).unwrap_or(&0.0).max(weight); + max_token_weight.insert(token, weight); + }); + + let rescored = rescore_vector(&max_token_weight, self.alpha); + + let max_value = *rescored.keys().max().unwrap_or(&0) as usize; + + // Convert HashMap into vec![] + let mut embedding = ndarray::Array::zeros(max_value + 1); + + for (k, v) in rescored.iter() { + embedding[*k as usize] = *v + } + + rescored_vectors.push(embedding); + } + + for (i, e) in batch.pooled_indices.into_iter().zip(rescored_vectors) { + embeddings.insert(i as usize, Embedding::Pooled(e.to_vec())); + } + }; + + println!("returning out "); + Ok(embeddings) + } + + fn predict(&self, batch: Batch) -> Result { + let batch_size = batch.len(); + let max_length = batch.max_length as usize; + + let (input_ids, type_ids, attention_mask) = { + let elems = batch_size * max_length; + + if batch_size > 1 { + // Prepare padded batch + let mut input_ids = Vec::with_capacity(elems); + let mut type_ids = Vec::with_capacity(elems); + let mut attention_mask = Vec::with_capacity(elems); + + for i in 0..batch_size { + let start = batch.cumulative_seq_lengths[i] as usize; + let end = batch.cumulative_seq_lengths[i + 1] as usize; + let seq_length = (end - start) as 
u32; + + // Copy values + for j in start..end { + input_ids.push(batch.input_ids[j] as i64); + type_ids.push(batch.token_type_ids[j] as i64); + attention_mask.push(1_i64); + } + + // Add padding if needed + let padding = batch.max_length - seq_length; + if padding > 0 { + for _ in 0..padding { + input_ids.push(0); + type_ids.push(0); + attention_mask.push(0_i64); + } + } + } + (input_ids, type_ids, attention_mask) + } else { + let attention_mask = vec![1_i64; elems]; + + ( + batch.input_ids.into_iter().map(|v| v as i64).collect(), + batch.token_type_ids.into_iter().map(|v| v as i64).collect(), + attention_mask, + ) + } + }; + + // Create ndarrays + let input_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), input_ids).e()?; + let attention_mask = + ndarray::Array2::from_shape_vec((batch_size, max_length), attention_mask).e()?; + + // Create onnx inputs + let inputs = match self.type_id_name.as_ref() { + Some(type_id_name) => { + // Add type ids to inputs + let type_ids = + ndarray::Array2::from_shape_vec((batch_size, max_length), type_ids).e()?; + ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone(), type_id_name => type_ids].e()? + } + None => { + ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone()] + .e()? + } + }; + + // Run model + let outputs = self.session.run(inputs).e()?; + // Get logits ndarray + let outputs = outputs["logits"] + .try_extract_tensor::() + .e()? 
+ .to_owned(); + + let mut predictions = + HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); + for (i, r) in outputs.rows().into_iter().enumerate() { + predictions.insert(i, r.to_vec()); + } + + Ok(predictions) + } +} + +pub fn stem_pair_tokens( + stemmer: &rust_stemmers::Stemmer, + tokens: Vec<(String, Vec)>, +) -> Vec<(String, Vec)> { + let mut result: Vec<(String, Vec)> = Vec::new(); + + for (token, value) in tokens.into_iter() { + let processed_token = stemmer.stem(&token).to_string(); + result.push((processed_token, value)); + } + + result +} + +pub fn rescore_vector(vector: &HashMap, alpha: f32) -> HashMap { + let mut new_vector: HashMap = HashMap::new(); + + for (token, &value) in vector.iter() { + let token_id = + (murmur3::murmur3_32(&mut std::io::Cursor::new(token), 0).unwrap() as i32).abs(); + + let new_score = (1.0 + value).ln().powf(alpha); + + new_vector.insert(token_id, new_score); + } + + new_vector +} + +pub fn aggregate_weights(tokens: &[(String, Vec)], weights: &[f32]) -> Vec<(String, f32)> { + let mut result: Vec<(String, f32)> = Vec::new(); + + for (token, idxs) in tokens.iter() { + let sum_weight: f32 = idxs.iter().map(|&idx| weights[idx]).sum(); + result.push((token.clone(), sum_weight)); + } + + result +} + +pub fn filter_pair_tokens( + tokens: Vec<(String, Vec)>, + stopwords: &[String], + punctuation: &[String], +) -> Vec<(String, Vec)> { + let mut result: Vec<(String, Vec)> = Vec::new(); + + for (token, value) in tokens.into_iter() { + if stopwords.contains(&token) || punctuation.contains(&token) { + continue; + } + result.push((token.clone(), value)); + } + + result +} + +pub fn reconstruct_bpe( + bpe_tokens: impl IntoIterator, + special_tokens: &[String], +) -> Vec<(String, Vec)> { + let mut result = Vec::new(); + let mut acc = String::new(); + let mut acc_idx = Vec::new(); + + let continuing_subword_prefix = "##"; + let continuing_subword_prefix_len = continuing_subword_prefix.len(); + + for (idx, token) in 
bpe_tokens { + if special_tokens.contains(&token) { + continue; + } + + if token.starts_with(continuing_subword_prefix) { + acc.push_str(&token[continuing_subword_prefix_len..]); + acc_idx.push(idx); + } else { + if !acc.is_empty() { + result.push((acc.clone(), acc_idx.clone())); + acc_idx = vec![]; + } + acc = token; + acc_idx.push(idx); + } + } + + if !acc.is_empty() { + result.push((acc, acc_idx)); + } + + result +} + +pub trait WrapErr { + fn s(self) -> Result; + fn e(self) -> Result; +} + +impl WrapErr for Result { + fn s(self) -> Result { + self.map_err(|e| BackendError::Start(e.to_string())) + } + fn e(self) -> Result { + self.map_err(|e| BackendError::Inference(e.to_string())) + } +} + +impl WrapErr for Result { + fn s(self) -> Result { + self.map_err(|e| BackendError::Start(e.to_string())) + } + fn e(self) -> Result { + self.map_err(|e| BackendError::Inference(e.to_string())) + } +} diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs index 9573f6b0..add3f3e8 100644 --- a/backends/ort/src/lib.rs +++ b/backends/ort/src/lib.rs @@ -1,3 +1,5 @@ +pub mod bm42; + use ndarray::{s, Axis}; use nohash_hasher::BuildNoHashHasher; use ort::{GraphOptimizationLevel, Session}; @@ -150,26 +152,35 @@ impl Backend for OrtBackend { let input_lengths = ndarray::Array1::from_vec(input_lengths); // Create onnx inputs - let inputs = match self.type_id_name.as_ref() { - Some(type_id_name) => { + let inputs = match (self.type_id_name.as_ref(), &self.pool) { + (_, Pool::BM42) => { + (ort::inputs![ + "input_ids" => ort::Value::from_array(input_ids.clone()).unwrap() + ]) + .e()? + } + (Some(type_id_name), _) => { // Add type ids to inputs let type_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), type_ids).e()?; ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone(), type_id_name => type_ids].e()? } - None => { + (None, _) => { ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone()] .e()? 
} }; - // Run model - let outputs = self.session.run(inputs).e()?; - // Get last_hidden_state ndarray + let session_outputs = self.session.run(inputs).e()?; - let outputs = outputs + // Final embeddings struct + let mut embeddings = + HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); + + // Get last_hidden_state ndarray + let outputs = session_outputs .get("last_hidden_state") - .or(outputs.get("token_embeddings")) + .or(session_outputs.get("token_embeddings")) .ok_or(BackendError::Inference(format!( "Unknown output keys: {:?}", self.session.outputs @@ -178,9 +189,35 @@ impl Backend for OrtBackend { .e()? .to_owned(); - // Final embeddings struct - let mut embeddings = - HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); + let bm42_outputs = if Pool::BM42 == self.pool { + let mut output_final: Vec>> = Vec::with_capacity(batch_size); + + let outputs = session_outputs + .get("attention_6") + .ok_or(BackendError::Inference(format!( + "Unknown output keys: {:?}", + self.session.outputs + )))? + .try_extract_tensor::() + .e()? 
+ .to_owned(); + + for i in 0..batch_size { + let output: Vec> = outputs + .view() + .slice(s![i, .., 0, ..]) + .rows() + .into_iter() + .map(|row| row.to_vec()) + .collect(); + + output_final.push(output); + } + + Some(output_final) + } else { + None + }; let has_pooling_requests = !batch.pooled_indices.is_empty(); let has_raw_requests = !batch.raw_indices.is_empty(); @@ -229,6 +266,7 @@ impl Backend for OrtBackend { outputs.mean_axis(Axis(1)).unwrap() } } + Pool::BM42 => unreachable!(), Pool::Splade => unreachable!(), }; @@ -242,6 +280,7 @@ impl Backend for OrtBackend { }; if has_raw_requests { + println!("reshape"); // Reshape outputs let s = outputs.shape().to_vec(); let outputs = outputs.into_shape((s[0] * s[1], s[2])).e()?; @@ -397,3 +436,4 @@ impl WrapErr for Result { self.map_err(|e| BackendError::Inference(e.to_string())) } } + diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 2ee63279..b00e44d4 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -12,12 +12,14 @@ use tracing::{instrument, Span}; pub use crate::dtype::DType; pub use text_embeddings_backend_core::{ - BackendError, Batch, Embedding, Embeddings, ModelType, Pool, + BackendError, Batch, Embedding, Embeddings, ModelType, Pool, ModelParams, Bm42Params, }; #[cfg(feature = "candle")] use text_embeddings_backend_candle::CandleBackend; +#[cfg(feature = "ort")] +use text_embeddings_backend_ort::bm42::Bm42Backend; #[cfg(feature = "ort")] use text_embeddings_backend_ort::OrtBackend; @@ -44,6 +46,7 @@ impl Backend { uds_path: String, otlp_endpoint: Option, otlp_service_name: String, + model_params: ModelParams ) -> Result { let (backend_sender, backend_receiver) = mpsc::channel(8); @@ -54,6 +57,7 @@ impl Backend { uds_path, otlp_endpoint, otlp_service_name, + model_params )?; let padded_model = backend.is_padded(); let max_batch_size = backend.max_batch_size(); @@ -200,6 +204,7 @@ fn init_backend( uds_path: String, otlp_endpoint: Option, otlp_service_name: String, + model_params: 
ModelParams, ) -> Result, BackendError> { if cfg!(feature = "candle") { #[cfg(feature = "candle")] @@ -227,12 +232,21 @@ fn init_backend( )); } } else if cfg!(feature = "ort") { - #[cfg(feature = "ort")] - return Ok(Box::new(OrtBackend::new( - model_path, - dtype.to_string(), - model_type, - )?)); + #[cfg(feature = "ort")] + if let ModelParams::Bm42(params) = model_params { + return Ok(Box::new(Bm42Backend::new( + model_path, + dtype.to_string(), + model_type, + params + )?)); + } else { + return Ok(Box::new(OrtBackend::new( + model_path, + dtype.to_string(), + model_type, + )?)); + } } Err(BackendError::NoBackend) } @@ -364,3 +378,4 @@ async fn download_safetensors(api: &ApiRepo) -> Result, ApiError> { Ok(safetensors_files) } + diff --git a/core/src/download.rs b/core/src/download.rs index 02ea484d..9ed17fdd 100644 --- a/core/src/download.rs +++ b/core/src/download.rs @@ -63,3 +63,18 @@ pub async fn download_new_st_config(api: &ApiRepo) -> Result let pool_config_path = api.get("config_sentence_transformers.json").await?; Ok(pool_config_path) } + +#[instrument(skip_all)] +pub async fn download_stopwords(api: &ApiRepo) -> Result { + tracing::info!("Downloading `stopwords.txt`"); + let pool_config_path = api.get("stopwords.txt").await?; + Ok(pool_config_path) +} + +#[instrument(skip_all)] +pub async fn download_special_tokens_maps(api: &ApiRepo) -> Result { + tracing::info!("Downloading `special_tokens_map.json`"); + let pool_config_path = api.get("special_tokens_map.json").await?; + Ok(pool_config_path) +} + diff --git a/core/src/infer.rs b/core/src/infer.rs index 23b343bf..d961a819 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -125,7 +125,7 @@ impl Infer { ) -> Result { let start_time = Instant::now(); - if self.is_splade() { + if self.is_sparse() { let counter = metrics::counter!("te_request_failure", "err" => "model_type"); counter.increment(1); let message = "`embed_all` is not available for SPLADE models".to_string(); @@ -180,7 +180,7 @@ impl Infer 
{ ) -> Result { let start_time = Instant::now(); - if !self.is_splade() { + if !self.is_sparse() { let counter = metrics::counter!("te_request_failure", "err" => "model_type"); counter.increment(1); let message = "Model is not an embedding model with SPLADE pooling".to_string(); @@ -236,7 +236,7 @@ impl Infer { ) -> Result { let start_time = Instant::now(); - if self.is_splade() && normalize { + if self.is_sparse() && normalize { let counter = metrics::counter!("te_request_failure", "err" => "model_type"); counter.increment(1); let message = "`normalize` is not available for SPLADE models".to_string(); @@ -480,10 +480,10 @@ impl Infer { } #[instrument(skip(self))] - pub fn is_splade(&self) -> bool { + pub fn is_sparse(&self) -> bool { matches!( self.backend.model_type, - ModelType::Embedding(text_embeddings_backend::Pool::Splade) + ModelType::Embedding(text_embeddings_backend::Pool::Splade) | ModelType::Embedding(text_embeddings_backend::Pool::BM42) ) } diff --git a/router/Cargo.toml b/router/Cargo.toml index 3a3fba27..fca7d33a 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -74,7 +74,7 @@ vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } tonic-build = { version = "0.11.0", optional = true } [features] -default = ["candle", "http"] +default = ["candle", "http", "ort"] http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:base64", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"] grpc = ["metrics-exporter-prometheus/http-listener", "dep:prost", "dep:tonic", "dep:tonic-health", "dep:tonic-reflection", "dep:tonic-build", "dep:async-stream", "dep:tokio-stream"] metal = ["text-embeddings-backend/metal"] diff --git a/router/src/http/server.rs b/router/src/http/server.rs index 040ac070..a23c88b7 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -1502,7 +1502,7 @@ async fn vertex_compatibility( futures.push(predict_future(local_infer, local_info, instance).boxed()); } ModelType::Embedding(_) => { - if 
infer.is_splade() { + if infer.is_sparse() { let instance = serde_json::from_value::(instance) .map_err(ErrorResponse::from)?; futures.push(embed_sparse_future(local_infer, local_info, instance).boxed()); diff --git a/router/src/lib.rs b/router/src/lib.rs index 86a0f884..15c3d715 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -26,10 +26,10 @@ use std::fs; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::Path; use std::time::{Duration, Instant}; -use text_embeddings_backend::{DType, Pool}; +use text_embeddings_backend::{Bm42Params, DType, ModelParams, Pool}; use text_embeddings_core::download::{ - download_artifacts, download_new_st_config, download_pool_config, download_st_config, - ST_CONFIG_NAMES, + download_artifacts, download_new_st_config, download_pool_config, download_special_tokens_maps, + download_st_config, download_stopwords, ST_CONFIG_NAMES, }; use text_embeddings_core::infer::Infer; use text_embeddings_core::queue::Queue; @@ -99,6 +99,12 @@ pub async fn run( // Download new sentence transformers config let _ = download_new_st_config(&api_repo).await; + // Download special_tokens_map.json + let _ = download_special_tokens_maps(&api_repo).await; + + // Download stopwords.txt + let _ = download_stopwords(&api_repo).await; + // // Download model from the Hub download_artifacts(&api_repo) .await @@ -112,7 +118,7 @@ pub async fn run( serde_json::from_str(&config).context("Failed to parse `config.json`")?; // Set model type from config - let backend_model_type = get_backend_model_type(&config, &model_root, pooling)?; + let backend_model_type = get_backend_model_type(&config, &model_root, pooling.clone())?; // Info model type let model_type = match &backend_model_type { @@ -224,6 +230,47 @@ pub async fn run( default_prompt }; + let model_params: ModelParams = if Some(Pool::BM42) == pooling.clone() { + let mut special_tokens = vec![]; + + let special_tokens_map_path = model_root.join("special_tokens_map.json"); + if let 
Ok(special_tokens_txt) = fs::read_to_string(special_tokens_map_path) { + let special_tokens_map = serde_json::from_str(&special_tokens_txt) + .context("Failed to parse `special_tokens_map.json`")?; + + if let serde_json::Value::Object(root_object) = special_tokens_map { + for (_, value) in root_object.iter() { + if value.is_string() { + let token = value.as_str().unwrap(); + + special_tokens.push(token.to_string()); + } else if value.is_object() { + let token = value["content"].as_str().unwrap(); + + special_tokens.push(token.to_string()); + } + } + } + } + + let stop_words_path = model_root.join("stopwords.txt"); + let stopwords = fs::read_to_string(stop_words_path).context("Failed to parse `stopwords.txt`")?.lines().map(|s| s.to_string()).collect(); + + let invert_vocab: HashMap = tokenizer + .get_vocab(true) + .into_iter() + .map(|(key, value)| (value, key)) + .collect(); + + ModelParams::Bm42(Bm42Params { + invert_vocab, + special_tokens, + stopwords, + }) + } else { + ModelParams::None + }; + // Tokenization logic let tokenization = Tokenization::new( tokenization_workers, @@ -246,6 +293,7 @@ pub async fn run( uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()), otlp_endpoint.clone(), otlp_service_name.clone(), + model_params, ) .context("Could not create backend")?; backend