feat: HuggingFace download support for FlagEmbedding #94

Merged: 8 commits, Jan 23, 2024
Changes from 1 commit
74 changes: 54 additions & 20 deletions fastembed/embedding.py
@@ -1,4 +1,6 @@
import functools
import json
import logging
import os
import shutil
import tarfile
@@ -7,16 +9,20 @@
from itertools import islice
from multiprocessing import get_all_start_methods
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union

import numpy as np
import onnxruntime as ort
import requests
from tokenizers import AddedToken, Tokenizer
from tqdm import tqdm
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError

from fastembed.parallel_processor import ParallelWorkerPool, Worker

logger = logging.getLogger(__name__)


def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
"""
@@ -174,6 +180,23 @@ class Embedding(ABC):
_type_: _description_
"""

# Internal helper decorator to maintain backward compatibility
# by supporting a fallback to download from Google Cloud Storage (GCS)
# if the model couldn't be downloaded from HuggingFace.
def gcs_fallback(hf_download_method: Callable) -> Callable:
@functools.wraps(hf_download_method)
def wrapper(self, *args, **kwargs):
try:
return hf_download_method(self, *args, **kwargs)
except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
logger.info(
f"Could not download model from HuggingFace: {e}"
"Falling back to download from Google Cloud Storage"
)
return self.retrieve_model_gcs(*args, **kwargs)

return wrapper

@abstractmethod
def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = None) -> List[np.ndarray]:
raise NotImplementedError
@@ -295,24 +318,30 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
return output_path

@classmethod
def download_files_from_huggingface(cls, repod_id: str, cache_dir: Optional[str] = None) -> str:
def download_files_from_huggingface(cls, repo_ids: List[str], cache_dir: Optional[str] = None) -> str:
"""
Downloads a model from HuggingFace Hub.
Args:
repod_id (str): The HF hub id (name) of the model to retrieve.
repo_ids (List[str]): The HF hub ids (names) of the model sources to try, in order.
cache_dir (Optional[str]): The path to the cache directory.
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
Returns:
Path: The path to the model directory.
"""
from huggingface_hub import snapshot_download

return snapshot_download(
repo_id=repod_id,
ignore_patterns=["model.safetensors", "pytorch_model.bin"],
cache_dir=cache_dir,
)
for index, repo_id in enumerate(repo_ids):
try:
return snapshot_download(
repo_id=repo_id,
ignore_patterns=["model.safetensors", "pytorch_model.bin"],
cache_dir=cache_dir,
)
except (RepositoryNotFoundError, EnvironmentError) as e:
logger.error(f"Failed to download model from HF source: {repo_id}: {e} ")
if repo_id == repo_ids[-1]:
raise e
logger.info(f"Trying another source: {repo_ids[index+1]}")

@classmethod
def decompress_to_cache(cls, targz_path: str, cache_dir: str):
@@ -363,9 +392,6 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
Returns:
Path: The path to the model directory.
"""

assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"

fast_model_name = f"fast-{model_name.split('/')[-1]}"

model_dir = Path(cache_dir) / fast_model_name
@@ -393,23 +419,25 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:

return model_dir

@gcs_fallback
def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
"""
Retrieves a model from HuggingFace Hub.
Args:
model_name (str): The name of the model to retrieve.
cache_dir (str): The path to the cache directory.
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
Returns:
Path: The path to the model directory.
"""
models_file_path = Path(__file__).with_name("models.json")
with open(models_file_path) as models_file:
    models = json.load(models_file)

assert (
"/" in model_name
), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
if model_name not in [model["name"] for model in models]:
raise ValueError(f"Could not find {model_name} in {models_file_path}")

sources = [item for model in models if model["name"] == model_name for item in model["sources"]]

return Path(self.download_files_from_huggingface(repod_id=model_name, cache_dir=cache_dir))
return Path(self.download_files_from_huggingface(repo_ids=sources, cache_dir=cache_dir))

def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
"""
@@ -470,6 +498,8 @@ def __init__(
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
"""
assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"

self.model_name = model_name

if cache_dir is None:
cache_dir.mkdir(parents=True, exist_ok=True)

self._cache_dir = cache_dir
self._model_dir = self.retrieve_model_gcs(model_name, cache_dir)
self._model_dir = self.retrieve_model_hf(model_name, cache_dir)
self._max_length = max_length

self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads)
@@ -586,8 +616,12 @@ def __init__(
Defaults to `fastembed_cache` in the system's temp directory.
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
"""
assert (
"/" in model_name
), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en"

self.model_name = model_name

if cache_dir is None:
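Taken together, the embedding.py changes mean model retrieval now prefers HuggingFace and only falls back to GCS when every HF source fails. A minimal usage sketch, assuming this branch is installed; the input sentence is illustrative:

# Hedged sketch: exercises the new retrieve_model_hf path. On any HF failure,
# the @gcs_fallback decorator transparently retries retrieve_model_gcs.
from fastembed.embedding import FlagEmbedding

model = FlagEmbedding(model_name="BAAI/bge-small-en-v1.5")  # must be <org>/<model>
embeddings = list(model.embed(["FastEmbed now downloads models from HuggingFace."]))
print(len(embeddings), embeddings[0].shape)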
45 changes: 45 additions & 0 deletions fastembed/models.json
@@ -0,0 +1,45 @@
[
{
"name": "BAAI/bge-small-en-v1.5",
"sources": [
"Qdrant/bge-small-en-v1.5-onnx-Q"
]
},
{
"name": "BAAI/bge-base-en-v1.5",
"sources": [
"Qdrant/bge-base-en-v1.5-onnx-Q"
]
},
{
"name": "BAAI/bge-large-en-v1.5",
"sources": [
"Qdrant/bge-large-en-v1.5-onnx",
"Qdrant/bge-large-en-v1.5-onnx-Q"
]
},
{
"name": "sentence-transformers/all-MiniLM-L6-v2",
"sources": [
"Qdrant/all-MiniLM-L6-v2-onnx"
]
},
{
"name": "intfloat/multilingual-e5-large",
"sources": [
"Qdrant/multilingual-e5-large-onnx"
]
},
{
"name": "jinaai/jina-embeddings-v2-base-en",
"sources": [
"jinaai/jina-embeddings-v2-base-en"
]
},
{
"name": "jinaai/jina-embeddings-v2-small-en",
"sources": [
"jinaai/jina-embeddings-v2-small-en"
]
}
]
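Each "name" is the user-facing model id, and "sources" lists the HF repos to try in order (ONNX mirrors under the Qdrant org where available). A standalone sketch of how download_files_from_huggingface walks this list; snapshot_download and RepositoryNotFoundError are real huggingface_hub APIs, and the repo ids are copied from the bge-large entry above:

# Standalone sketch of the fallback loop in download_files_from_huggingface.
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError

sources = ["Qdrant/bge-large-en-v1.5-onnx", "Qdrant/bge-large-en-v1.5-onnx-Q"]
model_dir = None
for index, repo_id in enumerate(sources):
    try:
        # Skip the PyTorch weights; only the ONNX artifacts are needed.
        model_dir = snapshot_download(
            repo_id=repo_id,
            ignore_patterns=["model.safetensors", "pytorch_model.bin"],
        )
        break
    except (RepositoryNotFoundError, EnvironmentError):
        if index == len(sources) - 1:
            raise  # no source left to try
print(model_dir)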