
Commit ede507e

feat: HuggingFace download support for FlagEmbedding (#94)
* feat: HF support for FlagEmbedding
* chore: update docstring embedding.py
* refactor: GCS URLs models.json
* chore: toLower() models.json
* chore: update tqdm declarative
* chore: exclude keys list_supported_models
* chore: review changes
1 parent f87330f commit ede507e

File tree

5 files changed (+303, -119 lines)


fastembed/embedding.py

+134, -117
@@ -1,3 +1,4 @@
+import functools
 import json
 import os
 import shutil
@@ -7,13 +8,16 @@
 from itertools import islice
 from multiprocessing import get_all_start_methods
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import onnxruntime as ort
 import requests
 from tokenizers import AddedToken, Tokenizer
 from tqdm import tqdm
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import RepositoryNotFoundError
+from loguru import logger

 from fastembed.parallel_processor import ParallelWorkerPool, Worker
@@ -179,71 +183,44 @@ class Embedding(ABC):
         _type_: _description_
     """

+    # Internal helper decorator to maintain backward compatibility
+    # by supporting a fallback to download from Google Cloud Storage (GCS)
+    # if the model couldn't be downloaded from HuggingFace.
+    def gcs_fallback(hf_download_method: Callable) -> Callable:
+        @functools.wraps(hf_download_method)
+        def wrapper(self, *args, **kwargs):
+            try:
+                return hf_download_method(self, *args, **kwargs)
+            except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
+                logger.exception(
+                    f"Could not download model from HuggingFace: {e}"
+                    "Falling back to download from Google Cloud Storage"
+                )
+                return self.retrieve_model_gcs(*args, **kwargs)
+
+        return wrapper
+
     @abstractmethod
     def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = None) -> List[np.ndarray]:
         raise NotImplementedError

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to [].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
-        return [
-            {
-                "model": "BAAI/bge-small-en",
-                "dim": 384,
-                "description": "Fast English model",
-                "size_in_GB": 0.2,
-            },
-            {
-                "model": "BAAI/bge-small-en-v1.5",
-                "dim": 384,
-                "description": "Fast and Default English model",
-                "size_in_GB": 0.13,
-            },
-            {
-                "model": "BAAI/bge-small-zh-v1.5",
-                "dim": 512,
-                "description": "Fast and recommended Chinese model",
-                "size_in_GB": 0.1,
-            },
-            {
-                "model": "BAAI/bge-base-en",
-                "dim": 768,
-                "description": "Base English model",
-                "size_in_GB": 0.5,
-            },
-            {
-                "model": "BAAI/bge-base-en-v1.5",
-                "dim": 768,
-                "description": "Base English model, v1.5",
-                "size_in_GB": 0.44,
-            },
-            {
-                "model": "sentence-transformers/all-MiniLM-L6-v2",
-                "dim": 384,
-                "description": "Sentence Transformer model, MiniLM-L6-v2",
-                "size_in_GB": 0.09,
-            },
-            {
-                "model": "intfloat/multilingual-e5-large",
-                "dim": 1024,
-                "description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
-                "size_in_GB": 2.24,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-base-en",
-                "dim": 768,
-                "description": " English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.55,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-small-en",
-                "dim": 512,
-                "description": " English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.13,
-            },
-        ]
+        models_file_path = Path(__file__).with_name("models.json")
+        with open(models_file_path, "r") as file:
+            models = json.load(file)
+
+        models = [{k: v for k, v in model.items() if k not in exclude} for model in models]
+
+        return models

     @classmethod
     def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
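
For reference, list_supported_models() now reads its metadata from a models.json file shipped next to embedding.py, and the download methods below look up per-model "hf_sources" and "compressed_url_sources" lists in it. A minimal sketch of what one entry presumably looks like, reusing the BAAI/bge-small-en-v1.5 values from the removed hard-coded list; the hf_sources repo id and the compressed_url_sources URL are illustrative assumptions, not copied from the real models.json:

import json

# Hypothetical models.json entry; the real file may use different repo ids and URLs.
example_entry = {
    "model": "BAAI/bge-small-en-v1.5",
    "dim": 384,
    "description": "Fast and Default English model",
    "size_in_GB": 0.13,
    # HuggingFace repos to try first, in order (repo id is an assumption).
    "hf_sources": ["qdrant/bge-small-en-v1.5-onnx-q"],
    # Compressed archives to fall back to (URL follows the old GCS naming, also an assumption).
    "compressed_url_sources": [
        "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en-v1.5.tar.gz"
    ],
}

# list_supported_models(exclude=[...]) simply drops the excluded keys from each entry:
exclude = ["hf_sources", "compressed_url_sources"]
public_view = {k: v for k, v in example_entry.items() if k not in exclude}
print(json.dumps(public_view, indent=2))
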
@@ -276,48 +253,49 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
         if total_size_in_bytes == 0:
             print(f"Warning: Content-length header is missing or zero in the response from {url}.")

-        # Initialize the progress bar
-        progress_bar = (
-            tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-            if total_size_in_bytes and show_progress
-            else None
-        )
+        show_progress = total_size_in_bytes and show_progress

-        # Attempt to download the file
-        try:
+        with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar:
             with open(output_path, "wb") as file:
-                for chunk in response.iter_content(chunk_size=1024):  # Adjust chunk size to your preference
+                for chunk in response.iter_content(chunk_size=1024):
                     if chunk:  # Filter out keep-alive new chunks
-                        if progress_bar is not None:
-                            progress_bar.update(len(chunk))
+                        progress_bar.update(len(chunk))
                         file.write(chunk)
-        except Exception as e:
-            print(f"An error occurred while trying to download the file: {str(e)}")
-            return
-        finally:
-            if progress_bar is not None:
-                progress_bar.close()
         return output_path

     @classmethod
-    def download_files_from_huggingface(cls, repod_id: str, cache_dir: Optional[str] = None) -> str:
+    def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[str] = None) -> str:
         """
         Downloads a model from HuggingFace Hub.
         Args:
-            repod_id (str): The HF hub id (name) of the model to retrieve.
+            model_name (str): Name of the model to download.
             cache_dir (Optional[str]): The path to the cache directory.
         Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
         Returns:
             Path: The path to the model directory.
         """
-        from huggingface_hub import snapshot_download
-
-        return snapshot_download(
-            repo_id=repod_id,
-            ignore_patterns=["model.safetensors", "pytorch_model.bin"],
-            cache_dir=cache_dir,
-        )
+        models = cls.list_supported_models(exclude=["compressed_url_sources"])
+
+        hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]]
+
+        # Check if the HF sources list is empty
+        # Raise an exception causing a fallback to GCS
+        if not hf_sources:
+            raise ValueError(f"No HuggingFace source for {model_name}")
+
+        for index, repo_id in enumerate(hf_sources):
+            try:
+                return snapshot_download(
+                    repo_id=repo_id,
+                    ignore_patterns=["model.safetensors", "pytorch_model.bin"],
+                    cache_dir=cache_dir,
+                )
+            except (RepositoryNotFoundError, EnvironmentError) as e:
+                logger.exception(f"Failed to download model from HF source: {repo_id}: {e} ")
+                if repo_id == hf_sources[-1]:
+                    raise e
+                logger.info(f"Trying another source: {hf_sources[index+1]}")

     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
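
Both download paths share the same retry shape: walk the configured sources in order, log each failure, and re-raise only when the last source has also failed. A stripped-down sketch of that pattern on its own (the fetch_one callable and sources list are illustrative helpers, not part of this commit):

from typing import Callable, List

from loguru import logger


def fetch_from_first_working_source(sources: List[str], fetch_one: Callable[[str], str]) -> str:
    """Try each source in order; re-raise only if the last source also fails."""
    for index, source in enumerate(sources):
        try:
            return fetch_one(source)
        except Exception as e:  # the real methods catch narrower, backend-specific errors
            logger.exception(f"Failed to download from source: {source}: {e}")
            if source == sources[-1]:
                raise
            logger.info(f"Trying another source: {sources[index + 1]}")
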
@@ -368,28 +346,36 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
         Returns:
             Path: The path to the model directory.
         """
-
-        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
-
         fast_model_name = f"fast-{model_name.split('/')[-1]}"

         model_dir = Path(cache_dir) / fast_model_name
         if model_dir.exists():
             return model_dir

         model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
-        try:
-            self.download_file_from_gcs(
-                f"https://storage.googleapis.com/qdrant-fastembed/{fast_model_name}.tar.gz",
-                output_path=str(model_tar_gz),
-            )
-        except PermissionError:
-            simple_model_name = model_name.replace("/", "-")
-            print(f"Was not able to download {fast_model_name}.tar.gz, trying {simple_model_name}.tar.gz")
-            self.download_file_from_gcs(
-                f"https://storage.googleapis.com/qdrant-fastembed/{simple_model_name}.tar.gz",
-                output_path=str(model_tar_gz),
-            )
+
+        models = self.list_supported_models(exclude=["hf_sources"])
+
+        compressed_url_sources = [
+            item for model in models if model["model"] == model_name for item in model["compressed_url_sources"]
+        ]
+
+        # Check if the GCS sources list is empty after falling back from HF
+        # A model should always have at least one source
+        if not compressed_url_sources:
+            raise ValueError(f"No GCS source for {model_name}")
+
+        for index, source in enumerate(compressed_url_sources):
+            try:
+                self.download_file_from_gcs(
+                    source,
+                    output_path=str(model_tar_gz),
+                )
+            except (RuntimeError, PermissionError) as e:
+                logger.exception(f"Failed to download model from GCS source: {source}: {e} ")
+                if source == compressed_url_sources[-1]:
+                    raise e
+                logger.info(f"Trying another source: {compressed_url_sources[index+1]}")

         self.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=cache_dir)
         assert model_dir.exists(), f"Could not find {model_dir} in {cache_dir}"
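
The GCS path keeps its existing cache layout: whichever compressed_url_sources archive succeeds is saved as fast-<model-suffix>.tar.gz in the cache directory and unpacked into a sibling directory of the same name. A small sketch of the names this implies (paths illustrative):

from pathlib import Path

cache_dir = Path("/tmp/fastembed_cache")      # illustrative cache location
model_name = "BAAI/bge-small-en-v1.5"

fast_model_name = f"fast-{model_name.split('/')[-1]}"    # "fast-bge-small-en-v1.5"
model_tar_gz = cache_dir / f"{fast_model_name}.tar.gz"   # downloaded archive
model_dir = cache_dir / fast_model_name                  # directory returned once unpacked

print(model_tar_gz)  # /tmp/fastembed_cache/fast-bge-small-en-v1.5.tar.gz
print(model_dir)     # /tmp/fastembed_cache/fast-bge-small-en-v1.5
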
@@ -398,23 +384,31 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:

         return model_dir

+    @gcs_fallback
     def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
         """
         Retrieves a model from HuggingFace Hub.
         Args:
             model_name (str): The name of the model to retrieve.
             cache_dir (str): The path to the cache directory.
-        Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         Returns:
             Path: The path to the model directory.
         """

-        assert (
-            "/" in model_name
-        ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
+        return Path(self.download_files_from_huggingface(model_name=model_name, cache_dir=cache_dir))
+
+    @classmethod
+    def assert_model_name(cls, model_name: str):
+        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"

-        return Path(self.download_files_from_huggingface(repod_id=model_name, cache_dir=cache_dir))
+        models = cls.list_supported_models()
+        model_names = [model["model"] for model in models]
+        if model_name not in model_names:
+            raise ValueError(
+                f"{model_name} is not a supported model.\n"
+                f"Try one of {', '.join(model_names)}.\n"
+                f"Use the 'list_supported_models()' method to get the model information."
+            )

     def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
         """
@@ -475,6 +469,9 @@ def __init__(
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         """
+
+        self.assert_model_name(model_name)
+
         self.model_name = model_name

         if cache_dir is None:
@@ -483,7 +480,7 @@ def __init__(
             cache_dir.mkdir(parents=True, exist_ok=True)

         self._cache_dir = cache_dir
-        self._model_dir = self.retrieve_model_gcs(model_name, cache_dir)
+        self._model_dir = self.retrieve_model_hf(model_name, cache_dir)
         self._max_length = max_length

         self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads)
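
Because __init__ now calls retrieve_model_hf, constructing a model tries HuggingFace first and, through the gcs_fallback decorator, silently falls back to the compressed GCS archive on any HF failure. A minimal usage sketch (constructor and embed() signatures as they appear elsewhere in this file; the default model choice is assumed to stay a small BAAI/bge model):

from fastembed.embedding import DefaultEmbedding

# Model files come from a HuggingFace repo when one is configured in models.json,
# otherwise (or on a download error) from the GCS tar.gz, exactly as before.
embedding_model = DefaultEmbedding()

documents = ["fastembed is supported by and maintained by Qdrant."]
embeddings = list(embedding_model.embed(documents))
print(len(embeddings), embeddings[0].shape)  # one vector whose length matches the model's "dim"
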
@@ -536,12 +533,21 @@ def embed(
                 yield from normalize(embeddings[:, 0]).astype(np.float32)

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(
+        cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+    ) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
         # jina models are not supported by this class
-        return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")]
+        return [
+            model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai")
+        ]


 class DefaultEmbedding(FlagEmbedding):
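
FlagEmbedding narrows the parent list to non-Jina models and, by default, hides the two source-list keys. A quick hedged check of the shape of the result (the concrete model set depends on the shipped models.json):

from fastembed.embedding import FlagEmbedding

for model_info in FlagEmbedding.list_supported_models():
    # Expect entries such as the BAAI/bge-* family, without "hf_sources"
    # or "compressed_url_sources" keys and without any jinaai/* models.
    print(model_info["model"], model_info.get("dim"))
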
@@ -591,8 +597,10 @@ def __init__(
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
         Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+            ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
         """
+        self.assert_model_name(model_name)
+
         self.model_name = model_name

         if cache_dir is None:
@@ -652,12 +660,21 @@ def embed(
                 yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(
+        cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+    ) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
         # only jina models are supported by this class
-        return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")]
+        return [
+            model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai")
+        ]

     @staticmethod
     def mean_pooling(model_output, attention_mask):
