new: Add missing type hints (#464)
* new: Add missing type hints

* refactor: Removed type ignore

* fix: fix mypy complaints

* fix: remove redundant type coercion, fix skip list type

* new: more precise type for sparse embedding inference, a small revert for parallel processor

---------

Co-authored-by: George Panchuk <[email protected]>
hh-space-invader and joein authored Feb 6, 2025
1 parent 2fe33c5 commit 0fa1596
Showing 10 changed files with 43 additions and 32 deletions.
8 changes: 6 additions & 2 deletions fastembed/common/onnx_model.py
@@ -3,8 +3,11 @@
 from pathlib import Path
 from typing import Any, Generic, Iterable, Optional, Sequence, Type, TypeVar
 
+import numpy as np
 import onnxruntime as ort
 
+from numpy.typing import NDArray
+
 from fastembed.common.types import OnnxProvider, NumpyArray
 from fastembed.parallel_processor import Worker
@@ -15,8 +18,8 @@
 @dataclass
 class OnnxOutputContext:
     model_output: NumpyArray
-    attention_mask: Optional[NumpyArray] = None
-    input_ids: Optional[NumpyArray] = None
+    attention_mask: Optional[NDArray[np.int64]] = None
+    input_ids: Optional[NDArray[np.int64]] = None
 
 
 class OnnxModel(Generic[T]):
@@ -90,6 +93,7 @@ def _load_onnx_model(
             str(model_path), providers=onnx_providers, sess_options=so
         )
         if "CUDAExecutionProvider" in requested_provider_names:
+            assert self.model is not None
             current_providers = self.model.get_providers()
             if "CUDAExecutionProvider" not in current_providers:
                 warnings.warn(
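
The added `assert self.model is not None` is how this commit narrows Optional attributes for mypy: the ONNX session is created lazily, so the attribute is typed Optional, and the assert proves to the checker that it has been populated before use. A minimal sketch of the pattern, with hypothetical names rather than fastembed code:

    from typing import Optional

    class Session:
        def get_providers(self) -> list[str]:
            return ["CPUExecutionProvider"]

    class LazyModel:
        def __init__(self) -> None:
            self.model: Optional[Session] = None  # not loaded until load() runs

        def load(self) -> None:
            self.model = Session()

        def providers(self) -> list[str]:
            # Without this assert, mypy reports that `self.model` may be None;
            # with it, the type is narrowed to Session for the rest of the body.
            assert self.model is not None
            return self.model.get_providers()
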
6 changes: 3 additions & 3 deletions fastembed/image/onnx_image_model.py
@@ -61,7 +61,7 @@ def load_onnx_model(self) -> None:
         raise NotImplementedError("Subclasses must implement this method")
 
     def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
-        input_name = self.model.get_inputs()[0].name
+        input_name = self.model.get_inputs()[0].name  # type: ignore
         return {input_name: encoded}
 
     def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
@@ -74,7 +74,7 @@ def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
         encoded = np.array(self.processor(image_files))
         onnx_input = self._build_onnx_input(encoded)
         onnx_input = self._preprocess_onnx_input(onnx_input)
-        model_output = self.model.run(None, onnx_input)
+        model_output = self.model.run(None, onnx_input)  # type: ignore
         embeddings = model_output[0].reshape(len(images), -1)
         return OnnxOutputContext(model_output=embeddings)
 
@@ -125,7 +125,7 @@ def _embed_images(
                 start_method=start_method,
             )
             for batch in pool.ordered_map(iter_batch(images, batch_size), **params):
-                yield from self._post_process_onnx_output(batch)
+                yield from self._post_process_onnx_output(batch)  # type: ignore
 
 
 class ImageEmbeddingWorker(EmbeddingWorker[T]):
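
Here the commit takes the other route for the same Optional problem: `load_onnx_model()` is guaranteed to have run before these methods at runtime, so rather than asserting on every access it suppresses the error per line with `# type: ignore`. A toy illustration of that trade-off (hypothetical classes, not the fastembed ones):

    from typing import Optional

    class Session:
        def run(self) -> list[int]:
            return [0]

    class ImageModel:
        def __init__(self) -> None:
            self.model: Optional[Session] = None

        def load(self) -> None:
            self.model = Session()

        def embed(self) -> list[int]:
            # Relies on the runtime guarantee that load() already ran; the
            # per-line ignore silences mypy without adding a runtime check.
            return self.model.run()  # type: ignore
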
11 changes: 7 additions & 4 deletions fastembed/late_interaction/colbert.py
@@ -58,7 +58,7 @@ def _post_process_onnx_output(
         )
 
         for i, token_sequence in enumerate(output.input_ids):
-            for j, token_id in enumerate(token_sequence):
+            for j, token_id in enumerate(token_sequence):  # type: ignore
                 if token_id in self.skip_list or token_id == self.pad_token_id:
                     output.attention_mask[i, j] = 0
 
@@ -88,6 +88,8 @@ def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) ->
         )
 
     def _tokenize_query(self, query: str) -> list[Encoding]:
+        assert self.tokenizer is not None
+
         encoded = self.tokenizer.encode_batch([query])
         # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
         if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
@@ -107,6 +109,7 @@ def _tokenize_query(self, query: str) -> list[Encoding]:
         return encoded
 
     def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
+        assert self.tokenizer is not None
         encoded = self.tokenizer.encode_batch(documents)
         return encoded
 
@@ -163,12 +166,11 @@ def __init__(
         self.cuda = cuda
 
         # This device_id will be used if we need to load model in current process
+        self.device_id: Optional[int] = None
         if device_id is not None:
             self.device_id = device_id
         elif self.device_ids is not None:
             self.device_id = self.device_ids[0]
-        else:
-            self.device_id = None
 
         self.model_description = self._get_model_description(model_name)
         self.cache_dir = str(define_cache_dir(cache_dir))
@@ -181,7 +183,7 @@ def __init__(
         )
         self.mask_token_id: Optional[int] = None
         self.pad_token_id: Optional[int] = None
-        self.skip_list: set[str] = set()
+        self.skip_list: set[int] = set()
 
         if not self.lazy_load:
             self.load_onnx_model()
@@ -195,6 +197,7 @@ def load_onnx_model(self) -> None:
             cuda=self.cuda,
             device_id=self.device_id,
        )
+        assert self.tokenizer is not None
         self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN]
         self.pad_token_id = self.tokenizer.padding["pad_id"]
         self.skip_list = {
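
The `set[str]` to `set[int]` fix is more than cosmetic: tokenizers hand back integer token ids, so membership tests against a set of strings could never match. A toy version of the corrected masking loop (the ids are made up):

    # Token ids are ints, so the skip set must be set[int] for the
    # `token_id in skip_list` test to ever succeed.
    skip_list: set[int] = {1010, 1012}  # e.g. ids of "," and "."
    pad_token_id = 0

    token_ids = [101, 1010, 2023, 1012, 0]
    attention_mask = [1, 1, 1, 1, 1]
    for j, token_id in enumerate(token_ids):
        if token_id in skip_list or token_id == pad_token_id:
            attention_mask[j] = 0

    print(attention_mask)  # [1, 0, 1, 0, 0]
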
2 changes: 1 addition & 1 deletion fastembed/parallel_processor.py
@@ -140,7 +140,7 @@ def start(self, **kwargs: Any) -> None:
             self.processes.append(process)
 
     def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
-        buffer: defaultdict[int, Any] = defaultdict(Any)
+        buffer: defaultdict[int, Any] = defaultdict(Any)  # type: ignore
         next_expected = 0
 
         for idx, item in self.semi_ordered_map(stream, *args, **kwargs):
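
`defaultdict(Any)` is an odd construction: `typing.Any` is not a usable default factory (calling `Any()` raises TypeError), and it only works here because `ordered_map` writes every key before reading it, so the factory is never invoked. The commit keeps the line and suppresses the error; a plain dict would arguably express the same thing without the ignore. A sketch of that alternative, not the shipped code:

    from typing import Any, Iterable

    def ordered(stream: Iterable[tuple[int, Any]]) -> Iterable[Any]:
        # Every key is assigned before it is read, so no default
        # factory is needed and a plain dict annotation suffices.
        buffer: dict[int, Any] = {}
        next_expected = 0
        for idx, item in stream:
            buffer[idx] = item
            while next_expected in buffer:
                yield buffer.pop(next_expected)
                next_expected += 1

    print(list(ordered([(1, "b"), (0, "a"), (2, "c")])))  # ['a', 'b', 'c']
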
11 changes: 6 additions & 5 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
@@ -43,12 +43,13 @@ def _load_onnx_model(
             device_id=device_id,
         )
         self.tokenizer, _ = load_tokenizer(model_dir=model_dir)
+        assert self.tokenizer is not None
 
     def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]:
-        return self.tokenizer.encode_batch(pairs)
+        return self.tokenizer.encode_batch(pairs)  # type: ignore
 
     def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]:
-        input_names: set[str] = {node.name for node in self.model.get_inputs()}
+        input_names: set[str] = {node.name for node in self.model.get_inputs()}  # type: ignore
         inputs: dict[str, NumpyArray] = {
             "input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64),
         }
@@ -70,7 +71,7 @@ def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxO
         tokenized_input = self.tokenize(pairs, **kwargs)
         inputs = self._build_onnx_input(tokenized_input)
         onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
-        outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
+        outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)  # type: ignore
         relevant_output = outputs[0]
         scores: NumpyArray = relevant_output[:, 0]
         return OnnxOutputContext(model_output=scores)
@@ -98,7 +99,7 @@ def _rerank_pairs(
         is_small = False
 
         if isinstance(pairs, tuple):
-            pairs = [pairs]  # type: ignore
+            pairs = [pairs]
             is_small = True
 
         if isinstance(pairs, list):
@@ -130,7 +131,7 @@ def _rerank_pairs(
                 start_method=start_method,
             )
             for batch in pool.ordered_map(iter_batch(pairs, batch_size), **params):
-                yield from self._post_process_onnx_output(batch)
+                yield from self._post_process_onnx_output(batch)  # type: ignore
 
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
         raise NotImplementedError("Subclasses must implement this method")
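
For context on `_build_onnx_input`: cross-encoder exports differ in which tensors they accept, so the method reads the session's declared input names and only supplies what the graph expects. A rough sketch of that pattern with toy values (not the actual method):

    import numpy as np

    # In the real code this set comes from self.model.get_inputs().
    input_names = {"input_ids", "attention_mask"}

    ids = [[101, 2023, 102], [101, 3231, 102]]
    inputs = {"input_ids": np.array(ids, dtype=np.int64)}
    if "attention_mask" in input_names:
        inputs["attention_mask"] = np.ones_like(inputs["input_ids"])
    if "token_type_ids" in input_names:  # absent for this toy model
        inputs["token_type_ids"] = np.zeros_like(inputs["input_ids"])
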
6 changes: 4 additions & 2 deletions fastembed/sparse/bm25.py
@@ -206,7 +206,7 @@ def _embed_documents(
             )
             for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
                 for record in batch:
-                    yield record
+                    yield record  # type: ignore
 
     def embed(
         self,
@@ -343,7 +343,9 @@ def __init__(
     def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "Bm25Worker":
         return cls(model_name=model_name, cache_dir=cache_dir, **kwargs)
 
-    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
+    def process(
+        self, items: Iterable[tuple[int, Any]]
+    ) -> Iterable[tuple[int, list[SparseEmbedding]]]:
         for idx, batch in items:
             onnx_output = self.model.raw_embed(batch)
             yield idx, onnx_output
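
The narrowed signature on `Bm25Worker.process` appears to be the "more precise type for sparse embedding inference" from the commit message: the base worker traffics in `tuple[int, Any]`, and the override now documents that the payload is a list of SparseEmbedding. Schematically, with stub types rather than the real classes:

    from dataclasses import dataclass, field
    from typing import Any, Iterable

    @dataclass
    class SparseEmbedding:  # stand-in for fastembed's dataclass
        values: list[float] = field(default_factory=list)
        indices: list[int] = field(default_factory=list)

    def process(
        items: Iterable[tuple[int, Any]],
    ) -> Iterable[tuple[int, list[SparseEmbedding]]]:
        # Narrowing the return type lets mypy check consumers of the
        # worker output instead of propagating Any downstream.
        for idx, batch in items:
            yield idx, [SparseEmbedding() for _ in batch]
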
10 changes: 5 additions & 5 deletions fastembed/sparse/bm42.py
@@ -102,12 +102,11 @@ def __init__(
         self.cuda = cuda
 
         # This device_id will be used if we need to load model in current process
+        self.device_id: Optional[int] = None
         if device_id is not None:
             self.device_id = device_id
         elif self.device_ids is not None:
             self.device_id = self.device_ids[0]
-        else:
-            self.device_id = None
 
         self.model_description = self._get_model_description(model_name)
         self.cache_dir = str(define_cache_dir(cache_dir))
@@ -140,6 +139,7 @@ def load_onnx_model(self) -> None:
             cuda=self.cuda,
             device_id=self.device_id,
         )
+        assert self.tokenizer is not None
         for token, idx in self.tokenizer.get_vocab().items():
             self.invert_vocab[idx] = token
         self.special_tokens = set(self.special_token_to_id.keys())
@@ -178,7 +178,7 @@ def _reconstruct_bpe(
         acc: str = ""
         acc_idx: list[int] = []
 
-        continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix
+        continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix  # type: ignore
         continuing_subword_prefix_len = len(continuing_subword_prefix)
 
         for idx, token in bpe_tokens:
@@ -222,7 +222,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Spars
         if output.input_ids is None:
             raise ValueError("input_ids must be provided for document post-processing")
 
-        token_ids_batch = output.input_ids
+        token_ids_batch = output.input_ids.astype(int)
 
         # attention_value shape: (batch_size, num_heads, num_tokens, num_tokens)
         pooled_attention = np.mean(output.model_output[:, :, 0], axis=1) * output.attention_mask
@@ -325,7 +325,7 @@ def query_embed(
             self.load_onnx_model()
 
         for text in query:
-            encoded = self.tokenizer.encode(text)
+            encoded = self.tokenizer.encode(text)  # type: ignore
             document_tokens_with_ids = enumerate(encoded.tokens)
             reconstructed = self._reconstruct_bpe(document_tokens_with_ids)
             filtered = self._filter_pair_tokens(reconstructed)
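
The new `astype(int)` pins `token_ids_batch` to a plain integer dtype before the ids are used for vocabulary lookups; numpy integer scalars already hash like ints, so the cast is chiefly about giving the checker (and readers) a concrete type. For example:

    import numpy as np

    invert_vocab = {101: "[CLS]", 2023: "this", 102: "[SEP]"}  # toy vocab

    input_ids = np.array([[101, 2023, 102]], dtype=np.int64)
    token_ids_batch = input_ids.astype(int)  # platform int, same values
    tokens = [invert_vocab[int(token_id)] for token_id in token_ids_batch[0]]
    print(tokens)  # ['[CLS]', 'this', '[SEP]']
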
5 changes: 3 additions & 2 deletions fastembed/sparse/sparse_embedding_base.py
@@ -2,6 +2,7 @@
 from typing import Iterable, Optional, Union, Any
 
 import numpy as np
+from numpy.typing import NDArray
 
 from fastembed.common.types import NumpyArray
 from fastembed.common.model_management import ModelManagement
@@ -10,7 +11,7 @@
 @dataclass
 class SparseEmbedding:
     values: NumpyArray
-    indices: NumpyArray
+    indices: Union[NDArray[np.int64], NDArray[np.int32]]
 
     def as_object(self) -> dict[str, NumpyArray]:
         return {
@@ -19,7 +20,7 @@ def as_object(self) -> dict[str, NumpyArray]:
         }
 
     def as_dict(self) -> dict[int, float]:
-        return {i: v for i, v in zip(self.indices, self.values)}
+        return {int(i): float(v) for i, v in zip(self.indices, self.values)}  # type: ignore[arg-type]
 
     @classmethod
     def from_dict(cls, data: dict[int, float]) -> "SparseEmbedding":
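
The `int(i)`/`float(v)` conversions make `as_dict` honour its declared `dict[int, float]` return type: iterating numpy arrays yields numpy scalar types, which would otherwise leak into the result. For instance:

    import numpy as np

    values = np.array([0.5, 1.25])
    indices = np.array([3, 7], dtype=np.int64)

    as_dict = {int(i): float(v) for i, v in zip(indices, values)}
    print(as_dict)                    # {3: 0.5, 7: 1.25}
    print(type(next(iter(as_dict))))  # <class 'int'>, not numpy.int64
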
3 changes: 1 addition & 2 deletions fastembed/sparse/splade_pp.py
@@ -106,12 +106,11 @@ def __init__(
         self.cuda = cuda
 
         # This device_id will be used if we need to load model in current process
+        self.device_id: Optional[int] = None
         if device_id is not None:
             self.device_id = device_id
         elif self.device_ids is not None:
             self.device_id = self.device_ids[0]
-        else:
-            self.device_id = None
 
         self.model_description = self._get_model_description(model_name)
         self.cache_dir = str(define_cache_dir(cache_dir))
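
This is the third copy of the device_id cleanup (it also appears in colbert.py and bm42.py): declaring `self.device_id: Optional[int] = None` up front replaces the `else` branch, so the attribute has a single annotated declaration instead of a type inferred from whichever branch assigns first. The shape of the pattern, as a standalone sketch:

    from typing import Optional

    def pick_device(device_id: Optional[int], device_ids: Optional[list[int]]) -> Optional[int]:
        # One annotated declaration; the branches only narrow or fill it in.
        chosen: Optional[int] = None
        if device_id is not None:
            chosen = device_id
        elif device_ids is not None:
            chosen = device_ids[0]
        return chosen

    print(pick_device(None, [2, 3]))  # 2
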
13 changes: 7 additions & 6 deletions fastembed/text/onnx_text_model.py
@@ -4,6 +4,7 @@
 from typing import Any, Iterable, Optional, Sequence, Type, Union
 
 import numpy as np
+from numpy.typing import NDArray
 from tokenizers import Encoding
 
 from fastembed.common.types import NumpyArray, OnnxProvider
@@ -23,14 +24,14 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker[T]"]:
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
         raise NotImplementedError("Subclasses must implement this method")
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.tokenizer = None
         self.special_token_to_id: dict[str, int] = {}
 
     def _preprocess_onnx_input(
         self, onnx_input: dict[str, NumpyArray], **kwargs: Any
-    ) -> dict[str, NumpyArray]:
+    ) -> dict[str, Union[NumpyArray, NDArray[np.int64]]]:
         """
         Preprocess the onnx input.
         """
@@ -60,7 +61,7 @@ def load_onnx_model(self) -> None:
         raise NotImplementedError("Subclasses must implement this method")
 
     def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
-        return self.tokenizer.encode_batch(documents)
+        return self.tokenizer.encode_batch(documents)  # type: ignore
 
     def onnx_embed(
         self,
@@ -70,7 +71,7 @@ def onnx_embed(
         encoded = self.tokenize(documents, **kwargs)
         input_ids = np.array([e.ids for e in encoded])
         attention_mask = np.array([e.attention_mask for e in encoded])
-        input_names = {node.name for node in self.model.get_inputs()}
+        input_names = {node.name for node in self.model.get_inputs()}  # type: ignore
         onnx_input: dict[str, NumpyArray] = {
             "input_ids": np.array(input_ids, dtype=np.int64),
         }
@@ -82,7 +83,7 @@ def onnx_embed(
         )
         onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
 
-        model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
+        model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)  # type: ignore
         return OnnxOutputContext(
             model_output=model_output[0],
             attention_mask=onnx_input.get("attention_mask", attention_mask),
@@ -136,7 +137,7 @@ def _embed_documents(
                 start_method=start_method,
             )
             for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
-                yield from self._post_process_onnx_output(batch)
+                yield from self._post_process_onnx_output(batch)  # type: ignore
 
 
 class TextEmbeddingWorker(EmbeddingWorker[T]):
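
One easy-to-miss fix above: mypy skips the bodies of unannotated functions by default, so `def __init__(self):` was effectively unchecked. Adding `-> None` is what brings the constructor under type checking at all, e.g.:

    from typing import Optional

    class TextModel:
        def __init__(self) -> None:
            # With `-> None`, mypy checks these assignments; without any
            # annotation the whole body would be skipped by default.
            self.tokenizer: Optional[object] = None
            self.special_token_to_id: dict[str, int] = {}
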
