Skip to content

Commit adfc03e

Browse files
Improve models cache progressbar (#406)
* chore: Remove typing hints of Python less than 3.9 * chore: Removed optional from cache as it cannot be undefined * improve: Turned off progress bar of huggingface models if cached
1 parent e9dc3b1 commit adfc03e

8 files changed

+51
-37
lines changed

README.md

+3-4
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ pip install fastembed-gpu
2828

2929
```python
3030
from fastembed import TextEmbedding
31-
from typing import List
31+
3232

3333
# Example list of documents
34-
documents: List[str] = [
34+
documents: list[str] = [
3535
"This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc.",
3636
"fastembed is supported by and maintained by Qdrant.",
3737
]
@@ -139,11 +139,10 @@ embeddings = list(model.embed(images))
139139

140140
### 🔄 Rerankers
141141
```python
142-
from typing import List
143142
from fastembed.rerank.cross_encoder import TextCrossEncoder
144143

145144
query = "Who is maintaining Qdrant?"
146-
documents: List[str] = [
145+
documents: list[str] = [
147146
"This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc.",
148147
"fastembed is supported by and maintained by Qdrant.",
149148
]

docs/examples/Hindi_Tamil_RAG_with_Navarasa7B.ipynb

+4-6
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
},
4545
{
4646
"cell_type": "code",
47-
"execution_count": 21,
47+
"execution_count": null,
4848
"metadata": {
4949
"ExecuteTime": {
5050
"end_time": "2024-03-30T00:45:24.814968Z",
@@ -58,8 +58,6 @@
5858
},
5959
"outputs": [],
6060
"source": [
61-
"from typing import List\n",
62-
"\n",
6361
"import numpy as np\n",
6462
"from datasets import load_dataset\n",
6563
"from peft import AutoPeftModelForCausalLM\n",
@@ -72,11 +70,11 @@
7270
},
7371
{
7472
"cell_type": "code",
75-
"execution_count": 23,
73+
"execution_count": null,
7674
"metadata": {},
7775
"outputs": [],
7876
"source": [
79-
"hf_token = <YOUR_HF_TOKEN_HERE> # Get your token from https://huggingface.co/settings/token, needed for Gemma weights"
77+
"hf_token = \"<YOUR_HF_TOKEN_HERE>\" # Get your token from https://huggingface.co/settings/token, needed for Gemma weights"
8078
]
8179
},
8280
{
@@ -246,7 +244,7 @@
246244
},
247245
"outputs": [],
248246
"source": [
249-
"context_embeddings: List[np.ndarray] = list(\n",
247+
"context_embeddings: list[np.ndarray] = list(\n",
250248
" embedding_model.embed(contexts)\n",
251249
") # Note the list() call - this is a generator"
252250
]

docs/examples/SPLADE_with_FastEmbed.ipynb

+14-9
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
},
4848
{
4949
"cell_type": "code",
50-
"execution_count": 2,
50+
"execution_count": null,
5151
"metadata": {
5252
"ExecuteTime": {
5353
"end_time": "2024-03-30T00:49:20.516644Z",
@@ -56,8 +56,7 @@
5656
},
5757
"outputs": [],
5858
"source": [
59-
"from fastembed import SparseTextEmbedding, SparseEmbedding\n",
60-
"from typing import List"
59+
"from fastembed import SparseTextEmbedding, SparseEmbedding"
6160
]
6261
},
6362
{
@@ -134,7 +133,7 @@
134133
},
135134
{
136135
"cell_type": "code",
137-
"execution_count": 5,
136+
"execution_count": null,
138137
"metadata": {
139138
"ExecuteTime": {
140139
"end_time": "2024-03-30T00:49:28.624109Z",
@@ -143,7 +142,7 @@
143142
},
144143
"outputs": [],
145144
"source": [
146-
"documents: List[str] = [\n",
145+
"documents: list[str] = [\n",
147146
" \"Chandrayaan-3 is India's third lunar mission\",\n",
148147
" \"It aimed to land a rover on the Moon's surface - joining the US, China and Russia\",\n",
149148
" \"The mission is a follow-up to Chandrayaan-2, which had partial success\",\n",
@@ -157,7 +156,7 @@
157156
" \"Chandrayaan-3 was launched from the Satish Dhawan Space Centre in Sriharikota\",\n",
158157
" \"Chandrayaan-3 was launched earlier in the year 2023\",\n",
159158
"]\n",
160-
"sparse_embeddings_list: List[SparseEmbedding] = list(\n",
159+
"sparse_embeddings_list: list[SparseEmbedding] = list(\n",
161160
" model.embed(documents, batch_size=6)\n",
162161
") # batch_size is optional, notice the generator"
163162
]
@@ -235,7 +234,9 @@
235234
"source": [
236235
"# Let's print the first 5 features and their weights for better understanding.\n",
237236
"for i in range(5):\n",
238-
" print(f\"Token at index {sparse_embeddings_list[0].indices[i]} has weight {sparse_embeddings_list[0].values[i]}\")"
237+
" print(\n",
238+
" f\"Token at index {sparse_embeddings_list[0].indices[i]} has weight {sparse_embeddings_list[0].values[i]}\"\n",
239+
" )"
239240
]
240241
},
241242
{
@@ -261,7 +262,9 @@
261262
"import json\n",
262263
"from transformers import AutoTokenizer\n",
263264
"\n",
264-
"tokenizer = AutoTokenizer.from_pretrained(SparseTextEmbedding.list_supported_models()[0][\"sources\"][\"hf\"])"
265+
"tokenizer = AutoTokenizer.from_pretrained(\n",
266+
" SparseTextEmbedding.list_supported_models()[0][\"sources\"][\"hf\"]\n",
267+
")"
265268
]
266269
},
267270
{
@@ -326,7 +329,9 @@
326329
" token_weight_dict[token] = weight\n",
327330
"\n",
328331
" # Sort the dictionary by weights\n",
329-
" token_weight_dict = dict(sorted(token_weight_dict.items(), key=lambda item: item[1], reverse=True))\n",
332+
" token_weight_dict = dict(\n",
333+
" sorted(token_weight_dict.items(), key=lambda item: item[1], reverse=True)\n",
334+
" )\n",
330335
" return token_weight_dict\n",
331336
"\n",
332337
"\n",

docs/index.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ pip install fastembed
2626
```python
2727
from fastembed import TextEmbedding
2828

29-
documents: List[str] = [
29+
documents: list[str] = [
3030
"passage: Hello, World!",
3131
"query: Hello, World!",
3232
"passage: This is an example passage.",
3333
"fastembed is supported by and maintained by Qdrant."
3434
]
3535
embedding_model = TextEmbedding()
36-
embeddings: List[np.ndarray] = embedding_model.embed(documents)
36+
embeddings: list[np.ndarray] = embedding_model.embed(documents)
3737
```
3838

3939
## Usage with Qdrant

fastembed/common/model_management.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77

88
import requests
99
from huggingface_hub import snapshot_download
10-
from huggingface_hub.utils import RepositoryNotFoundError
10+
from huggingface_hub.utils import (
11+
RepositoryNotFoundError,
12+
disable_progress_bars,
13+
enable_progress_bars,
14+
)
1115
from loguru import logger
1216
from tqdm import tqdm
1317

@@ -93,7 +97,7 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
9397
def download_files_from_huggingface(
9498
cls,
9599
hf_source_repo: str,
96-
cache_dir: Optional[str] = None,
100+
cache_dir: str,
97101
extra_patterns: Optional[list[str]] = None,
98102
local_files_only: bool = False,
99103
**kwargs,
@@ -119,6 +123,12 @@ def download_files_from_huggingface(
119123
if extra_patterns is not None:
120124
allow_patterns.extend(extra_patterns)
121125

126+
snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
127+
is_cached = snapshot_dir.exists()
128+
129+
if is_cached:
130+
disable_progress_bars()
131+
122132
return snapshot_download(
123133
repo_id=hf_source_repo,
124134
allow_patterns=allow_patterns,
@@ -265,6 +275,8 @@ def download_model(
265275
f"Could not download model from HuggingFace: {e} "
266276
"Falling back to other sources."
267277
)
278+
finally:
279+
enable_progress_bars()
268280
if url_source or local_files_only:
269281
try:
270282
return cls.retrieve_model_gcs(

fastembed/common/utils.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import os
2+
import sys
3+
import re
24
import tempfile
3-
from itertools import islice
5+
import unicodedata
46
from pathlib import Path
7+
from itertools import islice
58
from typing import Generator, Iterable, Optional, Union
6-
import unicodedata
7-
import sys
9+
810
import numpy as np
9-
import re
10-
from typing import Set
1111

1212

1313
def normalize(input_array, p=2, dim=1, eps=1e-12) -> np.ndarray:
@@ -45,7 +45,7 @@ def define_cache_dir(cache_dir: Optional[str] = None) -> Path:
4545
return cache_path
4646

4747

48-
def get_all_punctuation() -> Set[str]:
48+
def get_all_punctuation() -> set[str]:
4949
return set(
5050
chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")
5151
)

fastembed/image/onnx_embedding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(
8080
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
8181
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
8282
Defaults to False.
83-
device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
83+
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
8484
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
8585
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
8686
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
@@ -134,7 +134,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
134134
Lists the supported models.
135135
136136
Returns:
137-
List[Dict[str, Any]]: A list of dictionaries containing the model information.
137+
list[Dict[str, Any]]: A list of dictionaries containing the model information.
138138
"""
139139
return supported_onnx_models
140140

tests/profiling.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
# %%
1111
import time
12-
from typing import Callable, List, Tuple
12+
from typing import Callable
1313

1414
import matplotlib.pyplot as plt
1515
import torch.nn.functional as F
@@ -23,7 +23,7 @@
2323
# data is a list of strings, each string is a document.
2424

2525
# %%
26-
documents: List[str] = [
26+
documents: list[str] = [
2727
"Chandrayaan-3 is India's third lunar mission",
2828
"It aimed to land a rover on the Moon's surface - joining the US, China and Russia",
2929
"The mission is a follow-up to Chandrayaan-2, which had partial success",
@@ -56,7 +56,7 @@ def __init__(self, model_id: str):
5656
self.model = AutoModel.from_pretrained(model_id)
5757
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
5858

59-
def embed(self, texts: List[str]):
59+
def embed(self, texts: list[str]):
6060
encoded_input = self.tokenizer(
6161
texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
6262
)
@@ -88,7 +88,7 @@ def embed(self, texts: List[str]):
8888
# %%
8989
def calculate_time_stats(
9090
embed_func: Callable, documents: list, k: int
91-
) -> Tuple[float, float, float]:
91+
) -> tuple[float, float, float]:
9292
times = []
9393
for _ in range(k):
9494
# Timing the embed_func call
@@ -111,8 +111,8 @@ def calculate_time_stats(
111111

112112
# %%
113113
def plot_character_per_second_comparison(
114-
hf_stats: Tuple[float, float, float],
115-
fst_stats: Tuple[float, float, float],
114+
hf_stats: tuple[float, float, float],
115+
fst_stats: tuple[float, float, float],
116116
documents: list,
117117
):
118118
# Calculating total characters in documents

0 commit comments

Comments
 (0)