1
+ import functools
1
2
import json
3
+ import logging
2
4
import os
3
5
import shutil
4
6
import tarfile
7
9
from itertools import islice
8
10
from multiprocessing import get_all_start_methods
9
11
from pathlib import Path
10
- from typing import Any , Dict , Generator , Iterable , List , Optional , Tuple , Union
12
+ from typing import Any , Callable , Dict , Generator , Iterable , List , Optional , Tuple , Union
11
13
12
14
import numpy as np
13
15
import onnxruntime as ort
14
16
import requests
15
17
from tokenizers import AddedToken , Tokenizer
16
18
from tqdm import tqdm
19
+ from huggingface_hub import snapshot_download
20
+ from huggingface_hub .utils import RepositoryNotFoundError
17
21
18
22
from fastembed .parallel_processor import ParallelWorkerPool , Worker
19
23
24
+ logger = logging .getLogger (__name__ )
25
+
20
26
21
27
def iter_batch (iterable : Union [Iterable , Generator ], size : int ) -> Iterable :
22
28
"""
@@ -174,6 +180,23 @@ class Embedding(ABC):
174
180
_type_: _description_
175
181
"""
176
182
183
+ # Internal helper decorator to maintain backward compatibility
184
+ # by supporting a fallback to download from Google Cloud Storage (GCS)
185
+ # if the model couldn't be downloaded from HuggingFace.
186
+ def gcs_fallback (hf_download_method : Callable ) -> Callable :
187
+ @functools .wraps (hf_download_method )
188
+ def wrapper (self , * args , ** kwargs ):
189
+ try :
190
+ return hf_download_method (self , * args , ** kwargs )
191
+ except (EnvironmentError , RepositoryNotFoundError , ValueError ) as e :
192
+ logger .info (
193
+ f"Could not download model from HuggingFace: { e } "
194
+ "Falling back to download from Google Cloud Storage"
195
+ )
196
+ return self .retrieve_model_gcs (* args , ** kwargs )
197
+
198
+ return wrapper
199
+
177
200
    @abstractmethod
    def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: Optional[int] = None) -> List[np.ndarray]:
        """Encode texts into embedding vectors; implemented by subclasses.

        Args:
            texts (Iterable[str]): The texts to embed.
            batch_size (int): Batch size for encoding. Defaults to 256.
            parallel (Optional[int]): Number of parallel workers; ``None``
                presumably disables parallelism — confirm against subclass
                implementations.

        Raises:
            NotImplementedError: Always; this is an abstract method.
        """
        raise NotImplementedError
@@ -295,24 +318,30 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
295
318
return output_path
296
319
297
320
@classmethod
298
- def download_files_from_huggingface (cls , repod_id : str , cache_dir : Optional [str ] = None ) -> str :
321
+ def download_files_from_huggingface (cls , repo_ids : List [ str ] , cache_dir : Optional [str ] = None ) -> str :
299
322
"""
300
323
Downloads a model from HuggingFace Hub.
301
324
Args:
302
- repod_id (str): The HF hub id (name) of the model to retrieve.
325
+ repo_id (str): The HF hub id (name) of the model to retrieve.
303
326
cache_dir (Optional[str]): The path to the cache directory.
304
327
Raises:
305
328
ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
306
329
Returns:
307
330
Path: The path to the model directory.
308
331
"""
309
- from huggingface_hub import snapshot_download
310
332
311
- return snapshot_download (
312
- repo_id = repod_id ,
313
- ignore_patterns = ["model.safetensors" , "pytorch_model.bin" ],
314
- cache_dir = cache_dir ,
315
- )
333
+ for index , repo_id in enumerate (repo_ids ):
334
+ try :
335
+ return snapshot_download (
336
+ repo_id = repo_id ,
337
+ ignore_patterns = ["model.safetensors" , "pytorch_model.bin" ],
338
+ cache_dir = cache_dir ,
339
+ )
340
+ except (RepositoryNotFoundError , EnvironmentError ) as e :
341
+ logger .error (f"Failed to download model from HF source: { repo_id } : { e } " )
342
+ if repo_id == repo_ids [- 1 ]:
343
+ raise e
344
+ logger .info (f"Trying another source: { repo_ids [index + 1 ]} " )
316
345
317
346
@classmethod
318
347
def decompress_to_cache (cls , targz_path : str , cache_dir : str ):
@@ -363,9 +392,6 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
363
392
Returns:
364
393
Path: The path to the model directory.
365
394
"""
366
-
367
- assert "/" in model_name , "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
368
-
369
395
fast_model_name = f"fast-{ model_name .split ('/' )[- 1 ]} "
370
396
371
397
model_dir = Path (cache_dir ) / fast_model_name
@@ -393,23 +419,25 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
393
419
394
420
return model_dir
395
421
422
+ @gcs_fallback
396
423
def retrieve_model_hf (self , model_name : str , cache_dir : str ) -> Path :
397
424
"""
398
425
Retrieves a model from HuggingFace Hub.
399
426
Args:
400
427
model_name (str): The name of the model to retrieve.
401
428
cache_dir (str): The path to the cache directory.
402
- Raises:
403
- ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
404
429
Returns:
405
430
Path: The path to the model directory.
406
431
"""
432
+ models_file_path = Path (__file__ ).with_name ("models.json" )
433
+ models = json .load (open (str (models_file_path )))
407
434
408
- assert (
409
- "/" in model_name
410
- ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
435
+ if model_name not in [model ["name" ] for model in models ]:
436
+ raise ValueError (f"Could not find { model_name } in { models_file_path } " )
437
+
438
+ sources = [item for model in models if model ["name" ] == model_name for item in model ["sources" ]]
411
439
412
- return Path (self .download_files_from_huggingface (repod_id = model_name , cache_dir = cache_dir ))
440
+ return Path (self .download_files_from_huggingface (repo_ids = sources , cache_dir = cache_dir ))
413
441
414
442
def passage_embed (self , texts : Iterable [str ], ** kwargs ) -> Iterable [np .ndarray ]:
415
443
"""
@@ -470,6 +498,8 @@ def __init__(
470
498
Raises:
471
499
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
472
500
"""
501
+ assert "/" in model_name , "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
502
+
473
503
self .model_name = model_name
474
504
475
505
if cache_dir is None :
@@ -478,7 +508,7 @@ def __init__(
478
508
cache_dir .mkdir (parents = True , exist_ok = True )
479
509
480
510
self ._cache_dir = cache_dir
481
- self ._model_dir = self .retrieve_model_gcs (model_name , cache_dir )
511
+ self ._model_dir = self .retrieve_model_hf (model_name , cache_dir )
482
512
self ._max_length = max_length
483
513
484
514
self .model = EmbeddingModel (self ._model_dir , self .model_name , max_length = max_length , max_threads = threads )
@@ -586,8 +616,12 @@ def __init__(
586
616
Defaults to `fastembed_cache` in the system's temp directory.
587
617
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
588
618
Raises:
589
- ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge -base-en.
619
+ ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2 -base-en.
590
620
"""
621
+ assert (
622
+ "/" in model_name
623
+ ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en"
624
+
591
625
self .model_name = model_name
592
626
593
627
if cache_dir is None :
0 commit comments