Skip to content

Commit e1ecfe9

Browse files
authored
fix: fix None cache dir in parallel mode (#277)
1 parent 3312079 commit e1ecfe9

File tree

8 files changed

+35
-82
lines changed

8 files changed

+35
-82
lines changed

fastembed/image/onnx_embedding.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ def __init__(
5656
super().__init__(model_name, cache_dir, threads, **kwargs)
5757

5858
model_description = self._get_model_description(model_name)
59-
cache_dir = define_cache_dir(cache_dir)
59+
self.cache_dir = define_cache_dir(cache_dir)
6060
model_dir = self.download_model(
61-
model_description, cache_dir, local_files_only=self._local_files_only
61+
model_description, self.cache_dir, local_files_only=self._local_files_only
6262
)
6363

6464
self.load_onnx_model(
@@ -122,16 +122,10 @@ def _preprocess_onnx_input(
122122

123123
return onnx_input
124124

125-
def _post_process_onnx_output(
126-
self, output: OnnxOutputContext
127-
) -> Iterable[np.ndarray]:
125+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
128126
return normalize(output.model_output).astype(np.float32)
129127

130128

131129
class OnnxImageEmbeddingWorker(ImageEmbeddingWorker):
132-
def init_embedding(
133-
self, model_name: str, cache_dir: str, **kwargs
134-
) -> OnnxImageEmbedding:
135-
return OnnxImageEmbedding(
136-
model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
137-
)
130+
def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> OnnxImageEmbedding:
131+
return OnnxImageEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)

fastembed/late_interaction/colbert.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@ def _post_process_onnx_output(
4343
if token_id in self.skip_list or token_id == self.pad_token_id:
4444
output.attention_mask[i, j] = 0
4545

46-
output.model_output *= np.expand_dims(output.attention_mask, 2).astype(
47-
np.float32
48-
)
46+
output.model_output *= np.expand_dims(output.attention_mask, 2).astype(np.float32)
4947
norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
5048
norm_clamped = np.maximum(norm, 1e-12)
5149
output.model_output /= norm_clamped
@@ -126,10 +124,10 @@ def __init__(
126124
super().__init__(model_name, cache_dir, threads, **kwargs)
127125

128126
model_description = self._get_model_description(model_name)
129-
cache_dir = define_cache_dir(cache_dir)
127+
self.cache_dir = define_cache_dir(cache_dir)
130128

131129
model_dir = self.download_model(
132-
model_description, cache_dir, local_files_only=self._local_files_only
130+
model_description, self.cache_dir, local_files_only=self._local_files_only
133131
)
134132

135133
self.load_onnx_model(

fastembed/sparse/bm25.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@ def __init__(
7979
self.avg_len = avg_len
8080

8181
model_description = self._get_model_description(model_name)
82-
cache_dir = define_cache_dir(cache_dir)
82+
self.cache_dir = define_cache_dir(cache_dir)
8383

8484
model_dir = self.download_model(
85-
model_description, cache_dir, local_files_only=self._local_files_only
85+
model_description, self.cache_dir, local_files_only=self._local_files_only
8686
)
8787

8888
self.punctuation = set(string.punctuation)
@@ -133,9 +133,7 @@ def _embed_documents(
133133
for batch in iter_batch(documents, batch_size):
134134
yield from self.raw_embed(batch)
135135
else:
136-
start_method = (
137-
"forkserver" if "forkserver" in get_all_start_methods() else "spawn"
138-
)
136+
start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
139137
params = {
140138
"model_name": model_name,
141139
"cache_dir": cache_dir,
@@ -241,9 +239,7 @@ def _term_frequency(self, tokens: List[str]) -> Dict[int, float]:
241239
def compute_token_id(cls, token: str) -> int:
242240
return abs(mmh3.hash(token))
243241

244-
def query_embed(
245-
self, query: Union[str, Iterable[str]], **kwargs
246-
) -> Iterable[SparseEmbedding]:
242+
def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[SparseEmbedding]:
247243
"""To emulate BM25 behaviour, we don't need to use weights in the query, and
248244
it's enough to just hash the tokens and assign a weight of 1.0 to them.
249245
"""

fastembed/sparse/bm42.py

+8-20
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ def __init__(
8181
super().__init__(model_name, cache_dir, threads, **kwargs)
8282

8383
model_description = self._get_model_description(model_name)
84-
cache_dir = define_cache_dir(cache_dir)
84+
self.cache_dir = define_cache_dir(cache_dir)
8585

8686
model_dir = self.download_model(
87-
model_description, cache_dir, local_files_only=self._local_files_only
87+
model_description, self.cache_dir, local_files_only=self._local_files_only
8888
)
8989

9090
self.load_onnx_model(
@@ -106,9 +106,7 @@ def __init__(
106106
self.stemmer = get_stemmer(MODEL_TO_LANGUAGE[model_name])
107107
self.alpha = alpha
108108

109-
def _filter_pair_tokens(
110-
self, tokens: List[Tuple[str, Any]]
111-
) -> List[Tuple[str, Any]]:
109+
def _filter_pair_tokens(self, tokens: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]:
112110
result = []
113111
for token, value in tokens:
114112
if token in self.stopwords or token in self.punctuation:
@@ -180,19 +178,13 @@ def _rescore_vector(self, vector: Dict[str, float]) -> Dict[int, float]:
180178

181179
return new_vector
182180

183-
def _post_process_onnx_output(
184-
self, output: OnnxOutputContext
185-
) -> Iterable[SparseEmbedding]:
181+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[SparseEmbedding]:
186182
token_ids_batch = output.input_ids
187183

188184
# attention_value shape: (batch_size, num_heads, num_tokens, num_tokens)
189-
pooled_attention = (
190-
np.mean(output.model_output[:, :, 0], axis=1) * output.attention_mask
191-
)
185+
pooled_attention = np.mean(output.model_output[:, :, 0], axis=1) * output.attention_mask
192186

193-
for document_token_ids, attention_value in zip(
194-
token_ids_batch, pooled_attention
195-
):
187+
for document_token_ids, attention_value in zip(token_ids_batch, pooled_attention):
196188
document_tokens_with_ids = (
197189
(idx, self.invert_vocab[token_id])
198190
for idx, token_id in enumerate(document_token_ids)
@@ -272,9 +264,7 @@ def _query_rehash(cls, tokens: Iterable[str]) -> Dict[int, float]:
272264
result[token_id] = 1.0
273265
return result
274266

275-
def query_embed(
276-
self, query: Union[str, Iterable[str]], **kwargs
277-
) -> Iterable[SparseEmbedding]:
267+
def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[SparseEmbedding]:
278268
"""
279269
To emulate BM25 behaviour, we don't need to use smart weights in the query, and
280270
it's enough to just hash the tokens and assign a weight of 1.0 to them.
@@ -290,9 +280,7 @@ def query_embed(
290280
filtered = self._filter_pair_tokens(reconstructed)
291281
stemmed = self._stem_pair_tokens(filtered)
292282

293-
yield SparseEmbedding.from_dict(
294-
self._query_rehash(token for token, _ in stemmed)
295-
)
283+
yield SparseEmbedding.from_dict(self._query_rehash(token for token, _ in stemmed))
296284

297285
@classmethod
298286
def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:

fastembed/sparse/splade_pp.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@
3636

3737

3838
class SpladePP(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]):
39-
def _post_process_onnx_output(
40-
self, output: OnnxOutputContext
41-
) -> Iterable[SparseEmbedding]:
39+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[SparseEmbedding]:
4240
relu_log = np.log(1 + np.maximum(output.model_output, 0))
4341

4442
weighted_log = relu_log * np.expand_dims(output.attention_mask, axis=-1)
@@ -84,10 +82,10 @@ def __init__(
8482
super().__init__(model_name, cache_dir, threads, **kwargs)
8583

8684
model_description = self._get_model_description(model_name)
87-
cache_dir = define_cache_dir(cache_dir)
85+
self.cache_dir = define_cache_dir(cache_dir)
8886

8987
model_dir = self.download_model(
90-
model_description, cache_dir, local_files_only=self._local_files_only
88+
model_description, self.cache_dir, local_files_only=self._local_files_only
9189
)
9290

9391
self.load_onnx_model(

fastembed/text/mini_lm_embedding.py

+5-16
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,10 @@ def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
2828
return MiniLMEmbeddingWorker
2929

3030
@classmethod
31-
def mean_pooling(self, model_output: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
31+
def mean_pooling(cls, model_output: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
3232
token_embeddings = model_output
3333
input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
34-
input_mask_expanded = np.tile(
35-
input_mask_expanded, (1, 1, token_embeddings.shape[-1])
36-
)
34+
input_mask_expanded = np.tile(input_mask_expanded, (1, 1, token_embeddings.shape[-1]))
3735
input_mask_expanded = input_mask_expanded.astype(float)
3836
sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
3937
sum_mask = np.sum(input_mask_expanded, axis=1)
@@ -49,21 +47,12 @@ def list_supported_models(cls) -> List[Dict[str, Any]]:
4947
"""
5048
return supported_mini_lm_models
5149

52-
def _post_process_onnx_output(
53-
self, output: OnnxOutputContext
54-
) -> Iterable[np.ndarray]:
50+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
5551
embeddings = output.model_output
5652
attn_mask = output.attention_mask
5753
return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
5854

5955

6056
class MiniLMEmbeddingWorker(OnnxTextEmbeddingWorker):
61-
def init_embedding(
62-
self,
63-
model_name: str,
64-
cache_dir: str,
65-
**kwargs
66-
) -> OnnxTextEmbedding:
67-
return MiniLMOnnxEmbedding(
68-
model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
69-
)
57+
def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> OnnxTextEmbedding:
58+
return MiniLMOnnxEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)

fastembed/text/onnx_embedding.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,9 @@ def __init__(
219219
super().__init__(model_name, cache_dir, threads, **kwargs)
220220

221221
model_description = self._get_model_description(model_name)
222-
cache_dir = define_cache_dir(cache_dir)
222+
self.cache_dir = define_cache_dir(cache_dir)
223223
model_dir = self.download_model(
224-
model_description, cache_dir, local_files_only=self._local_files_only
224+
model_description, self.cache_dir, local_files_only=self._local_files_only
225225
)
226226

227227
self.load_onnx_model(
@@ -274,9 +274,7 @@ def _preprocess_onnx_input(
274274
"""
275275
return onnx_input
276276

277-
def _post_process_onnx_output(
278-
self, output: OnnxOutputContext
279-
) -> Iterable[np.ndarray]:
277+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
280278
embeddings = output.model_output
281279
return normalize(embeddings[:, 0]).astype(np.float32)
282280

@@ -288,6 +286,4 @@ def init_embedding(
288286
cache_dir: str,
289287
**kwargs,
290288
) -> OnnxTextEmbedding:
291-
return OnnxTextEmbedding(
292-
model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
293-
)
289+
return OnnxTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)

tests/test_late_interaction_embeddings.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ def test_single_embedding():
7777

7878
for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
7979
print("evaluating", model_name)
80-
model = LateInteractionTextEmbedding(
81-
model_name=model_name, cache_dir="colbert-cache"
82-
)
80+
model = LateInteractionTextEmbedding(model_name=model_name)
8381
result = next(iter(model.embed(docs_to_embed, batch_size=6)))
8482
token_num, abridged_dim = expected_result.shape
8583
assert np.allclose(result[:, :abridged_dim], expected_result, atol=10e-4)
@@ -90,18 +88,14 @@ def test_single_embedding_query():
9088

9189
for model_name, expected_result in CANONICAL_QUERY_VALUES.items():
9290
print("evaluating", model_name)
93-
model = LateInteractionTextEmbedding(
94-
model_name=model_name, cache_dir="colbert-cache"
95-
)
91+
model = LateInteractionTextEmbedding(model_name=model_name)
9692
result = next(iter(model.query_embed(queries_to_embed)))
9793
token_num, abridged_dim = expected_result.shape
9894
assert np.allclose(result[:, :abridged_dim], expected_result, atol=10e-4)
9995

10096

10197
def test_parallel_processing():
102-
model = LateInteractionTextEmbedding(
103-
model_name="colbert-ir/colbertv2.0", cache_dir="colbert-cache"
104-
)
98+
model = LateInteractionTextEmbedding(model_name="colbert-ir/colbertv2.0")
10599
token_dim = 128
106100
docs = ["hello world", "flag embedding"] * 100
107101
embeddings = list(model.embed(docs, batch_size=10, parallel=2))

0 commit comments

Comments
 (0)