Skip to content

Commit c3785ed

Browse files
authored
Add mean pooling strategy for Modernbert classifier (#616)
1 parent 26fa510 commit c3785ed

File tree

5 files changed

+67
-2
lines changed

5 files changed

+67
-2
lines changed

backends/candle/src/models/flash_modernbert.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,11 @@ impl FlashModernBertModel {
260260

261261
let (pool, classifier) = match model_type {
262262
ModelType::Classifier => {
263-
let pool = Pool::Cls;
263+
let pool: Pool = config
264+
.classifier_pooling
265+
.as_deref()
266+
.and_then(|s| Pool::from_str(s).ok())
267+
.unwrap_or(Pool::Cls);
264268

265269
let classifier: Box<dyn ClassificationHead + Send> =
266270
Box::new(ModernBertClassificationHead::load(vb.clone(), config)?);

backends/candle/src/models/modernbert.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
77
use candle_nn::{Embedding, VarBuilder};
88
use serde::Deserialize;
99
use text_embeddings_backend_core::{Batch, ModelType, Pool};
10+
use std::str::FromStr;
1011

1112
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/configuration_modernbert.py
1213
#[derive(Debug, Clone, PartialEq, Deserialize)]
@@ -484,7 +485,11 @@ impl ModernBertModel {
484485
pub fn load(vb: VarBuilder, config: &ModernBertConfig, model_type: ModelType) -> Result<Self> {
485486
let (pool, classifier) = match model_type {
486487
ModelType::Classifier => {
487-
let pool = Pool::Cls;
488+
let pool: Pool = config
489+
.classifier_pooling
490+
.as_deref()
491+
.and_then(|s| Pool::from_str(s).ok())
492+
.unwrap_or(Pool::Cls);
488493

489494
let classifier: Box<dyn ClassificationHead + Send> =
490495
Box::new(ModernBertClassificationHead::load(vb.clone(), config)?);
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
source: backends/candle/tests/test_modernbert.rs
3+
assertion_line: 229
4+
expression: predictions_single
5+
---
6+
- - -0.30617672
7+

backends/candle/tests/test_modernbert.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,35 @@ fn test_modernbert_classification() -> Result<()> {
202202

203203
Ok(())
204204
}
205+
206+
/// Runs a ModernBERT classifier end to end on a single query/passage pair and
/// compares the predicted scores with the stored
/// `modernbert_classification_mean_pooling` YAML snapshot.
///
/// NOTE(review): the test name implies the downloaded model's config selects
/// mean pooling for the classifier; that is not visible here — confirm
/// against the model's config file.
#[test]
207+
#[serial_test::serial]
208+
fn test_modernbert_classification_mean_pooling() -> Result<()> {
209+
// Fetch the model files, then build the tokenizer and a float32 classifier
// backend from the same model root.
let model_root = download_artifacts("tomaarsen/reranker-ModernBERT-large-gooaq-bce", None)?;
210+
let tokenizer = load_tokenizer(&model_root)?;
211+
let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;
212+
213+
// A batch holding one encoded (query, passage) pair. `unwrap()` panics if
// the tokenizer fails to encode the pair.
// NOTE(review): `[0].to_vec()` and `vec![]` are presumably the pooled and raw
// index lists expected by `batch` — confirm against the helper.
let input_single = batch(
214+
vec![tokenizer
215+
.encode(("What is Deep Learning?", "Deep Learning is not..."), true)
216+
.unwrap()],
217+
[0].to_vec(),
218+
vec![],
219+
);
220+
221+
// Keep only the score vectors; the first element of each returned tuple is
// discarded.
let predictions: Vec<Vec<f32>> = backend
222+
.predict(input_single)?
223+
.into_iter()
224+
.map(|(_, v)| v)
225+
.collect();
226+
let predictions_single = SnapshotScores::from(predictions);
227+
228+
// NOTE(review): `relative_matcher` presumably compares scores within a
// relative tolerance rather than exactly — confirm against its definition.
let matcher = relative_matcher();
229+
insta::assert_yaml_snapshot!(
230+
"modernbert_classification_mean_pooling",
231+
predictions_single,
232+
&matcher
233+
);
234+
235+
Ok(())
236+
}

backends/core/src/lib.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,23 @@ impl fmt::Display for Pool {
7878
}
7979
}
8080

81+
/// Parses a pooling strategy name into a `Pool` variant.
///
/// The input is trimmed and lowercased before matching, so surrounding
/// whitespace and letter case are ignored. Recognized names are `cls`,
/// `mean`, `splade` and `last_token`.
///
/// # Errors
///
/// Returns a `String` message that quotes the input exactly as given (not
/// trimmed or lowercased) and lists the valid options when the name is not
/// recognized.
impl std::str::FromStr for Pool {
82+
type Err = String;
83+
84+
fn from_str(s: &str) -> Result<Self, Self::Err> {
85+
// Normalize first so that e.g. " Mean " and "CLS" are accepted.
match s.trim().to_lowercase().as_str() {
86+
"cls" => Ok(Pool::Cls),
87+
"mean" => Ok(Pool::Mean),
88+
"splade" => Ok(Pool::Splade),
89+
"last_token" => Ok(Pool::LastToken),
90+
// Anything else is rejected; the message reports the caller's original input.
_ => Err(format!(
91+
"Invalid pooling method '{}'. Valid options: cls, mean, splade, last_token",
92+
s
93+
)),
94+
}
95+
}
96+
}
97+
8198
#[derive(Debug, Error, Clone)]
8299
pub enum BackendError {
83100
#[error("No backend found")]

0 commit comments

Comments
 (0)