153153docs = ["Hello World" ]
154154
155155
156- def test_batch_embedding ():
156+ @pytest .mark .parametrize ("model_name" , ["answerdotai/answerai-colbert-small-v1" ])
157+ def test_batch_embedding (model_name : str ):
157158 is_ci = os .getenv ("CI" )
158159 docs_to_embed = docs * 10
159160
160- for model_name , expected_result in CANONICAL_COLUMN_VALUES .items ():
161- print ("evaluating" , model_name )
162- model = LateInteractionTextEmbedding (model_name = model_name )
163- result = list (model .embed (docs_to_embed , batch_size = 6 ))
161+ model = LateInteractionTextEmbedding (model_name = model_name )
162+ result = list (model .embed (docs_to_embed , batch_size = 6 ))
163+ expected_result = CANONICAL_COLUMN_VALUES [model_name ]
164164
165- for value in result :
166- token_num , abridged_dim = expected_result .shape
167- assert np .allclose (value [:, :abridged_dim ], expected_result , atol = 2e-3 )
165+ for value in result :
166+ token_num , abridged_dim = expected_result .shape
167+ assert np .allclose (value [:, :abridged_dim ], expected_result , atol = 2e-3 )
168168
169- if is_ci :
170- delete_model_cache (model .model ._model_dir )
169+ if is_ci :
170+ delete_model_cache (model .model ._model_dir )
171171
172172
 def test_single_embedding():
     is_ci = os.getenv("CI")
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
     docs_to_embed = docs

-    for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
+    all_models = LateInteractionTextEmbedding._list_supported_models()
+    models_to_test = [all_models[0]] if not is_manual else all_models
+
+    for model_desc in models_to_test:
+        model_name = model_desc.model
+        if (
+            not is_ci and model_desc.size_in_GB > 1
+        ) or model_desc.model not in CANONICAL_COLUMN_VALUES:
+            continue
         print("evaluating", model_name)
         model = LateInteractionTextEmbedding(model_name=model_name)
         result = next(iter(model.embed(docs_to_embed, batch_size=6)))
+        expected_result = CANONICAL_COLUMN_VALUES[model_name]
         token_num, abridged_dim = expected_result.shape
         assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)

-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        if is_ci:
+            delete_model_cache(model.model._model_dir)


 def test_single_embedding_query():
     is_ci = os.getenv("CI")
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
     queries_to_embed = docs

-    for model_name, expected_result in CANONICAL_QUERY_VALUES.items():
+    all_models = LateInteractionTextEmbedding._list_supported_models()
+    models_to_test = [all_models[0]] if not is_manual else all_models
+
+    for model_desc in models_to_test:
+        model_name = model_desc.model
+        if (
+            not is_ci and model_desc.size_in_GB > 1
+        ) or model_desc.model not in CANONICAL_QUERY_VALUES:
+            continue
         print("evaluating", model_name)
         model = LateInteractionTextEmbedding(model_name=model_name)
         result = next(iter(model.query_embed(queries_to_embed)))
+        expected_result = CANONICAL_QUERY_VALUES[model_name]
         token_num, abridged_dim = expected_result.shape
         assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)

         if is_ci:
             delete_model_cache(model.model._model_dir)


-def test_parallel_processing():
+@pytest.mark.parametrize(
+    "token_dim,model_name",
+    [(96, "answerdotai/answerai-colbert-small-v1")],
+)
+def test_parallel_processing(token_dim: int, model_name: str):
     is_ci = os.getenv("CI")
-    model = LateInteractionTextEmbedding(model_name="colbert-ir/colbertv2.0")
-    token_dim = 128
+    model = LateInteractionTextEmbedding(model_name=model_name)
+
     docs = ["hello world", "flag embedding"] * 100
     embeddings = list(model.embed(docs, batch_size=10, parallel=2))
     embeddings = np.stack(embeddings, axis=0)
@@ -224,7 +248,7 @@ def test_parallel_processing():

 @pytest.mark.parametrize(
     "model_name",
-    ["colbert-ir/colbertv2.0"],
+    ["answerdotai/answerai-colbert-small-v1"],
 )
 def test_lazy_load(model_name: str):
     is_ci = os.getenv("CI")
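
The model-selection gating that test_single_embedding and test_single_embedding_query now share can be read as a small standalone helper. This is only a sketch of the pattern, not code from this commit; it assumes the descriptions returned by _list_supported_models() expose model and size_in_GB attributes, as used in the diff above.

    import os

    def models_to_test(all_models, canonical_values):
        # Run the full model list only on manual workflow_dispatch runs;
        # otherwise test just the first supported model.
        is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
        is_ci = os.getenv("CI")
        candidates = all_models if is_manual else all_models[:1]
        # Keep a model only if it is small enough to run outside CI (or CI is set)
        # and it has canonical reference values to compare against.
        return [
            desc
            for desc in candidates
            if (is_ci or desc.size_in_GB <= 1) and desc.model in canonical_values
        ]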