1+ import functools
12import json
3+ import logging
24import os
35import shutil
46import tarfile
79from itertools import islice
810from multiprocessing import get_all_start_methods
911from pathlib import Path
10- from typing import Any , Dict , Generator , Iterable , List , Optional , Tuple , Union
12+ from typing import Any , Callable , Dict , Generator , Iterable , List , Optional , Tuple , Union
1113
1214import numpy as np
1315import onnxruntime as ort
1416import requests
1517from tokenizers import AddedToken , Tokenizer
1618from tqdm import tqdm
19+ from huggingface_hub import snapshot_download
20+ from huggingface_hub .utils import RepositoryNotFoundError
1721
1822from fastembed .parallel_processor import ParallelWorkerPool , Worker
1923
24+ logger = logging .getLogger (__name__ )
25+
2026
2127def iter_batch (iterable : Union [Iterable , Generator ], size : int ) -> Iterable :
2228 """
@@ -174,6 +180,23 @@ class Embedding(ABC):
174180 _type_: _description_
175181 """
176182
# Internal helper decorator to maintain backward compatibility
# by supporting a fallback to download from Google Cloud Storage (GCS)
# if the model couldn't be downloaded from HuggingFace.
def gcs_fallback(hf_download_method: Callable) -> Callable:
    """
    Wraps a HuggingFace download method so that a failed download falls back to GCS.
    Args:
        hf_download_method (Callable): The method to wrap. It is called as
            `hf_download_method(self, *args, **kwargs)`.
    Returns:
        Callable: A wrapper that returns the result of `hf_download_method`, or, if it raises
            EnvironmentError, RepositoryNotFoundError or ValueError, the result of
            `self.retrieve_model_gcs(*args, **kwargs)`.
    """

    @functools.wraps(hf_download_method)
    def wrapper(self, *args, **kwargs):
        try:
            return hf_download_method(self, *args, **kwargs)
        except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
            # Pass the exception as a logger argument (formatted only if the record is emitted)
            # and keep an explicit ". " separator so the two sentences never run together.
            logger.info(
                "Could not download model from HuggingFace: %s. "
                "Falling back to download from Google Cloud Storage",
                e,
            )
            # The GCS retrieval receives exactly the arguments the HF method was called with.
            return self.retrieve_model_gcs(*args, **kwargs)

    return wrapper
199+
@abstractmethod
def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: Optional[int] = None) -> List[np.ndarray]:
    """
    Abstract entry point for embedding texts; concrete subclasses must override it.
    Args:
        texts (Iterable[str]): The texts to embed.
        batch_size (int): Batch size setting passed to the implementation. Defaults to 256.
        parallel (Optional[int]): Parallelism setting passed to the implementation;
            its exact meaning is defined by the overriding class. Defaults to None.
    Raises:
        NotImplementedError: Always, in this base definition.
    Returns:
        List[np.ndarray]: The embeddings produced by the overriding implementation.
    """
    raise NotImplementedError
@@ -295,24 +318,30 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
295318 return output_path
296319
@classmethod
def download_files_from_huggingface(cls, repo_ids: List[str], cache_dir: Optional[str] = None) -> str:
    """
    Downloads a model from HuggingFace Hub, trying each source repository in order.
    Args:
        repo_ids (List[str]): The HF hub ids (names) of the repositories to try, in priority order.
        cache_dir (Optional[str]): The path to the cache directory.
    Raises:
        ValueError: If repo_ids is empty.
        RepositoryNotFoundError: Re-raised from the last source when every source failed.
        EnvironmentError: Re-raised from the last source when every source failed.
    Returns:
        str: The path to the model directory, as returned by snapshot_download.
    """
    # An empty list would otherwise fall through the loop and return None despite `-> str`.
    if not repo_ids:
        raise ValueError("repo_ids must contain at least one HuggingFace repository id")

    last_index = len(repo_ids) - 1
    for index, repo_id in enumerate(repo_ids):
        try:
            return snapshot_download(
                repo_id=repo_id,
                ignore_patterns=["model.safetensors", "pytorch_model.bin"],
                cache_dir=cache_dir,
            )
        except (RepositoryNotFoundError, EnvironmentError) as e:
            logger.error(f"Failed to download model from HF source: {repo_id}: {e}")
            # Compare positions, not values: a repo id that also appears last in the list
            # must not end the loop before the remaining sources have been tried.
            if index == last_index:
                raise
            logger.info(f"Trying another source: {repo_ids[index + 1]}")
316345
317346 @classmethod
318347 def decompress_to_cache (cls , targz_path : str , cache_dir : str ):
@@ -363,9 +392,6 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
363392 Returns:
364393 Path: The path to the model directory.
365394 """
366-
367- assert "/" in model_name , "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
368-
369395 fast_model_name = f"fast-{ model_name .split ('/' )[- 1 ]} "
370396
371397 model_dir = Path (cache_dir ) / fast_model_name
@@ -393,23 +419,25 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
393419
394420 return model_dir
395421
@gcs_fallback
def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
    """
    Retrieves a model from HuggingFace Hub.
    Args:
        model_name (str): The name of the model to retrieve.
        cache_dir (str): The path to the cache directory.
    Raises:
        ValueError: If model_name has no entry in the models.json file located next to this module.
            Note that the gcs_fallback decorator applied to this method intercepts ValueError.
    Returns:
        Path: The path to the model directory.
    """
    models_file_path = Path(__file__).with_name("models.json")
    # Use a context manager so the file handle is closed deterministically.
    with open(models_file_path, encoding="utf-8") as models_file:
        models = json.load(models_file)

    # Single pass over the registry: collect every entry registered under this name.
    matching_models = [model for model in models if model["name"] == model_name]
    if not matching_models:
        raise ValueError(f"Could not find {model_name} in {models_file_path}")

    # Flatten the HF repository ids of all matching entries, preserving file order.
    sources = [source for model in matching_models for source in model["sources"]]

    return Path(self.download_files_from_huggingface(repo_ids=sources, cache_dir=cache_dir))
413441
414442 def passage_embed (self , texts : Iterable [str ], ** kwargs ) -> Iterable [np .ndarray ]:
415443 """
@@ -470,6 +498,8 @@ def __init__(
470498 Raises:
471499 ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
472500 """
501+ assert "/" in model_name , "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
502+
473503 self .model_name = model_name
474504
475505 if cache_dir is None :
@@ -478,7 +508,7 @@ def __init__(
478508 cache_dir .mkdir (parents = True , exist_ok = True )
479509
480510 self ._cache_dir = cache_dir
481- self ._model_dir = self .retrieve_model_gcs (model_name , cache_dir )
511+ self ._model_dir = self .retrieve_model_hf (model_name , cache_dir )
482512 self ._max_length = max_length
483513
484514 self .model = EmbeddingModel (self ._model_dir , self .model_name , max_length = max_length , max_threads = threads )
@@ -586,8 +616,12 @@ def __init__(
586616 Defaults to `fastembed_cache` in the system's temp directory.
587617 threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
588618 Raises:
589- ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge -base-en.
619+ ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2 -base-en.
590620 """
621+ assert (
622+ "/" in model_name
623+ ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en"
624+
591625 self .model_name = model_name
592626
593627 if cache_dir is None :
0 commit comments