diff --git a/tests/test_image_onnx_embeddings.py b/tests/test_image_onnx_embeddings.py index eb23b571..8618804a 100644 --- a/tests/test_image_onnx_embeddings.py +++ b/tests/test_image_onnx_embeddings.py @@ -8,7 +8,7 @@ from fastembed import ImageEmbedding from tests.config import TEST_MISC_DIR -from tests.utils import delete_model_cache +from tests.utils import delete_model_cache, should_test_model CANONICAL_VECTOR_VALUES = { "Qdrant/clip-ViT-B-32-vision": np.array([-0.0098, 0.0128, -0.0274, 0.002, -0.0059]), @@ -33,10 +33,7 @@ def test_embedding(model_name: str) -> None: is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" for model_desc in ImageEmbedding._list_supported_models(): - if not is_ci: - if model_desc.size_in_GB > 1: - continue - elif not is_manual and model_desc.model != model_name: + if not should_test_model(model_name, model_desc, is_ci, is_manual): continue dim = model_desc.dim diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py index 7f71e9da..474ad0ab 100644 --- a/tests/test_late_interaction_embeddings.py +++ b/tests/test_late_interaction_embeddings.py @@ -6,7 +6,7 @@ from fastembed.late_interaction.late_interaction_text_embedding import ( LateInteractionTextEmbedding, ) -from tests.utils import delete_model_cache +from tests.utils import delete_model_cache, should_test_model # vectors are abridged and rounded for brevity CANONICAL_COLUMN_VALUES = { @@ -177,10 +177,7 @@ def test_single_embedding(model_name: str): docs_to_embed = docs for model_desc in LateInteractionTextEmbedding._list_supported_models(): - if not is_ci: - if model_desc.size_in_GB > 1: - continue - elif not is_manual and model_desc.model != model_name: + if not should_test_model(model_name, model_desc, is_ci, is_manual): continue print("evaluating", model_name) @@ -201,10 +198,7 @@ def test_single_embedding_query(model_name: str): queries_to_embed = docs for model_desc in 
LateInteractionTextEmbedding._list_supported_models(): - if not is_ci: - if model_desc.size_in_GB > 1: - continue - elif not is_manual and model_desc.model != model_name: + if not should_test_model(model_name, model_desc, is_ci, is_manual): continue print("evaluating", model_name) diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py index 833e1fbe..389a7689 100644 --- a/tests/test_sparse_embeddings.py +++ b/tests/test_sparse_embeddings.py @@ -5,7 +5,7 @@ from fastembed.sparse.bm25 import Bm25 from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding -from tests.utils import delete_model_cache +from tests.utils import delete_model_cache, should_test_model CANONICAL_COLUMN_VALUES = { "prithivida/Splade_PP_en_v1": { @@ -73,10 +73,7 @@ def test_single_embedding(model_name: str) -> None: for model_desc in SparseTextEmbedding._list_supported_models(): if model_desc.model not in CANONICAL_COLUMN_VALUES: continue - if not is_ci: - if model_desc.size_in_GB > 1: - continue - elif not is_manual and model_desc.model != model_name: + if not should_test_model(model_name, model_desc, is_ci, is_manual): continue model = SparseTextEmbedding(model_name=model_name) diff --git a/tests/test_text_cross_encoder.py b/tests/test_text_cross_encoder.py index e7001623..0fb58e72 100644 --- a/tests/test_text_cross_encoder.py +++ b/tests/test_text_cross_encoder.py @@ -4,7 +4,7 @@ import pytest from fastembed.rerank.cross_encoder import TextCrossEncoder -from tests.utils import delete_model_cache +from tests.utils import delete_model_cache, should_test_model CANONICAL_SCORE_VALUES = { "Xenova/ms-marco-MiniLM-L-6-v2": np.array([8.500708, -2.541011]), @@ -22,10 +22,7 @@ def test_rerank(model_name: str) -> None: is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" for model_desc in TextCrossEncoder._list_supported_models(): - if not is_ci: - if model_desc.size_in_GB > 1: - continue - elif not is_manual and model_desc.model != model_name: + if not 
should_test_model(model_name, model_desc, is_ci, is_manual): continue model = TextCrossEncoder(model_name=model_name) diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index 52389e1d..19886bd3 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -229,8 +229,6 @@ def test_task_assignment(): continue model_name = model_desc.model - if model_name not in CANONICAL_VECTOR_VALUES: - continue model = TextEmbedding(model_name=model_name) diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 50221f3b..a8952d25 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -5,7 +5,7 @@ import pytest from fastembed.text.text_embedding import TextEmbedding -from tests.utils import delete_model_cache +from tests.utils import delete_model_cache, should_test_model CANONICAL_VECTOR_VALUES = { "BAAI/bge-small-en": np.array([-0.0232, -0.0255, 0.0174, -0.0639, -0.0006]), @@ -83,10 +83,7 @@ def test_embedding(model_name: str) -> None: is_mac and model_desc.model == "nomic-ai/nomic-embed-text-v1.5-Q" ): continue - if not is_ci: - if model_desc.size_in_GB > 1: - continue - elif not is_manual and model_desc.model != model_name: + if not should_test_model(model_name, model_desc, is_ci, is_manual): continue dim = model_desc.dim diff --git a/tests/utils.py b/tests/utils.py index cfd6ae8b..d9804d6f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,7 +3,9 @@ from pathlib import Path from types import TracebackType -from typing import Union, Callable, Any, Type +from typing import Union, Callable, Any, Type, Optional + +from fastembed.common.model_description import BaseModelDescription def delete_model_cache(model_dir: Union[str, Path]) -> None: @@ -35,3 +37,15 @@ def on_error( if model_dir.exists(): # todo: PermissionDenied is raised on blobs removal in Windows, with blobs > 2GB shutil.rmtree(model_dir, onerror=on_error) + + +def 
def should_test_model(
    model_name: str,
    model_desc: "BaseModelDescription",
    is_ci: Optional[str],
    is_manual: bool,
) -> bool:
    """Decide whether ``model_desc`` should be exercised in the current run.

    Three situations are distinguished:

    * Local run (``is_ci`` is falsy): every model whose ``size_in_GB`` does
      not exceed 1 is tested; ``model_name`` is not consulted.
    * CI run with ``is_manual`` true: every model is tested.
    * Any other CI run: only the model whose ``model`` attribute equals
      ``model_name`` is tested.

    Args:
        model_name: Name of the model the calling test was invoked for.
        model_desc: Description of the candidate model; only its ``model``
            and ``size_in_GB`` attributes are read.
        is_ci: CI marker supplied by the caller. Any falsy value (``None``
            or an empty string) is treated as a local run.
        is_manual: Whether the run was triggered manually.

    Returns:
        ``True`` if the model should be tested, ``False`` if the caller
        should skip it.
    """
    if not is_ci:
        # Local run: the size limit is the only criterion.
        return not model_desc.size_in_GB > 1
    # CI run: a manual trigger tests everything, otherwise only the
    # model the test was invoked for.
    return bool(is_manual or model_desc.model == model_name)