Skip to content

Commit 5461012

Browse files
authored
new: add bm25, fix param propagation in parallel mode, fix bm42 parallel (#274)
* new: add bm25, fix param propagation in parallel mode, fix bm42 parallel
* refactoring: remove redundant example
* fix: fix mp start method in bm25
* refactoring: refactor token id generation
* new: replace model repository
1 parent 29cfcda commit 5461012

18 files changed

+488
-92
lines changed

fastembed/common/onnx_model.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -92,22 +92,21 @@ def init_embedding(
9292
self,
9393
model_name: str,
9494
cache_dir: str,
95+
**kwargs,
9596
) -> OnnxModel:
9697
raise NotImplementedError()
9798

9899
def __init__(
99100
self,
100101
model_name: str,
101102
cache_dir: str,
103+
**kwargs,
102104
):
103-
self.model = self.init_embedding(model_name, cache_dir)
105+
self.model = self.init_embedding(model_name, cache_dir, **kwargs)
104106

105107
@classmethod
106108
def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "EmbeddingWorker":
107-
return cls(
108-
model_name=model_name,
109-
cache_dir=cache_dir,
110-
)
109+
return cls(model_name=model_name, cache_dir=cache_dir, **kwargs)
111110

112111
def process(self, items: Iterable[Tuple[int, Any]]) -> Iterable[Tuple[int, Any]]:
113112
raise NotImplementedError("Subclasses must implement this method")

fastembed/image/image_embedding.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
5454
if any(model_name.lower() == model["model"].lower() for model in supported_models):
5555
self.model = EMBEDDING_MODEL_TYPE(
56-
model_name, cache_dir, threads, providers=providers, **kwargs
56+
model_name, cache_dir, threads=threads, providers=providers, **kwargs
5757
)
5858
return
5959

fastembed/image/image_embedding_base.py

+5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ def embed(
3131
3232
Args:
3333
images - The list of image paths to preprocess and embed.
34+
batch_size: Batch size for encoding
35+
parallel:
36+
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
37+
If 0, use all available cores.
38+
If None, don't use data-parallel processing, use default onnxruntime threading instead.
3439
**kwargs: Additional keyword argument to pass to the embed method.
3540
3641
Yields:

fastembed/image/onnx_embedding.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def embed(
106106
images=images,
107107
batch_size=batch_size,
108108
parallel=parallel,
109+
**kwargs,
109110
)
110111

111112
@classmethod
@@ -126,9 +127,5 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.nd
126127

127128

128129
class OnnxImageEmbeddingWorker(ImageEmbeddingWorker):
129-
def init_embedding(
130-
self,
131-
model_name: str,
132-
cache_dir: str,
133-
) -> OnnxImageEmbedding:
134-
return OnnxImageEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1)
130+
def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> OnnxImageEmbedding:
131+
return OnnxImageEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)

fastembed/image/onnx_image_model.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,7 @@ def onnx_embed(self, images: List[PathInput], **kwargs) -> OnnxOutputContext:
5959
onnx_input = self._preprocess_onnx_input(onnx_input)
6060
model_output = self.model.run(None, onnx_input)
6161
embeddings = model_output[0].reshape(len(images), -1)
62-
return OnnxOutputContext(
63-
model_output=embeddings
64-
)
62+
return OnnxOutputContext(model_output=embeddings)
6563

6664
def _embed_images(
6765
self,
@@ -70,6 +68,7 @@ def _embed_images(
7068
images: ImageInput,
7169
batch_size: int = 256,
7270
parallel: Optional[int] = None,
71+
**kwargs,
7372
) -> Iterable[T]:
7473
is_small = False
7574

@@ -89,10 +88,7 @@ def _embed_images(
8988
yield from self._post_process_onnx_output(self.onnx_embed(batch))
9089
else:
9190
start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
92-
params = {
93-
"model_name": model_name,
94-
"cache_dir": cache_dir,
95-
}
91+
params = {"model_name": model_name, "cache_dir": cache_dir, **kwargs}
9692
pool = ParallelWorkerPool(
9793
parallel, self._get_worker_class(), start_method=start_method
9894
)

fastembed/late_interaction/colbert.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def embed(
171171
documents=documents,
172172
batch_size=batch_size,
173173
parallel=parallel,
174+
**kwargs,
174175
)
175176

176177
def query_embed(self, query: Union[str, List[str]], **kwargs) -> np.ndarray:
@@ -188,9 +189,5 @@ def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
188189

189190

190191
class ColbertEmbeddingWorker(TextEmbeddingWorker):
191-
def init_embedding(
192-
self,
193-
model_name: str,
194-
cache_dir: str,
195-
) -> Colbert:
196-
return Colbert(model_name=model_name, cache_dir=cache_dir, threads=1)
192+
def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> Colbert:
193+
return Colbert(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)

0 commit comments

Comments (0)