From 4599b933ca11d4594566f4ea9e6c41a94eae1d54 Mon Sep 17 00:00:00 2001 From: George Date: Thu, 6 Feb 2025 20:28:10 +0100 Subject: [PATCH] wip: type hints for colpali (#469) * wip: type hints for colpali * new: Add colpali type hints * refactor: Remove redundant type ignore * fix: address remaining mypy issues --------- Co-authored-by: hh-space-invader --- fastembed/common/onnx_model.py | 5 +- fastembed/image/onnx_image_model.py | 4 +- fastembed/late_interaction/colbert.py | 4 +- .../late_interaction_multimodal/colpali.py | 62 ++++++------ .../late_interaction_multimodal_embedding.py | 15 ++- ...e_interaction_multimodal_embedding_base.py | 16 +-- .../onnx_multimodal_model.py | 99 +++++++++++++------ .../rerank/cross_encoder/onnx_text_model.py | 6 +- fastembed/sparse/bm42.py | 8 +- fastembed/sparse/sparse_embedding_base.py | 2 +- fastembed/text/onnx_text_model.py | 10 +- 11 files changed, 135 insertions(+), 96 deletions(-) diff --git a/fastembed/common/onnx_model.py b/fastembed/common/onnx_model.py index 6227c3eb..52b08b43 100644 --- a/fastembed/common/onnx_model.py +++ b/fastembed/common/onnx_model.py @@ -7,6 +7,7 @@ import onnxruntime as ort from numpy.typing import NDArray +from tokenizers import Tokenizer from fastembed.common.types import OnnxProvider, NumpyArray from fastembed.parallel_processor import Worker @@ -31,8 +32,8 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]: raise NotImplementedError("Subclasses must implement this method") def __init__(self) -> None: - self.model = None - self.tokenizer = None + self.model: Optional[ort.InferenceSession] = None + self.tokenizer: Optional[Tokenizer] = None def _preprocess_onnx_input( self, onnx_input: dict[str, NumpyArray], **kwargs: Any diff --git a/fastembed/image/onnx_image_model.py b/fastembed/image/onnx_image_model.py index 0556e29d..a178e6c5 100644 --- a/fastembed/image/onnx_image_model.py +++ b/fastembed/image/onnx_image_model.py @@ -61,7 +61,7 @@ def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]: - input_name = self.model.get_inputs()[0].name # type: ignore + input_name = self.model.get_inputs()[0].name # type: ignore[union-attr] return {input_name: encoded} def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext: @@ -74,7 +74,7 @@ def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputConte encoded = np.array(self.processor(image_files)) onnx_input = self._build_onnx_input(encoded) onnx_input = self._preprocess_onnx_input(onnx_input) - model_output = self.model.run(None, onnx_input) # type: ignore + model_output = self.model.run(None, onnx_input) # type: ignore[union-attr] embeddings = model_output[0].reshape(len(images), -1) return OnnxOutputContext(model_output=embeddings) diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 6180614b..c544774a 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -89,7 +89,6 @@ def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) -> def _tokenize_query(self, query: str) -> list[Encoding]: assert self.tokenizer is not None - encoded = self.tokenizer.encode_batch([query]) # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance if len(encoded[0].ids) < self.MIN_QUERY_LENGTH: @@ -109,8 +108,7 @@ def _tokenize_query(self, query: str) -> list[Encoding]: return encoded def _tokenize_documents(self, documents: list[str]) -> list[Encoding]: - assert self.tokenizer is not None - encoded = self.tokenizer.encode_batch(documents) + encoded = self.tokenizer.encode_batch(documents) # type: ignore[union-attr] return encoded @classmethod diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 5944ed5f..053ecacb 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -5,6 +5,7 @@ from fastembed.common import OnnxProvider, ImageInput from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.types import NumpyArray from fastembed.common.utils import define_cache_dir from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import ( LateInteractionMultimodalEmbeddingBase, @@ -33,7 +34,7 @@ ] -class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[np.ndarray]): +class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[NumpyArray]): QUERY_PREFIX = "Query: " BOS_TOKEN = "" PAD_TOKEN = "" @@ -56,7 +57,7 @@ def __init__( lazy_load: bool = False, device_id: Optional[int] = None, specific_model_path: Optional[str] = None, - **kwargs, + **kwargs: Any, ): """ Args: @@ -88,15 +89,14 @@ def __init__( self.cuda = cuda # This device_id will be used if we need to load model in current process + self.device_id: Optional[int] = None if device_id is not None: self.device_id = device_id elif self.device_ids is not None: self.device_id = self.device_ids[0] - else: - self.device_id = None self.model_description = self._get_model_description(model_name) - self.cache_dir = define_cache_dir(cache_dir) + self.cache_dir = str(define_cache_dir(cache_dir)) self._model_dir = self.download_model( self.model_description, @@ -132,7 +132,7 @@ def load_onnx_model(self) -> None: def _post_process_onnx_image_output( self, output: OnnxOutputContext, - ) -> Iterable[np.ndarray]: + ) -> Iterable[NumpyArray]: """ Post-process the ONNX model output to convert it into a usable format. @@ -140,7 +140,7 @@ def _post_process_onnx_image_output( output (OnnxOutputContext): The raw output from the ONNX model. Returns: - Iterable[np.ndarray]: Post-processed output as NumPy arrays. + Iterable[NumpyArray]: Post-processed output as NumPy arrays. """ return output.model_output.reshape( output.model_output.shape[0], -1, self.model_description["dim"] @@ -149,7 +149,7 @@ def _post_process_onnx_image_output( def _post_process_onnx_text_output( self, output: OnnxOutputContext, - ) -> Iterable[np.ndarray]: + ) -> Iterable[NumpyArray]: """ Post-process the ONNX model output to convert it into a usable format. @@ -157,45 +157,47 @@ def _post_process_onnx_text_output( output (OnnxOutputContext): The raw output from the ONNX model. Returns: - Iterable[np.ndarray]: Post-processed output as NumPy arrays. + Iterable[NumpyArray]: Post-processed output as NumPy arrays. """ return output.model_output.astype(np.float32) - def tokenize(self, documents: list[str], **_) -> list[Encoding]: + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: texts_query: list[str] = [] for query in documents: query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10 query += "\n" texts_query.append(query) - encoded = self.tokenizer.encode_batch(texts_query) + encoded = self.tokenizer.encode_batch(texts_query) # type: ignore[union-attr] return encoded def _preprocess_onnx_text_input( - self, onnx_input: dict[str, np.ndarray], **kwargs - ) -> dict[str, np.ndarray]: + self, onnx_input: dict[str, NumpyArray], **kwargs: Any + ) -> dict[str, NumpyArray]: onnx_input["input_ids"] = np.array( [ self.QUERY_MARKER_TOKEN_ID + input_ids[2:].tolist() for input_ids in onnx_input["input_ids"] ] ) - empty_image_placeholder = np.zeros(self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32) + empty_image_placeholder: NumpyArray = np.zeros( + self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32 + ) onnx_input["pixel_values"] = np.array( - [empty_image_placeholder for _ in onnx_input["input_ids"]] + [empty_image_placeholder for _ in onnx_input["input_ids"]], ) return onnx_input def _preprocess_onnx_image_input( - self, onnx_input: dict[str, np.ndarray], **kwargs - ) -> dict[str, np.ndarray]: + self, onnx_input: dict[str, np.ndarray], **kwargs: Any + ) -> dict[str, NumpyArray]: """ Add placeholders for text input when processing image data for ONNX. Args: - onnx_input (Dict[str, np.ndarray]): Preprocessed image inputs. + onnx_input (Dict[str, NumpyArray]): Preprocessed image inputs. **kwargs: Additional arguments. Returns: - Dict[str, np.ndarray]: ONNX input with text placeholders. + Dict[str, NumpyArray]: ONNX input with text placeholders. """ onnx_input["input_ids"] = np.array( @@ -211,8 +213,8 @@ def embed_text( documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: Optional[int] = None, - **kwargs, - ) -> Iterable[np.ndarray]: + **kwargs: Any, + ) -> Iterable[NumpyArray]: """ Encode a list of documents into list of embeddings. @@ -241,11 +243,11 @@ def embed_text( def embed_image( self, - images: ImageInput, + images: Union[ImageInput, Iterable[ImageInput]], batch_size: int = 16, parallel: Optional[int] = None, - **kwargs, - ) -> Iterable[np.ndarray]: + **kwargs: Any, + ) -> Iterable[NumpyArray]: """ Encode a list of images into list of embeddings. @@ -273,16 +275,16 @@ def embed_image( ) @classmethod - def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker]: + def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]: return ColPaliTextEmbeddingWorker @classmethod - def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker]: + def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker[NumpyArray]]: return ColPaliImageEmbeddingWorker -class ColPaliTextEmbeddingWorker(TextEmbeddingWorker): - def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali: +class ColPaliTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColPali: return ColPali( model_name=model_name, cache_dir=cache_dir, @@ -291,8 +293,8 @@ def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali: ) -class ColPaliImageEmbeddingWorker(ImageEmbeddingWorker): - def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali: +class ColPaliImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColPali: return ColPali( model_name=model_name, cache_dir=cache_dir, diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index f1c7b794..08819a53 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -1,8 +1,7 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union -import numpy as np - from fastembed.common import OnnxProvider, ImageInput +from fastembed.common.types import NumpyArray from fastembed.late_interaction_multimodal.colpali import ColPali from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import ( @@ -41,7 +40,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]: ] ``` """ - result = [] + result: list[dict[str, Any]] = [] for embedding in cls.EMBEDDINGS_REGISTRY: result.extend(embedding.list_supported_models()) return result @@ -55,7 +54,7 @@ def __init__( cuda: bool = False, device_ids: Optional[list[int]] = None, lazy_load: bool = False, - **kwargs, + **kwargs: Any, ): super().__init__(model_name, cache_dir, threads, **kwargs) for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: @@ -83,8 +82,8 @@ def embed_text( documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: Optional[int] = None, - **kwargs, - ) -> Iterable[np.ndarray]: + **kwargs: Any, + ) -> Iterable[NumpyArray]: """ Encode a list of documents into list of embeddings. @@ -106,8 +105,8 @@ def embed_image( images: Union[ImageInput, Iterable[ImageInput]], batch_size: int = 16, parallel: Optional[int] = None, - **kwargs, - ) -> Iterable[np.ndarray]: + **kwargs: Any, + ) -> Iterable[NumpyArray]: """ Encode a list of images into list of embeddings. diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index 5cfe45ba..64ee8643 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -1,9 +1,9 @@ -from typing import Iterable, Optional, Union +from typing import Iterable, Optional, Union, Any -import numpy as np from fastembed.common import ImageInput from fastembed.common.model_management import ModelManagement +from fastembed.common.types import NumpyArray class LateInteractionMultimodalEmbeddingBase(ModelManagement): @@ -12,7 +12,7 @@ def __init__( model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, - **kwargs, + **kwargs: Any, ): self.model_name = model_name self.cache_dir = cache_dir @@ -24,8 +24,8 @@ def embed_text( documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: Optional[int] = None, - **kwargs, - ) -> Iterable[np.ndarray]: + **kwargs: Any, + ) -> Iterable[NumpyArray]: """ Embeds a list of documents into a list of embeddings. @@ -39,7 +39,7 @@ def embed_text( **kwargs: Additional keyword argument to pass to the embed method. Yields: - Iterable[np.ndarray]: The embeddings. + Iterable[NumpyArray]: The embeddings. """ raise NotImplementedError() @@ -48,8 +48,8 @@ def embed_image( images: Union[ImageInput, Iterable[ImageInput]], batch_size: int = 16, parallel: Optional[int] = None, - **kwargs, - ) -> Iterable[np.ndarray]: + **kwargs: Any, + ) -> Iterable[NumpyArray]: """ Encode a list of images into list of embeddings. Args: diff --git a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py index 75031cfa..e34b8c0e 100644 --- a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py +++ b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py @@ -2,16 +2,18 @@ import os from multiprocessing import get_all_start_methods from pathlib import Path -from typing import Any, Iterable, Optional, Sequence, Type, Union, get_args +from typing import Any, Iterable, Optional, Sequence, Type, Union import numpy as np from PIL import Image -from tokenizers import Encoding +from tokenizers import Encoding, Tokenizer from fastembed.common import OnnxProvider, ImageInput from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T from fastembed.common.preprocessor_utils import load_tokenizer, load_preprocessor +from fastembed.common.types import NumpyArray from fastembed.common.utils import iter_batch +from fastembed.image.transform.operators import Compose from fastembed.parallel_processor import ParallelWorkerPool @@ -20,32 +22,32 @@ class OnnxMultimodalModel(OnnxModel[T]): def __init__(self) -> None: super().__init__() - self.tokenizer = None - self.processor = None - self.special_token_to_id = {} + self.tokenizer: Optional[Tokenizer] = None + self.processor: Optional[Compose] = None + self.special_token_to_id: dict[str, int] = {} def _preprocess_onnx_text_input( - self, onnx_input: dict[str, np.ndarray], **kwargs - ) -> dict[str, np.ndarray]: + self, onnx_input: dict[str, NumpyArray], **kwargs: Any + ) -> dict[str, NumpyArray]: """ Preprocess the onnx input. """ return onnx_input def _preprocess_onnx_image_input( - self, onnx_input: dict[str, np.ndarray], **kwargs - ) -> dict[str, np.ndarray]: + self, onnx_input: dict[str, NumpyArray], **kwargs: Any + ) -> dict[str, NumpyArray]: """ Preprocess the onnx input. """ return onnx_input @classmethod - def _get_text_worker_class(cls) -> Type["TextEmbeddingWorker"]: + def _get_text_worker_class(cls) -> Type["TextEmbeddingWorker[T]"]: raise NotImplementedError("Subclasses must implement this method") @classmethod - def _get_image_worker_class(cls) -> Type["ImageEmbeddingWorker"]: + def _get_image_worker_class(cls) -> Type["ImageEmbeddingWorker[T]"]: raise NotImplementedError("Subclasses must implement this method") def _post_process_onnx_image_output(self, output: OnnxOutputContext) -> Iterable[T]: @@ -71,25 +73,26 @@ def _load_onnx_model( cuda=cuda, device_id=device_id, ) + assert self.tokenizer is not None self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir) self.processor = load_preprocessor(model_dir=model_dir) def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") - def tokenize(self, documents: list[str], **kwargs) -> list[Encoding]: - return self.tokenizer.encode_batch(documents) + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + return self.tokenizer.encode_batch(documents) # type: ignore[union-attr] def onnx_embed_text( self, documents: list[str], - **kwargs, + **kwargs: Any, ) -> OnnxOutputContext: encoded = self.tokenize(documents, **kwargs) input_ids = np.array([e.ids for e in encoded]) - attention_mask = np.array([e.attention_mask for e in encoded]) - input_names = {node.name for node in self.model.get_inputs()} - onnx_input = { + attention_mask = np.array([e.attention_mask for e in encoded]) # type: ignore[union-attr] + input_names = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] + onnx_input: dict[str, NumpyArray] = { "input_ids": np.array(input_ids, dtype=np.int64), } if "attention_mask" in input_names: @@ -100,7 +103,7 @@ def onnx_embed_text( ) onnx_input = self._preprocess_onnx_text_input(onnx_input, **kwargs) - model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) + model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] return OnnxOutputContext( model_output=model_output[0], attention_mask=onnx_input.get("attention_mask", attention_mask), @@ -117,7 +120,7 @@ def _embed_documents( providers: Optional[Sequence[OnnxProvider]] = None, cuda: bool = False, device_ids: Optional[list[int]] = None, - **kwargs, + **kwargs: Any, ) -> Iterable[T]: is_small = False @@ -154,21 +157,23 @@ def _embed_documents( start_method=start_method, ) for batch in pool.ordered_map(iter_batch(documents, batch_size), **params): - yield from self._post_process_onnx_text_output(batch) + yield from self._post_process_onnx_text_output(batch) # type: ignore - def _build_onnx_image_input(self, encoded: np.ndarray) -> dict[str, np.ndarray]: - return {node.name: encoded for node in self.model.get_inputs()} + def _build_onnx_image_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]: + input_name = self.model.get_inputs()[0].name # type: ignore[union-attr] + return {input_name: encoded} - def onnx_embed_image(self, images: list[ImageInput], **kwargs) -> OnnxOutputContext: + def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext: with contextlib.ExitStack(): image_files = [ Image.open(image) if not isinstance(image, Image.Image) else image for image in images ] - encoded = self.processor(image_files) + assert self.processor is not None, "Processor is not initialized" + encoded = np.array(self.processor(image_files)) onnx_input = self._build_onnx_image_input(encoded) onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs) - model_output = self.model.run(None, onnx_input) + model_output = self.model.run(None, onnx_input) # type: ignore[union-attr] embeddings = model_output[0].reshape(len(images), -1) return OnnxOutputContext(model_output=embeddings) @@ -182,11 +187,11 @@ def _embed_images( providers: Optional[Sequence[OnnxProvider]] = None, cuda: bool = False, device_ids: Optional[list[int]] = None, - **kwargs, + **kwargs: Any, ) -> Iterable[T]: is_small = False - if isinstance(images, get_args(ImageInput)): + if isinstance(images, (str, Path, Image.Image)): images = [images] is_small = True @@ -219,17 +224,51 @@ def _embed_images( start_method=start_method, ) for batch in pool.ordered_map(iter_batch(images, batch_size), **params): - yield from self._post_process_onnx_image_output(batch) + yield from self._post_process_onnx_image_output(batch) # type: ignore -class TextEmbeddingWorker(EmbeddingWorker): +class TextEmbeddingWorker(EmbeddingWorker[T]): + def __init__( + self, + model_name: str, + cache_dir: str, + **kwargs: Any, + ): + self.model: OnnxMultimodalModel + super().__init__(model_name, cache_dir, **kwargs) + + def init_embedding( + self, + model_name: str, + cache_dir: str, + **kwargs: Any, + ) -> OnnxMultimodalModel: + raise NotImplementedError() + def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]: for idx, batch in items: onnx_output = self.model.onnx_embed_text(batch) yield idx, onnx_output -class ImageEmbeddingWorker(EmbeddingWorker): +class ImageEmbeddingWorker(EmbeddingWorker[T]): + def __init__( + self, + model_name: str, + cache_dir: str, + **kwargs: Any, + ): + self.model: OnnxMultimodalModel + super().__init__(model_name, cache_dir, **kwargs) + + def init_embedding( + self, + model_name: str, + cache_dir: str, + **kwargs: Any, + ) -> OnnxMultimodalModel: + raise NotImplementedError() + def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]: for idx, batch in items: embeddings = self.model.onnx_embed_image(batch) diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py index f6473413..bc619856 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_model.py +++ b/fastembed/rerank/cross_encoder/onnx_text_model.py @@ -46,10 +46,10 @@ def _load_onnx_model( assert self.tokenizer is not None def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]: - return self.tokenizer.encode_batch(pairs) # type: ignore + return self.tokenizer.encode_batch(pairs) # type: ignore[union-attr] def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]: - input_names: set[str] = {node.name for node in self.model.get_inputs()} # type: ignore + input_names: set[str] = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] inputs: dict[str, NumpyArray] = { "input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64), } @@ -71,7 +71,7 @@ def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxO tokenized_input = self.tokenize(pairs, **kwargs) inputs = self._build_onnx_input(tokenized_input) onnx_input = self._preprocess_onnx_input(inputs, **kwargs) - outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore + outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] relevant_output = outputs[0] scores: NumpyArray = relevant_output[:, 0] return OnnxOutputContext(model_output=scores) diff --git a/fastembed/sparse/bm42.py b/fastembed/sparse/bm42.py index 6f848d32..f34abb29 100644 --- a/fastembed/sparse/bm42.py +++ b/fastembed/sparse/bm42.py @@ -139,8 +139,8 @@ def load_onnx_model(self) -> None: cuda=self.cuda, device_id=self.device_id, ) - assert self.tokenizer is not None - for token, idx in self.tokenizer.get_vocab().items(): + + for token, idx in self.tokenizer.get_vocab().items(): # type: ignore[union-attr] self.invert_vocab[idx] = token self.special_tokens = set(self.special_token_to_id.keys()) self.special_tokens_ids = set(self.special_token_to_id.values()) @@ -178,7 +178,7 @@ def _reconstruct_bpe( acc: str = "" acc_idx: list[int] = [] - continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix # type: ignore + continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix # type: ignore[union-attr] continuing_subword_prefix_len = len(continuing_subword_prefix) for idx, token in bpe_tokens: @@ -325,7 +325,7 @@ def query_embed( self.load_onnx_model() for text in query: - encoded = self.tokenizer.encode(text) # type: ignore + encoded = self.tokenizer.encode(text) # type: ignore[union-attr] document_tokens_with_ids = enumerate(encoded.tokens) reconstructed = self._reconstruct_bpe(document_tokens_with_ids) filtered = self._filter_pair_tokens(reconstructed) diff --git a/fastembed/sparse/sparse_embedding_base.py b/fastembed/sparse/sparse_embedding_base.py index 7c056ac1..c6dc3393 100644 --- a/fastembed/sparse/sparse_embedding_base.py +++ b/fastembed/sparse/sparse_embedding_base.py @@ -20,7 +20,7 @@ def as_object(self) -> dict[str, NumpyArray]: } def as_dict(self) -> dict[int, float]: - return {int(i): float(v) for i, v in zip(self.indices, self.values)} # type: ignore[arg-type] + return {int(i): float(v) for i, v in zip(self.indices, self.values)} # type: ignore @classmethod def from_dict(cls, data: dict[int, float]) -> "SparseEmbedding": diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index 45423a19..294c3670 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -5,7 +5,7 @@ import numpy as np from numpy.typing import NDArray -from tokenizers import Encoding +from tokenizers import Encoding, Tokenizer from fastembed.common.types import NumpyArray, OnnxProvider from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T @@ -26,7 +26,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]: def __init__(self) -> None: super().__init__() - self.tokenizer = None + self.tokenizer: Optional[Tokenizer] = None self.special_token_to_id: dict[str, int] = {} def _preprocess_onnx_input( @@ -61,7 +61,7 @@ def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: - return self.tokenizer.encode_batch(documents) # type: ignore + return self.tokenizer.encode_batch(documents) # type: ignore[union-attr] def onnx_embed( self, @@ -71,7 +71,7 @@ def onnx_embed( encoded = self.tokenize(documents, **kwargs) input_ids = np.array([e.ids for e in encoded]) attention_mask = np.array([e.attention_mask for e in encoded]) - input_names = {node.name for node in self.model.get_inputs()} # type: ignore + input_names = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] onnx_input: dict[str, NumpyArray] = { "input_ids": np.array(input_ids, dtype=np.int64), } @@ -83,7 +83,7 @@ def onnx_embed( ) onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) - model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore + model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] return OnnxOutputContext( model_output=model_output[0], attention_mask=onnx_input.get("attention_mask", attention_mask),