Skip to content

Commit 09edc12

Browse files
committed
update: codes
1 parent 06f533f commit 09edc12

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

backends/candle/tests/common.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ enum ModuleType {
113113
Pooling,
114114
#[serde(rename = "sentence_transformers.models.Transformer")]
115115
Transformer,
116+
#[serde(rename = "sentence_transformers.models.StaticEmbedding")]
117+
StaticEmbedding,
116118
}
117119

118120
#[derive(Deserialize)]
@@ -153,7 +155,11 @@ pub fn download_artifacts(
153155
};
154156

155157
api_repo.get("config.json")?;
156-
api_repo.get("tokenizer.json")?;
158+
159+
match api_repo.get("tokenizer.json") {
160+
Ok(path) => path,
161+
Err(_) => api_repo.get("0_StaticEmbedding/tokenizer.json")?,
162+
};
157163

158164
let model_files = match download_safetensors(&api_repo) {
159165
Ok(p) => p,
@@ -203,6 +209,17 @@ fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
203209
Ok(p) => return Ok(vec![p]),
204210
Err(err) => tracing::warn!("Could not download `model.safetensors`: {}", err),
205211
};
212+
if let Ok(path) = api.get("model.safetensors") {
213+
return Ok(vec![path]);
214+
}
215+
216+
tracing::warn!("Could not download `model.safetensors`");
217+
tracing::info!("Downloading `0_StaticEmbedding/model.safetensors`");
218+
if let Ok(path) = api.get("0_StaticEmbedding/model.safetensors") {
219+
return Ok(vec![path]);
220+
}
221+
222+
tracing::warn!("Could not download `model.safetensors`");
206223

207224
// Sharded weights
208225
// Download and parse index file
@@ -279,7 +296,16 @@ pub fn cosine_matcher() -> YamlMatcher<SnapshotEmbeddings> {
279296
pub fn load_tokenizer(model_root: &Path) -> Result<Tokenizer> {
280297
// Load tokenizer
281298
let tokenizer_path = model_root.join("tokenizer.json");
282-
let mut tokenizer = Tokenizer::from_file(tokenizer_path).expect("tokenizer.json not found");
299+
let mut tokenizer = match Tokenizer::from_file(&tokenizer_path) {
300+
Ok(t) => t,
301+
Err(e) if e.to_string().contains("No such file") || e.to_string().contains("not found") => {
302+
let fallback_path = model_root.join("0_StaticEmbedding").join("tokenizer.json");
303+
Tokenizer::from_file(&fallback_path)
304+
.expect("0_StaticEmbedding/tokenizer.json not found.")
305+
}
306+
Err(_) => anyhow::bail!("text-embeddings-inference only supports fast tokenizers"),
307+
};
308+
283309
// See https://github.com/huggingface/tokenizers/pull/1357
284310
if let Some(pre_tokenizer) = tokenizer.get_pre_tokenizer() {
285311
if let PreTokenizerWrapper::Metaspace(m) = pre_tokenizer {

backends/candle/tests/test_static_embedding.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@ use text_embeddings_backend_core::{Backend, ModelType, Pool};
99
#[test]
1010
#[serial_test::serial]
1111
fn test_static_embedding() -> Result<()> {
12-
let (model_root, _) =
13-
download_artifacts("sentence-transformers/static-similarity-mrl-multilingual-v1", None, None)?;
12+
let (model_root, _) = download_artifacts(
13+
"sentence-transformers/static-similarity-mrl-multilingual-v1",
14+
Some("refs/pr/7"),
15+
None,
16+
)?;
1417
let tokenizer = load_tokenizer(&model_root)?;
1518

1619
let backend = CandleBackend::new(
@@ -71,8 +74,11 @@ fn test_static_embedding() -> Result<()> {
7174
#[test]
7275
#[serial_test::serial]
7376
fn test_static_embedding_pooled_raw() -> Result<()> {
74-
let (model_root, _) =
75-
download_artifacts("sentence-transformers/static-similarity-mrl-multilingual-v1", None, None)?;
77+
let (model_root, _) = download_artifacts(
78+
"sentence-transformers/static-similarity-mrl-multilingual-v1",
79+
Some("refs/pr/7"),
80+
None,
81+
)?;
7682
let tokenizer = load_tokenizer(&model_root)?;
7783

7884
let backend = CandleBackend::new(
@@ -99,7 +105,11 @@ fn test_static_embedding_pooled_raw() -> Result<()> {
99105

100106
let (pooled_embeddings, raw_embeddings) = sort_embeddings(backend.embed(input_batch)?);
101107
let pooled_embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings);
102-
insta::assert_yaml_snapshot!("static_embedding_batch_pooled", pooled_embeddings_batch, &matcher);
108+
insta::assert_yaml_snapshot!(
109+
"static_embedding_batch_pooled",
110+
pooled_embeddings_batch,
111+
&matcher
112+
);
103113

104114
let raw_embeddings_batch = SnapshotEmbeddings::from(raw_embeddings);
105115
insta::assert_yaml_snapshot!("static_embedding_batch_raw", raw_embeddings_batch, &matcher);
@@ -118,7 +128,11 @@ fn test_static_embedding_pooled_raw() -> Result<()> {
118128

119129
let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?);
120130
let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings);
121-
insta::assert_yaml_snapshot!("static_embedding_single_pooled", embeddings_single, &matcher);
131+
insta::assert_yaml_snapshot!(
132+
"static_embedding_single_pooled",
133+
embeddings_single,
134+
&matcher
135+
);
122136

123137
assert_eq!(pooled_embeddings_batch[0], embeddings_single[0]);
124138
assert_eq!(pooled_embeddings_batch[2], embeddings_single[0]);

0 commit comments

Comments
 (0)