+import functools
 import json
 import os
 import shutil
 from itertools import islice
 from multiprocessing import get_all_start_methods
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import onnxruntime as ort
 import requests
 from tokenizers import AddedToken, Tokenizer
 from tqdm import tqdm
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import RepositoryNotFoundError
+from loguru import logger

 from fastembed.parallel_processor import ParallelWorkerPool, Worker

@@ -58,9 +62,14 @@ def load_tokenizer(cls, model_dir: Path, max_length: int = 512) -> Tokenizer:
         if not tokens_map_path.exists():
             raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")

-        config = json.load(open(str(config_path)))
-        tokenizer_config = json.load(open(str(tokenizer_config_path)))
-        tokens_map = json.load(open(str(tokens_map_path)))
+        with open(str(config_path)) as config_file:
+            config = json.load(config_file)
+
+        with open(str(tokenizer_config_path)) as tokenizer_config_file:
+            tokenizer_config = json.load(tokenizer_config_file)
+
+        with open(str(tokens_map_path)) as tokens_map_file:
+            tokens_map = json.load(tokens_map_file)

         tokenizer = Tokenizer.from_file(str(tokenizer_path))
         tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length))
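
The refactor above replaces bare open() calls with context managers, so the JSON file handles are closed promptly even if json.load raises. Condensed into a standalone helper, the pattern looks like this (hypothetical load_json name, not part of this change):

import json
from pathlib import Path

def load_json(path: Path) -> dict:
    # The with-block releases the handle even when json.load raises.
    with open(path) as f:
        return json.load(f)
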
@@ -174,71 +183,44 @@ class Embedding(ABC):
         _type_: _description_
     """

+    # Internal helper decorator to maintain backward compatibility
+    # by supporting a fallback to download from Google Cloud Storage (GCS)
+    # if the model couldn't be downloaded from HuggingFace.
+    def gcs_fallback(hf_download_method: Callable) -> Callable:
+        @functools.wraps(hf_download_method)
+        def wrapper(self, *args, **kwargs):
+            try:
+                return hf_download_method(self, *args, **kwargs)
+            except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
+                logger.exception(
+                    f"Could not download model from HuggingFace: {e}. "
+                    "Falling back to download from Google Cloud Storage."
+                )
+                return self.retrieve_model_gcs(*args, **kwargs)
+
+        return wrapper
+
     @abstractmethod
     def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = None) -> List[np.ndarray]:
         raise NotImplementedError

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to [].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
-        return [
-            {
-                "model": "BAAI/bge-small-en",
-                "dim": 384,
-                "description": "Fast English model",
-                "size_in_GB": 0.2,
-            },
-            {
-                "model": "BAAI/bge-small-en-v1.5",
-                "dim": 384,
-                "description": "Fast and Default English model",
-                "size_in_GB": 0.13,
-            },
-            {
-                "model": "BAAI/bge-small-zh-v1.5",
-                "dim": 512,
-                "description": "Fast and recommended Chinese model",
-                "size_in_GB": 0.1,
-            },
-            {
-                "model": "BAAI/bge-base-en",
-                "dim": 768,
-                "description": "Base English model",
-                "size_in_GB": 0.5,
-            },
-            {
-                "model": "BAAI/bge-base-en-v1.5",
-                "dim": 768,
-                "description": "Base English model, v1.5",
-                "size_in_GB": 0.44,
-            },
-            {
-                "model": "sentence-transformers/all-MiniLM-L6-v2",
-                "dim": 384,
-                "description": "Sentence Transformer model, MiniLM-L6-v2",
-                "size_in_GB": 0.09,
-            },
-            {
-                "model": "intfloat/multilingual-e5-large",
-                "dim": 1024,
-                "description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
-                "size_in_GB": 2.24,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-base-en",
-                "dim": 768,
-                "description": "English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.55,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-small-en",
-                "dim": 512,
-                "description": "English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.13,
-            },
-        ]
+        models_file_path = Path(__file__).with_name("models.json")
+        with open(models_file_path, "r") as file:
+            models = json.load(file)
+
+        models = [{k: v for k, v in model.items() if k not in exclude} for model in models]
+
+        return models

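list_supported_models now reads its metadata from a models.json shipped next to the module instead of the hard-coded list removed above. Judging by the keys this diff consumes (the old fields plus hf_sources and compressed_url_sources for the two download paths), an entry plausibly looks like the following sketch; the HF repo id is illustrative, and the URL merely follows the pattern from the removed retrieve_model_gcs code:

# Hypothetical models.json entry, written as a Python literal for illustration.
example_entry = {
    "model": "BAAI/bge-small-en-v1.5",
    "dim": 384,
    "description": "Fast and Default English model",
    "size_in_GB": 0.13,
    "hf_sources": ["qdrant/bge-small-en-v1.5-onnx-q"],  # assumed repo id
    "compressed_url_sources": [
        # URL pattern taken from the removed retrieve_model_gcs code
        "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en-v1.5.tar.gz"
    ],
}
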
     @classmethod
     def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
@@ -271,48 +253,49 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
         if total_size_in_bytes == 0:
             print(f"Warning: Content-length header is missing or zero in the response from {url}.")

-        # Initialize the progress bar
-        progress_bar = (
-            tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-            if total_size_in_bytes and show_progress
-            else None
-        )
+        show_progress = total_size_in_bytes and show_progress

-        # Attempt to download the file
-        try:
+        with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar:
             with open(output_path, "wb") as file:
-                for chunk in response.iter_content(chunk_size=1024):  # Adjust chunk size to your preference
+                for chunk in response.iter_content(chunk_size=1024):
                     if chunk:  # Filter out keep-alive new chunks
-                        if progress_bar is not None:
-                            progress_bar.update(len(chunk))
+                        progress_bar.update(len(chunk))
                         file.write(chunk)
-        except Exception as e:
-            print(f"An error occurred while trying to download the file: {str(e)}")
-            return
-        finally:
-            if progress_bar is not None:
-                progress_bar.close()
         return output_path

     @classmethod
-    def download_files_from_huggingface(cls, repod_id: str, cache_dir: Optional[str] = None) -> str:
+    def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[str] = None) -> str:
         """
         Downloads a model from HuggingFace Hub.
         Args:
-            repod_id (str): The HF hub id (name) of the model to retrieve.
+            model_name (str): Name of the model to download.
             cache_dir (Optional[str]): The path to the cache directory.
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
         Returns:
             Path: The path to the model directory.
         """
-        from huggingface_hub import snapshot_download
-
-        return snapshot_download(
-            repo_id=repod_id,
-            ignore_patterns=["model.safetensors", "pytorch_model.bin"],
-            cache_dir=cache_dir,
-        )
+        models = cls.list_supported_models(exclude=["compressed_url_sources"])
+
+        hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]]
+
+        # Check if the HF sources list is empty
+        # Raise an exception causing a fallback to GCS
+        if not hf_sources:
+            raise ValueError(f"No HuggingFace source for {model_name}")
+
+        for index, repo_id in enumerate(hf_sources):
+            try:
+                return snapshot_download(
+                    repo_id=repo_id,
+                    ignore_patterns=["model.safetensors", "pytorch_model.bin"],
+                    cache_dir=cache_dir,
+                )
+            except (RepositoryNotFoundError, EnvironmentError) as e:
+                logger.exception(f"Failed to download model from HF source: {repo_id}: {e}")
+                if repo_id == hf_sources[-1]:
+                    raise e
+                logger.info(f"Trying another source: {hf_sources[index + 1]}")

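Both download paths now share the same multi-source shape: try each source in order, log the failure, move on, and re-raise only when the last source fails. A minimal standalone sketch of that loop, with a hypothetical fetch callback standing in for the actual download call:

def download_with_fallback(sources, fetch):
    # Try each source in order; surface the error only when none is left.
    for index, source in enumerate(sources):
        try:
            return fetch(source)
        except Exception as e:
            if index == len(sources) - 1:
                raise
            print(f"{source} failed ({e}); trying {sources[index + 1]}")
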
     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
@@ -363,28 +346,36 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
         Returns:
             Path: The path to the model directory.
         """
-
-        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
-
         fast_model_name = f"fast-{model_name.split('/')[-1]}"

         model_dir = Path(cache_dir) / fast_model_name
         if model_dir.exists():
             return model_dir

         model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
-        try:
-            self.download_file_from_gcs(
-                f"https://storage.googleapis.com/qdrant-fastembed/{fast_model_name}.tar.gz",
-                output_path=str(model_tar_gz),
-            )
-        except PermissionError:
-            simple_model_name = model_name.replace("/", "-")
-            print(f"Was not able to download {fast_model_name}.tar.gz, trying {simple_model_name}.tar.gz")
-            self.download_file_from_gcs(
-                f"https://storage.googleapis.com/qdrant-fastembed/{simple_model_name}.tar.gz",
-                output_path=str(model_tar_gz),
-            )
+
+        models = self.list_supported_models(exclude=["hf_sources"])
+
+        compressed_url_sources = [
+            item for model in models if model["model"] == model_name for item in model["compressed_url_sources"]
+        ]
+
+        # Check if the GCS sources list is empty after falling back from HF
+        # A model should always have at least one source
+        if not compressed_url_sources:
+            raise ValueError(f"No GCS source for {model_name}")
+
+        for index, source in enumerate(compressed_url_sources):
+            try:
+                self.download_file_from_gcs(
+                    source,
+                    output_path=str(model_tar_gz),
+                )
+                break  # Stop at the first source that downloads successfully
+            except (RuntimeError, PermissionError) as e:
+                logger.exception(f"Failed to download model from GCS source: {source}: {e}")
+                if source == compressed_url_sources[-1]:
+                    raise e
+                logger.info(f"Trying another source: {compressed_url_sources[index + 1]}")

         self.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=cache_dir)
         assert model_dir.exists(), f"Could not find {model_dir} in {cache_dir}"
@@ -393,23 +384,31 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:

         return model_dir

+    @gcs_fallback
     def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
         """
         Retrieves a model from HuggingFace Hub.
         Args:
             model_name (str): The name of the model to retrieve.
             cache_dir (str): The path to the cache directory.
-        Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         Returns:
             Path: The path to the model directory.
         """

-        assert (
-            "/" in model_name
-        ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
+        return Path(self.download_files_from_huggingface(model_name=model_name, cache_dir=cache_dir))

-        return Path(self.download_files_from_huggingface(repod_id=model_name, cache_dir=cache_dir))
+    @classmethod
+    def assert_model_name(cls, model_name: str):
+        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
+
+        models = cls.list_supported_models()
+        model_names = [model["model"] for model in models]
+        if model_name not in model_names:
+            raise ValueError(
+                f"{model_name} is not a supported model.\n"
+                f"Try one of {', '.join(model_names)}.\n"
+                f"Use the 'list_supported_models()' method to get the model information."
+            )

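With @gcs_fallback applied above, resolution proceeds HF-first: every hf_sources repo is tried via snapshot_download, and any EnvironmentError, RepositoryNotFoundError, or ValueError (including the "no HuggingFace source" case) triggers retrieve_model_gcs. A runnable toy sketch of the decorator mechanics, with stand-in names rather than the real methods:

import functools

def with_fallback(primary):
    # Toy analogue of gcs_fallback: run primary, call self.backup on failure.
    @functools.wraps(primary)
    def wrapper(self, *args, **kwargs):
        try:
            return primary(self, *args, **kwargs)
        except ValueError:
            return self.backup(*args, **kwargs)
    return wrapper

class Retriever:
    @with_fallback
    def fetch(self, name):  # stands in for retrieve_model_hf
        raise ValueError("no HF source")

    def backup(self, name):  # stands in for retrieve_model_gcs
        return f"gcs:{name}"

assert Retriever().fetch("demo") == "gcs:demo"
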
     def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
         """
@@ -470,6 +469,9 @@ def __init__(
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         """
+
+        self.assert_model_name(model_name)
+
         self.model_name = model_name

         if cache_dir is None:
@@ -478,7 +480,7 @@ def __init__(
             cache_dir.mkdir(parents=True, exist_ok=True)

         self._cache_dir = cache_dir
-        self._model_dir = self.retrieve_model_gcs(model_name, cache_dir)
+        self._model_dir = self.retrieve_model_hf(model_name, cache_dir)
         self._max_length = max_length

         self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads)
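Since __init__ now validates the name before any network call and resolves models HF-first with the GCS fallback, basic usage stays unchanged; a hedged sketch (model name taken from the supported list):

# Usage sketch, not part of this diff.
embedder = FlagEmbedding(model_name="BAAI/bge-small-en-v1.5")
vectors = list(embedder.embed(["hello world", "fastembed is fast"]))
# An unsupported name now raises ValueError from assert_model_name
# up front instead of failing mid-download.
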
@@ -531,12 +533,21 @@ def embed(
             yield from normalize(embeddings[:, 0]).astype(np.float32)

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(
+        cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+    ) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
         # jina models are not supported by this class
-        return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")]
+        return [
+            model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai")
+        ]


 class DefaultEmbedding(FlagEmbedding):
@@ -586,8 +597,10 @@ def __init__(
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
         Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+            ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
         """
+        self.assert_model_name(model_name)
+
         self.model_name = model_name

         if cache_dir is None:
@@ -647,12 +660,21 @@ def embed(
             yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(
+        cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+    ) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
         # only jina models are supported by this class
-        return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")]
+        return [
+            model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai")
+        ]

     @staticmethod
     def mean_pooling(model_output, attention_mask):
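mean_pooling's body lies outside this diff; it presumably implements the standard masked token average, sketched here as an assumption rather than the repository's verbatim code:

import numpy as np

def mean_pooling(model_output: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    # Masked mean over the token axis: sum(tokens * mask) / sum(mask).
    # model_output: (batch, seq_len, dim); attention_mask: (batch, seq_len).
    mask = np.expand_dims(attention_mask, axis=-1).astype(model_output.dtype)
    summed = (model_output * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)
    return summed / counts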