Commit d09af55

d.rudenko and d.rudenko authored
Nomic-embeddings-support (#280)
* Nomic-embeddings-support
* Jina models moved to pooled-normalized embeddings
* Canonical vector for nomic-ai/nomic-embed-text-v1.5-Q
* Moved all nomics to pooled_embeddings

---------

Co-authored-by: d.rudenko <[email protected]>
1 parent 9387ca3 commit d09af55

6 files changed: +119 −114 lines changed
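For context, here is a minimal usage sketch of the models this commit rewires. It assumes fastembed's public `TextEmbedding` entry point and its generator-style `embed()` API, neither of which is part of this diff:

```python
from fastembed import TextEmbedding

# After this commit, nomic models resolve to PooledEmbedding
# (plain mean pooling over tokens, no L2 normalization).
model = TextEmbedding(model_name="nomic-ai/nomic-embed-text-v1.5")
vectors = list(model.embed(["FastEmbed is light and fast"]))
print(vectors[0].shape)  # (768,) per the model metadata in this diff
```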

fastembed/text/mini_lm_embedding.py (−58)

This file was deleted.

fastembed/text/onnx_embedding.py (+6 −32)

@@ -80,36 +80,6 @@
         },
         "model_file": "model_optimized.onnx",
     },
-    {
-        "model": "nomic-ai/nomic-embed-text-v1",
-        "dim": 768,
-        "description": "8192 context length english model",
-        "size_in_GB": 0.52,
-        "sources": {
-            "hf": "nomic-ai/nomic-embed-text-v1",
-        },
-        "model_file": "onnx/model.onnx",
-    },
-    {
-        "model": "nomic-ai/nomic-embed-text-v1.5",
-        "dim": 768,
-        "description": "8192 context length english model",
-        "size_in_GB": 0.52,
-        "sources": {
-            "hf": "nomic-ai/nomic-embed-text-v1.5",
-        },
-        "model_file": "onnx/model.onnx",
-    },
-    {
-        "model": "nomic-ai/nomic-embed-text-v1.5-Q",
-        "dim": 768,
-        "description": "Quantized 8192 context length english model",
-        "size_in_GB": 0.13,
-        "sources": {
-            "hf": "nomic-ai/nomic-embed-text-v1.5",
-        },
-        "model_file": "onnx/model_quantized.onnx",
-    },
     {
         "model": "thenlper/gte-large",
         "dim": 1024,
@@ -274,7 +244,9 @@ def _preprocess_onnx_input(
         """
         return onnx_input

-    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext
+    ) -> Iterable[np.ndarray]:
         embeddings = output.model_output
         return normalize(embeddings[:, 0]).astype(np.float32)

@@ -286,4 +258,6 @@ def init_embedding(
         cache_dir: str,
         **kwargs,
     ) -> OnnxTextEmbedding:
-        return OnnxTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
+        return OnnxTextEmbedding(
+            model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
+        )
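Aside from dropping the nomic entries, the hunks above only reflow the retained base-class behaviour: `_post_process_onnx_output` takes `embeddings[:, 0]`, i.e. the first (CLS) token of each sequence, and L2-normalizes it. A toy numpy illustration of that step, with made-up shapes:

```python
import numpy as np

# Toy ONNX output: batch=2, seq_len=3, dim=4 (shapes are illustrative only).
token_embeddings = np.random.rand(2, 3, 4).astype(np.float32)

# CLS pooling as in OnnxTextEmbedding: first token per sequence, then unit norm.
cls_vectors = token_embeddings[:, 0]
cls_vectors /= np.linalg.norm(cls_vectors, axis=1, keepdims=True)
print(np.linalg.norm(cls_vectors, axis=1))  # ~[1. 1.]
```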

fastembed/text/pooled_embedding.py (+87, new file)

@@ -0,0 +1,87 @@
+from typing import Any, Dict, Iterable, List, Type
+
+import numpy as np
+
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.common.utils import normalize
+from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
+from fastembed.text.onnx_text_model import TextEmbeddingWorker
+
+supported_pooled_models = [
+    {
+        "model": "nomic-ai/nomic-embed-text-v1.5",
+        "dim": 768,
+        "description": "8192 context length english model",
+        "size_in_GB": 0.52,
+        "sources": {
+            "hf": "nomic-ai/nomic-embed-text-v1.5",
+        },
+        "model_file": "onnx/model.onnx",
+    },
+    {
+        "model": "nomic-ai/nomic-embed-text-v1.5-Q",
+        "dim": 768,
+        "description": "Quantized 8192 context length english model",
+        "size_in_GB": 0.13,
+        "sources": {
+            "hf": "nomic-ai/nomic-embed-text-v1.5",
+        },
+        "model_file": "onnx/model_quantized.onnx",
+    },
+    {
+        "model": "nomic-ai/nomic-embed-text-v1",
+        "dim": 768,
+        "description": "8192 context length english model",
+        "size_in_GB": 0.52,
+        "sources": {
+            "hf": "nomic-ai/nomic-embed-text-v1",
+        },
+        "model_file": "onnx/model.onnx",
+    },
+]
+
+
+class PooledEmbedding(OnnxTextEmbedding):
+    @classmethod
+    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
+        return PooledEmbeddingWorker
+
+    @classmethod
+    def mean_pooling(
+        cls, model_output: np.ndarray, attention_mask: np.ndarray
+    ) -> np.ndarray:
+        token_embeddings = model_output
+        input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
+        input_mask_expanded = np.tile(
+            input_mask_expanded, (1, 1, token_embeddings.shape[-1])
+        )
+        input_mask_expanded = input_mask_expanded.astype(float)
+        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
+        sum_mask = np.sum(input_mask_expanded, axis=1)
+        pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
+        return pooled_embeddings
+
+    @classmethod
+    def list_supported_models(cls) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
+        """
+        return supported_pooled_models
+
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext
+    ) -> Iterable[np.ndarray]:
+        embeddings = output.model_output
+        attn_mask = output.attention_mask
+        return self.mean_pooling(embeddings, attn_mask).astype(np.float32)
+
+
+class PooledEmbeddingWorker(OnnxTextEmbeddingWorker):
+    def init_embedding(
+        self, model_name: str, cache_dir: str, **kwargs
+    ) -> OnnxTextEmbedding:
+        return PooledEmbedding(
+            model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
+        )
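`mean_pooling` averages token embeddings over the real (unpadded) positions: the attention mask is broadcast across the embedding dimension, masked embeddings are summed, and the sum is divided by the token count (clamped at 1e-9 to avoid division by zero). A worked toy example with made-up values:

```python
import numpy as np

# Toy batch: 1 sequence, 4 positions, dim 2; the last two positions are padding.
token_embeddings = np.array([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0], [9.0, 9.0]]])
attention_mask = np.array([[1, 1, 0, 0]])

# Same arithmetic as PooledEmbedding.mean_pooling above.
mask = np.tile(np.expand_dims(attention_mask, -1), (1, 1, 2)).astype(float)
pooled = (token_embeddings * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1e-9)
print(pooled)  # [[2. 3.]] -- the padded positions do not contribute
```

Note that `PooledEmbedding._post_process_onnx_output` returns this mean directly, without L2 normalization; that is the behavioural difference from the normalized variant below.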

fastembed/text/jina_onnx_embedding.py → fastembed/text/pooled_normalized_embedding.py (+18 −16)

@@ -6,8 +6,20 @@
 from fastembed.common.utils import normalize
 from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
 from fastembed.text.onnx_text_model import TextEmbeddingWorker
+from fastembed.text.pooled_embedding import PooledEmbedding

-supported_jina_models = [
+supported_pooled_normalized_models = [
+    {
+        "model": "sentence-transformers/all-MiniLM-L6-v2",
+        "dim": 384,
+        "description": "Sentence Transformer model, MiniLM-L6-v2",
+        "size_in_GB": 0.09,
+        "sources": {
+            "url": "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
+            "hf": "qdrant/all-MiniLM-L6-v2-onnx",
+        },
+        "model_file": "model.onnx",
+    },
     {
         "model": "jinaai/jina-embeddings-v2-base-en",
         "dim": 768,
@@ -35,20 +47,10 @@
 ]


-class JinaOnnxEmbedding(OnnxTextEmbedding):
+class PooledNormalizedEmbedding(PooledEmbedding):
     @classmethod
     def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
-        return JinaEmbeddingWorker
-
-    @classmethod
-    def mean_pooling(cls, model_output, attention_mask) -> np.ndarray:
-        token_embeddings = model_output
-        input_mask_expanded = (np.expand_dims(attention_mask, axis=-1)).astype(float)
-
-        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
-        mask_sum = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
-
-        return sum_embeddings / mask_sum
+        return PooledNormalizedEmbeddingWorker

     @classmethod
     def list_supported_models(cls) -> List[Dict[str, Any]]:
@@ -57,7 +59,7 @@ def list_supported_models(cls) -> List[Dict[str, Any]]:
         Returns:
             List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
-        return supported_jina_models
+        return supported_pooled_normalized_models

     def _post_process_onnx_output(
         self, output: OnnxOutputContext
@@ -67,10 +69,10 @@ def _post_process_onnx_output(
         return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)


-class JinaEmbeddingWorker(OnnxTextEmbeddingWorker):
+class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):
     def init_embedding(
         self, model_name: str, cache_dir: str, **kwargs
     ) -> OnnxTextEmbedding:
-        return JinaOnnxEmbedding(
+        return PooledNormalizedEmbedding(
             model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
         )
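`PooledNormalizedEmbedding` inherits `mean_pooling` from `PooledEmbedding` and only adds a `normalize` call in post-processing. Assuming `normalize` (from `fastembed.common.utils`, not shown in this diff) is row-wise L2 normalization, the extra step amounts to:

```python
import numpy as np

pooled = np.array([[3.0, 4.0]])  # a toy mean-pooled vector

# The additional step in PooledNormalizedEmbedding._post_process_onnx_output,
# assuming normalize() performs row-wise L2 normalization.
normalized = pooled / np.linalg.norm(pooled, axis=1, keepdims=True)
print(normalized)  # [[0.6 0.8]] -- unit length
```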

fastembed/text/text_embedding.py (+4 −4)

@@ -5,8 +5,8 @@
 from fastembed.common import OnnxProvider
 from fastembed.text.clip_embedding import CLIPOnnxEmbedding
 from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding
-from fastembed.text.jina_onnx_embedding import JinaOnnxEmbedding
-from fastembed.text.mini_lm_embedding import MiniLMOnnxEmbedding
+from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
+from fastembed.text.pooled_embedding import PooledEmbedding
 from fastembed.text.onnx_embedding import OnnxTextEmbedding
 from fastembed.text.text_embedding_base import TextEmbeddingBase

@@ -15,9 +15,9 @@ class TextEmbedding(TextEmbeddingBase):
     EMBEDDINGS_REGISTRY: List[Type[TextEmbeddingBase]] = [
         OnnxTextEmbedding,
         E5OnnxEmbedding,
-        JinaOnnxEmbedding,
         CLIPOnnxEmbedding,
-        MiniLMOnnxEmbedding,
+        PooledNormalizedEmbedding,
+        PooledEmbedding,
     ]

     @classmethod
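The registry order matters: a model name is served by the first class that lists it. Below is a hypothetical sketch of that lookup, reusing module paths shown in this diff; the real resolution logic lives inside `TextEmbedding` and may differ in detail:

```python
from typing import Type

from fastembed.text.text_embedding import TextEmbedding
from fastembed.text.text_embedding_base import TextEmbeddingBase

def resolve(model_name: str) -> Type[TextEmbeddingBase]:
    # Hypothetical helper: walk the registry in order and return the first
    # class whose list_supported_models() mentions the requested model.
    for embedding_cls in TextEmbedding.EMBEDDINGS_REGISTRY:
        names = [m["model"] for m in embedding_cls.list_supported_models()]
        if model_name in names:
            return embedding_cls
    raise ValueError(f"Model {model_name} is not supported")

# With this commit: "nomic-ai/nomic-embed-text-v1" -> PooledEmbedding, while
# "jinaai/jina-embeddings-v2-base-en" -> PooledNormalizedEmbedding.
```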

tests/test_text_onnx_embeddings.py (+4 −4)

@@ -45,16 +45,16 @@
         [-0.0332, -0.0509, 0.0287, -0.0043, -0.0077]
     ),
     "jinaai/jina-embeddings-v2-base-de": np.array(
-        [-0.0085, 0.0417, 0.0342, 0.0309, -0.0149]
+        [-0.0085, 0.0417, 0.0342, 0.0309, -0.0149]
     ),
     "nomic-ai/nomic-embed-text-v1": np.array(
-        [0.0061, 0.0103, -0.0296, -0.0242, -0.0170]
+        [0.3708, 0.2031, -0.3406, -0.2114, -0.3230]
     ),
     "nomic-ai/nomic-embed-text-v1.5": np.array(
-        [-1.6531514e-02, 8.5380634e-05, -1.8171231e-01, -3.9333291e-03, 1.2763254e-02]
+        [-0.15407836, -0.03053198, -3.9138033, 0.1910364, 0.13224715]
     ),
     "nomic-ai/nomic-embed-text-v1.5-Q": np.array(
-        [-0.01554983, 0.0129992, -0.17909265, -0.01062993, 0.00512859]
+        [-0.12525563, 0.38030425, -3.961622, 0.04176439, -0.0758301]
     ),
     "thenlper/gte-large": np.array(
         [-0.01920587, 0.00113156, -0.00708992, -0.00632304, -0.04025577]
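The nomic canonical vectors change because those models now go through plain mean pooling rather than CLS pooling plus normalization, so the expected components are no longer unit-scale (note the −3.91 and −3.96 entries). As a hedged sketch of how such canonical vectors are typically checked, with stand-in values since the real test input is not shown in this diff:

```python
import numpy as np

canonical = np.array([0.3708, 0.2031, -0.3406, -0.2114, -0.3230])
embedding = np.array([0.3709, 0.2030, -0.3405, -0.2115, -0.3231, 0.01])  # stand-in
assert np.allclose(embedding[: canonical.size], canonical, atol=1e-3)
```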
