Commit 640f106

tests: Updated tests

1 parent: 0fcf2ef
8 files changed: +183 -149 lines changed

tests/get_all_model_hash.py (-21)

This file was deleted.

tests/test_image_onnx_embeddings.py (+13 -2)

@@ -29,9 +29,16 @@

 def test_embedding() -> None:
     is_ci = os.getenv("CI")
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

-    for model_desc in ImageEmbedding._list_supported_models():
-        if not is_ci and model_desc.size_in_GB > 1:
+    all_models = ImageEmbedding._list_supported_models()
+
+    models_to_test = [all_models[0]] if not is_manual else all_models
+
+    for model_desc in models_to_test:
+        if (
+            not is_ci and model_desc.size_in_GB > 1
+        ) or model_desc.model not in CANONICAL_VECTOR_VALUES:
             continue

         dim = model_desc.dim
@@ -74,8 +81,12 @@ def test_batch_embedding(n_dims: int, model_name: str) -> None:

     embeddings = list(model.embed(images, batch_size=10))
     embeddings = np.stack(embeddings, axis=0)
+    assert np.allclose(embeddings[1], embeddings[2])
+
+    canonical_vector = CANONICAL_VECTOR_VALUES[model_name]

     assert embeddings.shape == (len(test_images) * n_images, n_dims)
+    assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3)
     if is_ci:
         delete_model_cache(model.model._model_dir)
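
The same selection pattern appears in every test file this commit touches: an ordinary CI run exercises only the first supported model, while a manually triggered workflow_dispatch run covers the full list. Distilled into a standalone sketch (the helper name select_models_to_test is hypothetical; the env-var logic is taken from the diff above):

import os

def select_models_to_test(all_models: list) -> list:
    # GITHUB_EVENT_NAME is set by GitHub Actions; "workflow_dispatch"
    # means the run was triggered manually. Only then test every model;
    # otherwise keep CI fast by testing just the first one.
    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
    return all_models if is_manual else [all_models[0]]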

tests/test_late_interaction_embeddings.py (+42 -18)

@@ -153,57 +153,81 @@
 docs = ["Hello World"]


-def test_batch_embedding():
+@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
+def test_batch_embedding(model_name: str):
     is_ci = os.getenv("CI")
     docs_to_embed = docs * 10

-    for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
-        print("evaluating", model_name)
-        model = LateInteractionTextEmbedding(model_name=model_name)
-        result = list(model.embed(docs_to_embed, batch_size=6))
+    model = LateInteractionTextEmbedding(model_name=model_name)
+    result = list(model.embed(docs_to_embed, batch_size=6))
+    expected_result = CANONICAL_COLUMN_VALUES[model_name]

-        for value in result:
-            token_num, abridged_dim = expected_result.shape
-            assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3)
+    for value in result:
+        token_num, abridged_dim = expected_result.shape
+        assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3)

-        if is_ci:
-            delete_model_cache(model.model._model_dir)
+    if is_ci:
+        delete_model_cache(model.model._model_dir)


 def test_single_embedding():
     is_ci = os.getenv("CI")
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
     docs_to_embed = docs

-    for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
+    all_models = LateInteractionTextEmbedding._list_supported_models()
+    models_to_test = [all_models[0]] if not is_manual else all_models
+
+    for model_desc in models_to_test:
+        model_name = model_desc.model
+        if (
+            not is_ci and model_desc.size_in_GB > 1
+        ) or model_desc.model not in CANONICAL_COLUMN_VALUES:
+            continue
         print("evaluating", model_name)
         model = LateInteractionTextEmbedding(model_name=model_name)
         result = next(iter(model.embed(docs_to_embed, batch_size=6)))
+        expected_result = CANONICAL_COLUMN_VALUES[model_name]
         token_num, abridged_dim = expected_result.shape
         assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)

-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        if is_ci:
+            delete_model_cache(model.model._model_dir)


 def test_single_embedding_query():
     is_ci = os.getenv("CI")
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
     queries_to_embed = docs

-    for model_name, expected_result in CANONICAL_QUERY_VALUES.items():
+    all_models = LateInteractionTextEmbedding._list_supported_models()
+    models_to_test = [all_models[0]] if not is_manual else all_models
+
+    for model_desc in models_to_test:
+        model_name = model_desc.model
+        if (
+            not is_ci and model_desc.size_in_GB > 1
+        ) or model_desc.model not in CANONICAL_QUERY_VALUES:
+            continue
         print("evaluating", model_name)
         model = LateInteractionTextEmbedding(model_name=model_name)
         result = next(iter(model.query_embed(queries_to_embed)))
+        expected_result = CANONICAL_COLUMN_VALUES[model_name]
         token_num, abridged_dim = expected_result.shape
         assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)

         if is_ci:
             delete_model_cache(model.model._model_dir)


-def test_parallel_processing():
+@pytest.mark.parametrize(
+    "token_dim,model_name",
+    [(96, "answerdotai/answerai-colbert-small-v1")],
+)
+def test_parallel_processing(token_dim: int, model_name: str):
     is_ci = os.getenv("CI")
-    model = LateInteractionTextEmbedding(model_name="colbert-ir/colbertv2.0")
-    token_dim = 128
+    model = LateInteractionTextEmbedding(model_name=model_name)
+
     docs = ["hello world", "flag embedding"] * 100
     embeddings = list(model.embed(docs, batch_size=10, parallel=2))
     embeddings = np.stack(embeddings, axis=0)
@@ -224,7 +248,7 @@ def test_parallel_processing():

 @pytest.mark.parametrize(
     "model_name",
-    ["colbert-ir/colbertv2.0"],
+    ["answerdotai/answerai-colbert-small-v1"],
 )
 def test_lazy_load(model_name: str):
     is_ci = os.getenv("CI")
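
A note on the assertion style above: the canonical tables store only an abridged slice of each embedding (a few leading dimensions per token), so the tests compare just that slice at a loose tolerance. A minimal self-contained sketch with invented numbers:

import numpy as np

# Abridged canonical slice: 1 token x 3 leading dimensions (values invented).
expected_result = np.array([[0.010, -0.020, 0.030]])
# Full model output for the same token: more dimensions than the canonical slice.
result = np.array([[0.011, -0.021, 0.029, 0.512, -0.404]])

token_num, abridged_dim = expected_result.shape
assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)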

tests/test_sparse_embeddings.py (+11 -1)

@@ -66,11 +66,21 @@ def test_batch_embedding() -> None:

 def test_single_embedding() -> None:
     is_ci = os.getenv("CI")
-    for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
+
+    all_models = SparseTextEmbedding._list_supported_models()
+    models_to_test = [all_models[0]] if not is_manual else all_models
+
+    for model_desc in models_to_test:
+        model_name = model_desc.model
+        if (not is_ci and model_desc.size_in_GB > 1) or model_name not in CANONICAL_COLUMN_VALUES:
+            continue
+
         model = SparseTextEmbedding(model_name=model_name)

         passage_result = next(iter(model.embed(docs, batch_size=6)))
         query_result = next(iter(model.query_embed(docs)))
+        expected_result = CANONICAL_COLUMN_VALUES[model_name]
         for result in [passage_result, query_result]:
             assert result.indices.tolist() == expected_result["indices"]
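
For the sparse models, CANONICAL_COLUMN_VALUES maps a model name to a dict whose "indices" entry lists the expected token ids. A sketch of the structure (the ids and the FakeSparseEmbedding stand-in are invented for illustration):

import numpy as np

# Hypothetical canonical entry for one sparse model (token ids invented).
expected_result = {"indices": [1012, 2023, 2003]}

# Stand-in for the sparse embedding object returned by model.embed(...).
class FakeSparseEmbedding:
    def __init__(self, indices: list) -> None:
        self.indices = np.array(indices)

result = FakeSparseEmbedding([1012, 2023, 2003])
assert result.indices.tolist() == expected_result["indices"]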

tests/test_text_cross_encoder.py (+27 -27)

@@ -15,43 +15,43 @@
     "jinaai/jina-reranker-v2-base-multilingual": np.array([1.6533, -1.6455]),
 }

-SELECTED_MODELS = {
-    "Xenova": "Xenova/ms-marco-MiniLM-L-6-v2",
-    "BAAI": "BAAI/bge-reranker-base",
-    "jinaai": "jinaai/jina-reranker-v1-tiny-en",
-}
-

-@pytest.mark.parametrize(
-    "model_name",
-    [model_name for model_name in CANONICAL_SCORE_VALUES],
-)
 def test_rerank(model_name: str) -> None:
     is_ci = os.getenv("CI")
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

-    model = TextCrossEncoder(model_name=model_name)
+    all_models = TextCrossEncoder._list_supported_models()
+    models_to_test = [all_models[0]] if not is_manual else all_models

-    query = "What is the capital of France?"
-    documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."]
-    scores = np.array(list(model.rerank(query, documents)))
+    for model_desc in models_to_test:
+        if (
+            not is_ci and model_desc.size_in_GB > 1
+        ) or model_desc.model not in CANONICAL_SCORE_VALUES:
+            continue

-    pairs = [(query, doc) for doc in documents]
-    scores2 = np.array(list(model.rerank_pairs(pairs)))
-    assert np.allclose(
-        scores, scores2, atol=1e-5
-    ), f"Model: {model_name}, Scores: {scores}, Scores2: {scores2}"
+        model = TextCrossEncoder(model_name=model_name)

-    canonical_scores = CANONICAL_SCORE_VALUES[model_name]
-    assert np.allclose(
-        scores, canonical_scores, atol=1e-3
-    ), f"Model: {model_name}, Scores: {scores}, Expected: {canonical_scores}"
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        query = "What is the capital of France?"
+        documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."]
+        scores = np.array(list(model.rerank(query, documents)))
+
+        pairs = [(query, doc) for doc in documents]
+        scores2 = np.array(list(model.rerank_pairs(pairs)))
+        assert np.allclose(
+            scores, scores2, atol=1e-5
+        ), f"Model: {model_name}, Scores: {scores}, Scores2: {scores2}"
+
+        canonical_scores = CANONICAL_SCORE_VALUES[model_name]
+        assert np.allclose(
+            scores, canonical_scores, atol=1e-3
+        ), f"Model: {model_name}, Scores: {scores}, Expected: {canonical_scores}"
+        if is_ci:
+            delete_model_cache(model.model._model_dir)


 @pytest.mark.parametrize(
     "model_name",
-    [model_name for model_name in SELECTED_MODELS.values()],
+    ["Xenova/ms-marco-MiniLM-L-6-v2"],
 )
 def test_batch_rerank(model_name: str) -> None:
     is_ci = os.getenv("CI")
@@ -97,7 +97,7 @@ def test_lazy_load(model_name: str) -> None:

 @pytest.mark.parametrize(
     "model_name",
-    [model_name for model_name in SELECTED_MODELS.values()],
+    ["Xenova/ms-marco-MiniLM-L-6-v2"],
 )
 def test_rerank_pairs_parallel(model_name: str) -> None:
     is_ci = os.getenv("CI")
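
To reproduce the full-matrix behaviour locally, without dispatching a workflow, you can set the same environment variable before invoking pytest. A convenience sketch, not something this commit adds:

import os
import pytest

# Pretend to be a manually dispatched GitHub Actions run so the tests
# iterate over every supported model instead of just the first one.
os.environ["GITHUB_EVENT_NAME"] = "workflow_dispatch"
pytest.main(["tests/"])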
