Skip to content

Commit fa3f20c

Browse files
committed
Test of different models
1 parent 99ff62f commit fa3f20c

File tree

1 file changed

+47
-42
lines changed

1 file changed

+47
-42
lines changed

tests/test_add_custom_model.py

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,103 +5,108 @@
55
from fastembed.text.text_embedding import TextEmbedding
66
from tests.utils import delete_model_cache
77

8-
# Four variations of the same base model, with different (mean_pooling, normalization).
9-
# Each has a separate canonical vector for "hello world" (first 5 dims).
108
# Base-model identifiers used across the scenarios below.
_E5_SMALL = "intfloat/multilingual-e5-small"
_MXBAI_XSMALL = "mixedbread-ai/mxbai-embed-xsmall-v1"

# One scenario per (base model, mean_pooling, normalization) combination.
# "canonical_vector" holds the expected first 5 dims of the "hello world"
# embedding for that configuration.
canonical_vectors = [
    {
        "model": model,
        "mean_pooling": mean_pooling,
        "normalization": normalization,
        "canonical_vector": vector,
    }
    for model, mean_pooling, normalization, vector in (
        (
            _E5_SMALL,
            True,
            True,
            [3.1317e-02, 3.0939e-02, -3.5117e-02, -6.7274e-02, 8.5084e-02],
        ),
        (
            _E5_SMALL,
            True,
            False,
            [1.4604e-01, 1.4428e-01, -1.6376e-01, -3.1372e-01, 3.9677e-01],
        ),
        (
            _MXBAI_XSMALL,
            False,
            False,
            [
                2.49407589e-02,
                1.00189969e-02,
                1.07807154e-02,
                3.63860987e-02,
                -2.27128249e-02,
            ],
        ),
    )
]

# Embedding dimensionality of each base model.
DIMENSIONS = {
    _E5_SMALL: 384,
    _MXBAI_XSMALL: 384,
}

# HuggingFace source repo for each base model (here identical to the model id).
SOURCES = {
    _E5_SMALL: _E5_SMALL,
    _MXBAI_XSMALL: _MXBAI_XSMALL,
}
44+
3345

3446
@pytest.mark.parametrize("scenario", canonical_vectors)
def test_add_custom_model_variations(scenario):
    """Register a custom model config and verify its embedding output.

    Each scenario pairs a base model with a (mean_pooling, normalization)
    combination. The first five dimensions of the "hello world" embedding
    must match the scenario's canonical vector.
    """
    running_in_ci = bool(os.getenv("CI", False))

    base_name = scenario["model"]
    pool = scenario["mean_pooling"]
    norm = scenario["normalization"]
    expected = np.array(scenario["canonical_vector"], dtype=np.float32)

    # Build a unique registry name per configuration, e.g.
    # "<base>-mean-norm" or "<base>-no-mean-no-norm", to avoid collisions.
    pool_tag = "mean" if pool else "no-mean"
    norm_tag = "norm" if norm else "no-norm"
    suffix_str = "-".join([pool_tag, norm_tag])
    custom_model_name = f"{base_name}-{suffix_str}"

    dim = DIMENSIONS[base_name]

    model_info = {
        "model": custom_model_name,
        "dim": dim,
        "description": f"{base_name} with {suffix_str}",
        "license": "mit",
        "size_in_GB": 0.13,
        "sources": {"hf": SOURCES[base_name]},
        "model_file": "onnx/model.onnx",
        "additional_files": [],
    }

    # Large models are skipped on CI to keep runs fast.
    if running_in_ci and model_info["size_in_GB"] > 1:
        pytest.skip(
            f"Skipping {custom_model_name} on CI due to size_in_GB={model_info['size_in_GB']}"
        )

    # Register the configuration so TextEmbedding can resolve it by name.
    TextEmbedding.add_custom_model(
        model_info=model_info, mean_pooling=pool, normalization=norm
    )
    model = TextEmbedding(model_name=custom_model_name)

    embeddings = np.stack(list(model.embed(["hello world", "flag embedding"])), axis=0)

    assert embeddings.shape == (
        2,
        dim,
    ), f"Expected shape (2, {dim}) for {custom_model_name}, but got {embeddings.shape}"

    # Only the leading dims are pinned; the full vector is model-dependent.
    num_compare_dims = expected.shape[0]
    assert np.allclose(
        embeddings[0, :num_compare_dims], expected, atol=1e-3
    ), f"Embedding mismatch for {custom_model_name} (first {num_compare_dims} dims)."

    assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be all zeros."

    # Keep CI disk usage bounded by dropping the downloaded model cache.
    if running_in_ci:
        delete_model_cache(model.model._model_dir)

0 commit comments

Comments
 (0)