class CustomCrossEncoderModel(OnnxTextCrossEncoder):
    """ONNX cross-encoder whose list of supported models is filled in at runtime.

    ``SUPPORTED_MODELS`` starts out empty. Descriptions are appended through
    :meth:`add_model` and reported back by :meth:`_list_supported_models`.
    """

    # Class-level registry: shared by the class and all of its instances and
    # mutated in place by ``add_model``.
    SUPPORTED_MODELS: list[BaseModelDescription] = []

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        device_id: Optional[int] = None,
        specific_model_path: Optional[str] = None,
        **kwargs: Any,
    ):
        """Initialise the encoder by delegating to ``OnnxTextCrossEncoder``.

        Every argument, including any extra ``kwargs``, is forwarded unchanged
        and by keyword to the parent constructor; this class adds no
        initialisation logic of its own.
        """
        super().__init__(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_ids=device_ids,
            lazy_load=lazy_load,
            device_id=device_id,
            specific_model_path=specific_model_path,
            **kwargs,
        )

    @classmethod
    def _list_supported_models(cls) -> list[BaseModelDescription]:
        """Return the registry list itself (not a copy) of registered descriptions."""
        return cls.SUPPORTED_MODELS

    @classmethod
    def add_model(
        cls,
        model_description: BaseModelDescription,
    ) -> None:
        """Append ``model_description`` to the registry.

        No duplicate check is performed here.
        """
        registry = cls.SUPPORTED_MODELS
        registry.append(model_description)
@classmethod
def add_custom_model(
    cls,
    model: str,
    sources: ModelSource,
    model_file: str = "onnx/model.onnx",
    description: str = "",
    license: str = "",
    size_in_gb: float = 0.0,
    additional_files: Optional[list[str]] = None,
) -> None:
    """Register a user-supplied cross-encoder model under the name ``model``.

    The description is stored in ``CustomCrossEncoderModel``'s registry.

    Args:
        model: Name to register the model under. Must differ from every
            name returned by ``cls._list_supported_models()``.
        sources: Where the model files can be obtained from.
        model_file: Path of the ONNX file, stored as ``model_file`` in the
            description.
        description: Free-form description stored with the model.
        license: License string stored with the model.
        size_in_gb: Model size, stored as ``size_in_GB`` in the description.
        additional_files: Extra file paths stored with the description;
            ``None`` is stored as an empty list.

    Raises:
        ValueError: If a model with exactly the same name is already registered.
    """
    # The comparison is an exact, case-sensitive string match on the model name.
    if any(model == registered.model for registered in cls._list_supported_models()):
        raise ValueError(
            f"Model {model} is already registered in TextCrossEncoder, if you still want to add this model, "
            f"please use another model name"
        )

    CustomCrossEncoderModel.add_model(
        BaseModelDescription(
            model=model,
            sources=sources,
            model_file=model_file,
            description=description,
            license=license,
            size_in_GB=size_in_gb,
            # Copy so that later mutation of the caller's list cannot alter
            # the description held in the class-level registry.
            additional_files=list(additional_files) if additional_files else [],
        )
    )
@pytest.fixture(autouse=True)
def restore_custom_models_fixture():
    """Empty both custom-model registries before and after every test.

    Each registry is rebound to a fresh list, so models added by one test
    are not visible to the next one.
    """
    CustomTextEmbedding.SUPPORTED_MODELS = []
    CustomCrossEncoderModel.SUPPORTED_MODELS = []
    yield
    CustomTextEmbedding.SUPPORTED_MODELS = []
    CustomCrossEncoderModel.SUPPORTED_MODELS = []


def test_cross_encoder_custom_model():
    """Register a custom cross-encoder, then check its description and its scores."""
    is_ci = os.getenv("CI")
    custom_model_name = "Xenova/ms-marco-MiniLM-L-4-v2"
    size_in_gb = 0.08
    source = ModelSource(hf=custom_model_name)
    canonical_vector = np.array([-5.7170815, -11.112114], dtype=np.float32)

    TextCrossEncoder.add_custom_model(
        custom_model_name,
        model_file="onnx/model.onnx",
        sources=source,
        size_in_gb=size_in_gb,
        additional_files=["onnx/model.onnx_data"],
    )

    # The registry was emptied by the autouse fixture, so the new entry is first.
    assert CustomCrossEncoderModel.SUPPORTED_MODELS[0] == BaseModelDescription(
        model=custom_model_name,
        sources=source,
        model_file="onnx/model.onnx",
        description="",
        license="",
        size_in_GB=size_in_gb,
        additional_files=["onnx/model.onnx_data"],
    )

    model = TextCrossEncoder(custom_model_name)
    try:
        pairs = [
            ("What is AI?", "Artificial intelligence is ..."),
            ("What is ML?", "Machine learning is ..."),
        ]
        scores = np.stack(list(model.rerank_pairs(pairs)), axis=0)

        # One score per input pair.
        assert scores.shape == (2,)
        assert np.allclose(scores, canonical_vector, atol=1e-3)
    finally:
        # Clean the cache on CI even when an assertion above fails, so a
        # failing run does not leave the model files behind.
        if is_ci:
            delete_model_cache(model.model._model_dir)


def test_do_not_add_existing_cross_encoder():
    """Reject a custom model whose name is already taken."""
    import re

    existing_base_model = "Xenova/ms-marco-MiniLM-L-6-v2"
    custom_model_name = "Xenova/ms-marco-MiniLM-L-4-v2"

    def already_registered(name: str) -> str:
        # ``match`` is treated as a regular expression, so escape the model
        # name to compare it literally.
        return re.escape(f"Model {name} is already registered")

    # A name that is expected to be registered already is refused.
    with pytest.raises(ValueError, match=already_registered(existing_base_model)):
        TextCrossEncoder.add_custom_model(
            existing_base_model,
            sources=ModelSource(hf=existing_base_model),
            size_in_gb=0.08,
        )

    # A fresh name is accepted the first time...
    TextCrossEncoder.add_custom_model(
        custom_model_name,
        sources=ModelSource(hf=existing_base_model),
        size_in_gb=0.08,
    )

    # ...and refused the second time, now that it is registered as a custom model.
    with pytest.raises(ValueError, match=already_registered(custom_model_name)):
        TextCrossEncoder.add_custom_model(
            custom_model_name,
            sources=ModelSource(hf=custom_model_name),
            size_in_gb=0.08,
        )