
Commit 193dcc1

feat: HF support for FlagEmbedding
1 parent 3b32619 commit 193dcc1

File tree

2 files changed: +100 -20 lines


fastembed/embedding.py (+54 -20)
```diff
@@ -1,4 +1,6 @@
+import functools
 import json
+import logging
 import os
 import shutil
 import tarfile
@@ -7,16 +9,20 @@
 from itertools import islice
 from multiprocessing import get_all_start_methods
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import onnxruntime as ort
 import requests
 from tokenizers import AddedToken, Tokenizer
 from tqdm import tqdm
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import RepositoryNotFoundError

 from fastembed.parallel_processor import ParallelWorkerPool, Worker

+logger = logging.getLogger(__name__)
+

 def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
     """
@@ -174,6 +180,23 @@ class Embedding(ABC):
         _type_: _description_
     """

+    # Internal helper decorator to maintain backward compatibility
+    # by supporting a fallback to download from Google Cloud Storage (GCS)
+    # if the model couldn't be downloaded from HuggingFace.
+    def gcs_fallback(hf_download_method: Callable) -> Callable:
+        @functools.wraps(hf_download_method)
+        def wrapper(self, *args, **kwargs):
+            try:
+                return hf_download_method(self, *args, **kwargs)
+            except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
+                logger.info(
+                    f"Could not download model from HuggingFace: {e}. "
+                    "Falling back to download from Google Cloud Storage"
+                )
+                return self.retrieve_model_gcs(*args, **kwargs)
+
+        return wrapper
+
     @abstractmethod
     def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = None) -> List[np.ndarray]:
         raise NotImplementedError
```
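Since `gcs_fallback` carries the backward-compatibility story, here is a self-contained sketch of how it behaves; `DemoEmbedding` and its two retrieve methods are hypothetical stand-ins for the real class, not part of this commit:

```python
import functools
import logging
from typing import Callable

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def gcs_fallback(hf_download_method: Callable) -> Callable:
    # Same pattern as the decorator above: try HF first, fall back to GCS on failure.
    @functools.wraps(hf_download_method)
    def wrapper(self, *args, **kwargs):
        try:
            return hf_download_method(self, *args, **kwargs)
        except EnvironmentError as e:
            logger.info(f"Could not download model from HuggingFace: {e}. Falling back to GCS")
            return self.retrieve_model_gcs(*args, **kwargs)

    return wrapper


class DemoEmbedding:  # hypothetical stand-in for the Embedding base class
    @gcs_fallback
    def retrieve_model_hf(self, model_name: str, cache_dir: str) -> str:
        raise EnvironmentError("HF unreachable")  # simulate a failed HF download

    def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> str:
        return f"{cache_dir}/fast-{model_name.split('/')[-1]}"


print(DemoEmbedding().retrieve_model_hf("BAAI/bge-base-en-v1.5", "/tmp/cache"))
# -> /tmp/cache/fast-bge-base-en-v1.5, reached via the GCS fallback
```

In the commit itself the decorator lives inside the `Embedding` class body, which is why `wrapper` takes `self` explicitly.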
```diff
@@ -295,24 +318,30 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
         return output_path

     @classmethod
-    def download_files_from_huggingface(cls, repod_id: str, cache_dir: Optional[str] = None) -> str:
+    def download_files_from_huggingface(cls, repo_ids: list[str], cache_dir: Optional[str] = None) -> str:
         """
         Downloads a model from HuggingFace Hub.
         Args:
-            repod_id (str): The HF hub id (name) of the model to retrieve.
+            repo_ids (list[str]): The HF hub ids (names) of the model sources to try, in order.
             cache_dir (Optional[str]): The path to the cache directory.
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
         Returns:
             Path: The path to the model directory.
         """
-        from huggingface_hub import snapshot_download

-        return snapshot_download(
-            repo_id=repod_id,
-            ignore_patterns=["model.safetensors", "pytorch_model.bin"],
-            cache_dir=cache_dir,
-        )
+        for index, repo_id in enumerate(repo_ids):
+            try:
+                return snapshot_download(
+                    repo_id=repo_id,
+                    ignore_patterns=["model.safetensors", "pytorch_model.bin"],
+                    cache_dir=cache_dir,
+                )
+            except (RepositoryNotFoundError, EnvironmentError) as e:
+                logger.error(f"Failed to download model from HF source: {repo_id}: {e}")
+                if repo_id == repo_ids[-1]:
+                    raise e
+                logger.info(f"Trying another source: {repo_ids[index+1]}")

     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
```
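The new loop gives first-success-wins semantics over an ordered list of repos: each source is tried in turn, failures are logged, and only the last source's failure propagates. A minimal sketch of that control flow, with `fake_download` as a hypothetical stand-in for `snapshot_download`:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def fake_download(repo_id: str) -> str:
    # Hypothetical stand-in for huggingface_hub.snapshot_download.
    if "broken" in repo_id:
        raise EnvironmentError(f"{repo_id} not found")
    return f"/cache/{repo_id}"


def download_from_sources(repo_ids: list) -> str:
    for index, repo_id in enumerate(repo_ids):
        try:
            return fake_download(repo_id)
        except EnvironmentError as e:
            logger.error(f"Failed to download from {repo_id}: {e}")
            if repo_id == repo_ids[-1]:
                raise  # every source failed
            logger.info(f"Trying another source: {repo_ids[index + 1]}")


print(download_from_sources(["org/broken-model", "org/working-model"]))
# logs one failure, then prints /cache/org/working-model
```

One design note: the last-source check compares repo ids by value, so it assumes the source list has no duplicates; `index == len(repo_ids) - 1` would be the purely positional equivalent.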
```diff
@@ -363,9 +392,6 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
         Returns:
             Path: The path to the model directory.
         """
-
-        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
-
         fast_model_name = f"fast-{model_name.split('/')[-1]}"

         model_dir = Path(cache_dir) / fast_model_name
@@ -393,23 +419,25 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:

         return model_dir

+    @gcs_fallback
     def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
         """
         Retrieves a model from HuggingFace Hub.
         Args:
             model_name (str): The name of the model to retrieve.
             cache_dir (str): The path to the cache directory.
-        Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         Returns:
             Path: The path to the model directory.
         """
+        models_file_path = Path(__file__).with_name("models.json")
+        models = json.load(open(str(models_file_path)))

-        assert (
-            "/" in model_name
-        ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
+        if model_name not in [model["name"] for model in models]:
+            raise ValueError(f"Could not find {model_name} in {models_file_path}")
+
+        sources = [item for model in models if model["name"] == model_name for item in model["sources"]]

-        return Path(self.download_files_from_huggingface(repod_id=model_name, cache_dir=cache_dir))
+        return Path(self.download_files_from_huggingface(repo_ids=sources, cache_dir=cache_dir))

     def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
         """
```
```diff
@@ -470,6 +498,8 @@ def __init__(
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         """
+        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
+
         self.model_name = model_name

         if cache_dir is None:
@@ -478,7 +508,7 @@ def __init__(
             cache_dir.mkdir(parents=True, exist_ok=True)

         self._cache_dir = cache_dir
-        self._model_dir = self.retrieve_model_gcs(model_name, cache_dir)
+        self._model_dir = self.retrieve_model_hf(model_name, cache_dir)
         self._max_length = max_length

         self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads)
```
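With the constructor switched from `retrieve_model_gcs` to the decorated `retrieve_model_hf`, a plain instantiation now exercises the whole chain (HF sources from models.json first, GCS only as a fallback) while the public API stays the same. A usage sketch, assuming this is the `FlagEmbedding` constructor the commit title refers to:

```python
from fastembed.embedding import FlagEmbedding

# First run downloads from the first reachable HF source in models.json,
# falling back to GCS if all HF sources fail; later runs hit the local cache.
model = FlagEmbedding(model_name="BAAI/bge-base-en-v1.5")
embeddings = list(model.embed(["fastembed makes local embeddings easy."]))
print(embeddings[0].shape)
```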
```diff
@@ -586,8 +616,12 @@ def __init__(
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
         Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+            ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
         """
+        assert (
+            "/" in model_name
+        ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en"
+
         self.model_name = model_name

         if cache_dir is None:
```

fastembed/models.json (+46 -0)
```diff
@@ -0,0 +1,46 @@
+[
+  {
+    "name": "BAAI/bge-small-en-v1.5",
+    "sources": [
+      "Qdrant/bge-small-en-v1.5-onnx-Q"
+    ]
+  },
+  {
+    "name": "BAAI/bge-base-en-v1.5",
+    "sources": [
+      "weakit-v/bge-base-en-v1.5-onnx",
+      "Qdrant/bge-base-en-v1.5-onnx-Q"
+    ]
+  },
+  {
+    "name": "BAAI/bge-large-en-v1.5",
+    "sources": [
+      "Qdrant/bge-large-en-v1.5-onnx",
+      "Qdrant/bge-large-en-v1.5-onnx-Q"
+    ]
+  },
+  {
+    "name": "sentence-transformers/all-MiniLM-L6-v2",
+    "sources": [
+      "Qdrant/all-MiniLM-L6-v2-onnx"
+    ]
+  },
+  {
+    "name": "intfloat/multilingual-e5-large",
+    "sources": [
+      "Qdrant/multilingual-e5-large-onnx"
+    ]
+  },
+  {
+    "name": "jinaai/jina-embeddings-v2-base-en",
+    "sources": [
+      "jinaai/jina-embeddings-v2-base-en"
+    ]
+  },
+  {
+    "name": "jinaai/jina-embeddings-v2-small-en",
+    "sources": [
+      "jinaai/jina-embeddings-v2-small-en"
+    ]
+  }
+]
```
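Each entry maps a canonical model name to an ordered list of HF repos hosting its ONNX export, where list order encodes download priority. A quick sanity check over the file (path assumed relative to a fastembed checkout):

```python
import json
from pathlib import Path

models = json.loads(Path("fastembed/models.json").read_text())
for model in models:
    # sources[0] is tried first; any later entries are fallbacks.
    print(f"{model['name']}: {len(model['sources'])} source(s), primary = {model['sources'][0]}")
```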
