import os
from multiprocessing import get_all_start_methods
from pathlib import Path
-from typing import Any, Iterable, Optional, Sequence, Type, Union
+from typing import Any, Iterable, Optional, Sequence, Type

import numpy as np
-from numpy.typing import NDArray
from tokenizers import Encoding

from fastembed.common.onnx_model import (
    OnnxOutputContext,
    OnnxProvider,
)
+from fastembed.common.types import NumpyArray
from fastembed.common.preprocessor_utils import load_tokenizer
from fastembed.common.utils import iter_batch
from fastembed.parallel_processor import ParallelWorkerPool
@@ -47,11 +47,9 @@ def _load_onnx_model(
    def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]:
        return self.tokenizer.encode_batch(pairs)

-    def _build_onnx_input(
-        self, tokenized_input
-    ) -> dict[str, NDArray[Union[np.float32, np.int64]]]:
-        input_names = {node.name for node in self.model.get_inputs()}
-        inputs = {
+    def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]:
+        input_names: set[str] = {node.name for node in self.model.get_inputs()}
+        inputs: dict[str, NumpyArray] = {
            "input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64),
        }
        if "token_type_ids" in input_names:
@@ -74,7 +72,7 @@ def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxO
        onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
        outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
        relevant_output = outputs[0]
-        scores = relevant_output[:, 0]
+        scores: NumpyArray = relevant_output[:, 0]
        return OnnxOutputContext(model_output=scores)

    def _rerank_documents(
@@ -100,7 +98,7 @@ def _rerank_pairs(
        is_small = False

        if isinstance(pairs, tuple):
-            pairs = [pairs]
+            pairs = [pairs]  # type: ignore
            is_small = True

        if isinstance(pairs, list):
@@ -138,15 +136,32 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float
        raise NotImplementedError("Subclasses must implement this method")

    def _preprocess_onnx_input(
-        self, onnx_input: dict[str, np.ndarray], **kwargs: Any
-    ) -> dict[str, np.ndarray]:
+        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+    ) -> dict[str, NumpyArray]:
        """
        Preprocess the onnx input.
        """
        return onnx_input


-class TextRerankerWorker(EmbeddingWorker):
+class TextRerankerWorker(EmbeddingWorker[float]):
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ):
+        self.model: OnnxCrossEncoderModel
+        super().__init__(model_name, cache_dir, **kwargs)
+
+    def init_embedding(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ) -> OnnxCrossEncoderModel:
+        raise NotImplementedError()
+
    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        for idx, batch in items:
            onnx_output = self.model.onnx_embed_pairs(batch)
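For context, the `_build_onnx_input` change above maps tokenizer output to named ONNX inputs keyed by the model's declared input names. A minimal standalone sketch of that mapping follows; it is not part of the diff, and the tokenizer identifier and input names are assumptions chosen only for illustration.

# Illustrative sketch: turning text pairs into named ONNX inputs,
# mirroring _build_onnx_input above. Model/tokenizer choice is hypothetical.
import numpy as np
from tokenizers import Tokenizer

# hypothetical cross-encoder tokenizer published with a tokenizer.json
tokenizer = Tokenizer.from_pretrained("Xenova/ms-marco-MiniLM-L-6-v2")

pairs = [("what is fastembed?", "fastembed is a lightweight embedding library")]
encodings = tokenizer.encode_batch(pairs)  # pair encoding: query + document per Encoding

onnx_input = {
    # token ids for each pair, int64 as ONNX runtimes typically expect
    "input_ids": np.array([e.ids for e in encodings], dtype=np.int64),
    # mask distinguishing real tokens from padding
    "attention_mask": np.array([e.attention_mask for e in encodings], dtype=np.int64),
    # segment ids separating query from document (only if the model declares this input)
    "token_type_ids": np.array([e.type_ids for e in encodings], dtype=np.int64),
}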