Skip to content

Commit

Permalink
tests: Updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hh-space-invader committed Mar 5, 2025
1 parent 4092ba6 commit 3826294
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 32 deletions.
7 changes: 2 additions & 5 deletions tests/test_image_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from fastembed import ImageEmbedding
from tests.config import TEST_MISC_DIR
from tests.utils import delete_model_cache
from tests.utils import delete_model_cache, should_test_model

CANONICAL_VECTOR_VALUES = {
"Qdrant/clip-ViT-B-32-vision": np.array([-0.0098, 0.0128, -0.0274, 0.002, -0.0059]),
Expand All @@ -33,10 +33,7 @@ def test_embedding(model_name: str) -> None:
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

for model_desc in ImageEmbedding._list_supported_models():
if not is_ci:
if model_desc.size_in_GB > 1:
continue
elif not is_manual and model_desc.model != model_name:
if not should_test_model(model_name, model_desc, is_ci, is_manual):
continue

dim = model_desc.dim
Expand Down
12 changes: 3 additions & 9 deletions tests/test_late_interaction_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastembed.late_interaction.late_interaction_text_embedding import (
LateInteractionTextEmbedding,
)
from tests.utils import delete_model_cache
from tests.utils import delete_model_cache, should_test_model

# vectors are abridged and rounded for brevity
CANONICAL_COLUMN_VALUES = {
Expand Down Expand Up @@ -177,10 +177,7 @@ def test_single_embedding(model_name: str):
docs_to_embed = docs

for model_desc in LateInteractionTextEmbedding._list_supported_models():
if not is_ci:
if model_desc.size_in_GB > 1:
continue
elif not is_manual and model_desc.model != model_name:
if not should_test_model(model_name, model_desc, is_ci, is_manual):
continue

print("evaluating", model_name)
Expand All @@ -201,10 +198,7 @@ def test_single_embedding_query(model_name: str):
queries_to_embed = docs

for model_desc in LateInteractionTextEmbedding._list_supported_models():
if not is_ci:
if model_desc.size_in_GB > 1:
continue
elif not is_manual and model_desc.model != model_name:
if not should_test_model(model_name, model_desc, is_ci, is_manual):
continue

print("evaluating", model_name)
Expand Down
7 changes: 2 additions & 5 deletions tests/test_sparse_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fastembed.sparse.bm25 import Bm25
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
from tests.utils import delete_model_cache
from tests.utils import delete_model_cache, should_test_model

CANONICAL_COLUMN_VALUES = {
"prithivida/Splade_PP_en_v1": {
Expand Down Expand Up @@ -73,10 +73,7 @@ def test_single_embedding(model_name: str) -> None:
for model_desc in SparseTextEmbedding._list_supported_models():
if model_desc.model not in CANONICAL_COLUMN_VALUES:
continue
if not is_ci:
if model_desc.size_in_GB > 1:
continue
elif not is_manual and model_desc.model != model_name:
if not should_test_model(model_name, model_desc, is_ci, is_manual):
continue

model = SparseTextEmbedding(model_name=model_name)
Expand Down
7 changes: 2 additions & 5 deletions tests/test_text_cross_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from fastembed.rerank.cross_encoder import TextCrossEncoder
from tests.utils import delete_model_cache
from tests.utils import delete_model_cache, should_test_model

CANONICAL_SCORE_VALUES = {
"Xenova/ms-marco-MiniLM-L-6-v2": np.array([8.500708, -2.541011]),
Expand All @@ -22,10 +22,7 @@ def test_rerank(model_name: str) -> None:
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

for model_desc in TextCrossEncoder._list_supported_models():
if not is_ci:
if model_desc.size_in_GB > 1:
continue
elif not is_manual and model_desc.model != model_name:
if not should_test_model(model_name, model_desc, is_ci, is_manual):
continue

model = TextCrossEncoder(model_name=model_name)
Expand Down
2 changes: 0 additions & 2 deletions tests/test_text_multitask_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,6 @@ def test_task_assignment():
continue

model_name = model_desc.model
if model_name not in CANONICAL_VECTOR_VALUES:
continue

model = TextEmbedding(model_name=model_name)

Expand Down
7 changes: 2 additions & 5 deletions tests/test_text_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from fastembed.text.text_embedding import TextEmbedding
from tests.utils import delete_model_cache
from tests.utils import delete_model_cache, should_test_model

CANONICAL_VECTOR_VALUES = {
"BAAI/bge-small-en": np.array([-0.0232, -0.0255, 0.0174, -0.0639, -0.0006]),
Expand Down Expand Up @@ -83,10 +83,7 @@ def test_embedding(model_name: str) -> None:
is_mac and model_desc.model == "nomic-ai/nomic-embed-text-v1.5-Q"
):
continue
if not is_ci:
if model_desc.size_in_GB > 1:
continue
elif not is_manual and model_desc.model != model_name:
if not should_test_model(model_name, model_desc, is_ci, is_manual):
continue

dim = model_desc.dim
Expand Down
16 changes: 15 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from pathlib import Path
from types import TracebackType
from typing import Union, Callable, Any, Type
from typing import Union, Callable, Any, Type, Optional

from fastembed.common.model_description import BaseModelDescription


def delete_model_cache(model_dir: Union[str, Path]) -> None:
Expand Down Expand Up @@ -35,3 +37,15 @@ def on_error(
if model_dir.exists():
# todo: PermissionDenied is raised on blobs removal in Windows, with blobs > 2GB
shutil.rmtree(model_dir, onerror=on_error)


def should_test_model(
    model_name: str, model_desc: BaseModelDescription, is_ci: Optional[str], is_manual: bool
) -> bool:
    """Decide whether ``model_desc`` should be exercised in the current test run.

    The rules depend on where the tests are running:

    * Outside CI (``is_ci`` is falsy): every model is tested unless its
      ``size_in_GB`` is greater than 1; ``model_name`` is not consulted.
    * In CI, manually triggered (``is_manual`` is true): every model is tested.
    * In CI otherwise: only the model whose ``model`` attribute equals
      ``model_name`` is tested.

    Args:
        model_name: Name of the model the calling test is parametrized with.
        model_desc: Description of the candidate model; only its ``model`` and
            ``size_in_GB`` attributes are read.
        is_ci: CI marker; any falsy value (``None`` or ``""``) means "not CI".
            NOTE(review): presumably the raw value of a ``CI`` environment
            variable; confirm against the callers.
        is_manual: Whether the run was triggered manually.

    Returns:
        ``True`` if the model should be tested, ``False`` if it should be skipped.
    """
    if not is_ci:
        # Written as ``not (x > 1)`` rather than ``x <= 1`` so that a
        # non-comparable size such as NaN is still treated as testable,
        # exactly like the original nested ``if``.
        return not model_desc.size_in_GB > 1
    if is_manual:
        # In CI a manual run applies no per-model filter at all.
        return True
    # Regular CI run: restrict to the model the test was parametrized with.
    return model_desc.model == model_name

0 comments on commit 3826294

Please sign in to comment.