Skip to content

Commit a77fba3

Browse files
committed
wip: type hints for colpali
1 parent 0fa1596 commit a77fba3

6 files changed

+46
-41
lines changed

fastembed/image/onnx_image_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from multiprocessing import get_all_start_methods
44
from pathlib import Path
5-
from typing import Any, Iterable, Optional, Sequence, Type, Union
5+
from typing import Any, Iterable, Optional, Sequence, Type, Union, get_args
66

77
import numpy as np
88
from PIL import Image
@@ -92,7 +92,7 @@ def _embed_images(
9292
) -> Iterable[T]:
9393
is_small = False
9494

95-
if isinstance(images, (str, Path, Image.Image)):
95+
if isinstance(images, get_args(ImageInput)):
9696
images = [images]
9797
is_small = True
9898

fastembed/late_interaction_multimodal/colpali.py

+18-15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from fastembed.common import OnnxProvider, ImageInput
77
from fastembed.common.onnx_model import OnnxOutputContext
8+
from fastembed.common.types import NumpyArray
89
from fastembed.common.utils import define_cache_dir
910
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
1011
LateInteractionMultimodalEmbeddingBase,
@@ -96,7 +97,7 @@ def __init__(
9697
self.device_id = None
9798

9899
self.model_description = self._get_model_description(model_name)
99-
self.cache_dir = define_cache_dir(cache_dir)
100+
self.cache_dir = str(define_cache_dir(cache_dir))
100101

101102
self._model_dir = self.download_model(
102103
self.model_description,
@@ -132,15 +133,15 @@ def load_onnx_model(self) -> None:
132133
def _post_process_onnx_image_output(
133134
self,
134135
output: OnnxOutputContext,
135-
) -> Iterable[np.ndarray]:
136+
) -> Iterable[NumpyArray]:
136137
"""
137138
Post-process the ONNX model output to convert it into a usable format.
138139
139140
Args:
140141
output (OnnxOutputContext): The raw output from the ONNX model.
141142
142143
Returns:
143-
Iterable[np.ndarray]: Post-processed output as NumPy arrays.
144+
Iterable[NumpyArray]: Post-processed output as NumPy arrays.
144145
"""
145146
return output.model_output.reshape(
146147
output.model_output.shape[0], -1, self.model_description["dim"]
@@ -149,15 +150,15 @@ def _post_process_onnx_image_output(
149150
def _post_process_onnx_text_output(
150151
self,
151152
output: OnnxOutputContext,
152-
) -> Iterable[np.ndarray]:
153+
) -> Iterable[NumpyArray]:
153154
"""
154155
Post-process the ONNX model output to convert it into a usable format.
155156
156157
Args:
157158
output (OnnxOutputContext): The raw output from the ONNX model.
158159
159160
Returns:
160-
Iterable[np.ndarray]: Post-processed output as NumPy arrays.
161+
Iterable[NumpyArray]: Post-processed output as NumPy arrays.
161162
"""
162163
return output.model_output.astype(np.float32)
163164

@@ -172,30 +173,32 @@ def tokenize(self, documents: list[str], **_) -> list[Encoding]:
172173
return encoded
173174

174175
def _preprocess_onnx_text_input(
175-
self, onnx_input: dict[str, np.ndarray], **kwargs
176-
) -> dict[str, np.ndarray]:
176+
self, onnx_input: dict[str, NumpyArray], **kwargs
177+
) -> dict[str, NumpyArray]:
177178
onnx_input["input_ids"] = np.array(
178179
[
179180
self.QUERY_MARKER_TOKEN_ID + input_ids[2:].tolist()
180181
for input_ids in onnx_input["input_ids"]
181182
]
182183
)
183-
empty_image_placeholder = np.zeros(self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32)
184+
empty_image_placeholder: NumpyArray = np.zeros(
185+
self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32
186+
)
184187
onnx_input["pixel_values"] = np.array(
185-
[empty_image_placeholder for _ in onnx_input["input_ids"]]
188+
[empty_image_placeholder for _ in onnx_input["input_ids"]],
186189
)
187190
return onnx_input
188191

189192
def _preprocess_onnx_image_input(
190193
self, onnx_input: dict[str, np.ndarray], **kwargs
191-
) -> dict[str, np.ndarray]:
194+
) -> dict[str, NumpyArray]:
192195
"""
193196
Add placeholders for text input when processing image data for ONNX.
194197
Args:
195-
onnx_input (Dict[str, np.ndarray]): Preprocessed image inputs.
198+
onnx_input (Dict[str, NumpyArray]): Preprocessed image inputs.
196199
**kwargs: Additional arguments.
197200
Returns:
198-
Dict[str, np.ndarray]: ONNX input with text placeholders.
201+
Dict[str, NumpyArray]: ONNX input with text placeholders.
199202
"""
200203

201204
onnx_input["input_ids"] = np.array(
@@ -212,7 +215,7 @@ def embed_text(
212215
batch_size: int = 256,
213216
parallel: Optional[int] = None,
214217
**kwargs,
215-
) -> Iterable[np.ndarray]:
218+
) -> Iterable[NumpyArray]:
216219
"""
217220
Encode a list of documents into list of embeddings.
218221
@@ -241,11 +244,11 @@ def embed_text(
241244

242245
def embed_image(
243246
self,
244-
images: ImageInput,
247+
images: Union[ImageInput, Iterable[ImageInput]],
245248
batch_size: int = 16,
246249
parallel: Optional[int] = None,
247250
**kwargs,
248-
) -> Iterable[np.ndarray]:
251+
) -> Iterable[NumpyArray]:
249252
"""
250253
Encode a list of images into list of embeddings.
251254

fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from typing import Any, Iterable, Optional, Sequence, Type, Union
22

3-
import numpy as np
4-
53
from fastembed.common import OnnxProvider, ImageInput
4+
from fastembed.common.types import NumpyArray
65
from fastembed.late_interaction_multimodal.colpali import ColPali
76

87
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
@@ -55,7 +54,7 @@ def __init__(
5554
cuda: bool = False,
5655
device_ids: Optional[list[int]] = None,
5756
lazy_load: bool = False,
58-
**kwargs,
57+
**kwargs: Any,
5958
):
6059
super().__init__(model_name, cache_dir, threads, **kwargs)
6160
for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
@@ -83,8 +82,8 @@ def embed_text(
8382
documents: Union[str, Iterable[str]],
8483
batch_size: int = 256,
8584
parallel: Optional[int] = None,
86-
**kwargs,
87-
) -> Iterable[np.ndarray]:
85+
**kwargs: Any,
86+
) -> Iterable[NumpyArray]:
8887
"""
8988
Encode a list of documents into list of embeddings.
9089
@@ -106,8 +105,8 @@ def embed_image(
106105
images: Union[ImageInput, Iterable[ImageInput]],
107106
batch_size: int = 16,
108107
parallel: Optional[int] = None,
109-
**kwargs,
110-
) -> Iterable[np.ndarray]:
108+
**kwargs: Any,
109+
) -> Iterable[NumpyArray]:
111110
"""
112111
Encode a list of images into list of embeddings.
113112

fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from typing import Iterable, Optional, Union
22

3-
import numpy as np
43

54
from fastembed.common import ImageInput
65
from fastembed.common.model_management import ModelManagement
6+
from fastembed.common.types import NumpyArray
77

88

99
class LateInteractionMultimodalEmbeddingBase(ModelManagement):
@@ -25,7 +25,7 @@ def embed_text(
2525
batch_size: int = 256,
2626
parallel: Optional[int] = None,
2727
**kwargs,
28-
) -> Iterable[np.ndarray]:
28+
) -> Iterable[NumpyArray]:
2929
"""
3030
Embeds a list of documents into a list of embeddings.
3131
@@ -39,7 +39,7 @@ def embed_text(
3939
**kwargs: Additional keyword argument to pass to the embed method.
4040
4141
Yields:
42-
Iterable[np.ndarray]: The embeddings.
42+
Iterable[NumpyArray]: The embeddings.
4343
"""
4444
raise NotImplementedError()
4545

@@ -49,7 +49,7 @@ def embed_image(
4949
batch_size: int = 16,
5050
parallel: Optional[int] = None,
5151
**kwargs,
52-
) -> Iterable[np.ndarray]:
52+
) -> Iterable[NumpyArray]:
5353
"""
5454
Encode a list of images into list of embeddings.
5555
Args:

fastembed/late_interaction_multimodal/onnx_multimodal_model.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from fastembed.common import OnnxProvider, ImageInput
1212
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
1313
from fastembed.common.preprocessor_utils import load_tokenizer, load_preprocessor
14+
from fastembed.common.types import NumpyArray
1415
from fastembed.common.utils import iter_batch
1516
from fastembed.parallel_processor import ParallelWorkerPool
1617

@@ -25,16 +26,16 @@ def __init__(self) -> None:
2526
self.special_token_to_id = {}
2627

2728
def _preprocess_onnx_text_input(
28-
self, onnx_input: dict[str, np.ndarray], **kwargs
29-
) -> dict[str, np.ndarray]:
29+
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
30+
) -> dict[str, NumpyArray]:
3031
"""
3132
Preprocess the onnx input.
3233
"""
3334
return onnx_input
3435

3536
def _preprocess_onnx_image_input(
36-
self, onnx_input: dict[str, np.ndarray], **kwargs
37-
) -> dict[str, np.ndarray]:
37+
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
38+
) -> dict[str, NumpyArray]:
3839
"""
3940
Preprocess the onnx input.
4041
"""
@@ -71,19 +72,20 @@ def _load_onnx_model(
7172
cuda=cuda,
7273
device_id=device_id,
7374
)
75+
assert self.tokenizer is not None
7476
self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir)
7577
self.processor = load_preprocessor(model_dir=model_dir)
7678

7779
def load_onnx_model(self) -> None:
7880
raise NotImplementedError("Subclasses must implement this method")
7981

80-
def tokenize(self, documents: list[str], **kwargs) -> list[Encoding]:
82+
def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
8183
return self.tokenizer.encode_batch(documents)
8284

8385
def onnx_embed_text(
8486
self,
8587
documents: list[str],
86-
**kwargs,
88+
**kwargs: Any,
8789
) -> OnnxOutputContext:
8890
encoded = self.tokenize(documents, **kwargs)
8991
input_ids = np.array([e.ids for e in encoded])
@@ -100,7 +102,7 @@ def onnx_embed_text(
100102
)
101103

102104
onnx_input = self._preprocess_onnx_text_input(onnx_input, **kwargs)
103-
model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
105+
model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore
104106
return OnnxOutputContext(
105107
model_output=model_output[0],
106108
attention_mask=onnx_input.get("attention_mask", attention_mask),
@@ -117,7 +119,7 @@ def _embed_documents(
117119
providers: Optional[Sequence[OnnxProvider]] = None,
118120
cuda: bool = False,
119121
device_ids: Optional[list[int]] = None,
120-
**kwargs,
122+
**kwargs: Any,
121123
) -> Iterable[T]:
122124
is_small = False
123125

@@ -156,10 +158,11 @@ def _embed_documents(
156158
for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
157159
yield from self._post_process_onnx_text_output(batch)
158160

159-
def _build_onnx_image_input(self, encoded: np.ndarray) -> dict[str, np.ndarray]:
160-
return {node.name: encoded for node in self.model.get_inputs()}
161+
def _build_onnx_image_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
162+
input_name = self.model.get_inputs()[0].name # type: ignore
163+
return {input_name: encoded}
161164

162-
def onnx_embed_image(self, images: list[ImageInput], **kwargs) -> OnnxOutputContext:
165+
def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
163166
with contextlib.ExitStack():
164167
image_files = [
165168
Image.open(image) if not isinstance(image, Image.Image) else image
@@ -182,7 +185,7 @@ def _embed_images(
182185
providers: Optional[Sequence[OnnxProvider]] = None,
183186
cuda: bool = False,
184187
device_ids: Optional[list[int]] = None,
185-
**kwargs,
188+
**kwargs: Any,
186189
) -> Iterable[T]:
187190
is_small = False
188191

fastembed/sparse/sparse_embedding_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def as_object(self) -> dict[str, NumpyArray]:
2020
}
2121

2222
def as_dict(self) -> dict[int, float]:
23-
return {int(i): float(v) for i, v in zip(self.indices, self.values)} # type: ignore[arg-type]
23+
return {int(i): float(v) for i, v in zip(self.indices, self.values)} # type: ignore
2424

2525
@classmethod
2626
def from_dict(cls, data: dict[int, float]) -> "SparseEmbedding":

0 commit comments

Comments
 (0)