Skip to content

Commit c8fff66

Browse files
authored
Colbert (#248)
* new: add late interaction embedding, colbert
* new: update imports
* new: add comments
* fix: rollback mp methods
* fix: restore existing padding after embed query
* fix: fix OnnxOutputContext in onnx embed, fix preprocessing for colbert
1 parent 85aaae4 commit c8fff66

11 files changed

+512
-19
lines changed

fastembed/__init__.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@
33
from fastembed.image import ImageEmbedding
44
from fastembed.text import TextEmbedding
55
from fastembed.sparse import SparseTextEmbedding, SparseEmbedding
6-
6+
from fastembed.late_interaction import LateInteractionTextEmbedding
77

88
# Report the version of whichever distribution is installed: the regular
# "fastembed" package or, when that one is absent, "fastembed-gpu".
try:
    version = importlib.metadata.version("fastembed")
except importlib.metadata.PackageNotFoundError:
    version = importlib.metadata.version("fastembed-gpu")

__version__ = version

# Public API re-exported from the package root.
__all__ = [
    "TextEmbedding",
    "SparseTextEmbedding",
    "SparseEmbedding",
    "ImageEmbedding",
    "LateInteractionTextEmbedding",
]

fastembed/common/onnx_model.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ def __init__(self) -> None:
3333
self.model = None
3434
self.tokenizer = None
3535

36-
def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
36+
def _preprocess_onnx_input(
37+
self, onnx_input: Dict[str, np.ndarray], **kwargs
38+
) -> Dict[str, np.ndarray]:
3739
"""
3840
Preprocess the onnx input.
3941
"""

fastembed/image/onnx_embedding.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,9 @@ def embed(
112112
def _get_worker_class(cls) -> Type["ImageEmbeddingWorker"]:
113113
return OnnxImageEmbeddingWorker
114114

115-
def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
115+
def _preprocess_onnx_input(
116+
self, onnx_input: Dict[str, np.ndarray], **kwargs
117+
) -> Dict[str, np.ndarray]:
116118
"""
117119
Preprocess the onnx input.
118120
"""

fastembed/image/onnx_image_model.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def __init__(self) -> None:
2828
super().__init__()
2929
self.processor = None
3030

31-
def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
31+
def _preprocess_onnx_input(
32+
self, onnx_input: Dict[str, np.ndarray], **kwargs
33+
) -> Dict[str, np.ndarray]:
3234
"""
3335
Preprocess the onnx input.
3436
"""
@@ -49,16 +51,14 @@ def load_onnx_model(
4951
def _build_onnx_input(self, encoded: np.ndarray) -> Dict[str, np.ndarray]:
5052
return {node.name: encoded for node in self.model.get_inputs()}
5153

52-
def onnx_embed(self, images: List[PathInput], **kwargs) -> OnnxOutputContext:
    """Run the ONNX image model on a batch of image files.

    Args:
        images: Paths of the images to embed; each one is opened with
            `Image.open`.
        **kwargs: Forwarded to `_preprocess_onnx_input`.

    Returns:
        OnnxOutputContext: Context whose `model_output` is the first model
            output reshaped to `(len(images), -1)`.
    """
    # Register every opened image with the stack so its file handle is
    # released once preprocessing is done, even if the processor raises.
    with contextlib.ExitStack() as stack:
        image_files = [stack.enter_context(Image.open(image)) for image in images]
        encoded = self.processor(image_files)

    onnx_input = self._build_onnx_input(encoded)
    onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
    model_output = self.model.run(None, onnx_input)
    embeddings = model_output[0].reshape(len(images), -1)
    return OnnxOutputContext(model_output=embeddings)
+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from fastembed.late_interaction.late_interaction_text_embedding import LateInteractionTextEmbedding
2+
3+
4+
__all__ = ["LateInteractionTextEmbedding"]

fastembed/late_interaction/colbert.py

+196
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
from typing import Any, Dict, Iterable, List, Optional, Union, Type, Sequence
2+
import string
3+
4+
import numpy as np
5+
from tokenizers import Encoding
6+
7+
from fastembed.common import OnnxProvider
8+
from fastembed.common.onnx_model import OnnxOutputContext
9+
from fastembed.common.utils import define_cache_dir
10+
from fastembed.late_interaction.late_interaction_embedding_base import (
11+
LateInteractionTextEmbeddingBase,
12+
)
13+
from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
14+
15+
16+
# Registry of the ColBERT checkpoints this module knows how to load.
# Each entry describes one model: its public name, the size of a single
# token vector ("dim"), the approximate download size, where the weights
# come from and which file inside the model directory holds the ONNX graph.
supported_colbert_models = [
    {
        "model": "colbert-ir/colbertv2.0",
        "dim": 128,
        "description": "Late interaction model",
        "size_in_GB": 0.44,
        "sources": {"hf": "colbert-ir/colbertv2.0"},
        "model_file": "model.onnx",
    },
]
28+
29+
30+
class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]):
    """ColBERT late-interaction text embedding model served through ONNX.

    Every text is prefixed with ". " before tokenization so that the token at
    position 1 can be overwritten with a query or document marker id.
    """

    # Token ids written into position 1 of ``input_ids`` to mark the input kind.
    QUERY_MARKER_TOKEN_ID = 1
    DOCUMENT_MARKER_TOKEN_ID = 2
    # Queries that tokenize to fewer ids than this are padded with MASK_TOKEN.
    MIN_QUERY_LENGTH = 32
    MASK_TOKEN = "[MASK]"

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, is_doc: bool = True
    ) -> Iterable[np.ndarray]:
        """Turn raw model output into float32 token embeddings.

        Args:
            output: Model output together with the `input_ids` and
                `attention_mask` that produced it.
            is_doc: True for documents, False for queries.

        Returns:
            For queries, the model output cast to float32 and otherwise
            untouched. For documents, the model output with the vectors of
            punctuation and padding tokens zeroed and every remaining vector
            L2-normalised along the last axis.

        Note:
            For documents, `output.attention_mask` and `output.model_output`
            are modified in place.
        """
        if not is_doc:
            return output.model_output.astype(np.float32)

        # Punctuation and padding tokens must not contribute to document
        # embeddings: clear their attention-mask entries in one vectorised pass.
        ignored_ids = list(self.skip_list | {self.pad_token_id})
        ignored_positions = np.isin(output.input_ids, ignored_ids)
        output.attention_mask[ignored_positions] = 0

        output.model_output *= np.expand_dims(output.attention_mask, 2).astype(np.float32)
        norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
        # Clamp the norm so the zeroed (masked) vectors do not divide by zero.
        norm_clamped = np.maximum(norm, 1e-12)
        output.model_output /= norm_clamped
        return output.model_output.astype(np.float32)

    def _preprocess_onnx_input(
        self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True, **kwargs
    ) -> Dict[str, np.ndarray]:
        """Overwrite the placeholder token at position 1 with the marker id.

        Args:
            onnx_input: ONNX feed dict; its "input_ids" array is modified in place.
            is_doc: True writes the document marker, False the query marker.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            The same `onnx_input` dict.
        """
        marker_id = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID
        onnx_input["input_ids"][:, 1] = marker_id
        return onnx_input

    def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
        """Tokenize documents, or a single query when `is_doc` is False.

        In query mode only the first element of `documents` is tokenized.
        """
        if is_doc:
            return self._tokenize_documents(documents=documents)
        return self._tokenize_query(query=next(iter(documents)))

    def _tokenize_query(self, query: str) -> List[Encoding]:
        """Tokenize one query, padding it with [MASK] up to MIN_QUERY_LENGTH."""
        # ". " is added to a query to be replaced with a special query token
        batch = [f". {query}"]
        encoded = self.tokenizer.encode_batch(batch)
        if len(encoded[0].ids) >= self.MIN_QUERY_LENGTH:
            return encoded

        # colbert authors recommend to pad queries with [MASK] tokens for
        # query augmentation to improve performance
        prev_padding = self.tokenizer.padding or None
        self.tokenizer.enable_padding(
            pad_token=self.MASK_TOKEN, pad_id=self.mask_token_id, length=self.MIN_QUERY_LENGTH
        )
        try:
            encoded = self.tokenizer.encode_batch(batch)
        finally:
            # Always put the tokenizer back the way it was, even when encoding
            # fails, so later document tokenization is not padded with [MASK].
            if prev_padding is None:
                self.tokenizer.no_padding()
            else:
                self.tokenizer.enable_padding(**prev_padding)
        return encoded

    def _tokenize_documents(self, documents: List[str]) -> List[Encoding]:
        """Tokenize a batch of documents."""
        # ". " is added to a document to be replaced with a special document token
        documents = [". " + doc for doc in documents]
        return self.tokenizer.encode_batch(documents)

    @classmethod
    def list_supported_models(cls) -> List[Dict[str, Any]]:
        """Lists the supported models.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return supported_colbert_models

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        **kwargs,
    ):
        """
        Args:
            model_name (str): The name of the model to use.
            cache_dir (str, optional): The path to the cache directory.
                                       Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                                       Defaults to `fastembed_cache` in the system's temp directory.
            threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
            providers (Sequence[OnnxProvider], optional): Passed through to `load_onnx_model`.
                Defaults to None.

        Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
        """

        super().__init__(model_name, cache_dir, threads, **kwargs)

        model_description = self._get_model_description(model_name)
        cache_dir = define_cache_dir(cache_dir)

        model_dir = self.download_model(
            model_description, cache_dir, local_files_only=self._local_files_only
        )

        self.load_onnx_model(
            model_dir=model_dir,
            model_file=model_description["model_file"],
            threads=threads,
            providers=providers,
        )
        self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN]
        self.pad_token_id = self.tokenizer.padding["pad_id"]

        # Ids of punctuation tokens; their vectors are zeroed in document embeddings.
        self.skip_list = {
            self.tokenizer.encode(symbol, add_special_tokens=False).ids[0]
            for symbol in string.punctuation
        }

    def embed(
        self,
        documents: Union[str, Iterable[str]],
        batch_size: int = 256,
        parallel: Optional[int] = None,
        **kwargs,
    ) -> Iterable[np.ndarray]:
        """
        Encode documents into late-interaction embeddings.

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Iterable of embeddings, one per document
        """
        yield from self._embed_documents(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            documents=documents,
            batch_size=batch_size,
            parallel=parallel,
        )

    def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
        """Embed one query or several queries, yielding one embedding per query.

        Args:
            query: A single query string or an iterable of query strings.
            **kwargs: Accepted for interface compatibility; not used here.

        Yields:
            np.ndarray: The embedding of each query, in input order.
        """
        if isinstance(query, str):
            query = [query]

        # Queries are run one at a time because each is padded individually
        # to MIN_QUERY_LENGTH during tokenization.
        for text in query:
            yield from self._post_process_onnx_output(
                self.onnx_embed([text], is_doc=False), is_doc=False
            )

    @classmethod
    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
        """Return the worker class that builds this model."""
        return ColbertEmbeddingWorker
189+
190+
class ColbertEmbeddingWorker(TextEmbeddingWorker):
    """Worker class returned by `Colbert._get_worker_class`."""

    def init_embedding(self, model_name: str, cache_dir: str) -> Colbert:
        """Create the `Colbert` model owned by this worker.

        Args:
            model_name: Name of the model to load.
            cache_dir: Directory holding the cached model files.

        Returns:
            Colbert: A model instance limited to a single onnxruntime thread.
        """
        worker_model = Colbert(model_name=model_name, cache_dir=cache_dir, threads=1)
        return worker_model
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from typing import Iterable, Optional, Union
2+
3+
import numpy as np
4+
5+
from fastembed.common.model_management import ModelManagement
6+
7+
8+
class LateInteractionTextEmbeddingBase(ModelManagement):
    """Common interface for late-interaction text embedding models.

    Subclasses implement `embed`; `passage_embed` and `query_embed` delegate
    to it unless a model overrides them.
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        **kwargs,
    ):
        """Store the model configuration.

        Args:
            model_name: The name of the model to use.
            cache_dir: The path to the cache directory, if any.
            threads: Thread count stored on the instance for subclasses to use.
            **kwargs: `local_files_only` (default False) is read from here and
                stored as `_local_files_only`; other keys are ignored.
        """
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.threads = threads
        self._local_files_only = kwargs.pop("local_files_only", False)

    def embed(
        self,
        documents: Union[str, Iterable[str]],
        batch_size: int = 256,
        parallel: Optional[int] = None,
        **kwargs,
    ) -> Iterable[np.ndarray]:
        """Embed documents; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
        """
        Embeds a list of text passages into a list of embeddings.

        Args:
            texts (Iterable[str]): The list of texts to embed.
            **kwargs: Additional keyword argument to pass to the embed method.

        Yields:
            Iterable[np.ndarray]: The embeddings.
        """

        # This is model-specific, so that different models can have specialized implementations
        yield from self.embed(texts, **kwargs)

    def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
        """
        Embeds queries

        Args:
            query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.
            **kwargs: Additional keyword argument to pass to the embed method.

        Returns:
            Iterable[np.ndarray]: The embeddings.
        """

        # This is model-specific, so that different models can have specialized implementations
        if isinstance(query, str):
            yield from self.embed([query], **kwargs)
        # `elif`, not `if`: a str is itself an Iterable, so two independent
        # checks would embed a single string query twice.
        elif isinstance(query, Iterable):
            yield from self.embed(query, **kwargs)

0 commit comments

Comments
 (0)