Skip to content

Commit 1371984

Browse files
committed
refactor: move model metadata and GCS URLs into models.json
1 parent be1d9d6 commit 1371984

File tree

3 files changed

+139
-100
lines changed

3 files changed

+139
-100
lines changed

fastembed/embedding.py

+53-83
Original file line numberDiff line numberDiff line change
@@ -206,62 +206,10 @@ def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]
206206
"""
207207
Lists the supported models.
208208
"""
209-
return [
210-
{
211-
"model": "BAAI/bge-small-en",
212-
"dim": 384,
213-
"description": "Fast English model",
214-
"size_in_GB": 0.2,
215-
},
216-
{
217-
"model": "BAAI/bge-small-en-v1.5",
218-
"dim": 384,
219-
"description": "Fast and Default English model",
220-
"size_in_GB": 0.13,
221-
},
222-
{
223-
"model": "BAAI/bge-small-zh-v1.5",
224-
"dim": 512,
225-
"description": "Fast and recommended Chinese model",
226-
"size_in_GB": 0.1,
227-
},
228-
{
229-
"model": "BAAI/bge-base-en",
230-
"dim": 768,
231-
"description": "Base English model",
232-
"size_in_GB": 0.5,
233-
},
234-
{
235-
"model": "BAAI/bge-base-en-v1.5",
236-
"dim": 768,
237-
"description": "Base English model, v1.5",
238-
"size_in_GB": 0.44,
239-
},
240-
{
241-
"model": "sentence-transformers/all-MiniLM-L6-v2",
242-
"dim": 384,
243-
"description": "Sentence Transformer model, MiniLM-L6-v2",
244-
"size_in_GB": 0.09,
245-
},
246-
{
247-
"model": "intfloat/multilingual-e5-large",
248-
"dim": 1024,
249-
"description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
250-
"size_in_GB": 2.24,
251-
},
252-
{
253-
"model": "jinaai/jina-embeddings-v2-base-en",
254-
"dim": 768,
255-
"description": " English embedding model supporting 8192 sequence length",
256-
"size_in_GB": 0.55,
257-
},
258-
{
259-
"model": "jinaai/jina-embeddings-v2-small-en",
260-
"dim": 512,
261-
"description": " English embedding model supporting 8192 sequence length",
262-
"size_in_GB": 0.13,
263-
},
264-
]
209+
models_file_path = Path(__file__).with_name("models.json")
210+
models = json.load(open(str(models_file_path)))
211+
212+
return models
265213

266214
@classmethod
267215
def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
@@ -318,19 +266,27 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
318266
return output_path
319267

320268
@classmethod
321-
def download_files_from_huggingface(cls, repo_ids: List[str], cache_dir: Optional[str] = None) -> str:
269+
def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[str] = None) -> str:
322270
"""
323271
Downloads a model from HuggingFace Hub.
324272
Args:
325-
repo_ids (List[str]): A list of HF model IDs to download.
273+
model_name (str): Name of the model to download.
326274
cache_dir (Optional[str]): The path to the cache directory.
327275
Raises:
328276
ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
329277
Returns:
330278
Path: The path to the model directory.
331279
"""
280+
models = cls.list_supported_models()
332281

333-
for index, repo_id in enumerate(repo_ids):
282+
hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]]
283+
284+
# Check if the HF sources list is empty
285+
# Raise an exception causing a fallback to GCS
286+
if not hf_sources:
287+
raise ValueError(f"No HuggingFace source for {model_name}")
288+
289+
for index, repo_id in enumerate(hf_sources):
334290
try:
335291
return snapshot_download(
336292
repo_id=repo_id,
@@ -339,9 +295,9 @@ def download_files_from_huggingface(cls, repo_ids: List[str], cache_dir: Optiona
339295
)
340296
except (RepositoryNotFoundError, EnvironmentError) as e:
341297
logger.error(f"Failed to download model from HF source: {repo_id}: {e} ")
342-
if repo_id == repo_ids[-1]:
298+
if repo_id == hf_sources[-1]:
343299
raise e
344-
logger.info(f"Trying another source: {repo_ids[index+1]}")
300+
logger.info(f"Trying another source: {hf_sources[index+1]}")
345301

346302
@classmethod
347303
def decompress_to_cache(cls, targz_path: str, cache_dir: str):
@@ -399,18 +355,27 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
399355
return model_dir
400356

401357
model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
402-
try:
403-
self.download_file_from_gcs(
404-
f"https://storage.googleapis.com/qdrant-fastembed/{fast_model_name}.tar.gz",
405-
output_path=str(model_tar_gz),
406-
)
407-
except PermissionError:
408-
simple_model_name = model_name.replace("/", "-")
409-
print(f"Was not able to download {fast_model_name}.tar.gz, trying {simple_model_name}.tar.gz")
410-
self.download_file_from_gcs(
411-
f"https://storage.googleapis.com/qdrant-fastembed/{simple_model_name}.tar.gz",
412-
output_path=str(model_tar_gz),
413-
)
358+
359+
models = self.list_supported_models()
360+
361+
gcs_sources = [item for model in models if model["model"] == model_name for item in model["gcs_sources"]]
362+
363+
# Check if the GCS sources list is empty after falling back from HF
364+
# A model should always have at least one source
365+
if not gcs_sources:
366+
raise ValueError(f"No GCS source for {model_name}")
367+
368+
for index, source in enumerate(gcs_sources):
369+
try:
370+
self.download_file_from_gcs(
371+
f"https://storage.googleapis.com/{source}",
372+
output_path=str(model_tar_gz),
373+
)
374+
except (RuntimeError, PermissionError) as e:
375+
logger.error(f"Failed to download model from GCS source: {source}: {e} ")
376+
if source == gcs_sources[-1]:
377+
raise e
378+
logger.info(f"Trying another source: {gcs_sources[index+1]}")
414379

415380
self.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=cache_dir)
416381
assert model_dir.exists(), f"Could not find {model_dir} in {cache_dir}"
@@ -429,15 +394,21 @@ def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
429394
Returns:
430395
Path: The path to the model directory.
431396
"""
432-
models_file_path = Path(__file__).with_name("models.json")
433-
models = json.load(open(str(models_file_path)))
434397

435-
if model_name not in [model["name"] for model in models]:
436-
raise ValueError(f"Could not find {model_name} in {models_file_path}")
398+
return Path(self.download_files_from_huggingface(model_name=model_name, cache_dir=cache_dir))
437399

438-
sources = [item for model in models if model["name"] == model_name for item in model["sources"]]
400+
@classmethod
401+
def assert_model_name(cls, model_name: str):
402+
assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
439403

440-
return Path(self.download_files_from_huggingface(repo_ids=sources, cache_dir=cache_dir))
404+
models = cls.list_supported_models()
405+
model_names = [model["model"] for model in models]
406+
if model_name not in model_names:
407+
raise ValueError(
408+
f"{model_name} is not a supported model.\n"
409+
f"Try one of {', '.join(model_names)}.\n"
410+
f"Use the 'list_supported_models()' method to get the model information."
411+
)
441412

442413
def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
443414
"""
@@ -498,7 +469,8 @@ def __init__(
498469
Raises:
499470
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
500471
"""
501-
assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"
472+
473+
self.assert_model_name(model_name)
502474

503475
self.model_name = model_name
504476

@@ -618,9 +590,7 @@ def __init__(
618590
Raises:
619591
ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
620592
"""
621-
assert (
622-
"/" in model_name
623-
), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en"
593+
self.assert_model_name(model_name)
624594

625595
self.model_name = model_name
626596

fastembed/models.json

+85-17
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,113 @@
11
[
22
{
3-
"name": "BAAI/bge-small-en-v1.5",
4-
"sources": [
5-
"Qdrant/bge-small-en-v1.5-onnx-Q"
3+
"model": "BAAI/bge-base-en",
4+
"dim": 768,
5+
"description": "Base English model",
6+
"size_in_GB": 0.5,
7+
"hf_sources": [],
8+
"gcs_sources": [
9+
"qdrant-fastembed/fast-bge-base-en.tar.gz"
610
]
711
},
812
{
9-
"name": "BAAI/bge-base-en-v1.5",
10-
"sources": [
13+
"model": "BAAI/bge-base-en-v1.5",
14+
"dim": 768,
15+
"description": "Base English model, v1.5",
16+
"size_in_GB": 0.44,
17+
"hf_sources": [
1118
"Qdrant/bge-base-en-v1.5-onnx-Q"
19+
],
20+
"gcs_sources": [
21+
"qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz"
1222
]
1323
},
1424
{
15-
"name": "BAAI/bge-large-en-v1.5",
16-
"sources": [
25+
"model": "BAAI/bge-large-en-v1.5",
26+
"dim": 1024,
27+
"description": "Large English model, v1.5",
28+
"size_in_GB": 1.34,
29+
"hf_sources": [
1730
"Qdrant/bge-large-en-v1.5-onnx",
1831
"Qdrant/bge-large-en-v1.5-onnx-Q"
32+
],
33+
"gcs_sources": []
34+
},
35+
{
36+
"model": "BAAI/bge-small-en",
37+
"dim": 384,
38+
"description": "Fast English model",
39+
"size_in_GB": 0.2,
40+
"hf_sources": [],
41+
"gcs_sources": [
42+
"qdrant-fastembed/fast-bge-small-en.tar.gz",
43+
"qdrant-fastembed/BAAI-bge-small-en.tar.gz"
1944
]
2045
},
2146
{
22-
"name": "sentence-transformers/all-MiniLM-L6-v2",
23-
"sources": [
24-
"Qdrant/all-MiniLM-L6-v2-onnx"
47+
"model": "BAAI/bge-small-en-v1.5",
48+
"dim": 384,
49+
"description": "Fast and Default English model",
50+
"size_in_GB": 0.13,
51+
"hf_sources": [
52+
"Qdrant/bge-small-en-v1.5-onnx-Q"
53+
],
54+
"gcs_sources": [
55+
"qdrant-fastembed/fast-bge-small-en-v1.5.tar.gz"
56+
]
57+
},
58+
{
59+
"model": "BAAI/bge-small-zh-v1.5",
60+
"dim": 512,
61+
"description": "Fast and recommended Chinese model",
62+
"size_in_GB": 0.1,
63+
"hf_sources": [],
64+
"gcs_sources": [
65+
"qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz"
2566
]
2667
},
2768
{
28-
"name": "intfloat/multilingual-e5-large",
29-
"sources": [
69+
"model": "intfloat/multilingual-e5-large",
70+
"dim": 1024,
71+
"description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
72+
"size_in_GB": 2.24,
73+
"hf_sources": [
3074
"Qdrant/multilingual-e5-large-onnx"
75+
],
76+
"gcs_sources": [
77+
"qdrant-fastembed/intfloat-multilingual-e5-large.tar.gz"
3178
]
3279
},
3380
{
34-
"name": "jinaai/jina-embeddings-v2-base-en",
35-
"sources": [
81+
"model": "jinaai/jina-embeddings-v2-base-en",
82+
"dim": 768,
83+
"description": " English embedding model supporting 8192 sequence length",
84+
"size_in_GB": 0.55,
85+
"hf_sources": [
3686
"jinaai/jina-embeddings-v2-base-en"
37-
]
87+
],
88+
"gcs_sources": []
3889
},
3990
{
40-
"name": "jinaai/jina-embeddings-v2-small-en",
41-
"sources": [
91+
"model": "jinaai/jina-embeddings-v2-small-en",
92+
"dim": 512,
93+
"description": " English embedding model supporting 8192 sequence length",
94+
"size_in_GB": 0.13,
95+
"hf_sources": [
4296
"jinaai/jina-embeddings-v2-small-en"
97+
],
98+
"gcs_sources": []
99+
},
100+
{
101+
"model": "sentence-transformers/all-MiniLM-L6-v2",
102+
"dim": 384,
103+
"description": "Sentence Transformer model, MiniLM-L6-v2",
104+
"size_in_GB": 0.09,
105+
"hf_sources": [
106+
"Qdrant/all-MiniLM-L6-v2-onnx"
107+
],
108+
"gcs_sources": [
109+
"qdrant-fastembed/fast-all-MiniLM-L6-v2.tar.gz",
110+
"qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz"
43111
]
44112
}
45113
]

tests/test_onnx_embeddings.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"BAAI/bge-small-zh-v1.5": np.array([-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762]),
1212
"BAAI/bge-base-en": np.array([0.0115, 0.0372, 0.0295, 0.0121, 0.0346]),
1313
"BAAI/bge-base-en-v1.5": np.array([0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045]),
14+
"BAAI/bge-large-en-v1.5": np.array([0.03434538, 0.03316108, 0.02191251, -0.03713358, -0.01577825]),
1415
"sentence-transformers/all-MiniLM-L6-v2": np.array([0.0259, 0.0058, 0.0114, 0.0380, -0.0233]),
1516
"intfloat/multilingual-e5-large": np.array([0.0098, 0.0045, 0.0066, -0.0354, 0.0070]),
1617
"jinaai/jina-embeddings-v2-small-en": np.array([-0.0455, -0.0428, -0.0122, 0.0613, 0.0015]),

0 commit comments

Comments (0)