Skip to content

Commit b05877d

Browse files
new: Added jina embedding v3 (#428)
* new: Added jina embedding v3 * refactor: Changed dim to int value * new: Updated notice * new: Extended text embedding with query embed and passage embed * fix: Fix lazy load in query and passage embed * tests: Added test for multitask embeddings * nit: Remove cache dir from tests * tests: Updated tests * improve: Improve task selection * fix: Fix ci * fix: Update fastembed/text/multitask_embedding.py Co-authored-by: George <[email protected]> * Update fastembed/text/multitask_embedding.py Co-authored-by: George <[email protected]> * fix: Pass task id using kwargs to parallel processor * tests: Added test for task assignment * prefer enums over ints * tests: Added test for parallel * improve: Updated model description * fix: Fix ci * fix: Fix ci * refactor: Refactor query_embed and passage_embed * tests: Added task propagation to parallel * refactor: Set default task as retrieval passage * chore: Update default task in tests --------- Co-authored-by: George <[email protected]>
1 parent 54f6cd9 commit b05877d

5 files changed

+388
-2
lines changed

NOTICE

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ This distribution includes the following Jina AI models, each with its respectiv
77
- License: cc-by-nc-4.0
88
- jinaai/jina-reranker-v2-base-multilingual
99
- License: cc-by-nc-4.0
10+
- jinaai/jina-embeddings-v3
11+
- License: cc-by-nc-4.0
1012

1113
These models are developed by Jina (https://jina.ai/) and are subject to Jina AI's licensing terms.
1214

fastembed/text/multitask_embedding.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from enum import Enum
2+
from typing import Any, Type, Iterable, Union, Optional
3+
4+
import numpy as np
5+
6+
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
7+
from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker
8+
from fastembed.text.onnx_text_model import TextEmbeddingWorker
9+
10+
supported_multitask_models = [
11+
{
12+
"model": "jinaai/jina-embeddings-v3",
13+
"dim": 1024,
14+
"tasks": {
15+
"retrieval.query": 0,
16+
"retrieval.passage": 1,
17+
"separation": 2,
18+
"classification": 3,
19+
"text-matching": 4,
20+
},
21+
"description": "Multi-task unimodal (text) embedding model, multi-lingual (~100), 1024 tokens truncation, and 8192 sequence length. Prefixes for queries/documents: not necessary, 2024 year.",
22+
"license": "cc-by-nc-4.0",
23+
"size_in_GB": 2.29,
24+
"sources": {
25+
"hf": "jinaai/jina-embeddings-v3",
26+
},
27+
"model_file": "onnx/model.onnx",
28+
"additional_files": ["onnx/model.onnx_data"],
29+
},
30+
]
31+
32+
33+
class Task(int, Enum):
34+
RETRIEVAL_QUERY = 0
35+
RETRIEVAL_PASSAGE = 1
36+
SEPARATION = 2
37+
CLASSIFICATION = 3
38+
TEXT_MATCHING = 4
39+
40+
41+
class JinaEmbeddingV3(PooledNormalizedEmbedding):
42+
PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
43+
QUERY_TASK = Task.RETRIEVAL_QUERY
44+
45+
def __init__(self, *args, **kwargs):
46+
super().__init__(*args, **kwargs)
47+
self._current_task_id = self.PASSAGE_TASK
48+
49+
@classmethod
50+
def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
51+
return JinaEmbeddingV3Worker
52+
53+
@classmethod
54+
def list_supported_models(cls) -> list[dict[str, Any]]:
55+
return supported_multitask_models
56+
57+
def _preprocess_onnx_input(
58+
self, onnx_input: dict[str, np.ndarray], **kwargs
59+
) -> dict[str, np.ndarray]:
60+
onnx_input["task_id"] = np.array(self._current_task_id, dtype=np.int64)
61+
return onnx_input
62+
63+
def embed(
64+
self,
65+
documents: Union[str, Iterable[str]],
66+
batch_size: int = 256,
67+
parallel: Optional[int] = None,
68+
task_id: int = PASSAGE_TASK,
69+
**kwargs,
70+
) -> Iterable[np.ndarray]:
71+
self._current_task_id = task_id
72+
kwargs["task_id"] = task_id
73+
yield from super().embed(documents, batch_size, parallel, **kwargs)
74+
75+
def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
76+
self._current_task_id = self.QUERY_TASK
77+
yield from super().embed(query, **kwargs)
78+
79+
def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
80+
self._current_task_id = self.PASSAGE_TASK
81+
yield from super().embed(texts, **kwargs)
82+
83+
84+
class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker):
85+
def init_embedding(
86+
self,
87+
model_name: str,
88+
cache_dir: str,
89+
**kwargs,
90+
) -> JinaEmbeddingV3:
91+
model = JinaEmbeddingV3(
92+
model_name=model_name,
93+
cache_dir=cache_dir,
94+
threads=1,
95+
**kwargs,
96+
)
97+
model._current_task_id = kwargs["task_id"]
98+
return model

fastembed/text/text_embedding.py

+29
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding
99
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
1010
from fastembed.text.pooled_embedding import PooledEmbedding
11+
from fastembed.text.multitask_embedding import JinaEmbeddingV3
1112
from fastembed.text.onnx_embedding import OnnxTextEmbedding
1213
from fastembed.text.text_embedding_base import TextEmbeddingBase
1314

@@ -19,6 +20,7 @@ class TextEmbedding(TextEmbeddingBase):
1920
CLIPOnnxEmbedding,
2021
PooledNormalizedEmbedding,
2122
PooledEmbedding,
23+
JinaEmbeddingV3,
2224
]
2325

2426
@classmethod
@@ -113,3 +115,30 @@ def embed(
113115
List of embeddings, one per document
114116
"""
115117
yield from self.model.embed(documents, batch_size, parallel, **kwargs)
118+
119+
def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
120+
"""
121+
Embeds queries
122+
123+
Args:
124+
query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.
125+
126+
Returns:
127+
Iterable[np.ndarray]: The embeddings.
128+
"""
129+
# This is model-specific, so that different models can have specialized implementations
130+
yield from self.model.query_embed(query, **kwargs)
131+
132+
def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
133+
"""
134+
Embeds a list of text passages into a list of embeddings.
135+
136+
Args:
137+
texts (Iterable[str]): The list of texts to embed.
138+
**kwargs: Additional keyword argument to pass to the embed method.
139+
140+
Yields:
141+
Iterable[SparseEmbedding]: The sparse embeddings.
142+
"""
143+
# This is model-specific, so that different models can have specialized implementations
144+
yield from self.model.passage_embed(texts, **kwargs)

0 commit comments

Comments
 (0)