-
Notifications
You must be signed in to change notification settings - Fork 250
/
Copy pathtest_jina_code.rs
50 lines (40 loc) · 1.67 KB
/
test_jina_code.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
mod common;
use crate::common::SnapshotEmbeddings;
use anyhow::Result;
use common::{cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};
/// Snapshot test for the `jinaai/jina-embeddings-v2-base-code` model run through
/// the Candle backend in `float32` with mean pooling.
///
/// The test embeds a three-item batch and a single-item batch, compares both
/// against their stored YAML snapshots, and then checks that the two batch
/// entries built from the same text as the single input produce the same
/// embedding as the single-input run.
///
/// # Errors
///
/// Propagates failures from downloading artifacts, loading the tokenizer,
/// constructing the backend, or running `embed`.
///
/// # Panics
///
/// Panics if tokenization fails, if a snapshot assertion fails, or if the
/// batched and single embeddings differ.
#[test]
fn test_jina_code_base() -> Result<()> {
    let artifacts_dir = download_artifacts("jinaai/jina-embeddings-v2-base-code", None)?;
    let tokenizer = load_tokenizer(&artifacts_dir)?;

    let backend = CandleBackend::new(
        artifacts_dir,
        "float32".to_string(),
        ModelType::Embedding(Pool::Mean),
    )?;

    // Tokenization failure is treated as a broken test fixture, hence the unwrap.
    let encode = |text: &str| tokenizer.encode(text, true).unwrap();
    let query = "What is Deep Learning?";

    // Positions 0 and 2 deliberately carry the same text as the single-input run below.
    let batched_input = batch(
        vec![encode(query), encode("Deep Learning is..."), encode(query)],
        vec![0, 1, 2],
        vec![],
    );

    // Matcher handed to insta for both snapshot comparisons.
    let matcher = cosine_matcher();

    // Only the first element returned by `sort_embeddings` is used; the second is discarded.
    let (batch_pooled, _) = sort_embeddings(backend.embed(batched_input)?);
    let embeddings_batch = SnapshotEmbeddings::from(batch_pooled);
    insta::assert_yaml_snapshot!("jina_code_batch", embeddings_batch, &matcher);

    let single_input = batch(vec![encode(query)], vec![0], vec![]);

    let (single_pooled, _) = sort_embeddings(backend.embed(single_input)?);
    let embeddings_single = SnapshotEmbeddings::from(single_pooled);
    insta::assert_yaml_snapshot!("jina_code_single", embeddings_single, &matcher);

    // Both batch entries that used `query` must equal the single-input embedding.
    for position in [0, 2] {
        assert_eq!(embeddings_batch[position], embeddings_single[0]);
    }

    Ok(())
}