@@ -81,10 +81,10 @@ def __init__(
         super().__init__(model_name, cache_dir, threads, **kwargs)

         model_description = self._get_model_description(model_name)
-        cache_dir = define_cache_dir(cache_dir)
+        self.cache_dir = define_cache_dir(cache_dir)

         model_dir = self.download_model(
-            model_description, cache_dir, local_files_only=self._local_files_only
+            model_description, self.cache_dir, local_files_only=self._local_files_only
         )

         self.load_onnx_model(
@@ -106,9 +106,7 @@ def __init__(
         self.stemmer = get_stemmer(MODEL_TO_LANGUAGE[model_name])
         self.alpha = alpha

-    def _filter_pair_tokens(
-        self, tokens: List[Tuple[str, Any]]
-    ) -> List[Tuple[str, Any]]:
+    def _filter_pair_tokens(self, tokens: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]:
         result = []
         for token, value in tokens:
             if token in self.stopwords or token in self.punctuation:
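Note: the hunk above ends inside the loop body of _filter_pair_tokens. As a rough, self-contained sketch of the filtering idea (the tail of the loop is an assumption, not part of this diff), it keeps every (token, value) pair whose token is neither a stopword nor punctuation:

    from typing import Any, List, Tuple

    def filter_pair_tokens(
        tokens: List[Tuple[str, Any]], stopwords: set, punctuation: set
    ) -> List[Tuple[str, Any]]:
        # Skip stopword/punctuation tokens, keep everything else.
        result = []
        for token, value in tokens:
            if token in stopwords or token in punctuation:
                continue
            result.append((token, value))
        return result

    # filter_pair_tokens([("the", 0.1), ("fox", 0.9), (".", 0.2)], {"the"}, {"."})
    # -> [("fox", 0.9)]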
@@ -180,19 +178,13 @@ def _rescore_vector(self, vector: Dict[str, float]) -> Dict[int, float]:

         return new_vector

-    def _post_process_onnx_output(
-        self, output: OnnxOutputContext
-    ) -> Iterable[SparseEmbedding]:
+    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[SparseEmbedding]:
         token_ids_batch = output.input_ids

         # attention_value shape: (batch_size, num_heads, num_tokens, num_tokens)
-        pooled_attention = (
-            np.mean(output.model_output[:, :, 0], axis=1) * output.attention_mask
-        )
+        pooled_attention = np.mean(output.model_output[:, :, 0], axis=1) * output.attention_mask

-        for document_token_ids, attention_value in zip(
-            token_ids_batch, pooled_attention
-        ):
+        for document_token_ids, attention_value in zip(token_ids_batch, pooled_attention):
             document_tokens_with_ids = (
                 (idx, self.invert_vocab[token_id])
                 for idx, token_id in enumerate(document_token_ids)
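Aside: a small standalone sketch of the pooling expression reformatted above, assuming model_output is the attention tensor described in the comment, shaped (batch_size, num_heads, num_tokens, num_tokens), and attention_mask is (batch_size, num_tokens) with 1 for real tokens and 0 for padding; the row at index 0 is read here as the attention the first ([CLS]) position pays to every token:

    import numpy as np

    batch_size, num_heads, num_tokens = 2, 4, 6
    attention = np.random.rand(batch_size, num_heads, num_tokens, num_tokens)
    attention_mask = np.array([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]])

    # Attention from position 0 ([CLS]) to every token, per head: (batch, heads, tokens).
    cls_attention = attention[:, :, 0]
    # Average over heads, then zero out padded positions: (batch, tokens).
    pooled_attention = np.mean(cls_attention, axis=1) * attention_mask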
@@ -272,9 +264,7 @@ def _query_rehash(cls, tokens: Iterable[str]) -> Dict[int, float]:
             result[token_id] = 1.0
         return result

-    def query_embed(
-        self, query: Union[str, Iterable[str]], **kwargs
-    ) -> Iterable[SparseEmbedding]:
+    def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[SparseEmbedding]:
         """
         To emulate BM25 behaviour, we don't need to use smart weights in the query, and
         it's enough to just hash the tokens and assign a weight of 1.0 to them.
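Aside: the docstring above boils down to mapping each remaining query token to an integer id with a flat weight of 1.0, which is what _query_rehash builds. A minimal sketch of that idea, assuming a hypothetical token_to_id stand-in (the actual hash behind the token ids is not shown in this diff):

    import hashlib
    from typing import Dict, Iterable

    def token_to_id(token: str) -> int:
        # Hypothetical stand-in for the real token hashing used by the model.
        return int(hashlib.md5(token.encode("utf-8")).hexdigest(), 16) % (2**31)

    def query_rehash(tokens: Iterable[str]) -> Dict[int, float]:
        result: Dict[int, float] = {}
        for token in tokens:
            # Every surviving query token gets the same flat weight, as the docstring describes.
            result[token_to_id(token)] = 1.0
        return result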
@@ -290,9 +280,7 @@ def query_embed(
         filtered = self._filter_pair_tokens(reconstructed)
         stemmed = self._stem_pair_tokens(filtered)

-        yield SparseEmbedding.from_dict(
-            self._query_rehash(token for token, _ in stemmed)
-        )
+        yield SparseEmbedding.from_dict(self._query_rehash(token for token, _ in stemmed))

     @classmethod
     def _get_worker_class(cls) -> Type[TextEmbeddingWorker]: