Skip to content

Commit 8b9f50c

Browse files
committed
refactor: move colmodernvbert related onnx embed to its class
1 parent ef9c496 commit 8b9f50c

2 files changed

Lines changed: 68 additions & 119 deletions

File tree

fastembed/late_interaction_multimodal/colmodernvbert.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import contextlib
12
from typing import Any, Iterable, Type, Optional, Sequence
23
import json
34

45
import numpy as np
56
from tokenizers import Encoding
7+
from PIL import Image
68

79
from fastembed.common import ImageInput
810
from fastembed.common.model_description import DenseModelDescription, ModelSource
@@ -211,6 +213,68 @@ def token_count(
211213
token_num += sum([sum(encoding.attention_mask) for encoding in tokenize_func(batch)])
212214
return token_num
213215

216+
def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
217+
with contextlib.ExitStack() as stack:
218+
image_files = [
219+
stack.enter_context(Image.open(image))
220+
if not isinstance(image, Image.Image)
221+
else image
222+
for image in images
223+
]
224+
assert self.processor is not None, "Processor is not initialized"
225+
processed = self.processor(image_files)
226+
encoded, attention_mask, metadata = self._process_nested_patches(processed) # type: ignore[arg-type]
227+
228+
onnx_input = {"pixel_values": encoded, "attention_mask": attention_mask}
229+
onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
230+
model_output = self.model.run(None, onnx_input) # type: ignore[union-attr]
231+
232+
return OnnxOutputContext(
233+
model_output=model_output[0],
234+
attention_mask=attention_mask, # type: ignore[arg-type]
235+
metadata=metadata,
236+
)
237+
238+
@staticmethod
239+
def _process_nested_patches(
240+
processed: list[list[NumpyArray]],
241+
) -> tuple[NumpyArray, NumpyArray, dict[str, Any]]:
242+
"""
243+
Process nested image patches (from ImageSplitter).
244+
245+
Args:
246+
processed: List of patch lists, one per image [[img1_patches], [img2_patches], ...]
247+
248+
Returns:
249+
tuple: (encoded array, attention_mask, metadata)
250+
- encoded: (batch_size, max_patches, C, H, W)
251+
- attention_mask: (batch_size, max_patches) with 1 for real patches, 0 for padding
252+
- metadata: Dict with 'patch_counts' key
253+
"""
254+
patch_counts = [len(patches) for patches in processed]
255+
max_patches = max(patch_counts)
256+
257+
# Get dimensions from first patch
258+
channels, height, width = processed[0][0].shape
259+
batch_size = len(processed)
260+
261+
# Create padded array
262+
encoded = np.zeros(
263+
(batch_size, max_patches, channels, height, width), dtype=processed[0][0].dtype
264+
)
265+
266+
# Create attention mask (1 for real patches, 0 for padding)
267+
attention_mask = np.zeros((batch_size, max_patches), dtype=np.int64)
268+
269+
# Fill in patches and attention mask
270+
for i, patches in enumerate(processed):
271+
for j, patch in enumerate(patches):
272+
encoded[i, j] = patch
273+
attention_mask[i, j] = 1
274+
275+
metadata = {"patch_counts": patch_counts}
276+
return encoded, attention_mask, metadata # type: ignore[return-value]
277+
214278
def _preprocess_onnx_image_input(
215279
self, onnx_input: dict[str, np.ndarray], **kwargs: Any
216280
) -> dict[str, NumpyArray]:

fastembed/late_interaction_multimodal/onnx_multimodal_model.py

Lines changed: 4 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -178,127 +178,12 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
178178
for image in images
179179
]
180180
assert self.processor is not None, "Processor is not initialized"
181-
processed = self.processor(image_files)
182-
183-
# Dispatch to appropriate handler based on structure.
184-
# ColModernVBERT processors divides the original image into
185-
# subimages and processes them separately.
186-
if isinstance(processed[0], list):
187-
encoded, attention_mask, metadata = self._process_nested_patches(processed)
188-
else:
189-
encoded, attention_mask, metadata = self._process_flat_images(
190-
processed, # type: ignore[arg-type]
191-
len(images),
192-
)
193-
194-
onnx_input = {"pixel_values": encoded, "attention_mask": attention_mask}
181+
encoded = np.array(self.processor(image_files))
182+
onnx_input = {"pixel_values": encoded}
195183
onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
196184
model_output = self.model.run(None, onnx_input) # type: ignore[union-attr]
197-
198-
return OnnxOutputContext(
199-
model_output=model_output[0],
200-
attention_mask=attention_mask, # type: ignore[arg-type]
201-
metadata=metadata,
202-
)
203-
204-
def _process_nested_patches(
205-
self, processed: list[list[NumpyArray]]
206-
) -> tuple[NumpyArray, NumpyArray, dict[str, Any]]:
207-
"""
208-
Process nested image patches (from ImageSplitter).
209-
210-
Args:
211-
processed: List of patch lists, one per image [[img1_patches], [img2_patches], ...]
212-
213-
Returns:
214-
tuple: (encoded array, attention_mask, metadata)
215-
- encoded: (batch_size, max_patches, C, H, W)
216-
- attention_mask: (batch_size, max_patches) with 1 for real patches, 0 for padding
217-
- metadata: Dict with 'patch_counts' key
218-
"""
219-
patch_counts = [len(patches) for patches in processed]
220-
max_patches = max(patch_counts)
221-
222-
# Get dimensions from first patch
223-
C, H, W = processed[0][0].shape
224-
batch_size = len(processed)
225-
226-
# Create padded array
227-
encoded = np.zeros((batch_size, max_patches, C, H, W), dtype=processed[0][0].dtype)
228-
229-
# Create attention mask (1 for real patches, 0 for padding)
230-
attention_mask = np.zeros((batch_size, max_patches), dtype=np.int64)
231-
232-
# Fill in patches and attention mask
233-
for i, patches in enumerate(processed):
234-
for j, patch in enumerate(patches):
235-
encoded[i, j] = patch
236-
attention_mask[i, j] = 1
237-
238-
metadata = {"patch_counts": patch_counts}
239-
return encoded, attention_mask, metadata # type: ignore[return-value]
240-
241-
def _process_flat_images(
242-
self, processed: list[NumpyArray], num_images: int
243-
) -> tuple[NumpyArray, NumpyArray, dict[str, Any]]:
244-
"""
245-
Process flat image arrays (from standard processors like SiglipImageProcessor).
246-
247-
For models expecting 5D input (Idefics3-based), adds patch dimension.
248-
For models expecting 4D input, keeps original shape.
249-
250-
Args:
251-
processed: List of image arrays
252-
num_images: Number of images being processed
253-
254-
Returns:
255-
tuple: (encoded array, attention_mask, metadata)
256-
- encoded: (batch_size, C, H, W) for 4D models OR (batch_size, 1, C, H, W) for 5D models
257-
- attention_mask: (batch_size, 1) with all ones
258-
- metadata: Dict with 'patch_counts' key
259-
"""
260-
encoded = np.array(processed)
261-
262-
# Check if model needs patch dimension based on ONNX signature
263-
if len(encoded.shape) == 4 and self._needs_patch_dimension():
264-
# Add patch dimension for Idefics3-based models: (batch, 1, C, H, W)
265-
encoded = encoded[:, np.newaxis, ...]
266-
267-
# Determine attention mask shape based on final tensor shape
268-
if len(encoded.shape) == 5:
269-
# 5D tensor: attention_mask shape is (batch, num_patches)
270-
attention_mask = np.ones((num_images, encoded.shape[1]), dtype=np.int64)
271-
metadata = {"patch_counts": [encoded.shape[1]] * num_images}
272-
else:
273-
# 4D tensor: attention_mask shape is (batch, 1)
274-
attention_mask = np.ones((num_images, 1), dtype=np.int64)
275-
metadata = {"patch_counts": [1] * num_images}
276-
277-
return encoded, attention_mask, metadata # type: ignore[return-value]
278-
279-
def _needs_patch_dimension(self) -> bool:
280-
"""
281-
Determine if this model needs the patch dimension by checking ONNX input shape.
282-
283-
Idefics3-based models (like ColModernVBERT) need 5D tensors (batch_size, patch_count, C, H, W).
284-
Earlier models (like ColPali v1.3) need 4D tensors (batch_size, C, H, W).
285-
286-
Returns:
287-
bool: True if pixel_values input has 5 dimensions, False if 4 dimensions
288-
"""
289-
if not hasattr(self, "model") or self.model is None:
290-
return False
291-
292-
# Get pixel_values input metadata
293-
for input_meta in self.model.get_inputs():
294-
if input_meta.name == "pixel_values":
295-
# input_meta.shape is a list like
296-
# ['batch_size', 'sequence_length', 'num_channels', 'height', 'width']
297-
# or ['batch_size', 'num_channels', 'height', 'width']
298-
return len(input_meta.shape) == 5
299-
300-
# Default to False for backward compatibility
301-
return False
185+
embeddings = model_output[0].reshape(len(images), -1)
186+
return OnnxOutputContext(model_output=embeddings)
302187

303188
def _embed_images(
304189
self,

0 commit comments

Comments
 (0)