@@ -110,7 +110,7 @@ def __init__(
         self.device_id = None
 
         self.model_description = self._get_model_description(model_name)
-        self.cache_dir = define_cache_dir(cache_dir)
+        self.cache_dir = str(define_cache_dir(cache_dir))
 
         self._model_dir = self.download_model(
             self.model_description,
@@ -119,10 +119,10 @@ def __init__(
             specific_model_path=specific_model_path,
         )
 
-        self.invert_vocab = {}
+        self.invert_vocab: dict[int, str] = {}
 
-        self.special_tokens = set()
-        self.special_tokens_ids = set()
+        self.special_tokens: set[str] = set()
+        self.special_tokens_ids: set[int] = set()
         self.punctuation = set(string.punctuation)
         self.stopwords = set(self._load_stopwords(self._model_dir))
         self.stemmer = SnowballStemmer(MODEL_TO_LANGUAGE[model_name])
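The annotations above are what let a type checker reason about these containers at all: mypy cannot infer the element type of an empty literal, so bare `= {}` and `= set()` assignments typically fail with a "Need type annotation" error. A minimal illustration of the rule (not fastembed code):

special_tokens: set[str] = set()   # element type declared up front
special_tokens.add("[CLS]")        # OK
# special_tokens.add(101)          # mypy would flag this: int is not str

bad = {}  # mypy: Need type annotation for "bad"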
@@ -147,15 +147,15 @@ def load_onnx_model(self) -> None:
         self.stopwords = set(self._load_stopwords(self._model_dir))
 
     def _filter_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
-        result = []
+        result: list[tuple[str, Any]] = []
         for token, value in tokens:
             if token in self.stopwords or token in self.punctuation:
                 continue
             result.append((token, value))
         return result
 
     def _stem_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
-        result = []
+        result: list[tuple[str, Any]] = []
         for token, value in tokens:
             processed_token = self.stemmer.stem_word(token)
             result.append((processed_token, value))
@@ -165,7 +165,7 @@ def _stem_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
     def _aggregate_weights(
         cls, tokens: list[tuple[str, list[int]]], weights: list[float]
     ) -> list[tuple[str, float]]:
-        result = []
+        result: list[tuple[str, float]] = []
        for token, idxs in tokens:
             sum_weight = sum(weights[idx] for idx in idxs)
             result.append((token, sum_weight))
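For context, `_aggregate_weights` turns per-subword attention weights into per-token weights by summing over each token's source indices. A tiny worked example with made-up numbers:

# each reconstructed token carries the indices of its subword pieces;
# its weight is the sum of the attention weights at those indices
tokens = [("hello", [0]), ("world", [1, 2])]
weights = [0.5, 0.25, 0.25]

aggregated = [(tok, sum(weights[i] for i in idxs)) for tok, idxs in tokens]
assert aggregated == [("hello", 0.5), ("world", 0.5)]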
@@ -174,9 +174,9 @@ def _aggregate_weights(
     def _reconstruct_bpe(
         self, bpe_tokens: Iterable[tuple[int, str]]
     ) -> list[tuple[str, list[int]]]:
-        result = []
-        acc = ""
-        acc_idx = []
+        result: list[tuple[str, list[int]]] = []
+        acc: str = ""
+        acc_idx: list[int] = []
 
         continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix
         continuing_subword_prefix_len = len(continuing_subword_prefix)
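`_reconstruct_bpe` glues subword pieces back into whole words while remembering which token positions each word came from. A rough standalone sketch of that idea, assuming a WordPiece-style `##` continuation prefix (the real method reads the prefix from the tokenizer and also handles special tokens):

from typing import Iterable

def reconstruct(pieces: Iterable[tuple[int, str]], prefix: str = "##") -> list[tuple[str, list[int]]]:
    result: list[tuple[str, list[int]]] = []
    acc: str = ""
    acc_idx: list[int] = []
    for idx, piece in pieces:
        if piece.startswith(prefix):
            acc += piece[len(prefix):]   # continuation: extend the current word
            acc_idx.append(idx)
        else:
            if acc:
                result.append((acc, acc_idx))
            acc, acc_idx = piece, [idx]  # start a new word
    if acc:
        result.append((acc, acc_idx))
    return result

# reconstruct(enumerate(["walk", "##ing", "fast"]))
# -> [("walking", [0, 1]), ("fast", [2])]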
@@ -206,7 +206,7 @@ def _rescore_vector(self, vector: dict[str, float]) -> dict[int, float]:
         So that the scoring doesn't depend on absolute values assigned by the model, but on the relative importance.
         """
 
-        new_vector = {}
+        new_vector: dict[int, float] = {}
 
         for token, value in vector.items():
             token_id = abs(mmh3.hash(token))
@@ -241,7 +241,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[SparseEmbedding]:
 
         weighted = self._aggregate_weights(stemmed, attention_value)
 
-        max_token_weight = {}
+        max_token_weight: dict[str, float] = {}
 
         for token, weight in weighted:
             max_token_weight[token] = max(max_token_weight.get(token, 0), weight)
@@ -304,7 +304,7 @@ def embed(
 
     @classmethod
     def _query_rehash(cls, tokens: Iterable[str]) -> dict[int, float]:
-        result = {}
+        result: dict[int, float] = {}
         for token in tokens:
             token_id = abs(mmh3.hash(token))
             result[token_id] = 1.0
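`_query_rehash` maps each token to a sparse dimension id with MurmurHash3: `mmh3.hash` returns a signed 32-bit integer, and `abs()` folds it into a non-negative id. A quick sketch of the same mapping (requires the `mmh3` package):

import mmh3

def query_rehash(tokens: list[str]) -> dict[int, float]:
    # every distinct token gets a deterministic dimension; query-side
    # weights are a flat 1.0, as in the method above
    return {abs(mmh3.hash(token)): 1.0 for token in tokens}

assert query_rehash(["hello"]) == query_rehash(["hello"])  # ids are stable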
@@ -334,11 +334,11 @@ def query_embed(
         yield SparseEmbedding.from_dict(self._query_rehash(token for token, _ in stemmed))
 
     @classmethod
-    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
+    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[SparseEmbedding]]:
         return Bm42TextEmbeddingWorker
 
 
-class Bm42TextEmbeddingWorker(TextEmbeddingWorker):
+class Bm42TextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Bm42:
         return Bm42(
             model_name=model_name,
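Parameterizing the worker as `TextEmbeddingWorker[SparseEmbedding]` tells the checker what `Bm42TextEmbeddingWorker` produces, instead of leaving the result type unknown. A simplified sketch of the generic pattern, using stand-in names rather than the actual fastembed classes:

from dataclasses import dataclass
from typing import Generic, Iterable, TypeVar

T = TypeVar("T")

@dataclass
class SparseVec:  # stand-in for SparseEmbedding
    indices: list[int]
    values: list[float]

class EmbeddingWorker(Generic[T]):  # stand-in for TextEmbeddingWorker
    def embed(self, texts: Iterable[str]) -> Iterable[T]:
        raise NotImplementedError

class SparseWorker(EmbeddingWorker[SparseVec]):
    def embed(self, texts: Iterable[str]) -> Iterable[SparseVec]:
        for _ in texts:
            yield SparseVec(indices=[0], values=[1.0])

# mypy now knows SparseWorker().embed(...) yields SparseVec values.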