Skip to content

Commit 3826294

Browse files
tests: Updated tests
1 parent 4092ba6 commit 3826294

7 files changed

+26
-32
lines changed

tests/test_image_onnx_embeddings.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from fastembed import ImageEmbedding
1010
from tests.config import TEST_MISC_DIR
11-
from tests.utils import delete_model_cache
11+
from tests.utils import delete_model_cache, should_test_model
1212

1313
CANONICAL_VECTOR_VALUES = {
1414
"Qdrant/clip-ViT-B-32-vision": np.array([-0.0098, 0.0128, -0.0274, 0.002, -0.0059]),
@@ -33,10 +33,7 @@ def test_embedding(model_name: str) -> None:
3333
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
3434

3535
for model_desc in ImageEmbedding._list_supported_models():
36-
if not is_ci:
37-
if model_desc.size_in_GB > 1:
38-
continue
39-
elif not is_manual and model_desc.model != model_name:
36+
if not should_test_model(model_name, model_desc, is_ci, is_manual):
4037
continue
4138

4239
dim = model_desc.dim

tests/test_late_interaction_embeddings.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from fastembed.late_interaction.late_interaction_text_embedding import (
77
LateInteractionTextEmbedding,
88
)
9-
from tests.utils import delete_model_cache
9+
from tests.utils import delete_model_cache, should_test_model
1010

1111
# vectors are abridged and rounded for brevity
1212
CANONICAL_COLUMN_VALUES = {
@@ -177,10 +177,7 @@ def test_single_embedding(model_name: str):
177177
docs_to_embed = docs
178178

179179
for model_desc in LateInteractionTextEmbedding._list_supported_models():
180-
if not is_ci:
181-
if model_desc.size_in_GB > 1:
182-
continue
183-
elif not is_manual and model_desc.model != model_name:
180+
if not should_test_model(model_name, model_desc, is_ci, is_manual):
184181
continue
185182

186183
print("evaluating", model_name)
@@ -201,10 +198,7 @@ def test_single_embedding_query(model_name: str):
201198
queries_to_embed = docs
202199

203200
for model_desc in LateInteractionTextEmbedding._list_supported_models():
204-
if not is_ci:
205-
if model_desc.size_in_GB > 1:
206-
continue
207-
elif not is_manual and model_desc.model != model_name:
201+
if not should_test_model(model_name, model_desc, is_ci, is_manual):
208202
continue
209203

210204
print("evaluating", model_name)

tests/test_sparse_embeddings.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from fastembed.sparse.bm25 import Bm25
77
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
8-
from tests.utils import delete_model_cache
8+
from tests.utils import delete_model_cache, should_test_model
99

1010
CANONICAL_COLUMN_VALUES = {
1111
"prithivida/Splade_PP_en_v1": {
@@ -73,10 +73,7 @@ def test_single_embedding(model_name: str) -> None:
7373
for model_desc in SparseTextEmbedding._list_supported_models():
7474
if model_desc.model not in CANONICAL_COLUMN_VALUES:
7575
continue
76-
if not is_ci:
77-
if model_desc.size_in_GB > 1:
78-
continue
79-
elif not is_manual and model_desc.model != model_name:
76+
if not should_test_model(model_name, model_desc, is_ci, is_manual):
8077
continue
8178

8279
model = SparseTextEmbedding(model_name=model_name)

tests/test_text_cross_encoder.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
from fastembed.rerank.cross_encoder import TextCrossEncoder
7-
from tests.utils import delete_model_cache
7+
from tests.utils import delete_model_cache, should_test_model
88

99
CANONICAL_SCORE_VALUES = {
1010
"Xenova/ms-marco-MiniLM-L-6-v2": np.array([8.500708, -2.541011]),
@@ -22,10 +22,7 @@ def test_rerank(model_name: str) -> None:
2222
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
2323

2424
for model_desc in TextCrossEncoder._list_supported_models():
25-
if not is_ci:
26-
if model_desc.size_in_GB > 1:
27-
continue
28-
elif not is_manual and model_desc.model != model_name:
25+
if not should_test_model(model_name, model_desc, is_ci, is_manual):
2926
continue
3027

3128
model = TextCrossEncoder(model_name=model_name)

tests/test_text_multitask_embeddings.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,6 @@ def test_task_assignment():
229229
continue
230230

231231
model_name = model_desc.model
232-
if model_name not in CANONICAL_VECTOR_VALUES:
233-
continue
234232

235233
model = TextEmbedding(model_name=model_name)
236234

tests/test_text_onnx_embeddings.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
from fastembed.text.text_embedding import TextEmbedding
8-
from tests.utils import delete_model_cache
8+
from tests.utils import delete_model_cache, should_test_model
99

1010
CANONICAL_VECTOR_VALUES = {
1111
"BAAI/bge-small-en": np.array([-0.0232, -0.0255, 0.0174, -0.0639, -0.0006]),
@@ -83,10 +83,7 @@ def test_embedding(model_name: str) -> None:
8383
is_mac and model_desc.model == "nomic-ai/nomic-embed-text-v1.5-Q"
8484
):
8585
continue
86-
if not is_ci:
87-
if model_desc.size_in_GB > 1:
88-
continue
89-
elif not is_manual and model_desc.model != model_name:
86+
if not should_test_model(model_name, model_desc, is_ci, is_manual):
9087
continue
9188

9289
dim = model_desc.dim

tests/utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
from pathlib import Path
55
from types import TracebackType
6-
from typing import Union, Callable, Any, Type
6+
from typing import Union, Callable, Any, Type, Optional
7+
8+
from fastembed.common.model_description import BaseModelDescription
79

810

911
def delete_model_cache(model_dir: Union[str, Path]) -> None:
@@ -35,3 +37,15 @@ def on_error(
3537
if model_dir.exists():
3638
# todo: PermissionDenied is raised on blobs removal in Windows, with blobs > 2GB
3739
shutil.rmtree(model_dir, onerror=on_error)
40+
41+
42+
def should_test_model(
43+
model_name: str, model_desc: BaseModelDescription, is_ci: Optional[str], is_manual: bool
44+
):
45+
"""Determine if a model should be tested based on environment"""
46+
if not is_ci:
47+
if model_desc.size_in_GB > 1:
48+
return False
49+
elif not is_manual and model_desc.model != model_name:
50+
return False
51+
return True

0 commit comments

Comments
 (0)