From 9ed91785662af27e6ee338d3110c5e2e242fdd2e Mon Sep 17 00:00:00 2001 From: "d.rudenko" Date: Mon, 10 Mar 2025 14:48:19 +0100 Subject: [PATCH] Custom rerankers support --- fastembed/rerank/cross_encoder/__init__.py | 3 +- .../cross_encoder/custom_reranker_model.py | 48 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 fastembed/rerank/cross_encoder/custom_reranker_model.py diff --git a/fastembed/rerank/cross_encoder/__init__.py b/fastembed/rerank/cross_encoder/__init__.py index 23c1e3591..58fbe6027 100644 --- a/fastembed/rerank/cross_encoder/__init__.py +++ b/fastembed/rerank/cross_encoder/__init__.py @@ -1,3 +1,4 @@ from fastembed.rerank.cross_encoder.text_cross_encoder import TextCrossEncoder +from fastembed.rerank.cross_encoder.custom_reranker_model import CustomCrossEncoderModel -__all__ = ["TextCrossEncoder"] +__all__ = ["TextCrossEncoder", "CustomCrossEncoderModel"] diff --git a/fastembed/rerank/cross_encoder/custom_reranker_model.py b/fastembed/rerank/cross_encoder/custom_reranker_model.py new file mode 100644 index 000000000..0c62ea720 --- /dev/null +++ b/fastembed/rerank/cross_encoder/custom_reranker_model.py @@ -0,0 +1,48 @@ +from typing import Optional, Sequence, Any + +from fastembed.common import OnnxProvider +from fastembed.common.model_description import ( + DenseModelDescription, +) +from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder + + +class CustomCrossEncoderModel(OnnxTextCrossEncoder): + SUPPORTED_MODELS: list[DenseModelDescription] = [] + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + lazy_load: bool = False, + device_id: Optional[int] = None, + specific_model_path: Optional[str] = None, + **kwargs: Any, + ): + super().__init__( + model_name=model_name, + cache_dir=cache_dir, + threads=threads, + providers=providers, + cuda=cuda, + device_ids=device_ids, + lazy_load=lazy_load, + device_id=device_id, + specific_model_path=specific_model_path, + **kwargs, + ) + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + return cls.SUPPORTED_MODELS + + @classmethod + def add_model( + cls, + model_description: DenseModelDescription, + ) -> None: + cls.SUPPORTED_MODELS.append(model_description)