From 9ed91785662af27e6ee338d3110c5e2e242fdd2e Mon Sep 17 00:00:00 2001 From: "d.rudenko" Date: Mon, 10 Mar 2025 14:48:19 +0100 Subject: [PATCH 1/5] Custom rerankers support --- fastembed/rerank/cross_encoder/__init__.py | 3 +- .../cross_encoder/custom_reranker_model.py | 48 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 fastembed/rerank/cross_encoder/custom_reranker_model.py diff --git a/fastembed/rerank/cross_encoder/__init__.py b/fastembed/rerank/cross_encoder/__init__.py index 23c1e3591..58fbe6027 100644 --- a/fastembed/rerank/cross_encoder/__init__.py +++ b/fastembed/rerank/cross_encoder/__init__.py @@ -1,3 +1,4 @@ from fastembed.rerank.cross_encoder.text_cross_encoder import TextCrossEncoder +from fastembed.rerank.cross_encoder.custom_reranker_model import CustomCrossEncoderModel -__all__ = ["TextCrossEncoder"] +__all__ = ["TextCrossEncoder", "CustomCrossEncoderModel"] diff --git a/fastembed/rerank/cross_encoder/custom_reranker_model.py b/fastembed/rerank/cross_encoder/custom_reranker_model.py new file mode 100644 index 000000000..0c62ea720 --- /dev/null +++ b/fastembed/rerank/cross_encoder/custom_reranker_model.py @@ -0,0 +1,48 @@ +from typing import Optional, Sequence, Any + +from fastembed.common import OnnxProvider +from fastembed.common.model_description import ( + DenseModelDescription, +) +from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder + + +class CustomCrossEncoderModel(OnnxTextCrossEncoder): + SUPPORTED_MODELS: list[DenseModelDescription] = [] + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + lazy_load: bool = False, + device_id: Optional[int] = None, + specific_model_path: Optional[str] = None, + **kwargs: Any, + ): + super().__init__( + model_name=model_name, + cache_dir=cache_dir, + threads=threads, + providers=providers, + cuda=cuda, + device_ids=device_ids, + lazy_load=lazy_load, + device_id=device_id, + specific_model_path=specific_model_path, + **kwargs, + ) + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + return cls.SUPPORTED_MODELS + + @classmethod + def add_model( + cls, + model_description: DenseModelDescription, + ) -> None: + cls.SUPPORTED_MODELS.append(model_description) From 73726b67aa554b0a2886ed6698595fb051f3ff2e Mon Sep 17 00:00:00 2001 From: "d.rudenko" Date: Mon, 10 Mar 2025 15:38:05 +0100 Subject: [PATCH 2/5] Test for reranker_custom_model --- fastembed/rerank/cross_encoder/__init__.py | 3 +- .../cross_encoder/custom_reranker_model.py | 4 +- .../cross_encoder/text_cross_encoder.py | 42 ++++++++++++++++- tests/test_custom_models.py | 46 +++++++++++++++++++ 4 files changed, 89 insertions(+), 6 deletions(-) diff --git a/fastembed/rerank/cross_encoder/__init__.py b/fastembed/rerank/cross_encoder/__init__.py index 58fbe6027..23c1e3591 100644 --- a/fastembed/rerank/cross_encoder/__init__.py +++ b/fastembed/rerank/cross_encoder/__init__.py @@ -1,4 +1,3 @@ from fastembed.rerank.cross_encoder.text_cross_encoder import TextCrossEncoder -from fastembed.rerank.cross_encoder.custom_reranker_model import CustomCrossEncoderModel -__all__ = ["TextCrossEncoder", "CustomCrossEncoderModel"] +__all__ = ["TextCrossEncoder"] diff --git a/fastembed/rerank/cross_encoder/custom_reranker_model.py b/fastembed/rerank/cross_encoder/custom_reranker_model.py index 0c62ea720..68ce0bc55 100644 --- a/fastembed/rerank/cross_encoder/custom_reranker_model.py +++ b/fastembed/rerank/cross_encoder/custom_reranker_model.py @@ -1,9 +1,7 @@ from typing import Optional, Sequence, Any from fastembed.common import OnnxProvider -from fastembed.common.model_description import ( - DenseModelDescription, -) +from fastembed.common.model_description import DenseModelDescription from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder.py b/fastembed/rerank/cross_encoder/text_cross_encoder.py index 573053e07..b6fb27616 100644 --- a/fastembed/rerank/cross_encoder/text_cross_encoder.py +++ b/fastembed/rerank/cross_encoder/text_cross_encoder.py @@ -3,13 +3,20 @@ from fastembed.common import OnnxProvider from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder +from fastembed.rerank.cross_encoder.custom_reranker_model import CustomCrossEncoderModel + from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase -from fastembed.common.model_description import BaseModelDescription +from fastembed.common.model_description import ( + DenseModelDescription, + ModelSource, + BaseModelDescription, +) class TextCrossEncoder(TextCrossEncoderBase): CROSS_ENCODER_REGISTRY: list[Type[TextCrossEncoderBase]] = [ OnnxTextCrossEncoder, + CustomCrossEncoderModel, ] @classmethod @@ -124,3 +131,36 @@ def rerank_pairs( yield from self.model.rerank_pairs( pairs, batch_size=batch_size, parallel=parallel, **kwargs ) + + @classmethod + def add_custom_model( + cls, + model: str, + sources: ModelSource, + dim: int, + model_file: str = "onnx/model.onnx", + description: str = "", + license: str = "", + size_in_gb: float = 0.0, + additional_files: Optional[list[str]] = None, + ) -> None: + registered_models = cls._list_supported_models() + for registered_model in registered_models: + if model == registered_model.model: + raise ValueError( + f"Model {model} is already registered in CrossEncoderModel, if you still want to add this model, " + f"please use another model name" + ) + + CustomCrossEncoderModel.add_model( + DenseModelDescription( + model=model, + sources=sources, + dim=dim, + model_file=model_file, + description=description, + license=license, + size_in_GB=size_in_gb, + additional_files=additional_files or [], + ), + ) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 082d4c8c3..af82f8b2a 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -7,6 +7,8 @@ from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import normalize, mean_pooling from fastembed.text.custom_text_embedding import CustomTextEmbedding, PostprocessingConfig +from fastembed.rerank.cross_encoder.custom_reranker_model import CustomCrossEncoderModel +from fastembed.rerank.cross_encoder import TextCrossEncoder from fastembed.text.text_embedding import TextEmbedding from tests.utils import delete_model_cache @@ -65,6 +67,50 @@ def test_text_custom_model(): delete_model_cache(model.model._model_dir) +def test_cross_encoder_custom_model(): + is_ci = os.getenv("CI") + custom_model_name = "viplao5/bge-reranker-v2-m3-onnx" + canonical_vector = np.array([1.3330, -1.2428], dtype=np.float32) + dim = 1 + size_in_gb = 2.5 + source = ModelSource(hf=custom_model_name) + + TextCrossEncoder.add_custom_model( + custom_model_name, + model_file="model.onnx", + sources=source, + dim=dim, + size_in_gb=size_in_gb, + # additional_files=['model.onnx_data'] + ) + + assert CustomCrossEncoderModel.SUPPORTED_MODELS[0] == DenseModelDescription( + model=custom_model_name, + sources=source, + model_file="model.onnx", + description="", + license="", + size_in_GB=size_in_gb, + additional_files=[], + dim=dim, + tasks={}, + ) + + model = TextCrossEncoder(custom_model_name) + pairs = [ + ("What is AI?", "Artificial intelligence is ..."), + ("What is ML?", "Machine learning is ..."), + ] + scores = list(model.rerank_pairs(pairs)) + + embeddings = np.stack(scores, axis=0) + assert embeddings.shape == (2,) + + assert np.allclose(embeddings[: canonical_vector.shape[0]], canonical_vector, atol=1e-3) + if is_ci: + delete_model_cache(model.model._model_dir) + + def test_mock_add_custom_models(): dim = 5 size_in_gb = 0.1 From ec83910e6fc8a8dd486246da02b877155e4b2f3e Mon Sep 17 00:00:00 2001 From: "d.rudenko" Date: Mon, 10 Mar 2025 15:43:01 +0100 Subject: [PATCH 3/5] test fix --- tests/test_custom_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index af82f8b2a..ba5402e59 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -81,7 +81,7 @@ def test_cross_encoder_custom_model(): sources=source, dim=dim, size_in_gb=size_in_gb, - # additional_files=['model.onnx_data'] + additional_files=["model.onnx_data"], ) assert CustomCrossEncoderModel.SUPPORTED_MODELS[0] == DenseModelDescription( @@ -91,7 +91,7 @@ def test_cross_encoder_custom_model(): description="", license="", size_in_GB=size_in_gb, - additional_files=[], + additional_files=["model.onnx_data"], dim=dim, tasks={}, ) From d2a312a89107d78c3c26427a75ed05764c2d0c50 Mon Sep 17 00:00:00 2001 From: "d.rudenko" Date: Mon, 10 Mar 2025 15:48:24 +0100 Subject: [PATCH 4/5] Model description type fix --- fastembed/rerank/cross_encoder/custom_reranker_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastembed/rerank/cross_encoder/custom_reranker_model.py b/fastembed/rerank/cross_encoder/custom_reranker_model.py index 68ce0bc55..7964f79c6 100644 --- a/fastembed/rerank/cross_encoder/custom_reranker_model.py +++ b/fastembed/rerank/cross_encoder/custom_reranker_model.py @@ -1,12 +1,12 @@ from typing import Optional, Sequence, Any from fastembed.common import OnnxProvider -from fastembed.common.model_description import DenseModelDescription +from fastembed.common.model_description import BaseModelDescription from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder class CustomCrossEncoderModel(OnnxTextCrossEncoder): - SUPPORTED_MODELS: list[DenseModelDescription] = [] + SUPPORTED_MODELS: list[BaseModelDescription] = [] def __init__( self, @@ -35,12 +35,12 @@ def __init__( ) @classmethod - def _list_supported_models(cls) -> list[DenseModelDescription]: + def _list_supported_models(cls) -> list[BaseModelDescription]: return cls.SUPPORTED_MODELS @classmethod def add_model( cls, - model_description: DenseModelDescription, + model_description: BaseModelDescription, ) -> None: cls.SUPPORTED_MODELS.append(model_description) From 2021d5c64b91519b89cdebfd82e51a4a95f01212 Mon Sep 17 00:00:00 2001 From: "d.rudenko" Date: Mon, 10 Mar 2025 16:33:56 +0100 Subject: [PATCH 5/5] Test fix --- .../cross_encoder/text_cross_encoder.py | 7 +-- tests/test_custom_models.py | 57 ++++++++++++++----- 2 files changed, 44 insertions(+), 20 deletions(-) diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder.py b/fastembed/rerank/cross_encoder/text_cross_encoder.py index b6fb27616..f204a0bd3 100644 --- a/fastembed/rerank/cross_encoder/text_cross_encoder.py +++ b/fastembed/rerank/cross_encoder/text_cross_encoder.py @@ -7,7 +7,6 @@ from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase from fastembed.common.model_description import ( - DenseModelDescription, ModelSource, BaseModelDescription, ) @@ -137,7 +136,6 @@ def add_custom_model( cls, model: str, sources: ModelSource, - dim: int, model_file: str = "onnx/model.onnx", description: str = "", license: str = "", @@ -153,14 +151,13 @@ def add_custom_model( ) CustomCrossEncoderModel.add_model( - DenseModelDescription( + BaseModelDescription( model=model, sources=sources, - dim=dim, model_file=model_file, description=description, license=license, size_in_GB=size_in_gb, additional_files=additional_files or [], - ), + ) ) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index ba5402e59..42e6a20f8 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -3,7 +3,12 @@ import numpy as np import pytest -from fastembed.common.model_description import PoolingType, ModelSource, DenseModelDescription +from fastembed.common.model_description import ( + PoolingType, + ModelSource, + DenseModelDescription, + BaseModelDescription, +) from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import normalize, mean_pooling from fastembed.text.custom_text_embedding import CustomTextEmbedding, PostprocessingConfig @@ -16,8 +21,10 @@ @pytest.fixture(autouse=True) def restore_custom_models_fixture(): CustomTextEmbedding.SUPPORTED_MODELS = [] + CustomCrossEncoderModel.SUPPORTED_MODELS = [] yield CustomTextEmbedding.SUPPORTED_MODELS = [] + CustomCrossEncoderModel.SUPPORTED_MODELS = [] def test_text_custom_model(): @@ -69,31 +76,27 @@ def test_text_custom_model(): def test_cross_encoder_custom_model(): is_ci = os.getenv("CI") - custom_model_name = "viplao5/bge-reranker-v2-m3-onnx" - canonical_vector = np.array([1.3330, -1.2428], dtype=np.float32) - dim = 1 - size_in_gb = 2.5 + custom_model_name = "Xenova/ms-marco-MiniLM-L-4-v2" + size_in_gb = 0.08 source = ModelSource(hf=custom_model_name) + canonical_vector = np.array([-5.7170815, -11.112114], dtype=np.float32) TextCrossEncoder.add_custom_model( custom_model_name, - model_file="model.onnx", + model_file="onnx/model.onnx", sources=source, - dim=dim, size_in_gb=size_in_gb, - additional_files=["model.onnx_data"], + additional_files=["onnx/model.onnx_data"], ) - assert CustomCrossEncoderModel.SUPPORTED_MODELS[0] == DenseModelDescription( + assert CustomCrossEncoderModel.SUPPORTED_MODELS[0] == BaseModelDescription( model=custom_model_name, sources=source, - model_file="model.onnx", + model_file="onnx/model.onnx", description="", license="", size_in_GB=size_in_gb, - additional_files=["model.onnx_data"], - dim=dim, - tasks={}, + additional_files=["onnx/model.onnx_data"], ) model = TextCrossEncoder(custom_model_name) @@ -105,8 +108,7 @@ def test_cross_encoder_custom_model(): embeddings = np.stack(scores, axis=0) assert embeddings.shape == (2,) - - assert np.allclose(embeddings[: canonical_vector.shape[0]], canonical_vector, atol=1e-3) + assert np.allclose(embeddings, canonical_vector, atol=1e-3) if is_ci: delete_model_cache(model.model._model_dir) @@ -202,3 +204,28 @@ def test_do_not_add_existing_model(): dim=384, size_in_gb=0.47, ) + + +def test_do_not_add_existing_cross_encoder(): + existing_base_model = "Xenova/ms-marco-MiniLM-L-6-v2" + custom_model_name = "Xenova/ms-marco-MiniLM-L-4-v2" + + with pytest.raises(ValueError, match=f"Model {existing_base_model} is already registered"): + TextCrossEncoder.add_custom_model( + existing_base_model, + sources=ModelSource(hf=existing_base_model), + size_in_gb=0.08, + ) + + TextCrossEncoder.add_custom_model( + custom_model_name, + sources=ModelSource(hf=existing_base_model), + size_in_gb=0.08, + ) + + with pytest.raises(ValueError, match=f"Model {custom_model_name} is already registered"): + TextCrossEncoder.add_custom_model( + custom_model_name, + sources=ModelSource(hf=custom_model_name), + size_in_gb=0.08, + )