@@ -178,127 +178,12 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
178178 for image in images
179179 ]
180180 assert self .processor is not None , "Processor is not initialized"
181- processed = self .processor (image_files )
182-
183- # Dispatch to appropriate handler based on structure.
184- # ColModernVBERT processors divides the original image into
185- # subimages and processes them separately.
186- if isinstance (processed [0 ], list ):
187- encoded , attention_mask , metadata = self ._process_nested_patches (processed )
188- else :
189- encoded , attention_mask , metadata = self ._process_flat_images (
190- processed , # type: ignore[arg-type]
191- len (images ),
192- )
193-
194- onnx_input = {"pixel_values" : encoded , "attention_mask" : attention_mask }
181+ encoded = np .array (self .processor (image_files ))
182+ onnx_input = {"pixel_values" : encoded }
195183 onnx_input = self ._preprocess_onnx_image_input (onnx_input , ** kwargs )
196184 model_output = self .model .run (None , onnx_input ) # type: ignore[union-attr]
197-
198- return OnnxOutputContext (
199- model_output = model_output [0 ],
200- attention_mask = attention_mask , # type: ignore[arg-type]
201- metadata = metadata ,
202- )
203-
204- def _process_nested_patches (
205- self , processed : list [list [NumpyArray ]]
206- ) -> tuple [NumpyArray , NumpyArray , dict [str , Any ]]:
207- """
208- Process nested image patches (from ImageSplitter).
209-
210- Args:
211- processed: List of patch lists, one per image [[img1_patches], [img2_patches], ...]
212-
213- Returns:
214- tuple: (encoded array, attention_mask, metadata)
215- - encoded: (batch_size, max_patches, C, H, W)
216- - attention_mask: (batch_size, max_patches) with 1 for real patches, 0 for padding
217- - metadata: Dict with 'patch_counts' key
218- """
219- patch_counts = [len (patches ) for patches in processed ]
220- max_patches = max (patch_counts )
221-
222- # Get dimensions from first patch
223- C , H , W = processed [0 ][0 ].shape
224- batch_size = len (processed )
225-
226- # Create padded array
227- encoded = np .zeros ((batch_size , max_patches , C , H , W ), dtype = processed [0 ][0 ].dtype )
228-
229- # Create attention mask (1 for real patches, 0 for padding)
230- attention_mask = np .zeros ((batch_size , max_patches ), dtype = np .int64 )
231-
232- # Fill in patches and attention mask
233- for i , patches in enumerate (processed ):
234- for j , patch in enumerate (patches ):
235- encoded [i , j ] = patch
236- attention_mask [i , j ] = 1
237-
238- metadata = {"patch_counts" : patch_counts }
239- return encoded , attention_mask , metadata # type: ignore[return-value]
240-
241- def _process_flat_images (
242- self , processed : list [NumpyArray ], num_images : int
243- ) -> tuple [NumpyArray , NumpyArray , dict [str , Any ]]:
244- """
245- Process flat image arrays (from standard processors like SiglipImageProcessor).
246-
247- For models expecting 5D input (Idefics3-based), adds patch dimension.
248- For models expecting 4D input, keeps original shape.
249-
250- Args:
251- processed: List of image arrays
252- num_images: Number of images being processed
253-
254- Returns:
255- tuple: (encoded array, attention_mask, metadata)
256- - encoded: (batch_size, C, H, W) for 4D models OR (batch_size, 1, C, H, W) for 5D models
257- - attention_mask: (batch_size, 1) with all ones
258- - metadata: Dict with 'patch_counts' key
259- """
260- encoded = np .array (processed )
261-
262- # Check if model needs patch dimension based on ONNX signature
263- if len (encoded .shape ) == 4 and self ._needs_patch_dimension ():
264- # Add patch dimension for Idefics3-based models: (batch, 1, C, H, W)
265- encoded = encoded [:, np .newaxis , ...]
266-
267- # Determine attention mask shape based on final tensor shape
268- if len (encoded .shape ) == 5 :
269- # 5D tensor: attention_mask shape is (batch, num_patches)
270- attention_mask = np .ones ((num_images , encoded .shape [1 ]), dtype = np .int64 )
271- metadata = {"patch_counts" : [encoded .shape [1 ]] * num_images }
272- else :
273- # 4D tensor: attention_mask shape is (batch, 1)
274- attention_mask = np .ones ((num_images , 1 ), dtype = np .int64 )
275- metadata = {"patch_counts" : [1 ] * num_images }
276-
277- return encoded , attention_mask , metadata # type: ignore[return-value]
278-
279- def _needs_patch_dimension (self ) -> bool :
280- """
281- Determine if this model needs the patch dimension by checking ONNX input shape.
282-
283- Idefics3-based models (like ColModernVBERT) need 5D tensors (batch_size, patch_count, C, H, W).
284- Earlier models (like ColPali v1.3) need 4D tensors (batch_size, C, H, W).
285-
286- Returns:
287- bool: True if pixel_values input has 5 dimensions, False if 4 dimensions
288- """
289- if not hasattr (self , "model" ) or self .model is None :
290- return False
291-
292- # Get pixel_values input metadata
293- for input_meta in self .model .get_inputs ():
294- if input_meta .name == "pixel_values" :
295- # input_meta.shape is a list like
296- # ['batch_size', 'sequence_length', 'num_channels', 'height', 'width']
297- # or ['batch_size', 'num_channels', 'height', 'width']
298- return len (input_meta .shape ) == 5
299-
300- # Default to False for backward compatibility
301- return False
185+ embeddings = model_output [0 ].reshape (len (images ), - 1 )
186+ return OnnxOutputContext (model_output = embeddings )
302187
303188 def _embed_images (
304189 self ,
0 commit comments