
Commit 877d963

Rerank type hints (#459)
* chore: Update type hints
* remove redundant array creation, update type hints

Co-authored-by: George Panchuk <[email protected]>

1 parent: 37a66d9

3 files changed: +30 −16 lines


fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py (+2 −3)
@@ -134,15 +134,14 @@ def __init__(
         )

         # This device_id will be used if we need to load model in current process
+        self.device_id: Optional[int] = None
         if device_id is not None:
             self.device_id = device_id
         elif self.device_ids is not None:
             self.device_id = self.device_ids[0]
-        else:
-            self.device_id = None

         self.model_description = self._get_model_description(model_name)
-        self.cache_dir = define_cache_dir(cache_dir)
+        self.cache_dir = str(define_cache_dir(cache_dir))
         self._model_dir = self.download_model(
             self.model_description,
             self.cache_dir,
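The device_id change follows a standard narrowing pattern: declare the attribute once with its full Optional type, then assign in branches, so a type checker sees Optional[int] on every path and the trailing else branch becomes unnecessary. A minimal standalone sketch of the same pattern (class and parameter names here are illustrative, not fastembed's API):

```python
from typing import Optional


class DeviceSelector:  # illustrative name, not part of fastembed
    def __init__(
        self,
        device_id: Optional[int] = None,
        device_ids: Optional[list[int]] = None,
    ) -> None:
        self.device_ids = device_ids
        # Declare once with the full type; every branch below keeps it Optional[int].
        self.device_id: Optional[int] = None
        if device_id is not None:
            self.device_id = device_id
        elif self.device_ids is not None:
            self.device_id = self.device_ids[0]
```

The str(...) wrapper around define_cache_dir presumably converts a Path-like return value, keeping self.cache_dir a plain str.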

fastembed/rerank/cross_encoder/onnx_text_model.py (+27 −12)
@@ -1,10 +1,9 @@
 import os
 from multiprocessing import get_all_start_methods
 from pathlib import Path
-from typing import Any, Iterable, Optional, Sequence, Type, Union
+from typing import Any, Iterable, Optional, Sequence, Type

 import numpy as np
-from numpy.typing import NDArray
 from tokenizers import Encoding

 from fastembed.common.onnx_model import (
@@ -13,6 +12,7 @@
     OnnxOutputContext,
     OnnxProvider,
 )
+from fastembed.common.types import NumpyArray
 from fastembed.common.preprocessor_utils import load_tokenizer
 from fastembed.common.utils import iter_batch
 from fastembed.parallel_processor import ParallelWorkerPool
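The NumpyArray alias imported here replaces the inline NDArray[Union[np.float32, np.int64]] annotations removed below. Its exact definition lives in fastembed.common.types; a rough sketch of such an alias (the real one may cover more dtypes) would be:

```python
# Illustrative only; the actual alias in fastembed.common.types may differ.
from typing import Union

import numpy as np
from numpy.typing import NDArray

NumpyArray = Union[NDArray[np.float32], NDArray[np.int64]]
```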
@@ -47,11 +47,9 @@ def _load_onnx_model(
     def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]:
         return self.tokenizer.encode_batch(pairs)

-    def _build_onnx_input(
-        self, tokenized_input
-    ) -> dict[str, NDArray[Union[np.float32, np.int64]]]:
-        input_names = {node.name for node in self.model.get_inputs()}
-        inputs = {
+    def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]:
+        input_names: set[str] = {node.name for node in self.model.get_inputs()}
+        inputs: dict[str, NumpyArray] = {
             "input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64),
         }
         if "token_type_ids" in input_names:
@@ -74,7 +72,7 @@ def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext:
         onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
         outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
         relevant_output = outputs[0]
-        scores = relevant_output[:, 0]
+        scores: NumpyArray = relevant_output[:, 0]
         return OnnxOutputContext(model_output=scores)

     def _rerank_documents(
@@ -100,7 +98,7 @@ def _rerank_pairs(
         is_small = False

         if isinstance(pairs, tuple):
-            pairs = [pairs]
+            pairs = [pairs]  # type: ignore
             is_small = True

         if isinstance(pairs, list):
@@ -138,15 +136,32 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
         raise NotImplementedError("Subclasses must implement this method")

     def _preprocess_onnx_input(
-        self, onnx_input: dict[str, np.ndarray], **kwargs: Any
-    ) -> dict[str, np.ndarray]:
+        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+    ) -> dict[str, NumpyArray]:
         """
         Preprocess the onnx input.
         """
         return onnx_input


-class TextRerankerWorker(EmbeddingWorker):
+class TextRerankerWorker(EmbeddingWorker[float]):
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ):
+        self.model: OnnxCrossEncoderModel
+        super().__init__(model_name, cache_dir, **kwargs)
+
+    def init_embedding(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ) -> OnnxCrossEncoderModel:
+        raise NotImplementedError()
+
     def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
         for idx, batch in items:
             onnx_output = self.model.onnx_embed_pairs(batch)
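Two things happen in the TextRerankerWorker change: the base class is parameterized (EmbeddingWorker[float], presumably the scalar type of the scores the worker yields), and the new __init__/init_embedding overrides let self.model be typed as OnnxCrossEncoderModel rather than a generic model. A simplified sketch of that generic-worker pattern, not fastembed's actual EmbeddingWorker:

```python
from typing import Any, Generic, Iterable, TypeVar

T = TypeVar("T")


class EmbeddingWorker(Generic[T]):
    """Illustrative stand-in for fastembed's worker base class."""

    def __init__(self, model_name: str, cache_dir: str, **kwargs: Any) -> None:
        # Subclasses can narrow the type of self.model by declaring it before
        # calling super().__init__, as TextRerankerWorker does above.
        self.model = self.init_embedding(model_name, cache_dir, **kwargs)

    def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Any:
        raise NotImplementedError()

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, T]]:
        raise NotImplementedError()
```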

fastembed/rerank/cross_encoder/text_cross_encoder.py (+1 −1)
@@ -33,7 +33,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
            ]
            ```
        """
-        result = []
+        result: list[dict[str, Any]] = []
        for encoder in cls.CROSS_ENCODER_REGISTRY:
            result.extend(encoder.list_supported_models())
        return result
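The annotation on the empty list is a typical strict-type-checking fix: a bare result = [] gives the checker nothing to infer an element type from, so the declaration states it explicitly. A tiny illustration (the appended value is made up):

```python
from typing import Any

result: list[dict[str, Any]] = []  # element type declared up front
result.append({"model": "illustrative-model-name"})  # accepted under strict checking
```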
