+import functools
 import json
 import os
 import shutil
 from itertools import islice
 from multiprocessing import get_all_start_methods
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import onnxruntime as ort
 import requests
 from tokenizers import AddedToken, Tokenizer
 from tqdm import tqdm
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import RepositoryNotFoundError
+from loguru import logger

 from fastembed.parallel_processor import ParallelWorkerPool, Worker
@@ -179,71 +183,44 @@ class Embedding(ABC):
             _type_: _description_
         """

+    # Internal helper decorator to maintain backward compatibility
+    # by supporting a fallback to download from Google Cloud Storage (GCS)
+    # if the model couldn't be downloaded from HuggingFace.
+    def gcs_fallback(hf_download_method: Callable) -> Callable:
+        @functools.wraps(hf_download_method)
+        def wrapper(self, *args, **kwargs):
+            try:
+                return hf_download_method(self, *args, **kwargs)
+            except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
+                logger.exception(
+                    f"Could not download model from HuggingFace: {e} "
+                    "Falling back to download from Google Cloud Storage"
+                )
+                return self.retrieve_model_gcs(*args, **kwargs)
+
+        return wrapper
+
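A note on the decorator just added: because `wrapper` receives `self`, the fallback can dispatch to the instance's `retrieve_model_gcs` with the original arguments. A minimal, self-contained sketch of the same pattern, with hypothetical names standing in for the PR's HF/GCS methods:

```python
# Sketch of the fallback-decorator pattern; PrimarySourceError and
# Retriever are illustrative stand-ins, not part of this PR.
import functools
from typing import Callable


class PrimarySourceError(Exception):
    pass


def with_fallback(primary: Callable) -> Callable:
    @functools.wraps(primary)
    def wrapper(self, *args, **kwargs):
        try:
            return primary(self, *args, **kwargs)
        except PrimarySourceError:
            # Same idea as gcs_fallback: on failure, try the secondary path.
            return self.retrieve_secondary(*args, **kwargs)

    return wrapper


class Retriever:
    @with_fallback
    def retrieve_primary(self, name: str) -> str:
        raise PrimarySourceError(f"{name}: primary source unavailable")

    def retrieve_secondary(self, name: str) -> str:
        return f"{name} (downloaded via fallback)"


print(Retriever().retrieve_primary("my-model"))  # my-model (downloaded via fallback)
```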
     @abstractmethod
     def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = None) -> List[np.ndarray]:
         raise NotImplementedError
 
     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to [].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
-        return [
-            {
-                "model": "BAAI/bge-small-en",
-                "dim": 384,
-                "description": "Fast English model",
-                "size_in_GB": 0.2,
-            },
-            {
-                "model": "BAAI/bge-small-en-v1.5",
-                "dim": 384,
-                "description": "Fast and Default English model",
-                "size_in_GB": 0.13,
-            },
-            {
-                "model": "BAAI/bge-small-zh-v1.5",
-                "dim": 512,
-                "description": "Fast and recommended Chinese model",
-                "size_in_GB": 0.1,
-            },
-            {
-                "model": "BAAI/bge-base-en",
-                "dim": 768,
-                "description": "Base English model",
-                "size_in_GB": 0.5,
-            },
-            {
-                "model": "BAAI/bge-base-en-v1.5",
-                "dim": 768,
-                "description": "Base English model, v1.5",
-                "size_in_GB": 0.44,
-            },
-            {
-                "model": "sentence-transformers/all-MiniLM-L6-v2",
-                "dim": 384,
-                "description": "Sentence Transformer model, MiniLM-L6-v2",
-                "size_in_GB": 0.09,
-            },
-            {
-                "model": "intfloat/multilingual-e5-large",
-                "dim": 1024,
-                "description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
-                "size_in_GB": 2.24,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-base-en",
-                "dim": 768,
-                "description": " English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.55,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-small-en",
-                "dim": 512,
-                "description": " English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.13,
-            },
-        ]
+        models_file_path = Path(__file__).with_name("models.json")
+        with open(models_file_path, "r") as file:
+            models = json.load(file)
+
+        models = [{k: v for k, v in model.items() if k not in exclude} for model in models]
+
+        return models
 
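The hard-coded registry is replaced by a `models.json` read from the package directory. The file itself is not part of this hunk; judging by the keys the new code references (`model`, `hf_sources`, `compressed_url_sources`) plus the fields from the removed list, an entry presumably looks roughly like this (values illustrative):

```python
# Hypothetical models.json entry, reconstructed from the keys this diff
# references; the actual file shipped with the package may differ.
sample_entry = {
    "model": "BAAI/bge-small-en-v1.5",
    "dim": 384,
    "description": "Fast and Default English model",
    "size_in_GB": 0.13,
    "hf_sources": ["some-org/bge-small-en-v1.5-onnx"],  # assumed repo id
    "compressed_url_sources": [
        "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en-v1.5.tar.gz"
    ],
}

# exclude= then simply strips keys, e.g. hiding source lists from callers:
public = {k: v for k, v in sample_entry.items() if k not in ["hf_sources", "compressed_url_sources"]}
print(public["model"], public["dim"])  # BAAI/bge-small-en-v1.5 384
```

One review nit: `exclude: List[str] = []` uses a mutable default argument; it is safe as written because the list is only read, but `Optional[List[str]] = None` would be the more defensive idiom.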
     @classmethod
     def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
@@ -276,48 +253,49 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
         if total_size_in_bytes == 0:
             print(f"Warning: Content-length header is missing or zero in the response from {url}.")

-        # Initialize the progress bar
-        progress_bar = (
-            tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-            if total_size_in_bytes and show_progress
-            else None
-        )
+        show_progress = total_size_in_bytes and show_progress

-        # Attempt to download the file
-        try:
+        with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar:
             with open(output_path, "wb") as file:
-                for chunk in response.iter_content(chunk_size=1024):  # Adjust chunk size to your preference
+                for chunk in response.iter_content(chunk_size=1024):
                     if chunk:  # Filter out keep-alive new chunks
-                        if progress_bar is not None:
-                            progress_bar.update(len(chunk))
+                        progress_bar.update(len(chunk))
                         file.write(chunk)
-        except Exception as e:
-            print(f"An error occurred while trying to download the file: {str(e)}")
-            return
-        finally:
-            if progress_bar is not None:
-                progress_bar.close()
         return output_path
 
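The progress-bar rework trades the conditionally-constructed `tqdm` (plus the manual `close()` in `finally`) for a single context-managed bar whose standard `disable=` flag turns every call into a no-op: one code path, no `None` checks. The shape in isolation:

```python
# Minimal sketch of the context-managed tqdm with disable=; the chunk
# list here is a stand-in for response.iter_content().
from tqdm import tqdm


def consume(chunks, total: int, show_progress: bool) -> int:
    written = 0
    with tqdm(total=total, unit="iB", unit_scale=True, disable=not show_progress) as bar:
        for chunk in chunks:
            bar.update(len(chunk))  # no-op when the bar is disabled
            written += len(chunk)
    return written


print(consume([b"abcd", b"efgh"], total=8, show_progress=False))  # 8
```

Dropping the broad `except Exception ... return` (which printed and silently returned `None`) also matters here: download failures now propagate, which is what the new multi-source retry logic in `retrieve_model_gcs` below depends on.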
     @classmethod
-    def download_files_from_huggingface(cls, repod_id: str, cache_dir: Optional[str] = None) -> str:
+    def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[str] = None) -> str:
         """
         Downloads a model from HuggingFace Hub.
         Args:
-            repod_id (str): The HF hub id (name) of the model to retrieve.
+            model_name (str): Name of the model to download.
             cache_dir (Optional[str]): The path to the cache directory.
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
         Returns:
             Path: The path to the model directory.
         """
-        from huggingface_hub import snapshot_download
-
-        return snapshot_download(
-            repo_id=repod_id,
-            ignore_patterns=["model.safetensors", "pytorch_model.bin"],
-            cache_dir=cache_dir,
-        )
+        models = cls.list_supported_models(exclude=["compressed_url_sources"])
+
+        hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]]
+
+        # Check if the HF sources list is empty
+        # Raise an exception causing a fallback to GCS
+        if not hf_sources:
+            raise ValueError(f"No HuggingFace source for {model_name}")
+
+        for index, repo_id in enumerate(hf_sources):
+            try:
+                return snapshot_download(
+                    repo_id=repo_id,
+                    ignore_patterns=["model.safetensors", "pytorch_model.bin"],
+                    cache_dir=cache_dir,
+                )
+            except (RepositoryNotFoundError, EnvironmentError) as e:
+                logger.exception(f"Failed to download model from HF source: {repo_id}: {e}")
+                if repo_id == hf_sources[-1]:
+                    raise e
+                logger.info(f"Trying another source: {hf_sources[index + 1]}")
 
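The new method tries each HF repo id in order, logging failures and re-raising only when the last source fails; that final error is what triggers the `gcs_fallback` decorator's GCS path. The control flow, reduced to its essentials (names are illustrative):

```python
# Generic ordered-fallback sketch mirroring the loop above; `fetch` is a
# hypothetical downloader that raises OSError on failure.
from typing import Callable, List


def first_successful(sources: List[str], fetch: Callable[[str], str]) -> str:
    for index, source in enumerate(sources):
        try:
            return fetch(source)
        except OSError as e:
            if source == sources[-1]:
                raise  # out of sources: let the caller's fallback kick in
            print(f"{source} failed ({e}); trying {sources[index + 1]}")


outcomes = {"mirror-a": OSError("unreachable"), "mirror-b": "payload"}


def fetch(source: str) -> str:
    result = outcomes[source]
    if isinstance(result, Exception):
        raise result
    return result


print(first_successful(["mirror-a", "mirror-b"], fetch))  # payload
```

One nit: the last-source test compares by value (`repo_id == hf_sources[-1]`), so a duplicated entry would re-raise one attempt early; `index == len(hf_sources) - 1` would sidestep that.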
     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
@@ -368,28 +346,36 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
         Returns:
             Path: The path to the model directory.
         """
-
-        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
-
         fast_model_name = f"fast-{model_name.split('/')[-1]}"

         model_dir = Path(cache_dir) / fast_model_name
         if model_dir.exists():
             return model_dir

         model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
-        try:
-            self.download_file_from_gcs(
-                f"https://storage.googleapis.com/qdrant-fastembed/{fast_model_name}.tar.gz",
-                output_path=str(model_tar_gz),
-            )
-        except PermissionError:
-            simple_model_name = model_name.replace("/", "-")
-            print(f"Was not able to download {fast_model_name}.tar.gz, trying {simple_model_name}.tar.gz")
-            self.download_file_from_gcs(
-                f"https://storage.googleapis.com/qdrant-fastembed/{simple_model_name}.tar.gz",
-                output_path=str(model_tar_gz),
-            )
+
+        models = self.list_supported_models(exclude=["hf_sources"])
+
+        compressed_url_sources = [
+            item for model in models if model["model"] == model_name for item in model["compressed_url_sources"]
+        ]
+
+        # Check if the GCS sources list is empty after falling back from HF
+        # A model should always have at least one source
+        if not compressed_url_sources:
+            raise ValueError(f"No GCS source for {model_name}")
+
+        for index, source in enumerate(compressed_url_sources):
+            try:
+                self.download_file_from_gcs(
+                    source,
+                    output_path=str(model_tar_gz),
+                )
+            except (RuntimeError, PermissionError) as e:
+                logger.exception(f"Failed to download model from GCS source: {source}: {e}")
+                if source == compressed_url_sources[-1]:
+                    raise e
+                logger.info(f"Trying another source: {compressed_url_sources[index + 1]}")

         self.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=cache_dir)
         assert model_dir.exists(), f"Could not find {model_dir} in {cache_dir}"
@@ -398,23 +384,31 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
 
         return model_dir

+    @gcs_fallback
     def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
         """
         Retrieves a model from HuggingFace Hub.
         Args:
             model_name (str): The name of the model to retrieve.
             cache_dir (str): The path to the cache directory.
-        Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         Returns:
             Path: The path to the model directory.
         """

-        assert (
-            "/" in model_name
-        ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
+        return Path(self.download_files_from_huggingface(model_name=model_name, cache_dir=cache_dir))
+
+    @classmethod
+    def assert_model_name(cls, model_name: str):
+        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"

-        return Path(self.download_files_from_huggingface(repod_id=model_name, cache_dir=cache_dir))
+        models = cls.list_supported_models()
+        model_names = [model["model"] for model in models]
+        if model_name not in model_names:
+            raise ValueError(
+                f"{model_name} is not a supported model.\n"
+                f"Try one of {', '.join(model_names)}.\n"
+                f"Use the 'list_supported_models()' method to get the model information."
+            )
 
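`assert_model_name` gives both constructors a fast pre-download check against the registry, so an unsupported name now fails before any network I/O. A hypothetical usage sketch (the import path assumes fastembed's public API; the model name is invented):

```python
# Hypothetical: an unsupported name fails fast with the new ValueError.
from fastembed.embedding import FlagEmbedding

try:
    FlagEmbedding(model_name="BAAI/no-such-model")
except ValueError as e:
    print(e)  # BAAI/no-such-model is not a supported model. Try one of ...
```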
     def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
         """
@@ -475,6 +469,9 @@ def __init__(
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         """
+
+        self.assert_model_name(model_name)
+
         self.model_name = model_name

         if cache_dir is None:
@@ -483,7 +480,7 @@ def __init__(
             cache_dir.mkdir(parents=True, exist_ok=True)

         self._cache_dir = cache_dir
-        self._model_dir = self.retrieve_model_gcs(model_name, cache_dir)
+        self._model_dir = self.retrieve_model_hf(model_name, cache_dir)
         self._max_length = max_length

         self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads)
@@ -536,12 +533,21 @@ def embed(
             yield from normalize(embeddings[:, 0]).astype(np.float32)

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(
+        cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+    ) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
         # jina models are not supported by this class
-        return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")]
+        return [
+            model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai")
+        ]


 class DefaultEmbedding(FlagEmbedding):
@@ -591,8 +597,10 @@ def __init__(
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
         Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+            ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
         """
+        self.assert_model_name(model_name)
+
         self.model_name = model_name

         if cache_dir is None:
@@ -652,12 +660,21 @@ def embed(
             yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(
+        cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+    ) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
         # only jina models are supported by this class
-        return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")]
+        return [
+            model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai")
+        ]

     @staticmethod
     def mean_pooling(model_output, attention_mask):