
Commit 2fe33c5

joein and I8dNLo authored
colpali v1.3 by AndrewOgn (#427)
* wip: design draft
* Operators fix
* Fix model inputs
* Import from fastembed.late_interaction_multimodal
* Fixed method misspelling
* Tests which do not run in CI; docstring improvements
* Fix tests
* Bump colpali to version v1.3
* Remove colpali v1.2
* Remove colpali v1.2 from tests
* Partial fix of change requests: descriptions, docs, black
* query_max_length
* black colpali
* Added comment for EMPTY_TEXT_PLACEHOLDER
* Review fixes
* Removed redundant VISUAL_PROMPT_PREFIX
* type fix + model info
* new: add specific model path to colpali
* fix: revert accidental renaming
* fix: remove max_length from encode_batch
* refactoring: remove redundant QUERY_MAX_LENGTH variable
* refactoring: remove redundant document marker token id
* fix: fix type hints, fix tests, handle single image path embed, rename model, update description
* license: add gemma to NOTICE
* fix: do not run colpali test in ci
* fix: fix colpali test

---------

Co-authored-by: d.rudenko <[email protected]>
1 parent 969ea29 commit 2fe33c5

File tree

9 files changed: +826 -4 lines changed


NOTICE

Lines changed: 6 additions & 0 deletions
@@ -12,5 +12,11 @@ This distribution includes the following Jina AI models, each with its respectiv
 
 These models are developed by Jina (https://jina.ai/) and are subject to Jina AI's licensing terms.
 
+This distribution includes the following Google models, each with its respective license:
+- vidore/colpali-v1.3
+  - License: gemma
+
+Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms
+
 Additional Notes:
 This project also includes third-party libraries with their respective licenses. Please refer to the documentation of each library for details regarding its usage and licensing terms.

fastembed/image/image_embedding.py

Lines changed: 1 addition & 2 deletions
@@ -79,8 +79,7 @@ def embed(
         **kwargs: Any,
     ) -> Iterable[NumpyArray]:
         """
-        Encode a list of documents into list of embeddings.
-        We use mean pooling with attention so that the model can handle variable-length inputs.
+        Encode a list of images into list of embeddings.
 
         Args:
             images: Iterator of image paths or single image path to embed

fastembed/image/transform/operators.py

Lines changed: 2 additions & 2 deletions
@@ -139,7 +139,7 @@ def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]) ->
     @classmethod
     def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> None:
         mode = config.get("image_processor_type", "CLIPImageProcessor")
-        if mode == "CLIPImageProcessor":
+        if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
             if config.get("do_resize", False):
                 size = config["size"]
                 if "shortest_edge" in size:
@@ -202,7 +202,7 @@ def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> Non
     @staticmethod
     def _get_center_crop(transforms: list[Transform], config: dict[str, Any]) -> None:
         mode = config.get("image_processor_type", "CLIPImageProcessor")
-        if mode == "CLIPImageProcessor":
+        if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
             if config.get("do_center_crop", False):
                 crop_size_raw = config["crop_size"]
                 crop_size: tuple[int, int]
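
The same widened `mode in (...)` check appears in both `_get_resize` and `_get_center_crop`, letting Siglip-style preprocessor configs take the same path as CLIP ones. A small self-contained sketch of the dispatch it enables; the config values below are assumptions mirroring a typical HF preprocessor_config.json, not taken from the diff:

# Illustrative sketch only: how a Siglip-style config now reaches the resize branch.
config = {
    "image_processor_type": "SiglipImageProcessor",
    "do_resize": True,
    "size": {"height": 448, "width": 448},
}

mode = config.get("image_processor_type", "CLIPImageProcessor")
if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
    if config.get("do_resize", False):
        size = config["size"]
        if "shortest_edge" in size:
            target = size["shortest_edge"]  # CLIP-style: resize the shortest edge
        else:
            target = (size["height"], size["width"])  # Siglip-style: fixed height x width
        print("resize target:", target)  # -> resize target: (448, 448)
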
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding import (
+    LateInteractionMultimodalEmbedding,
+)
+
+__all__ = ["LateInteractionMultimodalEmbedding"]
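
Judging by the bare re-export and `__all__`, this new five-line file looks like the package `__init__` for `fastembed.late_interaction_multimodal` (the file path is not shown in this view). If so, a minimal sketch of the shorter public import path it enables:

# Assumed public import path enabled by the re-export above (sketch only)
from fastembed.late_interaction_multimodal import LateInteractionMultimodalEmbedding
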
Lines changed: 301 additions & 0 deletions
@@ -0,0 +1,301 @@
+from typing import Any, Iterable, Optional, Sequence, Type, Union
+
+import numpy as np
+from tokenizers import Encoding
+
+from fastembed.common import OnnxProvider, ImageInput
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.common.utils import define_cache_dir
+from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
+    LateInteractionMultimodalEmbeddingBase,
+)
+from fastembed.late_interaction_multimodal.onnx_multimodal_model import (
+    OnnxMultimodalModel,
+    TextEmbeddingWorker,
+    ImageEmbeddingWorker,
+)
+
+supported_colpali_models = [
+    {
+        "model": "Qdrant/colpali-v1.3-fp16",
+        "dim": 128,
+        "description": "Text embeddings, Multimodal (text&image), English, 50 tokens query length truncation, 2024.",
+        "license": "mit",
+        "size_in_GB": 6.5,
+        "sources": {
+            "hf": "Qdrant/colpali-v1.3-fp16",
+        },
+        "additional_files": [
+            "model.onnx_data",
+        ],
+        "model_file": "model.onnx",
+    },
+]
+
+
+class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[np.ndarray]):
+    QUERY_PREFIX = "Query: "
+    BOS_TOKEN = "<s>"
+    PAD_TOKEN = "<pad>"
+    QUERY_MARKER_TOKEN_ID = [2, 5098]
+    IMAGE_PLACEHOLDER_SIZE = (3, 448, 448)
+    EMPTY_TEXT_PLACEHOLDER = np.array(
+        [257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]
+    )  # This is a tokenization of '<image>' * 1024 + '<bos>Describe the image.\n' line which is used as placeholder
+    # while processing an image
+    EVEN_ATTENTION_MASK = np.array([1] * 1030)
+
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: Optional[str] = None,
+        threads: Optional[int] = None,
+        providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_ids: Optional[list[int]] = None,
+        lazy_load: bool = False,
+        device_id: Optional[int] = None,
+        specific_model_path: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            model_name (str): The name of the model to use.
+            cache_dir (str, optional): The path to the cache directory.
+                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+                Defaults to `fastembed_cache` in the system's temp directory.
+            threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
+            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+            cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
+                Defaults to False.
+            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
+                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
+            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
+            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
+
+        Raises:
+            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+        """
+
+        super().__init__(model_name, cache_dir, threads, **kwargs)
+        self.providers = providers
+        self.lazy_load = lazy_load
+
+        # List of device ids, that can be used for data parallel processing in workers
+        self.device_ids = device_ids
+        self.cuda = cuda
+
+        # This device_id will be used if we need to load model in current process
+        if device_id is not None:
+            self.device_id = device_id
+        elif self.device_ids is not None:
+            self.device_id = self.device_ids[0]
+        else:
+            self.device_id = None
+
+        self.model_description = self._get_model_description(model_name)
+        self.cache_dir = define_cache_dir(cache_dir)
+
+        self._model_dir = self.download_model(
+            self.model_description,
+            self.cache_dir,
+            local_files_only=self._local_files_only,
+            specific_model_path=specific_model_path,
+        )
+        self.mask_token_id = None
+        self.pad_token_id = None
+
+        if not self.lazy_load:
+            self.load_onnx_model()
+
+    @classmethod
+    def list_supported_models(cls) -> list[dict[str, Any]]:
+        """Lists the supported models.
+
+        Returns:
+            list[dict[str, Any]]: A list of dictionaries containing the model information.
+        """
+        return supported_colpali_models
+
+    def load_onnx_model(self) -> None:
+        self._load_onnx_model(
+            model_dir=self._model_dir,
+            model_file=self.model_description["model_file"],
+            threads=self.threads,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_id=self.device_id,
+        )
+
+    def _post_process_onnx_image_output(
+        self,
+        output: OnnxOutputContext,
+    ) -> Iterable[np.ndarray]:
+        """
+        Post-process the ONNX model output to convert it into a usable format.
+
+        Args:
+            output (OnnxOutputContext): The raw output from the ONNX model.
+
+        Returns:
+            Iterable[np.ndarray]: Post-processed output as NumPy arrays.
+        """
+        return output.model_output.reshape(
+            output.model_output.shape[0], -1, self.model_description["dim"]
+        ).astype(np.float32)
+
+    def _post_process_onnx_text_output(
+        self,
+        output: OnnxOutputContext,
+    ) -> Iterable[np.ndarray]:
+        """
+        Post-process the ONNX model output to convert it into a usable format.
+
+        Args:
+            output (OnnxOutputContext): The raw output from the ONNX model.
+
+        Returns:
+            Iterable[np.ndarray]: Post-processed output as NumPy arrays.
+        """
+        return output.model_output.astype(np.float32)
+
+    def tokenize(self, documents: list[str], **_) -> list[Encoding]:
+        texts_query: list[str] = []
+        for query in documents:
+            query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10
+            query += "\n"
+
+            texts_query.append(query)
+        encoded = self.tokenizer.encode_batch(texts_query)
+        return encoded
+
+    def _preprocess_onnx_text_input(
+        self, onnx_input: dict[str, np.ndarray], **kwargs
+    ) -> dict[str, np.ndarray]:
+        onnx_input["input_ids"] = np.array(
+            [
+                self.QUERY_MARKER_TOKEN_ID + input_ids[2:].tolist()
+                for input_ids in onnx_input["input_ids"]
+            ]
+        )
+        empty_image_placeholder = np.zeros(self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32)
+        onnx_input["pixel_values"] = np.array(
+            [empty_image_placeholder for _ in onnx_input["input_ids"]]
+        )
+        return onnx_input
+
+    def _preprocess_onnx_image_input(
+        self, onnx_input: dict[str, np.ndarray], **kwargs
+    ) -> dict[str, np.ndarray]:
+        """
+        Add placeholders for text input when processing image data for ONNX.
+        Args:
+            onnx_input (Dict[str, np.ndarray]): Preprocessed image inputs.
+            **kwargs: Additional arguments.
+        Returns:
+            Dict[str, np.ndarray]: ONNX input with text placeholders.
+        """
+
+        onnx_input["input_ids"] = np.array(
+            [self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]]
+        )
+        onnx_input["attention_mask"] = np.array(
+            [self.EVEN_ATTENTION_MASK for _ in onnx_input["input_ids"]]
+        )
+        return onnx_input
+
+    def embed_text(
+        self,
+        documents: Union[str, Iterable[str]],
+        batch_size: int = 256,
+        parallel: Optional[int] = None,
+        **kwargs,
+    ) -> Iterable[np.ndarray]:
+        """
+        Encode a list of documents into list of embeddings.
+
+        Args:
+            documents: Iterator of documents or single document to embed
+            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+            parallel:
+                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                If 0, use all available cores.
+                If None, don't use data-parallel processing, use default onnxruntime threading instead.
+
+        Returns:
+            List of embeddings, one per document
+        """
+        yield from self._embed_documents(
+            model_name=self.model_name,
+            cache_dir=str(self.cache_dir),
+            documents=documents,
+            batch_size=batch_size,
+            parallel=parallel,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_ids=self.device_ids,
+            **kwargs,
+        )
+
+    def embed_image(
+        self,
+        images: ImageInput,
+        batch_size: int = 16,
+        parallel: Optional[int] = None,
+        **kwargs,
+    ) -> Iterable[np.ndarray]:
+        """
+        Encode a list of images into list of embeddings.
+
+        Args:
+            images: Iterator of image paths or single image path to embed
+            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+            parallel:
+                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                If 0, use all available cores.
+                If None, don't use data-parallel processing, use default onnxruntime threading instead.
+
+        Returns:
+            List of embeddings, one per document
+        """
+        yield from self._embed_images(
+            model_name=self.model_name,
+            cache_dir=str(self.cache_dir),
+            images=images,
+            batch_size=batch_size,
+            parallel=parallel,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_ids=self.device_ids,
+            **kwargs,
+        )
+
+    @classmethod
+    def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker]:
+        return ColPaliTextEmbeddingWorker
+
+    @classmethod
+    def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker]:
+        return ColPaliImageEmbeddingWorker
+
+
+class ColPaliTextEmbeddingWorker(TextEmbeddingWorker):
+    def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
+        return ColPali(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )
+
+
+class ColPaliImageEmbeddingWorker(ImageEmbeddingWorker):
+    def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
+        return ColPali(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )
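
A minimal usage sketch for the new model, assuming the `LateInteractionMultimodalEmbedding` class re-exported earlier dispatches to `ColPali` and mirrors the `embed_text`/`embed_image` signatures shown above; the model name comes from `supported_colpali_models`, and the image path is a placeholder:

# Sketch only: not part of the diff; interface assumed from the ColPali methods above.
from fastembed.late_interaction_multimodal import LateInteractionMultimodalEmbedding

# ~6.5 GB download on first use (see "size_in_GB" in supported_colpali_models)
model = LateInteractionMultimodalEmbedding("Qdrant/colpali-v1.3-fp16")

# Query side: texts are tokenized with the "Query: " prefix and pad tokens from ColPali.tokenize;
# each text yields one multi-vector embedding with dim=128 per token.
query_embeddings = list(model.embed_text(["What is shown in the figure?"]))

# Document side: pages are embedded as images, each yielding a (num_patches, 128) matrix.
image_embeddings = list(model.embed_image(["page.png"]))  # placeholder path
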
