
Commit e71073c

Merge branch 'main' into onnx-pipeline

2 parents f02bd57 + ede507e

File tree

5 files changed: +689 −280 lines

fastembed/embedding.py (+142 −120)
@@ -1,3 +1,4 @@
+import functools
 import json
 import os
 import shutil
@@ -7,13 +8,16 @@
 from itertools import islice
 from multiprocessing import get_all_start_methods
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import onnxruntime as ort
 import requests
 from tokenizers import AddedToken, Tokenizer
 from tqdm import tqdm
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import RepositoryNotFoundError
+from loguru import logger
 
 from fastembed.parallel_processor import ParallelWorkerPool, Worker
 
@@ -58,9 +62,14 @@ def load_tokenizer(cls, model_dir: Path, max_length: int = 512) -> Tokenizer:
         if not tokens_map_path.exists():
             raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")
 
-        config = json.load(open(str(config_path)))
-        tokenizer_config = json.load(open(str(tokenizer_config_path)))
-        tokens_map = json.load(open(str(tokens_map_path)))
+        with open(str(config_path)) as config_file:
+            config = json.load(config_file)
+
+        with open(str(tokenizer_config_path)) as tokenizer_config_file:
+            tokenizer_config = json.load(tokenizer_config_file)
+
+        with open(str(tokens_map_path)) as tokens_map_file:
+            tokens_map = json.load(tokens_map_file)
 
         tokenizer = Tokenizer.from_file(str(tokenizer_path))
         tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length))
@@ -174,71 +183,44 @@ class Embedding(ABC):
         _type_: _description_
     """
 
+    # Internal helper decorator to maintain backward compatibility
+    # by supporting a fallback to download from Google Cloud Storage (GCS)
+    # if the model couldn't be downloaded from HuggingFace.
+    def gcs_fallback(hf_download_method: Callable) -> Callable:
+        @functools.wraps(hf_download_method)
+        def wrapper(self, *args, **kwargs):
+            try:
+                return hf_download_method(self, *args, **kwargs)
+            except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
+                logger.exception(
+                    f"Could not download model from HuggingFace: {e} "
+                    "Falling back to download from Google Cloud Storage"
+                )
+                return self.retrieve_model_gcs(*args, **kwargs)
+
+        return wrapper
+
     @abstractmethod
     def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = None) -> List[np.ndarray]:
         raise NotImplementedError
 
     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to [].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
-        return [
-            {
-                "model": "BAAI/bge-small-en",
-                "dim": 384,
-                "description": "Fast English model",
-                "size_in_GB": 0.2,
-            },
-            {
-                "model": "BAAI/bge-small-en-v1.5",
-                "dim": 384,
-                "description": "Fast and Default English model",
-                "size_in_GB": 0.13,
-            },
-            {
-                "model": "BAAI/bge-small-zh-v1.5",
-                "dim": 512,
-                "description": "Fast and recommended Chinese model",
-                "size_in_GB": 0.1,
-            },
-            {
-                "model": "BAAI/bge-base-en",
-                "dim": 768,
-                "description": "Base English model",
-                "size_in_GB": 0.5,
-            },
-            {
-                "model": "BAAI/bge-base-en-v1.5",
-                "dim": 768,
-                "description": "Base English model, v1.5",
-                "size_in_GB": 0.44,
-            },
-            {
-                "model": "sentence-transformers/all-MiniLM-L6-v2",
-                "dim": 384,
-                "description": "Sentence Transformer model, MiniLM-L6-v2",
-                "size_in_GB": 0.09,
-            },
-            {
-                "model": "intfloat/multilingual-e5-large",
-                "dim": 1024,
-                "description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
-                "size_in_GB": 2.24,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-base-en",
-                "dim": 768,
-                "description": "English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.55,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-small-en",
-                "dim": 512,
-                "description": "English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.13,
-            },
-        ]
+        models_file_path = Path(__file__).with_name("models.json")
+        with open(models_file_path, "r") as file:
+            models = json.load(file)
+
+        models = [{k: v for k, v in model.items() if k not in exclude} for model in models]
+
+        return models
 
     @classmethod
     def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
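
The hard-coded model list above moves into a models.json shipped next to this module. The file itself is not part of this diff; a hypothetical entry, shown as a Python literal and inferred from the keys the new code reads ("model" plus the "hf_sources" and "compressed_url_sources" lists used by the download paths below), might look like this — the repo id is assumed, and the GCS URL follows the pattern the old code built by hand:

    # Illustrative sketch only; the real models.json may differ.
    example_entry = {
        "model": "BAAI/bge-small-en",
        "dim": 384,
        "description": "Fast English model",
        "size_in_GB": 0.2,
        "hf_sources": ["qdrant/bge-small-en-onnx"],  # assumed repo id
        "compressed_url_sources": [
            "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en.tar.gz",
        ],
    }
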
@@ -271,48 +253,49 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
         if total_size_in_bytes == 0:
             print(f"Warning: Content-length header is missing or zero in the response from {url}.")
 
-        # Initialize the progress bar
-        progress_bar = (
-            tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-            if total_size_in_bytes and show_progress
-            else None
-        )
+        show_progress = total_size_in_bytes and show_progress
 
-        # Attempt to download the file
-        try:
+        with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar:
             with open(output_path, "wb") as file:
-                for chunk in response.iter_content(chunk_size=1024):  # Adjust chunk size to your preference
+                for chunk in response.iter_content(chunk_size=1024):
                     if chunk:  # Filter out keep-alive new chunks
-                        if progress_bar is not None:
-                            progress_bar.update(len(chunk))
+                        progress_bar.update(len(chunk))
                         file.write(chunk)
-        except Exception as e:
-            print(f"An error occurred while trying to download the file: {str(e)}")
-            return
-        finally:
-            if progress_bar is not None:
-                progress_bar.close()
         return output_path
 
     @classmethod
-    def download_files_from_huggingface(cls, repod_id: str, cache_dir: Optional[str] = None) -> str:
+    def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[str] = None) -> str:
         """
         Downloads a model from HuggingFace Hub.
         Args:
-            repod_id (str): The HF hub id (name) of the model to retrieve.
+            model_name (str): Name of the model to download.
             cache_dir (Optional[str]): The path to the cache directory.
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
         Returns:
             Path: The path to the model directory.
         """
-        from huggingface_hub import snapshot_download
-
-        return snapshot_download(
-            repo_id=repod_id,
-            ignore_patterns=["model.safetensors", "pytorch_model.bin"],
-            cache_dir=cache_dir,
-        )
+        models = cls.list_supported_models(exclude=["compressed_url_sources"])
+
+        hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]]
+
+        # Check if the HF sources list is empty
+        # Raise an exception causing a fallback to GCS
+        if not hf_sources:
+            raise ValueError(f"No HuggingFace source for {model_name}")
+
+        for index, repo_id in enumerate(hf_sources):
+            try:
+                return snapshot_download(
+                    repo_id=repo_id,
+                    ignore_patterns=["model.safetensors", "pytorch_model.bin"],
+                    cache_dir=cache_dir,
+                )
+            except (RepositoryNotFoundError, EnvironmentError) as e:
+                logger.exception(f"Failed to download model from HF source: {repo_id}: {e}")
+                if repo_id == hf_sources[-1]:
+                    raise e
+                logger.info(f"Trying another source: {hf_sources[index + 1]}")
 
     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
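
The progress-bar refactor above replaces the conditional `progress_bar = tqdm(...) if ... else None` plus its try/finally bookkeeping with a single context manager whose `disable=` flag turns the display off. A self-contained sketch of the same pattern, under a hypothetical helper name:

    import requests
    from tqdm import tqdm

    def save_stream(url: str, output_path: str, show_progress: bool = True) -> str:
        """Stream url to output_path, showing a progress bar only when possible."""
        response = requests.get(url, stream=True)
        total = int(response.headers.get("content-length", 0))
        # disable= covers both "caller asked for no bar" and "size unknown"
        with tqdm(total=total, unit="iB", unit_scale=True, disable=not (total and show_progress)) as bar:
            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:  # skip keep-alive chunks
                        bar.update(len(chunk))
                        f.write(chunk)
        return output_path
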
@@ -363,28 +346,36 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
         Returns:
             Path: The path to the model directory.
         """
-
-        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
-
         fast_model_name = f"fast-{model_name.split('/')[-1]}"
 
         model_dir = Path(cache_dir) / fast_model_name
         if model_dir.exists():
             return model_dir
 
         model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
-        try:
-            self.download_file_from_gcs(
-                f"https://storage.googleapis.com/qdrant-fastembed/{fast_model_name}.tar.gz",
-                output_path=str(model_tar_gz),
-            )
-        except PermissionError:
-            simple_model_name = model_name.replace("/", "-")
-            print(f"Was not able to download {fast_model_name}.tar.gz, trying {simple_model_name}.tar.gz")
-            self.download_file_from_gcs(
-                f"https://storage.googleapis.com/qdrant-fastembed/{simple_model_name}.tar.gz",
-                output_path=str(model_tar_gz),
-            )
+
+        models = self.list_supported_models(exclude=["hf_sources"])
+
+        compressed_url_sources = [
+            item for model in models if model["model"] == model_name for item in model["compressed_url_sources"]
+        ]
+
+        # Check if the GCS sources list is empty after falling back from HF
+        # A model should always have at least one source
+        if not compressed_url_sources:
+            raise ValueError(f"No GCS source for {model_name}")
+
+        for index, source in enumerate(compressed_url_sources):
+            try:
+                self.download_file_from_gcs(
+                    source,
+                    output_path=str(model_tar_gz),
+                )
+            except (RuntimeError, PermissionError) as e:
+                logger.exception(f"Failed to download model from GCS source: {source}: {e}")
+                if source == compressed_url_sources[-1]:
+                    raise e
+                logger.info(f"Trying another source: {compressed_url_sources[index + 1]}")
 
         self.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=cache_dir)
         assert model_dir.exists(), f"Could not find {model_dir} in {cache_dir}"
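
One subtlety in the loop above: unlike `download_files_from_huggingface`, which returns from inside the `try`, this loop has no early exit, so after a successful download it still advances to the remaining sources. A standalone sketch of the retry pattern with an explicit `break` on success, using hypothetical names:

    def fetch_first_available(sources, download, output_path):
        """Try each source in order; stop at the first success, re-raise on the last failure."""
        for index, source in enumerate(sources):
            try:
                download(source, output_path=output_path)
                break  # success: don't re-download from the remaining sources
            except (RuntimeError, PermissionError) as e:
                if source == sources[-1]:
                    raise e  # no sources left
                print(f"{source} failed ({e}), trying {sources[index + 1]}")
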
@@ -393,23 +384,31 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
 
         return model_dir
 
+    @gcs_fallback
     def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
         """
         Retrieves a model from HuggingFace Hub.
         Args:
             model_name (str): The name of the model to retrieve.
             cache_dir (str): The path to the cache directory.
-        Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         Returns:
             Path: The path to the model directory.
         """
 
-        assert (
-            "/" in model_name
-        ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
+        return Path(self.download_files_from_huggingface(model_name=model_name, cache_dir=cache_dir))
 
-        return Path(self.download_files_from_huggingface(repod_id=model_name, cache_dir=cache_dir))
+    @classmethod
+    def assert_model_name(cls, model_name: str):
+        assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
+
+        models = cls.list_supported_models()
+        model_names = [model["model"] for model in models]
+        if model_name not in model_names:
+            raise ValueError(
+                f"{model_name} is not a supported model.\n"
+                f"Try one of {', '.join(model_names)}.\n"
+                f"Use the 'list_supported_models()' method to get the model information."
+            )
 
     def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
         """
@@ -470,6 +469,9 @@ def __init__(
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         """
+
+        self.assert_model_name(model_name)
+
         self.model_name = model_name
 
         if cache_dir is None:
@@ -478,7 +480,7 @@ def __init__(
             cache_dir.mkdir(parents=True, exist_ok=True)
 
         self._cache_dir = cache_dir
-        self._model_dir = self.retrieve_model_gcs(model_name, cache_dir)
+        self._model_dir = self.retrieve_model_hf(model_name, cache_dir)
         self._max_length = max_length
 
         self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads)
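
With this change the constructor's retrieval order inverts: HuggingFace Hub becomes the primary path and GCS the fallback, wired together by the `@gcs_fallback` decorator. A sketch of the resulting control flow (comments, not literal code):

    # FlagEmbedding(model_name="BAAI/bge-small-en-v1.5")
    #   -> assert_model_name(...)    # ValueError if the name is unknown
    #   -> retrieve_model_hf(...)    # tries each entry in hf_sources
    #        on EnvironmentError / RepositoryNotFoundError / ValueError:
    #   -> retrieve_model_gcs(...)   # tries each compressed_url_sources URL
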
@@ -531,12 +533,21 @@ def embed(
                 yield from normalize(embeddings[:, 0]).astype(np.float32)
 
     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(
+        cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+    ) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
         # jina models are not supported by this class
-        return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")]
+        return [
+            model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai")
+        ]
 
 
 class DefaultEmbedding(FlagEmbedding):
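
Downstream, the public listing stays lean by default, while internal callers can opt back into the source fields. A usage sketch:

    # Default: source fields are stripped, so the public schema is unchanged.
    for model in FlagEmbedding.list_supported_models():
        print(model["model"], model["dim"], model["size_in_GB"])

    # exclude=[] keeps every key from models.json, including hf_sources and
    # compressed_url_sources, which the download paths rely on.
    full = FlagEmbedding.list_supported_models(exclude=[])
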
@@ -586,8 +597,10 @@ def __init__(
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
         Raises:
-            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+            ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
         """
+        self.assert_model_name(model_name)
+
         self.model_name = model_name
 
         if cache_dir is None:
@@ -647,12 +660,21 @@ def embed(
                 yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
 
     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(
+        cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
+    ) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
         # only jina models are supported by this class
-        return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")]
+        return [
+            model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai")
+        ]
 
     @staticmethod
     def mean_pooling(model_output, attention_mask):
