153
153
# Single-document corpus shared by the embedding tests in this module;
# individual tests repeat it (e.g. `docs * 10`) to build larger batches.
docs = ["Hello World" ]
154
154
155
155
156
@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
def test_batch_embedding(model_name: str):
    """Embed a repeated batch of the shared docs and check every produced
    embedding against the canonical reference values for *model_name*.

    On CI the downloaded model cache is deleted afterwards to keep the
    runner's disk usage bounded.
    """
    is_ci = os.getenv("CI")
    docs_to_embed = docs * 10

    model = LateInteractionTextEmbedding(model_name=model_name)
    result = list(model.embed(docs_to_embed, batch_size=6))
    expected_result = CANONICAL_COLUMN_VALUES[model_name]

    # The canonical array's width is loop-invariant: compute it once instead
    # of unpacking `expected_result.shape` on every iteration (the original
    # also bound an unused `token_num`).
    abridged_dim = expected_result.shape[1]
    for value in result:
        # Canonical values store only the first `abridged_dim` columns of
        # each token embedding, so compare against that slice.
        assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3)

    if is_ci:
        delete_model_cache(model.model._model_dir)
171
171
172
172
173
173
def test_single_embedding():
    """Embed the shared docs with candidate models and compare the first
    embedding against the canonical column values.

    Only the first supported model is exercised on normal runs; a manual
    trigger (GITHUB_EVENT_NAME == "workflow_dispatch") covers every
    supported model. Models larger than 1 GB are skipped outside CI, as are
    models with no canonical reference data.
    """
    is_ci = os.getenv("CI")
    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
    docs_to_embed = docs

    all_models = LateInteractionTextEmbedding._list_supported_models()
    models_to_test = all_models if is_manual else [all_models[0]]

    for model_desc in models_to_test:
        model_name = model_desc.model
        # Skip heavyweight models on local runs and any model we have no
        # canonical values for (reuses the already-bound `model_name`
        # instead of re-reading `model_desc.model`).
        if (not is_ci and model_desc.size_in_GB > 1) or model_name not in CANONICAL_COLUMN_VALUES:
            continue
        print("evaluating", model_name)
        model = LateInteractionTextEmbedding(model_name=model_name)
        result = next(iter(model.embed(docs_to_embed, batch_size=6)))
        expected_result = CANONICAL_COLUMN_VALUES[model_name]
        # Canonical values keep only the first `abridged_dim` columns
        # (the original unpacked an unused `token_num` as well).
        abridged_dim = expected_result.shape[1]
        assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)

        if is_ci:
            delete_model_cache(model.model._model_dir)
186
196
187
197
188
198
def test_single_embedding_query():
    """Embed the shared docs as *queries* and compare the first query
    embedding against the canonical query values.

    Mirrors test_single_embedding but exercises query_embed(), so the
    expected values come from CANONICAL_QUERY_VALUES.
    """
    is_ci = os.getenv("CI")
    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
    queries_to_embed = docs

    all_models = LateInteractionTextEmbedding._list_supported_models()
    models_to_test = all_models if is_manual else [all_models[0]]

    for model_desc in models_to_test:
        model_name = model_desc.model
        # Skip heavyweight models on local runs and any model without
        # canonical query-side reference data.
        if (not is_ci and model_desc.size_in_GB > 1) or model_name not in CANONICAL_QUERY_VALUES:
            continue
        print("evaluating", model_name)
        model = LateInteractionTextEmbedding(model_name=model_name)
        result = next(iter(model.query_embed(queries_to_embed)))
        # BUG FIX: expected values for query embeddings live in
        # CANONICAL_QUERY_VALUES; the original looked them up in
        # CANONICAL_COLUMN_VALUES (the document-side table), which this
        # loop does not even guarantee contains `model_name`.
        expected_result = CANONICAL_QUERY_VALUES[model_name]
        abridged_dim = expected_result.shape[1]
        assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)

        if is_ci:
            delete_model_cache(model.model._model_dir)
201
221
202
222
203
- def test_parallel_processing ():
223
+ @pytest .mark .parametrize (
224
+ "token_dim,model_name" ,
225
+ [(96 , "answerdotai/answerai-colbert-small-v1" )],
226
+ )
227
+ def test_parallel_processing (token_dim : int , model_name : str ):
204
228
is_ci = os .getenv ("CI" )
205
- model = LateInteractionTextEmbedding (model_name = "colbert-ir/colbertv2.0" )
206
- token_dim = 128
229
+ model = LateInteractionTextEmbedding (model_name = model_name )
230
+
207
231
docs = ["hello world" , "flag embedding" ] * 100
208
232
embeddings = list (model .embed (docs , batch_size = 10 , parallel = 2 ))
209
233
embeddings = np .stack (embeddings , axis = 0 )
@@ -224,7 +248,7 @@ def test_parallel_processing():
224
248
225
249
@pytest .mark .parametrize (
226
250
"model_name" ,
227
- ["colbert-ir/colbertv2.0 " ],
251
+ ["answerdotai/answerai- colbert-small-v1 " ],
228
252
)
229
253
def test_lazy_load (model_name : str ):
230
254
is_ci = os .getenv ("CI" )
0 commit comments