Commit 3312079

I8dNLo and d.rudenko authored
MiniLM fix (#275)
* MiniLM fix
* Added MiniLM to text embedding
  Fixed MiniLM source destination
  Black + isort for repo
* Fixed model all-MiniLM-L6-v2 description
  Recomputed canonical vector for all-MiniLM-L6-v2 in test

---------

Co-authored-by: d.rudenko <[email protected]>
1 parent fd0b26f commit 3312079

39 files changed: +462 −193 lines changed

experiments/attention_export.py

+3 −1

@@ -12,4 +12,6 @@
 # print("Model already exported")
 # except FileNotFoundError:
 print(f"Exporting model to {output_dir}")
-main_export(model_id, output=output_dir, no_post_process=True, model_kwargs=model_kwargs)
+main_export(
+    model_id, output=output_dir, no_post_process=True, model_kwargs=model_kwargs
+)
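For context, `main_export` is Optimum's ONNX export entry point; the hunk above only reflows the call for Black. A minimal sketch of the surrounding script, with the model id, output directory, and `model_kwargs` contents assumed for illustration (the real script defines them elsewhere):

from optimum.exporters.onnx import main_export

# Assumed values for illustration only.
model_id = "sentence-transformers/all-MiniLM-L6-v2"
output_dir = "models/all-MiniLM-L6-v2-onnx"
model_kwargs = {"output_attentions": True}  # assumption: this export keeps attentions

print(f"Exporting model to {output_dir}")
main_export(
    model_id, output=output_dir, no_post_process=True, model_kwargs=model_kwargs
)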

experiments/try_attention_export.py

+6 −2

@@ -17,10 +17,14 @@
 attention_mask = tokenizer_output["attention_mask"]
 print(attention_mask)
 # Prepare the input
-input_ids = np.array(input_ids).astype(np.int64)  # Replace your_input_ids with actual input data
+input_ids = np.array(input_ids).astype(
+    np.int64
+)  # Replace your_input_ids with actual input data

 # Run the ONNX model
-outputs = ort_session.run(None, {"input_ids": input_ids, "attention_mask": attention_mask})
+outputs = ort_session.run(
+    None, {"input_ids": input_ids, "attention_mask": attention_mask}
+)

 # Get the attention weights
 attentions = outputs[-1]
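A self-contained sketch of the inference flow this script exercises, assuming file paths and the sample sentence (the attention-weight position follows the script's own `outputs[-1]`; some exports additionally require `token_type_ids`):

import numpy as np
import onnxruntime as ort
from tokenizers import Tokenizer

# Illustrative paths; the real script builds these from the export step above.
tokenizer = Tokenizer.from_file("models/all-MiniLM-L6-v2-onnx/tokenizer.json")
ort_session = ort.InferenceSession("models/all-MiniLM-L6-v2-onnx/model.onnx")

encoding = tokenizer.encode("hello world")
input_ids = np.array([encoding.ids], dtype=np.int64)
attention_mask = np.array([encoding.attention_mask], dtype=np.int64)

outputs = ort_session.run(
    None, {"input_ids": input_ids, "attention_mask": attention_mask}
)
attentions = outputs[-1]  # attention weights, as in the script above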

fastembed/__init__.py

+2 −2

@@ -1,9 +1,9 @@
 import importlib.metadata

 from fastembed.image import ImageEmbedding
-from fastembed.text import TextEmbedding
-from fastembed.sparse import SparseTextEmbedding, SparseEmbedding
 from fastembed.late_interaction import LateInteractionTextEmbedding
+from fastembed.sparse import SparseEmbedding, SparseTextEmbedding
+from fastembed.text import TextEmbedding

 try:
     version = importlib.metadata.version("fastembed")
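Since this commit wires all-MiniLM-L6-v2 into TextEmbedding, a minimal usage sketch (sample texts are illustrative; 384 is the model's published embedding size):

from fastembed import TextEmbedding

model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
embeddings = list(model.embed(["FastEmbed is fast", "MiniLM is compact"]))
print(len(embeddings), embeddings[0].shape)  # 2 vectors of dimension 384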

fastembed/common/__init__.py

+1 −1

@@ -1,3 +1,3 @@
-from fastembed.common.types import OnnxProvider, ImageInput, PathInput
+from fastembed.common.types import ImageInput, OnnxProvider, PathInput

 __all__ = ["OnnxProvider", "ImageInput", "PathInput"]

fastembed/common/model_management.py

+21 −8

@@ -2,13 +2,13 @@
 import shutil
 import tarfile
 from pathlib import Path
-from typing import List, Optional, Dict, Any
+from typing import Any, Dict, List, Optional

 import requests
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import RepositoryNotFoundError
-from tqdm import tqdm
 from loguru import logger
+from tqdm import tqdm


 class ModelManagement:
@@ -42,7 +42,9 @@ def _get_model_description(cls, model_name: str) -> Dict[str, Any]:
         raise ValueError(f"Model {model_name} is not supported in {cls.__name__}.")

     @classmethod
-    def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
+    def download_file_from_gcs(
+        cls, url: str, output_path: str, show_progress: bool = True
+    ) -> str:
         """
         Downloads a file from Google Cloud Storage.

@@ -71,12 +73,17 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool

         # Warn if the total size is zero
         if total_size_in_bytes == 0:
-            print(f"Warning: Content-length header is missing or zero in the response from {url}.")
+            print(
+                f"Warning: Content-length header is missing or zero in the response from {url}."
+            )

         show_progress = total_size_in_bytes and show_progress

         with tqdm(
-            total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress
+            total=total_size_in_bytes,
+            unit="iB",
+            unit_scale=True,
+            disable=not show_progress,
         ) as progress_bar:
             with open(output_path, "wb") as file:
                 for chunk in response.iter_content(chunk_size=1024):
@@ -156,7 +163,9 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str):
         return cache_dir

     @classmethod
-    def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> Path:
+    def retrieve_model_gcs(
+        cls, model_name: str, source_url: str, cache_dir: str
+    ) -> Path:
         fast_model_name = f"fast-{model_name.split('/')[-1]}"

         cache_tmp_dir = Path(cache_dir) / "tmp"
@@ -182,8 +191,12 @@ def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) ->
             output_path=str(model_tar_gz),
         )

-        cls.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir))
-        assert model_tmp_dir.exists(), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"
+        cls.decompress_to_cache(
+            targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir)
+        )
+        assert (
+            model_tmp_dir.exists()
+        ), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"

         model_tar_gz.unlink()
         # Rename from tmp to final name is atomic
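The method reflowed above is a standard streaming download with a tqdm progress bar; a self-contained sketch of the same pattern (hypothetical helper name, not the class method itself):

import requests
from tqdm import tqdm

def download_file(url: str, output_path: str, show_progress: bool = True) -> str:
    # Stream so large model archives never sit fully in memory.
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get("content-length", 0))
    with tqdm(
        total=total_size_in_bytes,
        unit="iB",
        unit_scale=True,
        disable=not (total_size_in_bytes and show_progress),
    ) as progress_bar:
        with open(output_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024):
                progress_bar.update(len(chunk))
                file.write(chunk)
    return output_path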

fastembed/common/onnx_model.py

+15 −4

@@ -1,15 +1,24 @@
+import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Sequence
-import warnings
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterable,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+)

 import numpy as np
 import onnxruntime as ort

 from fastembed.common.types import OnnxProvider
 from fastembed.parallel_processor import Worker

-
 # Holds type of the embedding result
 T = TypeVar("T")

@@ -51,7 +60,9 @@ def load_onnx_model(
         model_path = model_dir / model_file
         # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers

-        onnx_providers = ["CPUExecutionProvider"] if providers is None else list(providers)
+        onnx_providers = (
+            ["CPUExecutionProvider"] if providers is None else list(providers)
+        )
         available_providers = ort.get_available_providers()
         requested_provider_names = []
         for provider in onnx_providers:
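The hunk above only reflows the CPU fallback; the surrounding logic validates requested providers against what the installed onnxruntime build offers. A sketch of that pattern in isolation (model path illustrative):

import onnxruntime as ort

requested = ["CUDAExecutionProvider", "CPUExecutionProvider"]
available = ort.get_available_providers()

# Keep only providers this onnxruntime build actually ships with.
providers = [p for p in requested if p in available] or ["CPUExecutionProvider"]
session = ort.InferenceSession("model.onnx", providers=providers)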

fastembed/common/preprocessor_utils.py

+4 −2

@@ -2,7 +2,7 @@
 from pathlib import Path
 from typing import Tuple

-from tokenizers import Tokenizer, AddedToken
+from tokenizers import AddedToken, Tokenizer

 from fastembed.image.transform.operators import Compose

@@ -40,7 +40,9 @@ def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tuple[Tokenizer, d
     tokens_map = load_special_tokens(model_dir)

     tokenizer = Tokenizer.from_file(str(tokenizer_path))
-    tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length))
+    tokenizer.enable_truncation(
+        max_length=min(tokenizer_config["model_max_length"], max_length)
+    )
     tokenizer.enable_padding(
         pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"]
     )
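`enable_truncation` and `enable_padding` come from the `tokenizers` library; a standalone sketch of the configuration being reflowed here (file path and pad token are illustrative defaults):

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")
tokenizer.enable_truncation(max_length=512)
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")

# All encodings in a batch come back the same length after padding.
batch = tokenizer.encode_batch(["short text", "a somewhat longer input text"])
print([len(encoding.ids) for encoding in batch])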

fastembed/common/types.py

+1 −1

@@ -1,6 +1,6 @@
 import os
 import sys
-from typing import Union, Iterable, Tuple, Dict, Any
+from typing import Any, Dict, Iterable, Tuple, Union

 if sys.version_info >= (3, 10):
     from typing import TypeAlias

fastembed/common/utils.py

+1 −1

@@ -2,7 +2,7 @@
 import tempfile
 from itertools import islice
 from pathlib import Path
-from typing import Union, Iterable, Generator, Optional
+from typing import Generator, Iterable, Optional, Union

 import numpy as np

fastembed/image/__init__.py

−1

@@ -1,4 +1,3 @@
 from fastembed.image.image_embedding import ImageEmbedding

-
 __all__ = ["ImageEmbedding"]

fastembed/image/image_embedding.py

+10 −3

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterable, List, Optional, Type, Sequence
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Type

 import numpy as np

@@ -51,9 +51,16 @@ def __init__(

         for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
             supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
-            if any(model_name.lower() == model["model"].lower() for model in supported_models):
+            if any(
+                model_name.lower() == model["model"].lower()
+                for model in supported_models
+            ):
                 self.model = EMBEDDING_MODEL_TYPE(
-                    model_name, cache_dir, threads=threads, providers=providers, **kwargs
+                    model_name,
+                    cache_dir,
+                    threads=threads,
+                    providers=providers,
+                    **kwargs,
                 )
                 return
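The reflowed registry lookup dispatches on a case-insensitive model-name match; usage is unchanged. A sketch, with the model name assumed from fastembed's supported image models and illustrative image paths:

from fastembed import ImageEmbedding

# Model name and paths are assumptions for illustration.
model = ImageEmbedding(model_name="Qdrant/clip-ViT-B-32-vision")
embeddings = list(model.embed(["photos/cat.jpg", "photos/dog.jpg"]))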

fastembed/image/onnx_embedding.py

+13 −7

@@ -1,12 +1,12 @@
-from typing import Dict, Optional, Iterable, Type, List, Any, Sequence
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Type

 import numpy as np

-from fastembed.common.onnx_model import OnnxOutputContext
-from fastembed.common.utils import normalize, define_cache_dir
 from fastembed.common import ImageInput, OnnxProvider
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.common.utils import define_cache_dir, normalize
 from fastembed.image.image_embedding_base import ImageEmbeddingBase
-from fastembed.image.onnx_image_model import OnnxImageModel, ImageEmbeddingWorker
+from fastembed.image.onnx_image_model import ImageEmbeddingWorker, OnnxImageModel

 supported_onnx_models = [
     {
@@ -122,10 +122,16 @@ def _preprocess_onnx_input(

         return onnx_input

-    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext
+    ) -> Iterable[np.ndarray]:
         return normalize(output.model_output).astype(np.float32)


 class OnnxImageEmbeddingWorker(ImageEmbeddingWorker):
-    def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> OnnxImageEmbedding:
-        return OnnxImageEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
+    def init_embedding(
+        self, model_name: str, cache_dir: str, **kwargs
+    ) -> OnnxImageEmbedding:
+        return OnnxImageEmbedding(
+            model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
+        )
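`normalize` (imported from `fastembed.common.utils` above) presumably L2-normalizes each embedding row so cosine similarity reduces to a dot product; a sketch of that operation under this assumption (hypothetical helper, not fastembed's own function):

import numpy as np

def l2_normalize(rows: np.ndarray) -> np.ndarray:
    # Divide each row by its Euclidean norm, guarding against zero vectors.
    norms = np.linalg.norm(rows, axis=1, keepdims=True)
    return rows / np.maximum(norms, 1e-12)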

fastembed/image/onnx_image_model.py

+12 −7

@@ -1,15 +1,15 @@
-import os
 import contextlib
+import os
 from multiprocessing import get_all_start_methods
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Sequence
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type

-from PIL import Image
 import numpy as np
+from PIL import Image

+from fastembed.common import ImageInput, OnnxProvider, PathInput
+from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
 from fastembed.common.preprocessor_utils import load_preprocessor
-from fastembed.common.onnx_model import OnnxModel, EmbeddingWorker, T, OnnxOutputContext
-from fastembed.common import PathInput, ImageInput, OnnxProvider
 from fastembed.common.utils import iter_batch
 from fastembed.parallel_processor import ParallelWorkerPool

@@ -44,7 +44,10 @@ def load_onnx_model(
         providers: Optional[Sequence[OnnxProvider]] = None,
     ) -> None:
         super().load_onnx_model(
-            model_dir=model_dir, model_file=model_file, threads=threads, providers=providers
+            model_dir=model_dir,
+            model_file=model_file,
+            threads=threads,
+            providers=providers,
         )
         self.processor = load_preprocessor(model_dir=model_dir)

@@ -87,7 +90,9 @@ def _embed_images(
             for batch in iter_batch(images, batch_size):
                 yield from self._post_process_onnx_output(self.onnx_embed(batch))
         else:
-            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
+            start_method = (
+                "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
+            )
             params = {"model_name": model_name, "cache_dir": cache_dir, **kwargs}
             pool = ParallelWorkerPool(
                 parallel, self._get_worker_class(), start_method=start_method
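The start-method choice above prefers forkserver where the platform supports it (Linux) and falls back to spawn (macOS, Windows); a minimal illustration of the same selection:

import multiprocessing as mp

start_method = "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"
ctx = mp.get_context(start_method)

proc = ctx.Process(target=print, args=("worker started",))
proc.start()
proc.join()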

fastembed/image/transform/functional.py

+2 −3

@@ -1,8 +1,7 @@
-from typing import Union, Tuple, Sized
-
-from PIL import Image
+from typing import Sized, Tuple, Union

 import numpy as np
+from PIL import Image


 def convert_to_rgb(image: Image.Image) -> Image.Image: