+import functools
 import json
 import os
 import shutil
 from itertools import islice
 from multiprocessing import get_all_start_methods
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import onnxruntime as ort
 import requests
 from tokenizers import AddedToken, Tokenizer
 from tqdm import tqdm
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import RepositoryNotFoundError
+from loguru import logger

 from fastembed.parallel_processor import ParallelWorkerPool, Worker
@@ -58,9 +62,14 @@ def load_tokenizer(cls, model_dir: Path, max_length: int = 512) -> Tokenizer:
         if not tokens_map_path.exists():
             raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")

-        config = json.load(open(str(config_path)))
-        tokenizer_config = json.load(open(str(tokenizer_config_path)))
-        tokens_map = json.load(open(str(tokens_map_path)))
+        with open(str(config_path)) as config_file:
+            config = json.load(config_file)
+
+        with open(str(tokenizer_config_path)) as tokenizer_config_file:
+            tokenizer_config = json.load(tokenizer_config_file)
+
+        with open(str(tokens_map_path)) as tokens_map_file:
+            tokens_map = json.load(tokens_map_file)

         tokenizer = Tokenizer.from_file(str(tokenizer_path))
         tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length))
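Note: the refactor above matters beyond style. `json.load(open(...))` leaves the file handle open until garbage collection, while the `with` form closes it deterministically even when parsing raises. A minimal sketch of the pattern (the `read_json` helper is hypothetical, not part of the diff):

    import json
    from pathlib import Path

    def read_json(path: Path) -> dict:
        # The handle is closed on block exit, even if json.load raises.
        with open(path) as f:
            return json.load(f)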
@@ -174,71 +183,44 @@ class Embedding(ABC):
         _type_: _description_
     """

+    # Internal helper decorator to maintain backward compatibility
+    # by supporting a fallback to download from Google Cloud Storage (GCS)
+    # if the model couldn't be downloaded from HuggingFace.
+    def gcs_fallback(hf_download_method: Callable) -> Callable:
+        @functools.wraps(hf_download_method)
+        def wrapper(self, *args, **kwargs):
+            try:
+                return hf_download_method(self, *args, **kwargs)
+            except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
+                logger.exception(
+                    f"Could not download model from HuggingFace: {e}. "
+                    "Falling back to download from Google Cloud Storage."
+                )
+                return self.retrieve_model_gcs(*args, **kwargs)
+
+        return wrapper
+
     @abstractmethod
     def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = None) -> List[np.ndarray]:
         raise NotImplementedError

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to [].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
-        return [
-            {
-                "model": "BAAI/bge-small-en",
-                "dim": 384,
-                "description": "Fast English model",
-                "size_in_GB": 0.2,
-            },
-            {
-                "model": "BAAI/bge-small-en-v1.5",
-                "dim": 384,
-                "description": "Fast and Default English model",
-                "size_in_GB": 0.13,
-            },
-            {
-                "model": "BAAI/bge-small-zh-v1.5",
-                "dim": 512,
-                "description": "Fast and recommended Chinese model",
-                "size_in_GB": 0.1,
-            },
-            {
-                "model": "BAAI/bge-base-en",
-                "dim": 768,
-                "description": "Base English model",
-                "size_in_GB": 0.5,
-            },
-            {
-                "model": "BAAI/bge-base-en-v1.5",
-                "dim": 768,
-                "description": "Base English model, v1.5",
-                "size_in_GB": 0.44,
-            },
-            {
-                "model": "sentence-transformers/all-MiniLM-L6-v2",
-                "dim": 384,
-                "description": "Sentence Transformer model, MiniLM-L6-v2",
-                "size_in_GB": 0.09,
-            },
-            {
-                "model": "intfloat/multilingual-e5-large",
-                "dim": 1024,
-                "description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
-                "size_in_GB": 2.24,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-base-en",
-                "dim": 768,
-                "description": "English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.55,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-small-en",
-                "dim": 512,
-                "description": "English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.13,
-            },
-        ]
+        models_file_path = Path(__file__).with_name("models.json")
+        with open(models_file_path, "r") as file:
+            models = json.load(file)
+
+        models = [{k: v for k, v in model.items() if k not in exclude} for model in models]
+
+        return models

     @classmethod
     def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
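For orientation, here is how the `gcs_fallback` decorator introduced above composes with a download method. This is a self-contained sketch; `DummyEmbedding` and its method bodies are invented for illustration and are not part of the diff:

    import functools
    from typing import Callable

    def gcs_fallback(hf_download_method: Callable) -> Callable:
        @functools.wraps(hf_download_method)
        def wrapper(self, *args, **kwargs):
            try:
                return hf_download_method(self, *args, **kwargs)
            except ValueError:
                # Same arguments, retried against the GCS path.
                return self.retrieve_model_gcs(*args, **kwargs)
        return wrapper

    class DummyEmbedding:
        @gcs_fallback
        def retrieve_model_hf(self, model_name: str, cache_dir: str) -> str:
            raise ValueError("simulated HuggingFace outage")

        def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> str:
            return f"{cache_dir}/fast-{model_name.split('/')[-1]}"

    print(DummyEmbedding().retrieve_model_hf("BAAI/bge-small-en", "/tmp"))
    # /tmp/fast-bge-small-en

Because `functools.wraps` copies `__name__` and `__doc__`, the wrapped method still introspects as `retrieve_model_hf`.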
@@ -271,48 +253,49 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
         if total_size_in_bytes == 0:
             print(f"Warning: Content-length header is missing or zero in the response from {url}.")

-        # Initialize the progress bar
-        progress_bar = (
-            tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-            if total_size_in_bytes and show_progress
-            else None
-        )
+        show_progress = total_size_in_bytes and show_progress

-        # Attempt to download the file
-        try:
+        with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar:
             with open(output_path, "wb") as file:
-                for chunk in response.iter_content(chunk_size=1024):  # Adjust chunk size to your preference
+                for chunk in response.iter_content(chunk_size=1024):
                     if chunk:  # Filter out keep-alive new chunks
-                        if progress_bar is not None:
-                            progress_bar.update(len(chunk))
+                        progress_bar.update(len(chunk))
                     file.write(chunk)
-        except Exception as e:
-            print(f"An error occurred while trying to download the file: {str(e)}")
-            return
-        finally:
-            if progress_bar is not None:
-                progress_bar.close()
         return output_path

     @classmethod
-    def download_files_from_huggingface(cls, repod_id: str, cache_dir: Optional[str] = None) -> str:
+    def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[str] = None) -> str:
         """
         Downloads a model from HuggingFace Hub.
         Args:
-            repod_id (str): The HF hub id (name) of the model to retrieve.
+            model_name (str): Name of the model to download.
             cache_dir (Optional[str]): The path to the cache directory.
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
         Returns:
             Path: The path to the model directory.
         """
-        from huggingface_hub import snapshot_download
-
-        return snapshot_download(
-            repo_id=repod_id,
-            ignore_patterns=["model.safetensors", "pytorch_model.bin"],
-            cache_dir=cache_dir,
-        )
+        models = cls.list_supported_models(exclude=["compressed_url_sources"])
+
+        hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]]
+
+        # Check if the HF sources list is empty
+        # Raise an exception causing a fallback to GCS
+        if not hf_sources:
+            raise ValueError(f"No HuggingFace source for {model_name}")
+
+        for index, repo_id in enumerate(hf_sources):
+            try:
+                return snapshot_download(
+                    repo_id=repo_id,
+                    ignore_patterns=["model.safetensors", "pytorch_model.bin"],
+                    cache_dir=cache_dir,
+                )
+            except (RepositoryNotFoundError, EnvironmentError) as e:
+                logger.exception(f"Failed to download model from HF source: {repo_id}: {e}")
+                if repo_id == hf_sources[-1]:
+                    raise e
+                logger.info(f"Trying another source: {hf_sources[index + 1]}")

     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
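The retry loop above reads per-model source lists from models.json, a file that is not part of this diff. Judging from the keys the code accesses (`model`, `hf_sources`, `compressed_url_sources`), an entry presumably looks like the following; the `hf_sources` value is a guess, while the URL pattern comes from the removed GCS code further down:

    {
        "model": "BAAI/bge-small-en",
        "dim": 384,
        "description": "Fast English model",
        "size_in_GB": 0.2,
        "hf_sources": ["qdrant/bge-small-en-onnx"],
        "compressed_url_sources": [
            "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en.tar.gz"
        ]
    }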
@@ -363,28 +346,37 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
         Returns:
             Path: The path to the model directory.
         """
-
-        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
-
         fast_model_name = f"fast-{model_name.split('/')[-1]}"

         model_dir = Path(cache_dir) / fast_model_name
         if model_dir.exists():
             return model_dir

         model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
-        try:
-            self.download_file_from_gcs(
-                f"https://storage.googleapis.com/qdrant-fastembed/{fast_model_name}.tar.gz",
-                output_path=str(model_tar_gz),
-            )
-        except PermissionError:
-            simple_model_name = model_name.replace("/", "-")
-            print(f"Was not able to download {fast_model_name}.tar.gz, trying {simple_model_name}.tar.gz")
-            self.download_file_from_gcs(
-                f"https://storage.googleapis.com/qdrant-fastembed/{simple_model_name}.tar.gz",
-                output_path=str(model_tar_gz),
-            )
+
+        models = self.list_supported_models(exclude=["hf_sources"])
+
+        compressed_url_sources = [
+            item for model in models if model["model"] == model_name for item in model["compressed_url_sources"]
+        ]
+
+        # Check if the GCS sources list is empty after falling back from HF
+        # A model should always have at least one source
+        if not compressed_url_sources:
+            raise ValueError(f"No GCS source for {model_name}")
+
+        for index, source in enumerate(compressed_url_sources):
+            try:
+                self.download_file_from_gcs(
+                    source,
+                    output_path=str(model_tar_gz),
+                )
+                break  # Stop at the first source that downloads successfully
+            except (RuntimeError, PermissionError) as e:
+                logger.exception(f"Failed to download model from GCS source: {source}: {e}")
+                if source == compressed_url_sources[-1]:
+                    raise e
+                logger.info(f"Trying another source: {compressed_url_sources[index + 1]}")

         self.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=cache_dir)
         assert model_dir.exists(), f"Could not find {model_dir} in {cache_dir}"
@@ -393,23 +384,31 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
         return model_dir

+    @gcs_fallback
     def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
         """
         Retrieves a model from HuggingFace Hub.
         Args:
             model_name (str): The name of the model to retrieve.
             cache_dir (str): The path to the cache directory.
-        Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         Returns:
             Path: The path to the model directory.
         """
-        assert (
-            "/" in model_name
-        ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
+        return Path(self.download_files_from_huggingface(model_name=model_name, cache_dir=cache_dir))

-        return Path(self.download_files_from_huggingface(repod_id=model_name, cache_dir=cache_dir))
+    @classmethod
+    def assert_model_name(cls, model_name: str):
+        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
+
+        models = cls.list_supported_models()
+        model_names = [model["model"] for model in models]
+        if model_name not in model_names:
+            raise ValueError(
+                f"{model_name} is not a supported model.\n"
+                f"Try one of {', '.join(model_names)}.\n"
+                f"Use the 'list_supported_models()' method to get the model information."
+            )

     def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
         """
@@ -470,6 +469,9 @@ def __init__(
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         """
+
+        self.assert_model_name(model_name)
+
         self.model_name = model_name
@@ -478,7 +480,7 @@ def __init__(
         cache_dir.mkdir(parents=True, exist_ok=True)

         self._cache_dir = cache_dir
-        self._model_dir = self.retrieve_model_gcs(model_name, cache_dir)
+        self._model_dir = self.retrieve_model_hf(model_name, cache_dir)
         self._max_length = max_length

         self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads)
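With this hunk, construction goes HuggingFace-first. A hypothetical end-to-end sketch of the resulting flow (cache path illustrative):

    from fastembed.embedding import DefaultEmbedding

    # __init__ -> assert_model_name -> retrieve_model_hf
    #   -> download_files_from_huggingface tries each hf_source in order
    #   -> on failure, @gcs_fallback reroutes to retrieve_model_gcs
    model = DefaultEmbedding("BAAI/bge-small-en-v1.5", cache_dir="local_cache")

    embeddings = list(model.embed(["hello world"]))
    print(embeddings[0].shape)  # (384,) for this model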
@@ -531,12 +533,21 @@ def embed(
             yield from normalize(embeddings[:, 0]).astype(np.float32)

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(
+        cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+    ) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
         # jina models are not supported by this class
-        return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")]
+        return [
+            model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai")
+        ]


 class DefaultEmbedding(FlagEmbedding):
@@ -586,8 +597,10 @@ def __init__(
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
         Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+            ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
         """
+        self.assert_model_name(model_name)
+
         self.model_name = model_name

         if cache_dir is None:
@@ -647,12 +660,21 @@ def embed(
             yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(
+        cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+    ) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
         # only jina models are supported by this class
-        return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")]
+        return [
+            model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai")
+        ]

     @staticmethod
     def mean_pooling(model_output, attention_mask):
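`mean_pooling`, whose body falls outside this diff, is referenced by the Jina `embed` method above. A sketch of the standard masked-mean formulation it presumably implements:

    import numpy as np

    def mean_pooling(model_output: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
        # model_output: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1
        mask = np.expand_dims(attention_mask, axis=-1).astype(np.float32)
        summed = np.sum(model_output * mask, axis=1)          # sum of real-token embeddings
        counts = np.clip(mask.sum(axis=1), 1e-9, None)        # avoid division by zero on empty masks
        return summed / counts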