-
Notifications
You must be signed in to change notification settings - Fork 127
/
Copy pathcustom_reranker_model.py
46 lines (40 loc) · 1.4 KB
/
custom_reranker_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from typing import Optional, Sequence, Any
from fastembed.common import OnnxProvider
from fastembed.common.model_description import BaseModelDescription
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
class CustomCrossEncoderModel(OnnxTextCrossEncoder):
    """ONNX cross-encoder whose supported-model list is populated at runtime.

    The class starts with an empty model registry. Model descriptions are
    added with ``add_model``, and ``_list_supported_models`` reports whatever
    has been registered so far. Every construction option is forwarded
    verbatim to ``OnnxTextCrossEncoder``.
    """

    # Runtime model registry. This single list object lives on the class, so
    # it is shared by every instance and by any subclass that does not assign
    # its own SUPPORTED_MODELS.
    SUPPORTED_MODELS: list[BaseModelDescription] = []

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        device_id: Optional[int] = None,
        specific_model_path: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Build the encoder by delegating every option to the parent class.

        No argument is inspected or altered here; see
        ``OnnxTextCrossEncoder.__init__`` for what each option controls.
        Any extra keyword arguments are passed through unchanged as well.
        """
        # NOTE(review): presumably the parent resolves ``model_name`` through
        # ``_list_supported_models``, so ``add_model`` would need to be called
        # before construction — confirm against OnnxTextCrossEncoder.
        super().__init__(
            model_name=model_name,
            specific_model_path=specific_model_path,
            cache_dir=cache_dir,
            lazy_load=lazy_load,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_id=device_id,
            device_ids=device_ids,
            **kwargs,
        )

    @classmethod
    def _list_supported_models(cls) -> list[BaseModelDescription]:
        """Return the model descriptions registered on this class.

        The registry list itself is returned, not a copy, so mutating the
        result also mutates the registry.
        """
        registry = cls.SUPPORTED_MODELS
        return registry

    @classmethod
    def add_model(
        cls,
        model_description: BaseModelDescription,
    ) -> None:
        """Append ``model_description`` to the class's model registry.

        The list is mutated in place and no duplicate check is made, so
        registering the same description twice lists it twice.
        """
        registry = cls.SUPPORTED_MODELS
        registry.append(model_description)