Skip to content

Commit 5bd5c0a

Browse files
authored
fix: fix colpali preprocessing, add examples to readme (#487)
1 parent 58ee7cc commit 5bd5c0a

File tree

3 files changed

+42
-9
lines changed

3 files changed

+42
-9
lines changed

README.md

+38
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,23 @@ embeddings = list(model.embed(documents))
6363

6464
```
6565

66+
Dense text embedding can also be extended with models which are not in the list of supported models.
67+
68+
```python
69+
from fastembed import TextEmbedding
70+
from fastembed.common.model_description import PoolingType, ModelSource
71+
72+
TextEmbedding.add_custom_model(
73+
model="intfloat/multilingual-e5-small",
74+
pooling=PoolingType.MEAN,
75+
normalization=True,
76+
sources=ModelSource(hf="intfloat/multilingual-e5-small"), # can be used with an `url` to load files from a private storage
77+
dim=384,
78+
model_file="onnx/model.onnx", # can be used to load an already supported model with another optimization or quantization, e.g. onnx/model_O4.onnx
79+
)
80+
model = TextEmbedding(model_name="intfloat/multilingual-e5-small")
81+
embeddings = list(model.embed(documents))
82+
```
6683

6784

6885
### 🔱 Sparse text embeddings
@@ -137,6 +154,27 @@ embeddings = list(model.embed(images))
137154
# ]
138155
```
139156

157+
### Late interaction multimodal models (ColPali)
158+
159+
```python
160+
from fastembed import LateInteractionMultimodalEmbedding
161+
162+
doc_images = [
163+
"./path/to/qdrant_pdf_doc_1_screenshot.jpg",
164+
"./path/to/colpali_pdf_doc_2_screenshot.jpg",
165+
]
166+
167+
query = "What is Qdrant?"
168+
169+
model = LateInteractionMultimodalEmbedding(model_name="Qdrant/colpali-v1.3-fp16")
170+
doc_images_embeddings = list(model.embed_image(doc_images))
171+
# shape (2, 1030, 128)
172+
# [array([[-0.03353882, -0.02090454, ..., -0.15576172, -0.07678223]], dtype=float32)]
173+
query_embedding = model.embed_text(query)
174+
# shape (1, 20, 128)
175+
# [array([[-0.00218201, 0.14758301, ..., -0.02207947, 0.16833496]], dtype=float32)]
176+
```
177+
140178
### 🔄 Rerankers
141179
```python
142180
from fastembed.rerank.cross_encoder import TextCrossEncoder

fastembed/late_interaction_multimodal/colpali.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,11 @@ def _preprocess_onnx_image_input(
197197
Returns:
198198
Dict[str, NumpyArray]: ONNX input with text placeholders.
199199
"""
200-
201200
onnx_input["input_ids"] = np.array(
202-
[self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]]
201+
[self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["pixel_values"]]
203202
)
204203
onnx_input["attention_mask"] = np.array(
205-
[self.EVEN_ATTENTION_MASK for _ in onnx_input["input_ids"]]
204+
[self.EVEN_ATTENTION_MASK for _ in onnx_input["pixel_values"]]
206205
)
207206
return onnx_input
208207

fastembed/late_interaction_multimodal/onnx_multimodal_model.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ def _load_onnx_model(
7373
cuda=cuda,
7474
device_id=device_id,
7575
)
76-
assert self.tokenizer is not None
7776
self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir)
77+
assert self.tokenizer is not None
7878
self.processor = load_preprocessor(model_dir=model_dir)
7979

8080
def load_onnx_model(self) -> None:
@@ -159,10 +159,6 @@ def _embed_documents(
159159
for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
160160
yield from self._post_process_onnx_text_output(batch) # type: ignore
161161

162-
def _build_onnx_image_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
163-
input_name = self.model.get_inputs()[0].name # type: ignore[union-attr]
164-
return {input_name: encoded}
165-
166162
def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
167163
with contextlib.ExitStack():
168164
image_files = [
@@ -171,7 +167,7 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
171167
]
172168
assert self.processor is not None, "Processor is not initialized"
173169
encoded = np.array(self.processor(image_files))
174-
onnx_input = self._build_onnx_image_input(encoded)
170+
onnx_input = {"pixel_values": encoded}
175171
onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
176172
model_output = self.model.run(None, onnx_input) # type: ignore[union-attr]
177173
embeddings = model_output[0].reshape(len(images), -1)

0 commit comments

Comments (0)