 from fastembed.text.text_embedding import TextEmbedding
 from tests.utils import delete_model_cache

-# Four variations of the same base model, with different (mean_pooling, normalization).
-# Each has a separate canonical vector for "hello world" (first 5 dims).
 canonical_vectors = [
     {
+        "model": "intfloat/multilingual-e5-small",
         "mean_pooling": True,
         "normalization": True,
         "canonical_vector": [3.1317e-02, 3.0939e-02, -3.5117e-02, -6.7274e-02, 8.5084e-02],
     },
     {
+        "model": "intfloat/multilingual-e5-small",
         "mean_pooling": True,
         "normalization": False,
         "canonical_vector": [1.4604e-01, 1.4428e-01, -1.6376e-01, -3.1372e-01, 3.9677e-01],
     },
     {
+        "model": "mixedbread-ai/mxbai-embed-xsmall-v1",
         "mean_pooling": False,
         "normalization": False,
-        "canonical_vector": [1.8612e-01, 9.1158e-02, -1.4521e-01, -3.3533e-01, 3.2876e-01],
-    },
-    {
-        "mean_pooling": False,
-        "normalization": True,
-        "canonical_vector": [4.6600e-01, 2.1830e-01, -3.3190e-01, -4.2530e-01, 3.3240e-01],
+        "canonical_vector": [
+            2.49407589e-02,
+            1.00189969e-02,
+            1.07807154e-02,
+            3.63860987e-02,
+            -2.27128249e-02,
+        ],
     },
 ]

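For context, reference values like these canonical vectors can be reproduced outside fastembed by pooling raw transformer outputs. A minimal sketch, assuming the `torch` and `transformers` packages are installed; the helper name `reference_embedding` is hypothetical and not part of fastembed or this test:

    import torch
    from transformers import AutoModel, AutoTokenizer

    def reference_embedding(text, model_id, mean_pooling, normalization):
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModel.from_pretrained(model_id)
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            # Token-level embeddings, shape (1, seq_len, hidden_dim)
            token_embeddings = model(**inputs).last_hidden_state
        if mean_pooling:
            # Average token embeddings, ignoring padding positions
            mask = inputs["attention_mask"].unsqueeze(-1)
            vec = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1)
        else:
            # Fall back to the first ([CLS]) token embedding
            vec = token_embeddings[:, 0]
        if normalization:
            vec = torch.nn.functional.normalize(vec, p=2, dim=1)
        return vec[0, :5]  # first 5 dims, comparable to "canonical_vector"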
+DIMENSIONS = {
+    "intfloat/multilingual-e5-small": 384,
+    "mixedbread-ai/mxbai-embed-xsmall-v1": 384,
+}
+
+SOURCES = {
+    "intfloat/multilingual-e5-small": "intfloat/multilingual-e5-small",
+    "mixedbread-ai/mxbai-embed-xsmall-v1": "mixedbread-ai/mxbai-embed-xsmall-v1",
+}
+

 @pytest.mark.parametrize("scenario", canonical_vectors)
 def test_add_custom_model_variations(scenario):
     """
-    Tests that add_custom_model successfully registers the same base model
-    with different (mean_pooling, normalization) configurations. We check
-    whether we get the correct partial embedding values for "hello world".
+    Tests add_custom_model for different base models and different
+    (mean_pooling, normalization) configs. Checks that the first 5 dims
+    of the "hello world" embedding match the scenario's canonical vector.
     """

-    is_ci = os.getenv("CI", False)
+    is_ci = bool(os.getenv("CI", False))

-    base_model_name = "intfloat/multilingual-e5-small"
+    base_model_name = scenario["model"]
+    mean_pooling = scenario["mean_pooling"]
+    normalization = scenario["normalization"]
+    cv = np.array(scenario["canonical_vector"], dtype=np.float32)

-    # Build a unique model name to avoid collisions in the registry
-    suffix = []
-    suffix.append("mean" if scenario["mean_pooling"] else "no-mean")
-    suffix.append("norm" if scenario["normalization"] else "no-norm")
-    suffix_str = "-".join(suffix)  # e.g. "mean-norm" or "no-mean-norm", etc.
+    suffixes = []
+    suffixes.append("mean" if mean_pooling else "no-mean")
+    suffixes.append("norm" if normalization else "no-norm")
+    suffix_str = "-".join(suffixes)

     custom_model_name = f"{base_model_name}-{suffix_str}"

-    # Build the base model_info dictionary
+    dim = DIMENSIONS[base_model_name]
+    hf_source = SOURCES[base_model_name]
+
     model_info = {
-        "model": custom_model_name,  # The registry key
-        "dim": 384,
-        "description": f"E5-small with {suffix_str}",
+        "model": custom_model_name,
+        "dim": dim,
+        "description": f"{base_model_name} with {suffix_str}",
         "license": "mit",
         "size_in_GB": 0.13,
         "sources": {
-            "hf": "intfloat/multilingual-e5-small",
+            "hf": hf_source,
         },
         "model_file": "onnx/model.onnx",
         "additional_files": [],
     }

-    # Possibly skip on CI if the model is large:
     if is_ci and model_info["size_in_GB"] > 1:
         pytest.skip(
-            f"Skipping {custom_model_name} on CI because size_in_GB={model_info['size_in_GB']}"
+            f"Skipping {custom_model_name} on CI due to size_in_GB={model_info['size_in_GB']}"
         )

-    # Register it so TextEmbedding can find it
     TextEmbedding.add_custom_model(
-        model_info=model_info,
-        mean_pooling=scenario["mean_pooling"],
-        normalization=scenario["normalization"],
+        model_info=model_info, mean_pooling=mean_pooling, normalization=normalization
     )

-    # Instantiate the newly added custom model
     model = TextEmbedding(model_name=custom_model_name)

-    # Prepare docs and embed
     docs = ["hello world", "flag embedding"]
     embeddings = list(model.embed(docs))
-    embeddings = np.stack(embeddings, axis=0)  # shape => (2, 1024)
+    embeddings = np.stack(embeddings, axis=0)

-    # Check shape
-    assert embeddings.shape == (2, model_info["dim"]), (
-        f"Expected shape (2, {model_info['dim']}) for {custom_model_name}, "
-        f"but got {embeddings.shape}"
-    )
+    assert embeddings.shape == (
+        2,
+        dim,
+    ), f"Expected shape (2, {dim}) for {custom_model_name}, but got {embeddings.shape}"

-    # Compare the first 5 dimensions of the first embedding to the canonical vector
-    cv = np.array(scenario["canonical_vector"], dtype=np.float32)  # shape => (5,)
     num_compare_dims = cv.shape[0]
     assert np.allclose(
         embeddings[0, :num_compare_dims], cv, atol=1e-3
     ), f"Embedding mismatch for {custom_model_name} (first {num_compare_dims} dims)."

-    # Optional: check that embedding is not all zeros
-    assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be entirely zeros."
+    assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be all zeros."

-    # Clean up cache in CI environment
     if is_ci:
         delete_model_cache(model.model._model_dir)
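Outside the test harness, the same registration flow looks like the sketch below. It reuses only the call pattern shown in the diff above; the registry key "my-org/e5-small-no-norm" is a hypothetical placeholder:

    from fastembed.text.text_embedding import TextEmbedding

    # Register a custom ONNX model under a new name. The model_info fields
    # mirror the dict built in the test above.
    TextEmbedding.add_custom_model(
        model_info={
            "model": "my-org/e5-small-no-norm",
            "dim": 384,
            "description": "multilingual-e5-small without normalization",
            "license": "mit",
            "size_in_GB": 0.13,
            "sources": {"hf": "intfloat/multilingual-e5-small"},
            "model_file": "onnx/model.onnx",
            "additional_files": [],
        },
        mean_pooling=True,
        normalization=False,
    )

    # The custom name is now resolvable like any built-in model.
    model = TextEmbedding(model_name="my-org/e5-small-no-norm")
    vectors = list(model.embed(["hello world"]))  # one 384-dim numpy vector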