Commit b186301

Clear supported_models list to keep only supported models
1 parent fa3f20c commit b186301

File tree: 1 file changed (+28 -25 lines)

Diff for: tests/test_add_custom_model.py

@@ -45,19 +45,17 @@
 
 @pytest.mark.parametrize("scenario", canonical_vectors)
 def test_add_custom_model_variations(scenario):
-    """
-    Tests add_custom_model for different base models and different
-    (mean_pooling, normalization) configs. Checks the first 5 dims
-    of "hello world" match the scenario's canonical vector.
-    """
-
     is_ci = bool(os.getenv("CI", False))
 
     base_model_name = scenario["model"]
     mean_pooling = scenario["mean_pooling"]
     normalization = scenario["normalization"]
     cv = np.array(scenario["canonical_vector"], dtype=np.float32)
 
+    backup_supported_models = {}
+    for embedding_cls in TextEmbedding.EMBEDDINGS_REGISTRY:
+        backup_supported_models[embedding_cls] = embedding_cls.list_supported_models().copy()
+
     suffixes = []
     suffixes.append("mean" if mean_pooling else "no-mean")
     suffixes.append("norm" if normalization else "no-norm")
@@ -81,32 +79,37 @@ def test_add_custom_model_variations(scenario):
         "additional_files": [],
     }
 
-    if is_ci and model_info["size_in_GB"] > 1:
+    if is_ci and model_info["size_in_GB"] > 1.0:
         pytest.skip(
             f"Skipping {custom_model_name} on CI due to size_in_GB={model_info['size_in_GB']}"
         )
 
-    TextEmbedding.add_custom_model(
-        model_info=model_info, mean_pooling=mean_pooling, normalization=normalization
-    )
+    try:
+        TextEmbedding.add_custom_model(
+            model_info=model_info, mean_pooling=mean_pooling, normalization=normalization
+        )
+
+        model = TextEmbedding(model_name=custom_model_name)
 
-    model = TextEmbedding(model_name=custom_model_name)
+        docs = ["hello world", "flag embedding"]
+        embeddings = list(model.embed(docs))
+        embeddings = np.stack(embeddings, axis=0)
 
-    docs = ["hello world", "flag embedding"]
-    embeddings = list(model.embed(docs))
-    embeddings = np.stack(embeddings, axis=0)
+        assert embeddings.shape == (
+            2,
+            dim,
+        ), f"Expected shape (2, {dim}) for {custom_model_name}, but got {embeddings.shape}"
 
-    assert embeddings.shape == (
-        2,
-        dim,
-    ), f"Expected shape (2, {dim}) for {custom_model_name}, but got {embeddings.shape}"
+        num_compare_dims = cv.shape[0]
+        assert np.allclose(
+            embeddings[0, :num_compare_dims], cv, atol=1e-3
+        ), f"Embedding mismatch for {custom_model_name} (first {num_compare_dims} dims)."
 
-    num_compare_dims = cv.shape[0]
-    assert np.allclose(
-        embeddings[0, :num_compare_dims], cv, atol=1e-3
-    ), f"Embedding mismatch for {custom_model_name} (first {num_compare_dims} dims)."
+        assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be all zeros."
 
-    assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be all zeros."
+        if is_ci:
+            delete_model_cache(model.model._model_dir)
 
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+    finally:
+        for embedding_cls, old_list in backup_supported_models.items():
+            embedding_cls.supported_models = old_list
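
For context, the backup/restore added by this commit could also be expressed as a pytest fixture instead of an inline try/finally. The sketch below is not part of the commit; it assumes a fastembed-style "from fastembed import TextEmbedding" import (not shown in this diff) and reuses the EMBEDDINGS_REGISTRY, list_supported_models(), and supported_models names that do appear above.

import pytest
from fastembed import TextEmbedding  # assumed import; not shown in this diff


@pytest.fixture
def restore_supported_models():
    # Snapshot the supported_models list of every registered embedding class
    # before the test runs, mirroring the backup loop added in this commit.
    backup = {
        embedding_cls: embedding_cls.list_supported_models().copy()
        for embedding_cls in TextEmbedding.EMBEDDINGS_REGISTRY
    }
    yield
    # Restore the snapshots afterwards, even if the test body raised,
    # so a custom model registered by one test cannot leak into another.
    for embedding_cls, old_list in backup.items():
        embedding_cls.supported_models = old_list

A test would opt in by accepting restore_supported_models as an argument, keeping the cleanup out of the test body; the commit's try/finally achieves the same isolation without a fixture.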
