Skip to content

Commit c8b1a18

Browse files
I8dNLojoein
andauthored
Cross encoders parallelism (#419)
* Merge master * rerank_pairs interface + parallelism support * remove test notebook * Removed unused code * New tests for cross encoders and new interface * Importing Self fix. We will need it for mypy support in newer versions * Removed Self typing * Removed non-needed changes from text * Isort + black * wip: start reviewing (#420) Co-authored-by: Dmitrii Ogn <[email protected]> * Test fix * Update fastembed/rerank/cross_encoder/text_cross_encoder.py Co-authored-by: George <[email protected]> * Update fastembed/rerank/cross_encoder/text_cross_encoder.py Co-authored-by: George <[email protected]> * Update fastembed/rerank/cross_encoder/text_cross_encoder.py Co-authored-by: George <[email protected]> * Update fastembed/rerank/cross_encoder/text_cross_encoder_base.py Co-authored-by: George <[email protected]> * Test for parallel processing + bugfix of PosixPath passing * Removed non-needed import and added docstring * Typing fix + argument passing * Test parametrization Moved to selected models set to test * Run base test on all models * Typing fix + improvement of input_names check * nit: fix post process, update docstring, update tokenize, remove redundant imports --------- Co-authored-by: George <[email protected]>
1 parent 3b5e4c8 commit c8b1a18

File tree

5 files changed

+267
-49
lines changed

5 files changed

+267
-49
lines changed

fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py

+53-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
from typing import Iterable, Any, Sequence, Optional
1+
from typing import Any, Iterable, Optional, Sequence, Type
22

33
from loguru import logger
44

55
from fastembed.common import OnnxProvider
6-
from fastembed.rerank.cross_encoder.onnx_text_model import OnnxCrossEncoderModel
7-
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
6+
from fastembed.common.onnx_model import OnnxOutputContext
87
from fastembed.common.utils import define_cache_dir
8+
from fastembed.rerank.cross_encoder.onnx_text_model import (
9+
OnnxCrossEncoderModel,
10+
TextRerankerWorker,
11+
)
12+
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
913

1014
supported_onnx_models = [
1115
{
@@ -91,7 +95,7 @@ def __init__(
9195
device_ids: Optional[list[int]] = None,
9296
lazy_load: bool = False,
9397
device_id: Optional[int] = None,
94-
**kwargs,
98+
**kwargs: Any,
9599
):
96100
"""
97101
Args:
@@ -138,7 +142,9 @@ def __init__(
138142
self.model_description = self._get_model_description(model_name)
139143
self.cache_dir = define_cache_dir(cache_dir)
140144
self._model_dir = self.download_model(
141-
self.model_description, self.cache_dir, local_files_only=self._local_files_only
145+
self.model_description,
146+
self.cache_dir,
147+
local_files_only=self._local_files_only,
142148
)
143149

144150
if not self.lazy_load:
@@ -159,7 +165,7 @@ def rerank(
159165
query: str,
160166
documents: Iterable[str],
161167
batch_size: int = 64,
162-
**kwargs,
168+
**kwargs: Any,
163169
) -> Iterable[float]:
164170
"""Reranks documents based on their relevance to a given query.
165171
@@ -175,3 +181,44 @@ def rerank(
175181
yield from self._rerank_documents(
176182
query=query, documents=documents, batch_size=batch_size, **kwargs
177183
)
184+
185+
def rerank_pairs(
186+
self,
187+
pairs: Iterable[tuple[str, str]],
188+
batch_size: int = 64,
189+
parallel: Optional[int] = None,
190+
**kwargs: Any,
191+
) -> Iterable[float]:
192+
yield from self._rerank_pairs(
193+
model_name=self.model_name,
194+
cache_dir=str(self.cache_dir),
195+
pairs=pairs,
196+
batch_size=batch_size,
197+
parallel=parallel,
198+
providers=self.providers,
199+
cuda=self.cuda,
200+
device_ids=self.device_ids,
201+
**kwargs,
202+
)
203+
204+
@classmethod
205+
def _get_worker_class(cls) -> Type[TextRerankerWorker]:
206+
return TextCrossEncoderWorker
207+
208+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
209+
return (float(elem) for elem in output.model_output)
210+
211+
212+
class TextCrossEncoderWorker(TextRerankerWorker):
213+
def init_embedding(
214+
self,
215+
model_name: str,
216+
cache_dir: str,
217+
**kwargs,
218+
) -> OnnxTextCrossEncoder:
219+
return OnnxTextCrossEncoder(
220+
model_name=model_name,
221+
cache_dir=cache_dir,
222+
threads=1,
223+
**kwargs,
224+
)
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,29 @@
1-
from typing import Sequence, Optional, Iterable
1+
import os
2+
from multiprocessing import get_all_start_methods
23
from pathlib import Path
4+
from typing import Any, Iterable, Optional, Sequence, Type
35

46
import numpy as np
57
from tokenizers import Encoding
68

7-
from fastembed.common.onnx_model import OnnxModel, OnnxProvider, OnnxOutputContext
9+
from fastembed.common.onnx_model import (
10+
EmbeddingWorker,
11+
OnnxModel,
12+
OnnxOutputContext,
13+
OnnxProvider,
14+
)
815
from fastembed.common.preprocessor_utils import load_tokenizer
916
from fastembed.common.utils import iter_batch
17+
from fastembed.parallel_processor import ParallelWorkerPool
1018

1119

12-
class OnnxCrossEncoderModel(OnnxModel):
20+
class OnnxCrossEncoderModel(OnnxModel[float]):
1321
ONNX_OUTPUT_NAMES: Optional[list[str]] = None
1422

23+
@classmethod
24+
def _get_worker_class(cls) -> Type["TextRerankerWorker"]:
25+
raise NotImplementedError("Subclasses must implement this method")
26+
1527
def _load_onnx_model(
1628
self,
1729
model_dir: Path,
@@ -31,40 +43,108 @@ def _load_onnx_model(
3143
)
3244
self.tokenizer, _ = load_tokenizer(model_dir=model_dir)
3345

34-
def tokenize(self, query: str, documents: list[str], **kwargs) -> list[Encoding]:
35-
return self.tokenizer.encode_batch([(query, doc) for doc in documents])
36-
37-
def onnx_embed(self, query: str, documents: list[str], **kwargs) -> OnnxOutputContext:
38-
tokenized_input = self.tokenize(query, documents, **kwargs)
46+
def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]:
47+
return self.tokenizer.encode_batch(pairs)
3948

49+
def _build_onnx_input(self, tokenized_input):
50+
input_names = {node.name for node in self.model.get_inputs()}
4051
inputs = {
4152
"input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64),
42-
"attention_mask": np.array(
43-
[enc.attention_mask for enc in tokenized_input], dtype=np.int64
44-
),
4553
}
46-
input_names = {node.name for node in self.model.get_inputs()}
4754
if "token_type_ids" in input_names:
4855
inputs["token_type_ids"] = np.array(
4956
[enc.type_ids for enc in tokenized_input], dtype=np.int64
5057
)
58+
if "attention_mask" in input_names:
59+
inputs["attention_mask"] = np.array(
60+
[enc.attention_mask for enc in tokenized_input], dtype=np.int64
61+
)
62+
return inputs
63+
64+
def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
65+
pairs = [(query, doc) for doc in documents]
66+
return self.onnx_embed_pairs(pairs, **kwargs)
5167

68+
def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext:
69+
tokenized_input = self.tokenize(pairs, **kwargs)
70+
inputs = self._build_onnx_input(tokenized_input)
5271
onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
5372
outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
54-
return OnnxOutputContext(model_output=outputs[0][:, 0].tolist())
73+
relevant_output = outputs[0]
74+
scores = relevant_output[:, 0]
75+
return OnnxOutputContext(model_output=scores)
5576

5677
def _rerank_documents(
57-
self, query: str, documents: Iterable[str], batch_size: int, **kwargs
78+
self, query: str, documents: Iterable[str], batch_size: int, **kwargs: Any
5879
) -> Iterable[float]:
5980
if not hasattr(self, "model") or self.model is None:
6081
self.load_onnx_model()
6182
for batch in iter_batch(documents, batch_size):
62-
yield from self.onnx_embed(query, batch, **kwargs).model_output
83+
yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs))
84+
85+
def _rerank_pairs(
86+
self,
87+
model_name: str,
88+
cache_dir: str,
89+
pairs: Iterable[tuple[str, str]],
90+
batch_size: int,
91+
parallel: Optional[int] = None,
92+
providers: Optional[Sequence[OnnxProvider]] = None,
93+
cuda: bool = False,
94+
device_ids: Optional[list[int]] = None,
95+
**kwargs: Any,
96+
) -> Iterable[float]:
97+
is_small = False
98+
99+
if isinstance(pairs, tuple):
100+
pairs = [pairs]
101+
is_small = True
102+
103+
if isinstance(pairs, list):
104+
if len(pairs) < batch_size:
105+
is_small = True
106+
107+
if parallel is None or is_small:
108+
if not hasattr(self, "model") or self.model is None:
109+
self.load_onnx_model()
110+
for batch in iter_batch(pairs, batch_size):
111+
yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs))
112+
else:
113+
if parallel == 0:
114+
parallel = os.cpu_count()
115+
116+
start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
117+
params = {
118+
"model_name": model_name,
119+
"cache_dir": cache_dir,
120+
"providers": providers,
121+
**kwargs,
122+
}
123+
124+
pool = ParallelWorkerPool(
125+
num_workers=parallel or 1,
126+
worker=self._get_worker_class(),
127+
cuda=cuda,
128+
device_ids=device_ids,
129+
start_method=start_method,
130+
)
131+
for batch in pool.ordered_map(iter_batch(pairs, batch_size), **params):
132+
yield from self._post_process_onnx_output(batch)
133+
134+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
135+
raise NotImplementedError("Subclasses must implement this method")
63136

64137
def _preprocess_onnx_input(
65-
self, onnx_input: dict[str, np.ndarray], **kwargs
138+
self, onnx_input: dict[str, np.ndarray], **kwargs: Any
66139
) -> dict[str, np.ndarray]:
67140
"""
68141
Preprocess the onnx input.
69142
"""
70143
return onnx_input
144+
145+
146+
class TextRerankerWorker(EmbeddingWorker):
147+
def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
148+
for idx, batch in items:
149+
onnx_output = self.model.onnx_embed_pairs(batch)
150+
yield idx, onnx_output

fastembed/rerank/cross_encoder/text_cross_encoder.py

+37-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import Any, Iterable, Optional, Sequence, Type
22

3-
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
4-
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
53
from fastembed.common import OnnxProvider
4+
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
5+
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
66

77

88
class TextCrossEncoder(TextCrossEncoderBase):
@@ -47,7 +47,7 @@ def __init__(
4747
cuda: bool = False,
4848
device_ids: Optional[list[int]] = None,
4949
lazy_load: bool = False,
50-
**kwargs,
50+
**kwargs: Any,
5151
):
5252
super().__init__(model_name, cache_dir, threads, **kwargs)
5353

@@ -72,7 +72,7 @@ def __init__(
7272
)
7373

7474
def rerank(
75-
self, query: str, documents: Iterable[str], batch_size: int = 64, **kwargs
75+
self, query: str, documents: Iterable[str], batch_size: int = 64, **kwargs: Any
7676
) -> Iterable[float]:
7777
"""Rerank a list of documents based on a query.
7878
@@ -85,3 +85,36 @@ def rerank(
8585
Iterable of scores for each document
8686
"""
8787
yield from self.model.rerank(query, documents, batch_size=batch_size, **kwargs)
88+
89+
def rerank_pairs(
90+
self,
91+
pairs: Iterable[tuple[str, str]],
92+
batch_size: int = 64,
93+
parallel: Optional[int] = None,
94+
**kwargs: Any,
95+
) -> Iterable[float]:
96+
"""
97+
Rerank a list of query-document pairs.
98+
99+
Args:
100+
pairs (Iterable[tuple[str, str]]): An iterable of tuples, where each tuple contains a query and a document
101+
to be scored together.
102+
batch_size (int, optional): The number of query-document pairs to process in a single batch. Defaults to 64.
103+
parallel (Optional[int], optional): The number of parallel processes to use for reranking.
104+
If None, parallelization is disabled. Defaults to None.
105+
**kwargs (Any): Additional arguments to pass to the underlying reranking model.
106+
107+
Returns:
108+
Iterable[float]: An iterable of scores corresponding to each query-document pair in the input.
109+
Higher scores indicate a stronger match between the query and the document.
110+
111+
Example:
112+
>>> encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2")
113+
>>> pairs = [("What is AI?", "Artificial intelligence is ..."), ("What is ML?", "Machine learning is ...")]
114+
>>> scores = list(encoder.rerank_pairs(pairs))
115+
>>> print(list(map(lambda x: round(x, 2), scores)))
116+
[-1.24, -10.6]
117+
"""
118+
yield from self.model.rerank_pairs(
119+
pairs, batch_size=batch_size, parallel=parallel, **kwargs
120+
)

fastembed/rerank/cross_encoder/text_cross_encoder_base.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterable, Optional
1+
from typing import Any, Iterable, Optional
22

33
from fastembed.common.model_management import ModelManagement
44

@@ -23,7 +23,7 @@ def rerank(
2323
batch_size: int = 64,
2424
**kwargs,
2525
) -> Iterable[float]:
26-
"""Reranks a list of documents given a query.
26+
"""Rerank a list of documents given a query.
2727
2828
Args:
2929
query (str): The query to rerank the documents.
@@ -32,6 +32,27 @@ def rerank(
3232
**kwargs: Additional keyword argument to pass to the rerank method.
3333
3434
Yields:
35-
Iterable[float]: The scores of reranked the documents.
35+
Iterable[float]: The scores of the reranked the documents.
36+
"""
37+
raise NotImplementedError("This method should be overridden by subclasses")
38+
39+
def rerank_pairs(
40+
self,
41+
pairs: Iterable[tuple[str, str]],
42+
batch_size: int = 64,
43+
parallel: Optional[int] = None,
44+
**kwargs: Any,
45+
) -> Iterable[float]:
46+
"""Rerank query-document pairs.
47+
Args:
48+
pairs (Iterable[tuple[str, str]]): Query-document pairs to rerank
49+
batch_size (int): The batch size to use for reranking.
50+
parallel: parallel:
51+
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
52+
If 0, use all available cores.
53+
If None, don't use data-parallel processing, use default onnxruntime threading instead.
54+
**kwargs: Additional keyword argument to pass to the rerank method.
55+
Yields:
56+
Iterable[float]: Scores for each individual pair
3657
"""
3758
raise NotImplementedError("This method should be overridden by subclasses")

0 commit comments

Comments
 (0)