Skip to content

Commit ce98631

Browse files
authored
Fix spladepp parallelism (#169)
* fix: add get_worker_class implementation to spladepp
* fix: add tests for parallel embed for spladepp
1 parent 0a4ed42 commit ce98631

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

fastembed/sparse/splade_pp.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
1+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type
22

33
import numpy as np
44

@@ -114,6 +114,10 @@ def embed(
114114
parallel=parallel,
115115
)
116116

117+
@classmethod
def _get_worker_class(cls) -> Type[EmbeddingWorker]:
    """Return the worker class associated with this model.

    Returns:
        ``SpladePPEmbeddingWorker``, the ``EmbeddingWorker`` subclass
        declared in this module.
    """
    # NOTE(review): presumably consulted by the parallel branch of `embed`
    # to pick the worker type — the caller is not visible here; confirm.
    return SpladePPEmbeddingWorker
120+
117121

118122
class SpladePPEmbeddingWorker(EmbeddingWorker):
119123
def init_embedding(

tests/test_sparse_embeddings.py

+25
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,28 @@ def test_single_embedding():
7171

7272
for i, value in enumerate(result.values):
7373
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
74+
75+
76+
def test_parallel_processing():
    """Check that sparse embeddings do not depend on the ``parallel`` setting.

    The same 60 documents are embedded three times with ``batch_size=10``:
    with ``parallel=2``, ``parallel=0`` and ``parallel=None``. All three runs
    must yield one embedding per document, identical index lists, and values
    that agree within an absolute tolerance of 1e-3.
    """
    import numpy as np

    model = SparseTextEmbedding(
        model_name="prithivida/Splade_PP_en_v1",
    )
    docs = ["hello world", "flag embedding"] * 30

    # NOTE(review): presumably parallel=2 spawns two workers, parallel=0 uses
    # every available core and parallel=None stays in-process — confirm
    # against the `embed` documentation.
    sparse_embeddings_duo = list(model.embed(docs, batch_size=10, parallel=2))
    sparse_embeddings_all = list(model.embed(docs, batch_size=10, parallel=0))
    sparse_embeddings = list(model.embed(docs, batch_size=10, parallel=None))

    assert (
        len(sparse_embeddings)
        == len(sparse_embeddings_duo)
        == len(sparse_embeddings_all)
        == len(docs)
    )

    for sparse_embedding, sparse_embedding_duo, sparse_embedding_all in zip(
        sparse_embeddings, sparse_embeddings_duo, sparse_embeddings_all
    ):
        # Token indices must match exactly across all three runs.
        assert (
            sparse_embedding.indices.tolist()
            == sparse_embedding_duo.indices.tolist()
            == sparse_embedding_all.indices.tolist()
        )
        # Weights are floats, so compare with a tolerance rather than exactly.
        assert np.allclose(sparse_embedding.values, sparse_embedding_duo.values, atol=1e-3)
        assert np.allclose(sparse_embedding.values, sparse_embedding_all.values, atol=1e-3)

0 commit comments

Comments
 (0)