-
Notifications
You must be signed in to change notification settings - Fork 126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Custom rerankers support #496
base: main
Are you sure you want to change the base?
Conversation
📝 WalkthroughWalkthroughThe changes introduce a new class, ✨ Finishing Touches
🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (3)
fastembed/rerank/cross_encoder/custom_reranker_model.py (3)
10-12
: Class-level SUPPORTED_MODELS could lead to issues.The
SUPPORTED_MODELS
list is defined at the class level and can be shared across all instances of the class. Consider alternatives:
- If this is intentionally a class-level shared registry, add a comment explaining this design.
- Otherwise, consider initializing in
__init__
if model lists should be instance-specific.Also, missing docstring for the class that explains its purpose and usage.
class CustomCrossEncoderModel(OnnxTextCrossEncoder): + """ + Custom cross-encoder model that allows adding and managing custom reranker models. + + This class extends OnnxTextCrossEncoder to provide functionality for registering + and using custom reranker models via a class-level registry of model descriptions. + """ SUPPORTED_MODELS: list[DenseModelDescription] = [] + # Class-level registry intentionally shared across all instances
13-38
: Constructor implementation is appropriate.The constructor properly forwards all parameters to the parent class, maintaining the expected interface.
Consider adding a docstring to document the parameters and their usage.
def __init__( self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, providers: Optional[Sequence[OnnxProvider]] = None, cuda: bool = False, device_ids: Optional[list[int]] = None, lazy_load: bool = False, device_id: Optional[int] = None, specific_model_path: Optional[str] = None, **kwargs: Any, ): + """ + Initialize a custom cross-encoder model. + + Args: + model_name: Name of the model to use + cache_dir: Directory to cache models + threads: Number of threads to use for inference + providers: ONNX execution providers + cuda: Whether to use CUDA + device_ids: List of device IDs to use + lazy_load: Whether to load the model lazily + device_id: Specific device ID to use + specific_model_path: Path to a specific model file + **kwargs: Additional arguments passed to the parent class + """ super().__init__( model_name=model_name, cache_dir=cache_dir, threads=threads, providers=providers, cuda=cuda, device_ids=device_ids, lazy_load=lazy_load, device_id=device_id, specific_model_path=specific_model_path, **kwargs, )
43-49
: Add docstring and validation to add_model method.The
add_model
method lacks a docstring and input validation. Consider adding:
- A descriptive docstring that explains the method's purpose and parameters
- Validation to ensure the provided model description is valid
@classmethod def add_model( cls, model_description: DenseModelDescription, ) -> None: + """ + Add a custom model to the list of supported models. + + Args: + model_description: Description of the model to add + + Raises: + ValueError: If the model description is invalid or a model with the same name already exists + """ + # Validate model description + if not model_description.source or not model_description.model_file: + raise ValueError("Model description must include source and model file") + + # Check for duplicate model names + if any(model.name == model_description.name for model in cls.SUPPORTED_MODELS): + raise ValueError(f"Model with name '{model_description.name}' already exists") + cls.SUPPORTED_MODELS.append(model_description)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
fastembed/rerank/cross_encoder/__init__.py
(1 hunks)fastembed/rerank/cross_encoder/custom_reranker_model.py
(1 hunks)
🧰 Additional context used
🪛 GitHub Actions: type-checkers
fastembed/rerank/cross_encoder/custom_reranker_model.py
[error] 40-40: Return type 'list[DenseModelDescription]' of '_list_supported_models' incompatible with return type 'list[BaseModelDescription]' in supertype 'OnnxTextCrossEncoder' [override]
[error] 40-40: Return type 'list[DenseModelDescription]' of '_list_supported_models' incompatible with return type 'list[BaseModelDescription]' in supertype 'ModelManagement' [override]
⏰ Context from checks skipped due to timeout of 90000ms (16)
- GitHub Check: Python 3.13.x on windows-latest test
- GitHub Check: Python 3.13.x on macos-latest test
- GitHub Check: Python 3.13.x on ubuntu-latest test
- GitHub Check: Python 3.12.x on windows-latest test
- GitHub Check: Python 3.12.x on macos-latest test
- GitHub Check: Python 3.12.x on ubuntu-latest test
- GitHub Check: Python 3.11.x on windows-latest test
- GitHub Check: Python 3.11.x on macos-latest test
- GitHub Check: Python 3.11.x on ubuntu-latest test
- GitHub Check: Python 3.10.x on windows-latest test
- GitHub Check: Python 3.10.x on macos-latest test
- GitHub Check: Python 3.10.x on ubuntu-latest test
- GitHub Check: Python 3.9.x on windows-latest test
- GitHub Check: Python 3.9.x on macos-latest test
- GitHub Check: Python 3.9.x on ubuntu-latest test
- GitHub Check: Analyze (python)
🔇 Additional comments (4)
fastembed/rerank/cross_encoder/__init__.py (2)
2-2
: Import statement looks good.The import of
CustomCrossEncoderModel
fromfastembed.rerank.cross_encoder.custom_reranker_model
is correctly added.
4-4
: Public interface update is appropriate.The
__all__
list is properly updated to include the newCustomCrossEncoderModel
class, making it part of the module's public API.fastembed/rerank/cross_encoder/custom_reranker_model.py (2)
1-8
: Imports are appropriate and organized well.The necessary imports for the implementation are included and well-organized, with standard library imports first, followed by package-specific imports.
39-42
:✅ Verification successful
Type error in _list_supported_models return annotation.
The pipeline failures indicate a type compatibility issue. The method overrides a parent method that returns
list[BaseModelDescription]
, but this implementation specifieslist[DenseModelDescription]
.@classmethod - def _list_supported_models(cls) -> list[DenseModelDescription]: + def _list_supported_models(cls) -> list["BaseModelDescription"]: return cls.SUPPORTED_MODELSThis assumes that
DenseModelDescription
is a subclass ofBaseModelDescription
. If not, you'll need to ensure type compatibility by other means.
🏁 Script executed:
#!/bin/bash # Check the inheritance hierarchy of DenseModelDescription grep -r "class DenseModelDescription" --include="*.py" . grep -r "BaseModelDescription" --include="*.py" .Length of output: 2686
Type Annotation Correction in
_list_supported_models
The issue is confirmed: the parent method expects a return type of
list[BaseModelDescription]
, while this override uses the more specificlist[DenseModelDescription]
. SinceDenseModelDescription
is indeed a subclass ofBaseModelDescription
(as verified infastembed/common/model_description.py
), the method signature should be updated to the parent’s annotation to ensure type compatibility.
- File:
fastembed/rerank/cross_encoder/custom_reranker_model.py
- Lines: 39–42
- Change: Update the return annotation from
list[DenseModelDescription]
tolist["BaseModelDescription"]
(using a string literal for forward reference) to correctly reflect the parent type.@classmethod - def _list_supported_models(cls) -> list[DenseModelDescription]: + def _list_supported_models(cls) -> list["BaseModelDescription"]: return cls.SUPPORTED_MODELS🧰 Tools
🪛 GitHub Actions: type-checkers
[error] 40-40: Return type 'list[DenseModelDescription]' of '_list_supported_models' incompatible with return type 'list[BaseModelDescription]' in supertype 'OnnxTextCrossEncoder' [override]
[error] 40-40: Return type 'list[DenseModelDescription]' of '_list_supported_models' incompatible with return type 'list[BaseModelDescription]' in supertype 'ModelManagement' [override]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (1)
fastembed/rerank/cross_encoder/text_cross_encoder.py (1)
135-167
: Well-implemented custom model registration methodThe
add_custom_model
method is well-structured with appropriate parameter validation and error handling. It checks for existing models before registration and provides sensible defaults.Two minor suggestions:
- Consider adding docstring documentation for this public method, similar to other methods in the class
- The error message on line 151-152 refers to "CrossEncoderModel" but should probably say "TextCrossEncoder" for consistency
@classmethod def add_custom_model( cls, model: str, sources: ModelSource, dim: int, model_file: str = "onnx/model.onnx", description: str = "", license: str = "", size_in_gb: float = 0.0, additional_files: Optional[list[str]] = None, + ) -> None: + """Registers a custom cross-encoder model. + + Args: + model: Unique name for the custom model + sources: Model source information + dim: Model output dimension + model_file: Path to the model file, defaults to "onnx/model.onnx" + description: Optional model description + license: Optional license information + size_in_gb: Size of the model in GB + additional_files: List of additional files required by the model + + Raises: + ValueError: If a model with the given name is already registered + """ registered_models = cls._list_supported_models() for registered_model in registered_models: if model == registered_model.model: raise ValueError( - f"Model {model} is already registered in CrossEncoderModel, if you still want to add this model, " + f"Model {model} is already registered in TextCrossEncoder, if you still want to add this model, " f"please use another model name" )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
fastembed/rerank/cross_encoder/custom_reranker_model.py
(1 hunks)fastembed/rerank/cross_encoder/text_cross_encoder.py
(2 hunks)tests/test_custom_models.py
(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- fastembed/rerank/cross_encoder/custom_reranker_model.py
⏰ Context from checks skipped due to timeout of 90000ms (15)
- GitHub Check: Python 3.13.x on windows-latest test
- GitHub Check: Python 3.13.x on macos-latest test
- GitHub Check: Python 3.13.x on ubuntu-latest test
- GitHub Check: Python 3.12.x on windows-latest test
- GitHub Check: Python 3.12.x on macos-latest test
- GitHub Check: Python 3.12.x on ubuntu-latest test
- GitHub Check: Python 3.11.x on windows-latest test
- GitHub Check: Python 3.11.x on macos-latest test
- GitHub Check: Python 3.11.x on ubuntu-latest test
- GitHub Check: Python 3.10.x on windows-latest test
- GitHub Check: Python 3.10.x on macos-latest test
- GitHub Check: Python 3.10.x on ubuntu-latest test
- GitHub Check: Python 3.9.x on windows-latest test
- GitHub Check: Python 3.9.x on macos-latest test
- GitHub Check: Python 3.9.x on ubuntu-latest test
🔇 Additional comments (4)
fastembed/rerank/cross_encoder/text_cross_encoder.py (3)
6-7
: Appropriate addition of the custom reranker model importThe import for
CustomCrossEncoderModel
is necessary for the registry update and the new functionality added to theTextCrossEncoder
class.
9-13
: Clean expansion of model description importsAll necessary model description classes are now properly imported. The additional imports are required for the new
add_custom_model
method.
19-19
: Good addition of CustomCrossEncoderModel to the registryAdding the model to the registry enables the TextCrossEncoder to recognize and use custom reranker models.
tests/test_custom_models.py (1)
10-11
: Proper imports for testing custom rerankersThe imports for
CustomCrossEncoderModel
andTextCrossEncoder
are necessary for the new test function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (1)
fastembed/rerank/cross_encoder/custom_reranker_model.py (1)
41-47
: Consider adding duplicate model validation.The current implementation allows adding duplicate models with the same name to the SUPPORTED_MODELS list. Consider adding validation to prevent duplicates.
@classmethod def add_model( cls, model_description: BaseModelDescription, ) -> None: + # Check if model with same name already exists + for existing_model in cls.SUPPORTED_MODELS: + if existing_model.model_name == model_description.model_name: + raise ValueError(f"Model with name '{model_description.model_name}' already exists") cls.SUPPORTED_MODELS.append(model_description)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
fastembed/rerank/cross_encoder/custom_reranker_model.py
(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (16)
- GitHub Check: Python 3.13.x on windows-latest test
- GitHub Check: Python 3.13.x on macos-latest test
- GitHub Check: Python 3.13.x on ubuntu-latest test
- GitHub Check: Python 3.12.x on windows-latest test
- GitHub Check: Python 3.12.x on macos-latest test
- GitHub Check: Python 3.12.x on ubuntu-latest test
- GitHub Check: Python 3.11.x on windows-latest test
- GitHub Check: Python 3.11.x on macos-latest test
- GitHub Check: Python 3.11.x on ubuntu-latest test
- GitHub Check: Python 3.10.x on windows-latest test
- GitHub Check: Python 3.10.x on macos-latest test
- GitHub Check: Python 3.10.x on ubuntu-latest test
- GitHub Check: Python 3.9.x on windows-latest test
- GitHub Check: Python 3.9.x on macos-latest test
- GitHub Check: Python 3.13 test
- GitHub Check: Python 3.9.x on ubuntu-latest test
🔇 Additional comments (4)
fastembed/rerank/cross_encoder/custom_reranker_model.py (4)
8-9
: Nice implementation of a custom cross encoder model registry.The class design allows users to dynamically register their own cross encoder models, effectively extending the library's capabilities without modifying core code.
11-35
: Constructor implementation looks good.The constructor properly inherits from OnnxTextCrossEncoder and passes all the necessary parameters to the parent class. The type hints are well-defined, and the method follows best practices by accepting **kwargs to handle any additional parameters.
37-40
: Simple and effective implementation.This class method provides a clean way to access the list of supported models.
9-9
:❓ Verification inconclusive
Consider thread-safety for the model registry.
The
SUPPORTED_MODELS
class variable is mutable and could potentially be accessed from multiple threads. If this code will be used in a multi-threaded environment, consider using thread-safe data structures or implementing synchronization.
🏁 Script executed:
#!/bin/bash # Check if the code might be used in a multi-threaded context rg -A 5 -B 5 "Thread|thread|concurrent|Concurrent|Lock|lock|Semaphore|semaphore" --glob "**/*.py"Length of output: 77618
Attention: Thread-Safety for Model Registry
In
fastembed/rerank/cross_encoder/custom_reranker_model.py
, the class variableSUPPORTED_MODELS: list[BaseModelDescription] = []is defined as a mutable list without any synchronization. While it appears that this registry is populated during initialization (and not modified thereafter), the broader codebase shows several components that operate in a multi-threaded context. To avoid potential race conditions if the registry is ever updated concurrently, please verify the following:
- If modifications are expected at runtime: Consider using synchronization mechanisms (e.g., a threading lock) or a thread-safe data structure to guard updates.
- If the registry is static post-initialization: Converting it to an immutable tuple after population might be a safer alternative.
Please review the intended usage of
SUPPORTED_MODELS
under concurrent scenarios and adjust its implementation accordingly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (2)
fastembed/rerank/cross_encoder/text_cross_encoder.py (1)
133-164
: Minor naming enhancement suggested.The
add_custom_model
method is well-structured and clearly checks for duplicates before adding a new model. However, consider renaming the parametermodel
tomodel_name
for readability, since “model” might be confused with an instantiated model object.- def add_custom_model( - cls, - model: str, + def add_custom_model( + cls, + model_name: str, sources: ModelSource, ... ): ... for registered_model in registered_models: - if model == registered_model.model: + if model_name == registered_model.model: ... ... - CustomCrossEncoderModel.add_model( + CustomCrossEncoderModel.add_model( BaseModelDescription( - model=model, + model=model_name, ... ) )tests/test_custom_models.py (1)
77-114
: Rename to clarify meaning.The test logic is solid. However, variable
embeddings
might be misleading because the cross-encoder returns a scalar score for each pair. Consider renaming it toscores
oroutputs
for clarity.- embeddings = np.stack(scores, axis=0) - assert embeddings.shape == (2,) - assert np.allclose(embeddings, canonical_vector, atol=1e-3) + scored_outputs = np.stack(scores, axis=0) + assert scored_outputs.shape == (2,) + assert np.allclose(scored_outputs, canonical_vector, atol=1e-3)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
fastembed/rerank/cross_encoder/text_cross_encoder.py
(2 hunks)tests/test_custom_models.py
(3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (16)
- GitHub Check: Python 3.13.x on windows-latest test
- GitHub Check: Python 3.13.x on macos-latest test
- GitHub Check: Python 3.13.x on ubuntu-latest test
- GitHub Check: Python 3.12.x on windows-latest test
- GitHub Check: Python 3.12.x on macos-latest test
- GitHub Check: Python 3.12.x on ubuntu-latest test
- GitHub Check: Python 3.11.x on windows-latest test
- GitHub Check: Python 3.11.x on macos-latest test
- GitHub Check: Python 3.11.x on ubuntu-latest test
- GitHub Check: Python 3.10.x on windows-latest test
- GitHub Check: Python 3.10.x on macos-latest test
- GitHub Check: Python 3.13 test
- GitHub Check: Python 3.10.x on ubuntu-latest test
- GitHub Check: Python 3.9.x on windows-latest test
- GitHub Check: Python 3.9.x on macos-latest test
- GitHub Check: Python 3.9.x on ubuntu-latest test
🔇 Additional comments (7)
fastembed/rerank/cross_encoder/text_cross_encoder.py (3)
6-6
: Module import looks good.Including
CustomCrossEncoderModel
as part of the imports is correct and aligns with the new functionality for custom rerankers.
9-12
: Imports are valid.The added imports for
ModelSource
andBaseModelDescription
are necessary for the new custom model registration logic.
18-18
: Registration extension is appropriate.Adding
CustomCrossEncoderModel
to theCROSS_ENCODER_REGISTRY
ensures theTextCrossEncoder
can instantiate custom cross-encoder models.tests/test_custom_models.py (4)
6-11
: New imports are valid.Importing
PoolingType
,ModelSource
,DenseModelDescription
, andBaseModelDescription
is essential for describing and registering custom models.
15-16
: Importing cross-encoder classes is correct.Bringing in
CustomCrossEncoderModel
andTextCrossEncoder
aligns with the tests for custom cross-encoder functionality.
24-27
: Fixture cleanup is consistent.Resetting
CustomCrossEncoderModel.SUPPORTED_MODELS
before and after tests ensures a clean environment across tests.
209-232
: Validation of duplicate model addition is correct.The test correctly ensures a
ValueError
is raised when attempting to add an already registered cross-encoder.
Added support of custom rerankers. It unlocks 494 functionality as: