Skip to content

Commit dc62b26

Browse files
committed
feature: support bm42 embeddings
1 parent dcbea38 commit dc62b26

File tree

17 files changed

+705
-31
lines changed

17 files changed

+705
-31
lines changed

Cargo.lock

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backends/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ tokio = { workspace = true }
1717
tracing = { workspace = true }
1818

1919
[features]
20+
default = ["ort"]
2021
clap = ["dep:clap", "text-embeddings-backend-core/clap"]
2122
python = ["dep:text-embeddings-backend-python"]
2223
ort = ["dep:text-embeddings-backend-ort"]

backends/candle/src/models/bert.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,7 @@ impl BertModel {
856856

857857
(outputs.sum(1)?.broadcast_div(&input_lengths))?
858858
}
859+
Pool::BM42 => unreachable!(),
859860
Pool::Splade => {
860861
// Unwrap is safe here
861862
let splade_head = self.splade.as_ref().unwrap();
@@ -874,7 +875,7 @@ impl BertModel {
874875
}
875876

876877
relu_log.max(1)?
877-
}
878+
},
878879
};
879880
Some(pooled_embeddings)
880881
} else {

backends/candle/src/models/distilbert.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ impl DistilBertModel {
587587

588588
(outputs.sum(1)?.broadcast_div(&input_lengths))?
589589
}
590+
Pool::BM42 => unreachable!(),
590591
Pool::Splade => {
591592
// Unwrap is safe here
592593
let splade_head = self.splade.as_ref().unwrap();

backends/candle/src/models/jina.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,7 @@ impl JinaBertModel {
616616

617617
(outputs.sum(1)?.broadcast_div(&input_lengths))?
618618
}
619+
Pool::BM42 => unreachable!(),
619620
Pool::Splade => unreachable!(),
620621
};
621622
Some(pooled_embeddings)

backends/candle/src/models/jina_code.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,8 @@ impl JinaCodeBertModel {
604604
}
605605

606606
(outputs.sum(1)?.broadcast_div(&input_lengths))?
607-
}
607+
},
608+
Pool::BM42 => unreachable!(),
608609
Pool::Splade => unreachable!(),
609610
};
610611
Some(pooled_embeddings)

backends/candle/src/models/nomic.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,7 @@ impl NomicBertModel {
631631

632632
(outputs.sum(1)?.broadcast_div(&input_lengths))?
633633
}
634+
Pool::BM42 => unreachable!(),
634635
Pool::Splade => unreachable!(),
635636
};
636637
Some(pooled_embeddings)

backends/core/src/lib.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,34 @@ pub enum Pool {
6363
/// This option is only available if the loaded model is a `ForMaskedLM` Transformer
6464
/// model.
6565
Splade,
66+
/// Apply BM42 to the model embeddings.
67+
/// This option is only available if the loaded model is `Qdrant/all_miniLM_L6_v2_with_attentions`.
68+
BM42,
6669
/// Select the last token as embedding
6770
LastToken,
6871
}
6972

73+
#[derive(Debug, Clone)]
74+
pub struct Bm42Params {
75+
pub invert_vocab: std::collections::HashMap<u32, String>,
76+
pub stopwords: Vec<String>,
77+
pub special_tokens: Vec<String>,
78+
}
79+
80+
#[derive(Debug, Clone)]
81+
pub enum ModelParams {
82+
Bm42(Bm42Params),
83+
None
84+
}
85+
86+
7087
impl fmt::Display for Pool {
7188
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
7289
match self {
7390
Pool::Cls => write!(f, "cls"),
7491
Pool::Mean => write!(f, "mean"),
7592
Pool::Splade => write!(f, "splade"),
93+
Pool::BM42 => write!(f, "bm42"),
7694
Pool::LastToken => write!(f, "last_token"),
7795
}
7896
}

backends/ort/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ tracing = { workspace = true }
1515
thiserror = { workspace = true }
1616
serde = { workspace = true }
1717
serde_json = { workspace = true }
18+
rust-stemmers = "1.2.0"
19+
murmur3 = "0.5.2"

0 commit comments

Comments
 (0)