+ import functools
import json
import os
import shutil
from itertools import islice
from multiprocessing import get_all_start_methods
from pathlib import Path
- from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
+ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union

import numpy as np
import onnxruntime as ort
import requests
from tokenizers import AddedToken, Tokenizer
from tqdm import tqdm
+ from huggingface_hub import snapshot_download
+ from huggingface_hub.utils import RepositoryNotFoundError
+ from loguru import logger

from fastembed.parallel_processor import ParallelWorkerPool, Worker
@@ -179,71 +183,44 @@ class Embedding(ABC):
            _type_: _description_
        """

+     # Internal helper decorator to maintain backward compatibility
+     # by supporting a fallback to download from Google Cloud Storage (GCS)
+     # if the model couldn't be downloaded from HuggingFace.
+     def gcs_fallback(hf_download_method: Callable) -> Callable:
+         @functools.wraps(hf_download_method)
+         def wrapper(self, *args, **kwargs):
+             try:
+                 return hf_download_method(self, *args, **kwargs)
+             except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
+                 logger.exception(
+                     f"Could not download model from HuggingFace: {e} "
+                     "Falling back to download from Google Cloud Storage"
+                 )
+                 return self.retrieve_model_gcs(*args, **kwargs)
+
+         return wrapper
+
    @abstractmethod
    def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = None) -> List[np.ndarray]:
        raise NotImplementedError

    @classmethod
-     def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-         """
-         Lists the supported models.
+     def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Any]]:
+         """Lists the supported models.
+
+         Args:
+             exclude (List[str], optional): Keys to exclude from the result. Defaults to [].
+
+         Returns:
+             List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
-         return [
-             {
-                 "model": "BAAI/bge-small-en",
-                 "dim": 384,
-                 "description": "Fast English model",
-                 "size_in_GB": 0.2,
-             },
-             {
-                 "model": "BAAI/bge-small-en-v1.5",
-                 "dim": 384,
-                 "description": "Fast and Default English model",
-                 "size_in_GB": 0.13,
-             },
-             {
-                 "model": "BAAI/bge-small-zh-v1.5",
-                 "dim": 512,
-                 "description": "Fast and recommended Chinese model",
-                 "size_in_GB": 0.1,
-             },
-             {
-                 "model": "BAAI/bge-base-en",
-                 "dim": 768,
-                 "description": "Base English model",
-                 "size_in_GB": 0.5,
-             },
-             {
-                 "model": "BAAI/bge-base-en-v1.5",
-                 "dim": 768,
-                 "description": "Base English model, v1.5",
-                 "size_in_GB": 0.44,
-             },
-             {
-                 "model": "sentence-transformers/all-MiniLM-L6-v2",
-                 "dim": 384,
-                 "description": "Sentence Transformer model, MiniLM-L6-v2",
-                 "size_in_GB": 0.09,
-             },
-             {
-                 "model": "intfloat/multilingual-e5-large",
-                 "dim": 1024,
-                 "description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
-                 "size_in_GB": 2.24,
-             },
-             {
-                 "model": "jinaai/jina-embeddings-v2-base-en",
-                 "dim": 768,
-                 "description": " English embedding model supporting 8192 sequence length",
-                 "size_in_GB": 0.55,
-             },
-             {
-                 "model": "jinaai/jina-embeddings-v2-small-en",
-                 "dim": 512,
-                 "description": " English embedding model supporting 8192 sequence length",
-                 "size_in_GB": 0.13,
-             },
-         ]
+         models_file_path = Path(__file__).with_name("models.json")
+         with open(models_file_path, "r") as file:
+             models = json.load(file)
+
+         models = [{k: v for k, v in model.items() if k not in exclude} for model in models]
+
+         return models

    @classmethod
    def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
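Note: the hardcoded registry above moves into a models.json file shipped next to this module, which now also carries per-model download sources. The file's exact contents are not part of this diff; below is a minimal sketch of one assumed entry, with the keys the old list used plus the hf_sources and compressed_url_sources lists that the download methods further down read. The repo id is illustrative; the URL follows the fast-{name}.tar.gz pattern from the removed GCS code.

import json

sample_models_json = """
[
  {
    "model": "BAAI/bge-small-en-v1.5",
    "dim": 384,
    "description": "Fast and Default English model",
    "size_in_GB": 0.13,
    "hf_sources": ["some-org/bge-small-en-v1.5-onnx"],
    "compressed_url_sources": [
      "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en-v1.5.tar.gz"
    ]
  }
]
"""

models = json.loads(sample_models_json)
exclude = ["compressed_url_sources", "hf_sources"]
models = [{k: v for k, v in m.items() if k not in exclude} for m in models]
print(models[0].keys())  # dict_keys(['model', 'dim', 'description', 'size_in_GB'])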
@@ -276,48 +253,49 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
        if total_size_in_bytes == 0:
            print(f"Warning: Content-length header is missing or zero in the response from {url}.")

-         # Initialize the progress bar
-         progress_bar = (
-             tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-             if total_size_in_bytes and show_progress
-             else None
-         )
+         show_progress = total_size_in_bytes and show_progress

-         # Attempt to download the file
-         try:
+         with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar:
            with open(output_path, "wb") as file:
-                 for chunk in response.iter_content(chunk_size=1024):  # Adjust chunk size to your preference
+                 for chunk in response.iter_content(chunk_size=1024):
                    if chunk:  # Filter out keep-alive new chunks
-                         if progress_bar is not None:
-                             progress_bar.update(len(chunk))
+                         progress_bar.update(len(chunk))
                        file.write(chunk)
-         except Exception as e:
-             print(f"An error occurred while trying to download the file: {str(e)}")
-             return
-         finally:
-             if progress_bar is not None:
-                 progress_bar.close()
        return output_path

    @classmethod
-     def download_files_from_huggingface(cls, repod_id: str, cache_dir: Optional[str] = None) -> str:
+     def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[str] = None) -> str:
        """
        Downloads a model from HuggingFace Hub.
        Args:
-             repod_id (str): The HF hub id (name) of the model to retrieve.
+             model_name (str): Name of the model to download.
            cache_dir (Optional[str]): The path to the cache directory.
        Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
        Returns:
            Path: The path to the model directory.
        """
-         from huggingface_hub import snapshot_download
-
-         return snapshot_download(
-             repo_id=repod_id,
-             ignore_patterns=["model.safetensors", "pytorch_model.bin"],
-             cache_dir=cache_dir,
-         )
+         models = cls.list_supported_models(exclude=["compressed_url_sources"])
+
+         hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]]
+
+         # Check if the HF sources list is empty
+         # Raise an exception causing a fallback to GCS
+         if not hf_sources:
+             raise ValueError(f"No HuggingFace source for {model_name}")
+
+         for index, repo_id in enumerate(hf_sources):
+             try:
+                 return snapshot_download(
+                     repo_id=repo_id,
+                     ignore_patterns=["model.safetensors", "pytorch_model.bin"],
+                     cache_dir=cache_dir,
+                 )
+             except (RepositoryNotFoundError, EnvironmentError) as e:
+                 logger.exception(f"Failed to download model from HF source: {repo_id}: {e}")
+                 if repo_id == hf_sources[-1]:
+                     raise e
+                 logger.info(f"Trying another source: {hf_sources[index + 1]}")

    @classmethod
    def decompress_to_cache(cls, targz_path: str, cache_dir: str):
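Note: both download paths now share the same shape — walk an ordered list of sources, log and move on while sources remain, and re-raise only when the last one fails. A standalone sketch of that retry-then-reraise pattern (the function name and the download callable are stand-ins, not part of this diff):

from typing import Callable, List

def download_from_sources(sources: List[str], download: Callable[[str], str]) -> str:
    # Try each source in order; surface the error only when the final source fails.
    for index, source in enumerate(sources):
        try:
            return download(source)
        except Exception as e:
            if source == sources[-1]:
                raise e
            print(f"Trying another source: {sources[index + 1]}")

One design note: the loops compare the failing source to sources[-1] by value, so a repo id or URL that appears twice in the list would re-raise on its first failure; checking index == len(sources) - 1 instead would be immune to duplicates.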
@@ -368,28 +346,36 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
        Returns:
            Path: The path to the model directory.
        """
-
-         assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
-
        fast_model_name = f"fast-{model_name.split('/')[-1]}"

        model_dir = Path(cache_dir) / fast_model_name
        if model_dir.exists():
            return model_dir

        model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
-         try:
-             self.download_file_from_gcs(
-                 f"https://storage.googleapis.com/qdrant-fastembed/{fast_model_name}.tar.gz",
-                 output_path=str(model_tar_gz),
-             )
-         except PermissionError:
-             simple_model_name = model_name.replace("/", "-")
-             print(f"Was not able to download {fast_model_name}.tar.gz, trying {simple_model_name}.tar.gz")
-             self.download_file_from_gcs(
-                 f"https://storage.googleapis.com/qdrant-fastembed/{simple_model_name}.tar.gz",
-                 output_path=str(model_tar_gz),
-             )
+
+         models = self.list_supported_models(exclude=["hf_sources"])
+
+         compressed_url_sources = [
+             item for model in models if model["model"] == model_name for item in model["compressed_url_sources"]
+         ]
+
+         # Check if the GCS sources list is empty after falling back from HF
+         # A model should always have at least one source
+         if not compressed_url_sources:
+             raise ValueError(f"No GCS source for {model_name}")
+
+         for index, source in enumerate(compressed_url_sources):
+             try:
+                 self.download_file_from_gcs(
+                     source,
+                     output_path=str(model_tar_gz),
+                 )
+             except (RuntimeError, PermissionError) as e:
+                 logger.exception(f"Failed to download model from GCS source: {source}: {e}")
+                 if source == compressed_url_sources[-1]:
+                     raise e
+                 logger.info(f"Trying another source: {compressed_url_sources[index + 1]}")

        self.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=cache_dir)
        assert model_dir.exists(), f"Could not find {model_dir} in {cache_dir}"
@@ -398,23 +384,31 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:

        return model_dir

+     @gcs_fallback
    def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
        """
        Retrieves a model from HuggingFace Hub.
        Args:
            model_name (str): The name of the model to retrieve.
            cache_dir (str): The path to the cache directory.
-         Raises:
-             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
        Returns:
            Path: The path to the model directory.
        """
-         assert (
-             "/" in model_name
-         ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
+         return Path(self.download_files_from_huggingface(model_name=model_name, cache_dir=cache_dir))
+
+     @classmethod
+     def assert_model_name(cls, model_name: str):
+         assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"

-         return Path(self.download_files_from_huggingface(repod_id=model_name, cache_dir=cache_dir))
+         models = cls.list_supported_models()
+         model_names = [model["model"] for model in models]
+         if model_name not in model_names:
+             raise ValueError(
+                 f"{model_name} is not a supported model.\n"
+                 f"Try one of {', '.join(model_names)}.\n"
+                 f"Use the 'list_supported_models()' method to get the model information."
+             )

    def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
        """
@@ -475,6 +469,9 @@ def __init__(
        Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
        """
+
+         self.assert_model_name(model_name)
+
        self.model_name = model_name

        if cache_dir is None:
@@ -483,7 +480,7 @@ def __init__(
            cache_dir.mkdir(parents=True, exist_ok=True)

        self._cache_dir = cache_dir
-         self._model_dir = self.retrieve_model_gcs(model_name, cache_dir)
+         self._model_dir = self.retrieve_model_hf(model_name, cache_dir)
        self._max_length = max_length

        self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads)
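Note: with both __init__ changes in place, construction validates the model name up front and then downloads via HuggingFace with the GCS fallback. A hypothetical usage, assuming the constructor keeps the model_name signature shown in context:

from fastembed.embedding import FlagEmbedding

model = FlagEmbedding(model_name="BAAI/bge-small-en-v1.5")  # asserts the name, then retrieve_model_hf
embeddings = list(model.embed(["hello world"]))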
@@ -536,12 +533,21 @@ def embed(
        yield from normalize(embeddings[:, 0]).astype(np.float32)

    @classmethod
-     def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-         """
-         Lists the supported models.
+     def list_supported_models(
+         cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+     ) -> List[Dict[str, Any]]:
+         """Lists the supported models.
+
+         Args:
+             exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+         Returns:
+             List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
        # jina models are not supported by this class
-         return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")]
+         return [
+             model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai")
+         ]


class DefaultEmbedding(FlagEmbedding):
@@ -591,8 +597,10 @@ def __init__(
                Defaults to `fastembed_cache` in the system's temp directory.
            threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
        Raises:
-             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+             ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
        """
+         self.assert_model_name(model_name)
+
        self.model_name = model_name

        if cache_dir is None:
@@ -652,12 +660,21 @@ def embed(
        yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)

    @classmethod
-     def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-         """
-         Lists the supported models.
+     def list_supported_models(
+         cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+     ) -> List[Dict[str, Any]]:
+         """Lists the supported models.
+
+         Args:
+             exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+         Returns:
+             List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
        # only jina models are supported by this class
-         return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")]
+         return [
+             model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai")
+         ]

    @staticmethod
    def mean_pooling(model_output, attention_mask):
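Note: the jina path pools with a masked mean over token embeddings rather than taking the CLS vector (embeddings[:, 0]) as FlagEmbedding does above. mean_pooling's body sits outside this diff; a minimal numpy sketch of masked mean pooling, assuming (batch, seq_len, dim) embeddings and a 0/1 attention mask:

import numpy as np

def mean_pooling_sketch(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    # token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len)
    mask = attention_mask[..., None].astype(np.float32)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)       # sum of non-padding token vectors
    counts = np.clip(mask.sum(axis=1), 1e-9, None)       # avoid division by zero
    return summed / counts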