 from fastembed.text.text_embedding import TextEmbedding
 from tests.utils import delete_model_cache

-# Four variations of the same base model, with different (mean_pooling, normalization).
-# Each has a separate canonical vector for "hello world" (first 5 dims).
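+# Each scenario pairs a base model with a (mean_pooling, normalization) config
+# and the expected first 5 dims of its "hello world" embedding.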
 canonical_vectors = [
     {
+        "model": "intfloat/multilingual-e5-small",
         "mean_pooling": True,
         "normalization": True,
         "canonical_vector": [3.1317e-02, 3.0939e-02, -3.5117e-02, -6.7274e-02, 8.5084e-02],
     },
     {
+        "model": "intfloat/multilingual-e5-small",
         "mean_pooling": True,
         "normalization": False,
         "canonical_vector": [1.4604e-01, 1.4428e-01, -1.6376e-01, -3.1372e-01, 3.9677e-01],
     },
     {
+        "model": "mixedbread-ai/mxbai-embed-xsmall-v1",
         "mean_pooling": False,
         "normalization": False,
-        "canonical_vector": [1.8612e-01, 9.1158e-02, -1.4521e-01, -3.3533e-01, 3.2876e-01],
-    },
-    {
-        "mean_pooling": False,
-        "normalization": True,
-        "canonical_vector": [4.6600e-01, 2.1830e-01, -3.3190e-01, -4.2530e-01, 3.3240e-01],
+        "canonical_vector": [
+            2.49407589e-02,
+            1.00189969e-02,
+            1.07807154e-02,
+            3.63860987e-02,
+            -2.27128249e-02,
+        ],
     },
 ]

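+# Embedding dimension and Hugging Face source repo for each base model.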
+DIMENSIONS = {
+    "intfloat/multilingual-e5-small": 384,
+    "mixedbread-ai/mxbai-embed-xsmall-v1": 384,
+}
+
+SOURCES = {
+    "intfloat/multilingual-e5-small": "intfloat/multilingual-e5-small",
+    "mixedbread-ai/mxbai-embed-xsmall-v1": "mixedbread-ai/mxbai-embed-xsmall-v1",
+}
+

 @pytest.mark.parametrize("scenario", canonical_vectors)
 def test_add_custom_model_variations(scenario):
     """
-    Tests that add_custom_model successfully registers the same base model
-    with different (mean_pooling, normalization) configurations. We check
-    whether we get the correct partial embedding values for "hello world".
+    Tests add_custom_model for different base models and different
+    (mean_pooling, normalization) configs, checking that the first 5 dims
+    of the "hello world" embedding match the scenario's canonical vector.
     """

-    is_ci = os.getenv("CI", False)
+    is_ci = bool(os.getenv("CI", False))

-    base_model_name = "intfloat/multilingual-e5-small"
+    base_model_name = scenario["model"]
+    mean_pooling = scenario["mean_pooling"]
+    normalization = scenario["normalization"]
+    cv = np.array(scenario["canonical_vector"], dtype=np.float32)

-    # Build a unique model name to avoid collisions in the registry
-    suffix = []
-    suffix.append("mean" if scenario["mean_pooling"] else "no-mean")
-    suffix.append("norm" if scenario["normalization"] else "no-norm")
-    suffix_str = "-".join(suffix)  # e.g. "mean-norm" or "no-mean-norm", etc.
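+    # Build a unique registry name per scenario to avoid collisions.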
+    suffixes = []
+    suffixes.append("mean" if mean_pooling else "no-mean")
+    suffixes.append("norm" if normalization else "no-norm")
+    suffix_str = "-".join(suffixes)

     custom_model_name = f"{base_model_name}-{suffix_str}"

-    # Build the base model_info dictionary
+    dim = DIMENSIONS[base_model_name]
+    hf_source = SOURCES[base_model_name]
+
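+    # Registry metadata; dim and hf_source come from the lookup tables above.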
     model_info = {
-        "model": custom_model_name,  # The registry key
-        "dim": 384,
-        "description": f"E5-small with {suffix_str}",
+        "model": custom_model_name,
+        "dim": dim,
+        "description": f"{base_model_name} with {suffix_str}",
         "license": "mit",
         "size_in_GB": 0.13,
         "sources": {
-            "hf": "intfloat/multilingual-e5-small",
+            "hf": hf_source,
         },
         "model_file": "onnx/model.onnx",
         "additional_files": [],
     }

-    # Possibly skip on CI if the model is large:
     if is_ci and model_info["size_in_GB"] > 1:
         pytest.skip(
-            f"Skipping {custom_model_name} on CI because size_in_GB={model_info['size_in_GB']}"
+            f"Skipping {custom_model_name} on CI due to size_in_GB={model_info['size_in_GB']}"
         )

-    # Register it so TextEmbedding can find it
     TextEmbedding.add_custom_model(
-        model_info=model_info,
-        mean_pooling=scenario["mean_pooling"],
-        normalization=scenario["normalization"],
+        model_info=model_info, mean_pooling=mean_pooling, normalization=normalization
     )

-    # Instantiate the newly added custom model
     model = TextEmbedding(model_name=custom_model_name)

-    # Prepare docs and embed
     docs = ["hello world", "flag embedding"]
     embeddings = list(model.embed(docs))
-    embeddings = np.stack(embeddings, axis=0)  # shape => (2, 1024)
+    embeddings = np.stack(embeddings, axis=0)

-    # Check shape
-    assert embeddings.shape == (2, model_info["dim"]), (
-        f"Expected shape (2, {model_info['dim']}) for {custom_model_name}, "
-        f"but got {embeddings.shape}"
-    )
+    assert embeddings.shape == (
+        2,
+        dim,
+    ), f"Expected shape (2, {dim}) for {custom_model_name}, but got {embeddings.shape}"

-    # Compare the first 5 dimensions of the first embedding to the canonical vector
-    cv = np.array(scenario["canonical_vector"], dtype=np.float32)  # shape => (5,)
     num_compare_dims = cv.shape[0]
     assert np.allclose(
         embeddings[0, :num_compare_dims], cv, atol=1e-3
     ), f"Embedding mismatch for {custom_model_name} (first {num_compare_dims} dims)."

-    # Optional: check that embedding is not all zeros
-    assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be entirely zeros."
+    assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be all zeros."

-    # Clean up cache in CI environment
     if is_ci:
         delete_model_cache(model.model._model_dir)
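
A property the canonical vectors only check indirectly: when normalization=True, each embedding row should come out unit length, assuming the normalization fastembed applies is L2 normalization. A minimal sketch of such a check; the helper name assert_unit_norm is hypothetical and not part of this diff:

import numpy as np

def assert_unit_norm(embeddings: np.ndarray, atol: float = 1e-3) -> None:
    # Sketch of an extra invariant: rows of an L2-normalized matrix
    # should each have norm ~1.0.
    norms = np.linalg.norm(embeddings, axis=1)
    assert np.allclose(norms, 1.0, atol=atol), f"Row norms deviate from 1: {norms}"

In the test above this could run right after the canonical-vector assertion, guarded by "if normalization:", so the unnormalized scenarios are unaffected.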