4545
4646@pytest .mark .parametrize ("scenario" , canonical_vectors )
4747def test_add_custom_model_variations (scenario ):
48- """
49- Tests add_custom_model for different base models and different
50- (mean_pooling, normalization) configs. Checks the first 5 dims
51- of "hello world" match the scenario's canonical vector.
52- """
53-
5448 is_ci = bool (os .getenv ("CI" , False ))
5549
5650 base_model_name = scenario ["model" ]
5751 mean_pooling = scenario ["mean_pooling" ]
5852 normalization = scenario ["normalization" ]
5953 cv = np .array (scenario ["canonical_vector" ], dtype = np .float32 )
6054
55+ backup_supported_models = {}
56+ for embedding_cls in TextEmbedding .EMBEDDINGS_REGISTRY :
57+ backup_supported_models [embedding_cls ] = embedding_cls .list_supported_models ().copy ()
58+
6159 suffixes = []
6260 suffixes .append ("mean" if mean_pooling else "no-mean" )
6361 suffixes .append ("norm" if normalization else "no-norm" )
@@ -81,32 +79,37 @@ def test_add_custom_model_variations(scenario):
8179 "additional_files" : [],
8280 }
8381
84- if is_ci and model_info ["size_in_GB" ] > 1 :
82+ if is_ci and model_info ["size_in_GB" ] > 1.0 :
8583 pytest .skip (
8684 f"Skipping { custom_model_name } on CI due to size_in_GB={ model_info ['size_in_GB' ]} "
8785 )
8886
89- TextEmbedding .add_custom_model (
90- model_info = model_info , mean_pooling = mean_pooling , normalization = normalization
91- )
87+ try :
88+ TextEmbedding .add_custom_model (
89+ model_info = model_info , mean_pooling = mean_pooling , normalization = normalization
90+ )
91+
92+ model = TextEmbedding (model_name = custom_model_name )
9293
93- model = TextEmbedding (model_name = custom_model_name )
94+ docs = ["hello world" , "flag embedding" ]
95+ embeddings = list (model .embed (docs ))
96+ embeddings = np .stack (embeddings , axis = 0 )
9497
95- docs = ["hello world" , "flag embedding" ]
96- embeddings = list (model .embed (docs ))
97- embeddings = np .stack (embeddings , axis = 0 )
98+ assert embeddings .shape == (
99+ 2 ,
100+ dim ,
101+ ), f"Expected shape (2, { dim } ) for { custom_model_name } , but got { embeddings .shape } "
98102
99- assert embeddings . shape == (
100- 2 ,
101- dim ,
102- ), f"Expected shape (2, { dim } ) for { custom_model_name } , but got { embeddings . shape } "
103+ num_compare_dims = cv . shape [ 0 ]
104+ assert np . allclose (
105+ embeddings [ 0 , : num_compare_dims ], cv , atol = 1e-3
106+ ), f"Embedding mismatch for { custom_model_name } (first { num_compare_dims } dims). "
103107
104- num_compare_dims = cv .shape [0 ]
105- assert np .allclose (
106- embeddings [0 , :num_compare_dims ], cv , atol = 1e-3
107- ), f"Embedding mismatch for { custom_model_name } (first { num_compare_dims } dims)."
108+ assert not np .allclose (embeddings [0 , :], 0.0 ), "Embedding should not be all zeros."
108109
109- assert not np .allclose (embeddings [0 , :], 0.0 ), "Embedding should not be all zeros."
110+ if is_ci :
111+ delete_model_cache (model .model ._model_dir )
110112
111- if is_ci :
112- delete_model_cache (model .model ._model_dir )
113+ finally :
114+ for embedding_cls , old_list in backup_supported_models .items ():
115+ embedding_cls .supported_models = old_list
0 commit comments