Skip to content

Commit 0fa1596

Browse files
new: Add missing type hints (#464)
* new: Add missing type hints * refactor: Removed type ignore * fix: fix mypy complaints * fix: remove redundant type coercion, fix skip list type * new: more precise type for sparse embedding inference, a small revert for parallel processor --------- Co-authored-by: George Panchuk <[email protected]>
1 parent 2fe33c5 commit 0fa1596

File tree

10 files changed

+43
-32
lines changed

10 files changed

+43
-32
lines changed

fastembed/common/onnx_model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
from pathlib import Path
44
from typing import Any, Generic, Iterable, Optional, Sequence, Type, TypeVar
55

6+
import numpy as np
67
import onnxruntime as ort
78

9+
from numpy.typing import NDArray
10+
811
from fastembed.common.types import OnnxProvider, NumpyArray
912
from fastembed.parallel_processor import Worker
1013

@@ -15,8 +18,8 @@
1518
@dataclass
1619
class OnnxOutputContext:
1720
model_output: NumpyArray
18-
attention_mask: Optional[NumpyArray] = None
19-
input_ids: Optional[NumpyArray] = None
21+
attention_mask: Optional[NDArray[np.int64]] = None
22+
input_ids: Optional[NDArray[np.int64]] = None
2023

2124

2225
class OnnxModel(Generic[T]):
@@ -90,6 +93,7 @@ def _load_onnx_model(
9093
str(model_path), providers=onnx_providers, sess_options=so
9194
)
9295
if "CUDAExecutionProvider" in requested_provider_names:
96+
assert self.model is not None
9397
current_providers = self.model.get_providers()
9498
if "CUDAExecutionProvider" not in current_providers:
9599
warnings.warn(

fastembed/image/onnx_image_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def load_onnx_model(self) -> None:
6161
raise NotImplementedError("Subclasses must implement this method")
6262

6363
def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
64-
input_name = self.model.get_inputs()[0].name
64+
input_name = self.model.get_inputs()[0].name # type: ignore
6565
return {input_name: encoded}
6666

6767
def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
@@ -74,7 +74,7 @@ def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputConte
7474
encoded = np.array(self.processor(image_files))
7575
onnx_input = self._build_onnx_input(encoded)
7676
onnx_input = self._preprocess_onnx_input(onnx_input)
77-
model_output = self.model.run(None, onnx_input)
77+
model_output = self.model.run(None, onnx_input) # type: ignore
7878
embeddings = model_output[0].reshape(len(images), -1)
7979
return OnnxOutputContext(model_output=embeddings)
8080

@@ -125,7 +125,7 @@ def _embed_images(
125125
start_method=start_method,
126126
)
127127
for batch in pool.ordered_map(iter_batch(images, batch_size), **params):
128-
yield from self._post_process_onnx_output(batch)
128+
yield from self._post_process_onnx_output(batch) # type: ignore
129129

130130

131131
class ImageEmbeddingWorker(EmbeddingWorker[T]):

fastembed/late_interaction/colbert.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def _post_process_onnx_output(
5858
)
5959

6060
for i, token_sequence in enumerate(output.input_ids):
61-
for j, token_id in enumerate(token_sequence):
61+
for j, token_id in enumerate(token_sequence): # type: ignore
6262
if token_id in self.skip_list or token_id == self.pad_token_id:
6363
output.attention_mask[i, j] = 0
6464

@@ -88,6 +88,8 @@ def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) ->
8888
)
8989

9090
def _tokenize_query(self, query: str) -> list[Encoding]:
91+
assert self.tokenizer is not None
92+
9193
encoded = self.tokenizer.encode_batch([query])
9294
# colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
9395
if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
@@ -107,6 +109,7 @@ def _tokenize_query(self, query: str) -> list[Encoding]:
107109
return encoded
108110

109111
def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
112+
assert self.tokenizer is not None
110113
encoded = self.tokenizer.encode_batch(documents)
111114
return encoded
112115

@@ -163,12 +166,11 @@ def __init__(
163166
self.cuda = cuda
164167

165168
# This device_id will be used if we need to load model in current process
169+
self.device_id: Optional[int] = None
166170
if device_id is not None:
167171
self.device_id = device_id
168172
elif self.device_ids is not None:
169173
self.device_id = self.device_ids[0]
170-
else:
171-
self.device_id = None
172174

173175
self.model_description = self._get_model_description(model_name)
174176
self.cache_dir = str(define_cache_dir(cache_dir))
@@ -181,7 +183,7 @@ def __init__(
181183
)
182184
self.mask_token_id: Optional[int] = None
183185
self.pad_token_id: Optional[int] = None
184-
self.skip_list: set[str] = set()
186+
self.skip_list: set[int] = set()
185187

186188
if not self.lazy_load:
187189
self.load_onnx_model()
@@ -195,6 +197,7 @@ def load_onnx_model(self) -> None:
195197
cuda=self.cuda,
196198
device_id=self.device_id,
197199
)
200+
assert self.tokenizer is not None
198201
self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN]
199202
self.pad_token_id = self.tokenizer.padding["pad_id"]
200203
self.skip_list = {

fastembed/parallel_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def start(self, **kwargs: Any) -> None:
140140
self.processes.append(process)
141141

142142
def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
143-
buffer: defaultdict[int, Any] = defaultdict(Any)
143+
buffer: defaultdict[int, Any] = defaultdict(Any) # type: ignore
144144
next_expected = 0
145145

146146
for idx, item in self.semi_ordered_map(stream, *args, **kwargs):

fastembed/rerank/cross_encoder/onnx_text_model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,13 @@ def _load_onnx_model(
4343
device_id=device_id,
4444
)
4545
self.tokenizer, _ = load_tokenizer(model_dir=model_dir)
46+
assert self.tokenizer is not None
4647

4748
def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]:
48-
return self.tokenizer.encode_batch(pairs)
49+
return self.tokenizer.encode_batch(pairs) # type: ignore
4950

5051
def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]:
51-
input_names: set[str] = {node.name for node in self.model.get_inputs()}
52+
input_names: set[str] = {node.name for node in self.model.get_inputs()} # type: ignore
5253
inputs: dict[str, NumpyArray] = {
5354
"input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64),
5455
}
@@ -70,7 +71,7 @@ def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxO
7071
tokenized_input = self.tokenize(pairs, **kwargs)
7172
inputs = self._build_onnx_input(tokenized_input)
7273
onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
73-
outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
74+
outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore
7475
relevant_output = outputs[0]
7576
scores: NumpyArray = relevant_output[:, 0]
7677
return OnnxOutputContext(model_output=scores)
@@ -98,7 +99,7 @@ def _rerank_pairs(
9899
is_small = False
99100

100101
if isinstance(pairs, tuple):
101-
pairs = [pairs] # type: ignore
102+
pairs = [pairs]
102103
is_small = True
103104

104105
if isinstance(pairs, list):
@@ -130,7 +131,7 @@ def _rerank_pairs(
130131
start_method=start_method,
131132
)
132133
for batch in pool.ordered_map(iter_batch(pairs, batch_size), **params):
133-
yield from self._post_process_onnx_output(batch)
134+
yield from self._post_process_onnx_output(batch) # type: ignore
134135

135136
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
136137
raise NotImplementedError("Subclasses must implement this method")

fastembed/sparse/bm25.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def _embed_documents(
206206
)
207207
for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
208208
for record in batch:
209-
yield record
209+
yield record # type: ignore
210210

211211
def embed(
212212
self,
@@ -343,7 +343,9 @@ def __init__(
343343
def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "Bm25Worker":
344344
return cls(model_name=model_name, cache_dir=cache_dir, **kwargs)
345345

346-
def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
346+
def process(
347+
self, items: Iterable[tuple[int, Any]]
348+
) -> Iterable[tuple[int, list[SparseEmbedding]]]:
347349
for idx, batch in items:
348350
onnx_output = self.model.raw_embed(batch)
349351
yield idx, onnx_output

fastembed/sparse/bm42.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,11 @@ def __init__(
102102
self.cuda = cuda
103103

104104
# This device_id will be used if we need to load model in current process
105+
self.device_id: Optional[int] = None
105106
if device_id is not None:
106107
self.device_id = device_id
107108
elif self.device_ids is not None:
108109
self.device_id = self.device_ids[0]
109-
else:
110-
self.device_id = None
111110

112111
self.model_description = self._get_model_description(model_name)
113112
self.cache_dir = str(define_cache_dir(cache_dir))
@@ -140,6 +139,7 @@ def load_onnx_model(self) -> None:
140139
cuda=self.cuda,
141140
device_id=self.device_id,
142141
)
142+
assert self.tokenizer is not None
143143
for token, idx in self.tokenizer.get_vocab().items():
144144
self.invert_vocab[idx] = token
145145
self.special_tokens = set(self.special_token_to_id.keys())
@@ -178,7 +178,7 @@ def _reconstruct_bpe(
178178
acc: str = ""
179179
acc_idx: list[int] = []
180180

181-
continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix
181+
continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix # type: ignore
182182
continuing_subword_prefix_len = len(continuing_subword_prefix)
183183

184184
for idx, token in bpe_tokens:
@@ -222,7 +222,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Spars
222222
if output.input_ids is None:
223223
raise ValueError("input_ids must be provided for document post-processing")
224224

225-
token_ids_batch = output.input_ids
225+
token_ids_batch = output.input_ids.astype(int)
226226

227227
# attention_value shape: (batch_size, num_heads, num_tokens, num_tokens)
228228
pooled_attention = np.mean(output.model_output[:, :, 0], axis=1) * output.attention_mask
@@ -325,7 +325,7 @@ def query_embed(
325325
self.load_onnx_model()
326326

327327
for text in query:
328-
encoded = self.tokenizer.encode(text)
328+
encoded = self.tokenizer.encode(text) # type: ignore
329329
document_tokens_with_ids = enumerate(encoded.tokens)
330330
reconstructed = self._reconstruct_bpe(document_tokens_with_ids)
331331
filtered = self._filter_pair_tokens(reconstructed)

fastembed/sparse/sparse_embedding_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Iterable, Optional, Union, Any
33

44
import numpy as np
5+
from numpy.typing import NDArray
56

67
from fastembed.common.types import NumpyArray
78
from fastembed.common.model_management import ModelManagement
@@ -10,7 +11,7 @@
1011
@dataclass
1112
class SparseEmbedding:
1213
values: NumpyArray
13-
indices: NumpyArray
14+
indices: Union[NDArray[np.int64], NDArray[np.int32]]
1415

1516
def as_object(self) -> dict[str, NumpyArray]:
1617
return {
@@ -19,7 +20,7 @@ def as_object(self) -> dict[str, NumpyArray]:
1920
}
2021

2122
def as_dict(self) -> dict[int, float]:
22-
return {i: v for i, v in zip(self.indices, self.values)}
23+
return {int(i): float(v) for i, v in zip(self.indices, self.values)} # type: ignore[arg-type]
2324

2425
@classmethod
2526
def from_dict(cls, data: dict[int, float]) -> "SparseEmbedding":

fastembed/sparse/splade_pp.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,11 @@ def __init__(
106106
self.cuda = cuda
107107

108108
# This device_id will be used if we need to load model in current process
109+
self.device_id: Optional[int] = None
109110
if device_id is not None:
110111
self.device_id = device_id
111112
elif self.device_ids is not None:
112113
self.device_id = self.device_ids[0]
113-
else:
114-
self.device_id = None
115114

116115
self.model_description = self._get_model_description(model_name)
117116
self.cache_dir = str(define_cache_dir(cache_dir))

fastembed/text/onnx_text_model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Iterable, Optional, Sequence, Type, Union
55

66
import numpy as np
7+
from numpy.typing import NDArray
78
from tokenizers import Encoding
89

910
from fastembed.common.types import NumpyArray, OnnxProvider
@@ -23,14 +24,14 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker[T]"]:
2324
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
2425
raise NotImplementedError("Subclasses must implement this method")
2526

26-
def __init__(self):
27+
def __init__(self) -> None:
2728
super().__init__()
2829
self.tokenizer = None
2930
self.special_token_to_id: dict[str, int] = {}
3031

3132
def _preprocess_onnx_input(
3233
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
33-
) -> dict[str, NumpyArray]:
34+
) -> dict[str, Union[NumpyArray, NDArray[np.int64]]]:
3435
"""
3536
Preprocess the onnx input.
3637
"""
@@ -60,7 +61,7 @@ def load_onnx_model(self) -> None:
6061
raise NotImplementedError("Subclasses must implement this method")
6162

6263
def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
63-
return self.tokenizer.encode_batch(documents)
64+
return self.tokenizer.encode_batch(documents) # type: ignore
6465

6566
def onnx_embed(
6667
self,
@@ -70,7 +71,7 @@ def onnx_embed(
7071
encoded = self.tokenize(documents, **kwargs)
7172
input_ids = np.array([e.ids for e in encoded])
7273
attention_mask = np.array([e.attention_mask for e in encoded])
73-
input_names = {node.name for node in self.model.get_inputs()}
74+
input_names = {node.name for node in self.model.get_inputs()} # type: ignore
7475
onnx_input: dict[str, NumpyArray] = {
7576
"input_ids": np.array(input_ids, dtype=np.int64),
7677
}
@@ -82,7 +83,7 @@ def onnx_embed(
8283
)
8384
onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
8485

85-
model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
86+
model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore
8687
return OnnxOutputContext(
8788
model_output=model_output[0],
8889
attention_mask=onnx_input.get("attention_mask", attention_mask),
@@ -136,7 +137,7 @@ def _embed_documents(
136137
start_method=start_method,
137138
)
138139
for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
139-
yield from self._post_process_onnx_output(batch)
140+
yield from self._post_process_onnx_output(batch) # type: ignore
140141

141142

142143
class TextEmbeddingWorker(EmbeddingWorker[T]):

0 commit comments

Comments
 (0)