@classmethod
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
    """
    Lists the supported models.

    Returns:
        List[Dict[str, Union[str, Union[int, float]]]]: The model metadata
        records (model name, dim, description, size_in_GB, ...) loaded from
        the bundled models.json file.
    """
    # models.json lives next to this module; it replaces the previously
    # hard-coded registry of model dicts.
    models_file_path = Path(__file__).with_name("models.json")
    # Open via a context manager so the handle is closed deterministically;
    # the original `json.load(open(str(path)))` left the file open until GC.
    with models_file_path.open(encoding="utf-8") as models_file:
        models = json.load(models_file)
    return models
@classmethod
267
215
def download_file_from_gcs (cls , url : str , output_path : str , show_progress : bool = True ) -> str :
@@ -318,19 +266,27 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
318
266
return output_path
319
267
320
268
@classmethod
321
- def download_files_from_huggingface (cls , repo_ids : List [ str ] , cache_dir : Optional [str ] = None ) -> str :
269
+ def download_files_from_huggingface (cls , model_name : str , cache_dir : Optional [str ] = None ) -> str :
322
270
"""
323
271
Downloads a model from HuggingFace Hub.
324
272
Args:
325
- repo_ids (List[ str] ): A list of HF model IDs to download.
273
+ model_name ( str): Name of the model to download.
326
274
cache_dir (Optional[str]): The path to the cache directory.
327
275
Raises:
328
276
ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
329
277
Returns:
330
278
Path: The path to the model directory.
331
279
"""
280
+ models = cls .list_supported_models ()
332
281
333
- for index , repo_id in enumerate (repo_ids ):
282
+ hf_sources = [item for model in models if model ["model" ] == model_name for item in model ["hf_sources" ]]
283
+
284
+ # Check if the HF sources list is empty
285
+ # Raise an exception causing a fallback to GCS
286
+ if not hf_sources :
287
+ raise ValueError (f"No HuggingFace source for { model_name } " )
288
+
289
+ for index , repo_id in enumerate (hf_sources ):
334
290
try :
335
291
return snapshot_download (
336
292
repo_id = repo_id ,
@@ -339,9 +295,9 @@ def download_files_from_huggingface(cls, repo_ids: List[str], cache_dir: Optiona
339
295
)
340
296
except (RepositoryNotFoundError , EnvironmentError ) as e :
341
297
logger .error (f"Failed to download model from HF source: { repo_id } : { e } " )
342
- if repo_id == repo_ids [- 1 ]:
298
+ if repo_id == hf_sources [- 1 ]:
343
299
raise e
344
- logger .info (f"Trying another source: { repo_ids [index + 1 ]} " )
300
+ logger .info (f"Trying another source: { hf_sources [index + 1 ]} " )
345
301
346
302
@classmethod
347
303
def decompress_to_cache (cls , targz_path : str , cache_dir : str ):
@@ -399,18 +355,27 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
399
355
return model_dir
400
356
401
357
model_tar_gz = Path (cache_dir ) / f"{ fast_model_name } .tar.gz"
402
- try :
403
- self .download_file_from_gcs (
404
- f"https://storage.googleapis.com/qdrant-fastembed/{ fast_model_name } .tar.gz" ,
405
- output_path = str (model_tar_gz ),
406
- )
407
- except PermissionError :
408
- simple_model_name = model_name .replace ("/" , "-" )
409
- print (f"Was not able to download { fast_model_name } .tar.gz, trying { simple_model_name } .tar.gz" )
410
- self .download_file_from_gcs (
411
- f"https://storage.googleapis.com/qdrant-fastembed/{ simple_model_name } .tar.gz" ,
412
- output_path = str (model_tar_gz ),
413
- )
358
+
359
+ models = self .list_supported_models ()
360
+
361
+ gcs_sources = [item for model in models if model ["model" ] == model_name for item in model ["gcs_sources" ]]
362
+
363
+ # Check if the GCS sources list is empty after falling back from HF
364
+ # A model should always have at least one source
365
+ if not gcs_sources :
366
+ raise ValueError (f"No GCS source for { model_name } " )
367
+
368
+ for index , source in enumerate (gcs_sources ):
369
+ try :
370
+ self .download_file_from_gcs (
371
+ f"https://storage.googleapis.com/{ source } " ,
372
+ output_path = str (model_tar_gz ),
373
+ )
374
+ except (RuntimeError , PermissionError ) as e :
375
+ logger .error (f"Failed to download model from GCS source: { source } : { e } " )
376
+ if source == gcs_sources [- 1 ]:
377
+ raise e
378
+ logger .info (f"Trying another source: { gcs_sources [index + 1 ]} " )
414
379
415
380
self .decompress_to_cache (targz_path = str (model_tar_gz ), cache_dir = cache_dir )
416
381
assert model_dir .exists (), f"Could not find { model_dir } in { cache_dir } "
def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
    """
    Retrieves a model from the HuggingFace Hub into the local cache.

    Args:
        model_name (str): Name of the model in <org>/<model> form.
        cache_dir (str): The path to the cache directory.

    Returns:
        Path: The path to the downloaded model directory.
    """
    # Delegate the actual download (including multi-source fallback) to the
    # HuggingFace helper, then hand callers a Path instead of a raw string.
    model_dir = self.download_files_from_huggingface(model_name=model_name, cache_dir=cache_dir)
    return Path(model_dir)
@classmethod
def assert_model_name(cls, model_name: str):
    """
    Validates that model_name is well-formed and supported.

    Args:
        model_name (str): Name of the model, in <org>/<model> form
            (e.g. "BAAI/bge-base-en").

    Raises:
        ValueError: If model_name is not in <org>/<model> format, or is not
            one of the supported models.
    """
    # Raise ValueError instead of a bare `assert`: asserts are stripped under
    # `python -O`, and the callers' docstrings document ValueError for a
    # malformed model_name.
    if "/" not in model_name:
        raise ValueError("model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en")

    models = cls.list_supported_models()
    model_names = [model["model"] for model in models]
    if model_name not in model_names:
        raise ValueError(
            f"{model_name} is not a supported model.\n"
            f"Try one of {', '.join(model_names)}.\n"
            f"Use the 'list_supported_models()' method to get the model information."
        )
441
412
442
413
def passage_embed (self , texts : Iterable [str ], ** kwargs ) -> Iterable [np .ndarray ]:
443
414
"""
@@ -498,7 +469,8 @@ def __init__(
498
469
Raises:
499
470
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
500
471
"""
501
- assert "/" in model_name , "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
472
+
473
+ self .assert_model_name (model_name )
502
474
503
475
self .model_name = model_name
504
476
@@ -618,9 +590,7 @@ def __init__(
618
590
Raises:
619
591
ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
620
592
"""
621
- assert (
622
- "/" in model_name
623
- ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en"
593
+ self .assert_model_name (model_name )
624
594
625
595
self .model_name = model_name
626
596
0 commit comments