Skip to content

Commit 44644e7

Browse files
refactor: Call one model
1 parent 671b874 commit 44644e7

6 files changed

+18
-89
lines changed

tests/test_image_onnx_embeddings.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,9 @@
2727
}
2828

2929
ALL_IMAGE_MODEL_DESC = ImageEmbedding._list_supported_models()
30-
smallest_model = min(ALL_IMAGE_MODEL_DESC, key=lambda m: m.size_in_GB).model
3130

3231

33-
@pytest.mark.parametrize(
34-
"model_name",
35-
[
36-
smallest_model
37-
if smallest_model in CANONICAL_VECTOR_VALUES
38-
else "Qdrant/clip-ViT-B-32-vision"
39-
],
40-
)
32+
@pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"])
4133
def test_embedding(model_name: str) -> None:
4234
is_ci = os.getenv("CI")
4335
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

tests/test_late_interaction_embeddings.py

+4-16
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,7 @@ def test_batch_embedding(model_name: str):
170170
delete_model_cache(model.model._model_dir)
171171

172172

173-
@pytest.mark.parametrize(
174-
"model_name",
175-
["answerdotai/answerai-colbert-small-v1"],
176-
)
173+
@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
177174
def test_single_embedding(model_name: str):
178175
is_ci = os.getenv("CI")
179176
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
@@ -203,10 +200,7 @@ def test_single_embedding(model_name: str):
203200
delete_model_cache(model.model._model_dir)
204201

205202

206-
@pytest.mark.parametrize(
207-
"model_name",
208-
["answerdotai/answerai-colbert-small-v1"],
209-
)
203+
@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
210204
def test_single_embedding_query(model_name: str):
211205
is_ci = os.getenv("CI")
212206
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
@@ -236,10 +230,7 @@ def test_single_embedding_query(model_name: str):
236230
delete_model_cache(model.model._model_dir)
237231

238232

239-
@pytest.mark.parametrize(
240-
"token_dim,model_name",
241-
[(96, "answerdotai/answerai-colbert-small-v1")],
242-
)
233+
@pytest.mark.parametrize("token_dim,model_name", [(96, "answerdotai/answerai-colbert-small-v1")])
243234
def test_parallel_processing(token_dim: int, model_name: str):
244235
is_ci = os.getenv("CI")
245236
model = LateInteractionTextEmbedding(model_name=model_name)
@@ -262,10 +253,7 @@ def test_parallel_processing(token_dim: int, model_name: str):
262253
delete_model_cache(model.model._model_dir)
263254

264255

265-
@pytest.mark.parametrize(
266-
"model_name",
267-
["answerdotai/answerai-colbert-small-v1"],
268-
)
256+
@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
269257
def test_lazy_load(model_name: str):
270258
is_ci = os.getenv("CI")
271259

tests/test_sparse_embeddings.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,7 @@ def test_batch_embedding() -> None:
6666
delete_model_cache(model.model._model_dir)
6767

6868

69-
@pytest.mark.parametrize(
70-
"model_name",
71-
[
72-
min(ALL_SPARSE_MODEL_DESC, key=lambda m: m.size_in_GB).model
73-
if CANONICAL_COLUMN_VALUES
74-
else "prithivida/Splade_PP_en_v1"
75-
],
76-
)
69+
@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"])
7770
def test_single_embedding(model_name: str) -> None:
7871
is_ci = os.getenv("CI")
7972
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
@@ -195,10 +188,7 @@ def test_disable_stemmer_behavior(disable_stemmer: bool) -> None:
195188
assert result == expected, f"Expected {expected}, but got {result}"
196189

197190

198-
@pytest.mark.parametrize(
199-
"model_name",
200-
["prithivida/Splade_PP_en_v1"],
201-
)
191+
@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"])
202192
def test_lazy_load(model_name: str) -> None:
203193
is_ci = os.getenv("CI")
204194
model = SparseTextEmbedding(model_name=model_name, lazy_load=True)

tests/test_text_cross_encoder.py

+4-20
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,7 @@
1818
ALL_RERANK_MODEL_DESC = TextCrossEncoder._list_supported_models()
1919

2020

21-
@pytest.mark.parametrize(
22-
"model_name",
23-
[
24-
min(ALL_RERANK_MODEL_DESC, key=lambda m: m.size_in_GB).model
25-
if CANONICAL_SCORE_VALUES
26-
else "Xenova/ms-marco-MiniLM-L-6-v2"
27-
],
28-
)
21+
@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"])
2922
def test_rerank(model_name: str) -> None:
3023
is_ci = os.getenv("CI")
3124
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
@@ -61,10 +54,7 @@ def test_rerank(model_name: str) -> None:
6154
delete_model_cache(model.model._model_dir)
6255

6356

64-
@pytest.mark.parametrize(
65-
"model_name",
66-
["Xenova/ms-marco-MiniLM-L-6-v2"],
67-
)
57+
@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"])
6858
def test_batch_rerank(model_name: str) -> None:
6959
is_ci = os.getenv("CI")
7060

@@ -90,10 +80,7 @@ def test_batch_rerank(model_name: str) -> None:
9080
delete_model_cache(model.model._model_dir)
9181

9282

93-
@pytest.mark.parametrize(
94-
"model_name",
95-
["Xenova/ms-marco-MiniLM-L-6-v2"],
96-
)
83+
@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"])
9784
def test_lazy_load(model_name: str) -> None:
9885
is_ci = os.getenv("CI")
9986
model = TextCrossEncoder(model_name=model_name, lazy_load=True)
@@ -107,10 +94,7 @@ def test_lazy_load(model_name: str) -> None:
10794
delete_model_cache(model.model._model_dir)
10895

10996

110-
@pytest.mark.parametrize(
111-
"model_name",
112-
["Xenova/ms-marco-MiniLM-L-6-v2"],
113-
)
97+
@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"])
11498
def test_rerank_pairs_parallel(model_name: str) -> None:
11599
is_ci = os.getenv("CI")
116100

tests/test_text_multitask_embeddings.py

+6-24
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,7 @@
6060
docs = ["Hello World", "Follow the white rabbit."]
6161

6262

63-
@pytest.mark.parametrize(
64-
"dim,model_name",
65-
[(1024, "jinaai/jina-embeddings-v3")],
66-
)
63+
@pytest.mark.parametrize("dim,model_name", [(1024, "jinaai/jina-embeddings-v3")])
6764
def test_batch_embedding(dim: int, model_name: str):
6865
is_ci = os.getenv("CI")
6966
docs_to_embed = docs * 10
@@ -85,10 +82,7 @@ def test_batch_embedding(dim: int, model_name: str):
8582
delete_model_cache(model.model._model_dir)
8683

8784

88-
@pytest.mark.parametrize(
89-
"model_name",
90-
["jinaai/jina-embeddings-v3"],
91-
)
85+
@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
9286
def test_single_embedding(model_name: str):
9387
is_ci = os.getenv("CI")
9488
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
@@ -128,10 +122,7 @@ def test_single_embedding(model_name: str):
128122
delete_model_cache(model.model._model_dir)
129123

130124

131-
@pytest.mark.parametrize(
132-
"model_name",
133-
["jinaai/jina-embeddings-v3"],
134-
)
125+
@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
135126
def test_single_embedding_query(model_name: str):
136127
is_ci = os.getenv("CI")
137128
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
@@ -171,10 +162,7 @@ def test_single_embedding_query(model_name: str):
171162
delete_model_cache(model.model._model_dir)
172163

173164

174-
@pytest.mark.parametrize(
175-
"model_name",
176-
["jinaai/jina-embeddings-v3"],
177-
)
165+
@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
178166
def test_single_embedding_passage(model_name: str):
179167
is_ci = os.getenv("CI")
180168
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
@@ -214,10 +202,7 @@ def test_single_embedding_passage(model_name: str):
214202
delete_model_cache(model.model._model_dir)
215203

216204

217-
@pytest.mark.parametrize(
218-
"dim,model_name",
219-
[(1024, "jinaai/jina-embeddings-v3")],
220-
)
205+
@pytest.mark.parametrize("dim,model_name", [(1024, "jinaai/jina-embeddings-v3")])
221206
def test_parallel_processing(dim: int, model_name: str):
222207
is_ci = os.getenv("CI")
223208

@@ -263,10 +248,7 @@ def test_task_assignment():
263248
delete_model_cache(model.model._model_dir)
264249

265250

266-
@pytest.mark.parametrize(
267-
"model_name",
268-
["jinaai/jina-embeddings-v3"],
269-
)
251+
@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
270252
def test_lazy_load(model_name: str):
271253
is_ci = os.getenv("CI")
272254
model = TextEmbedding(model_name=model_name, lazy_load=True)

tests/test_text_onnx_embeddings.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,7 @@
7474
ALL_TEXT_MODEL_DESC = TextEmbedding._list_supported_models()
7575

7676

77-
@pytest.mark.parametrize(
78-
"model_name",
79-
[
80-
min(ALL_TEXT_MODEL_DESC, key=lambda m: m.size_in_GB).model
81-
if CANONICAL_VECTOR_VALUES
82-
else "BAAI/bge-small-en-v1.5"
83-
],
84-
)
77+
@pytest.mark.parametrize("model_name", ["BAAI/bge-small-en-v1.5"])
8578
def test_embedding(model_name: str) -> None:
8679
is_ci = os.getenv("CI")
8780
is_mac = platform.system() == "Darwin"

0 commit comments

Comments (0)