
Commit 62607c2

I8dNLojoe and George authored

Fix to avoid overflow and get rid of model_max_length (#319)

* Fix to avoid overflow and get rid of model_max_length
* Fixes for max_length vs model_max_length logic; Jupyter warning disabled
* Support of jwodder/versioningit#48
* Update fastembed/common/preprocessor_utils.py

Co-authored-by: George <[email protected]>

1 parent 49762a6 · commit 62607c2

File tree: 5 files changed, +31 −44 lines


fastembed/common/model_management.py (+3 −1)

@@ -148,7 +148,9 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str):
             # Open the tar.gz file
             with tarfile.open(targz_path, "r:gz") as tar:
                 # Extract all files into the cache directory
-                tar.extractall(path=cache_dir)
+                tar.extractall(
+                    path=cache_dir,
+                )
         except tarfile.TarError as e:
             # If any error occurs while opening or extracting the tar.gz file,
             # delete the cache directory (if it was created in this function)
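For context, a minimal sketch of the pattern this hunk sits inside, reconstructed from the context lines; the surrounding bookkeeping (the created flag, the cleanup, the re-raise) is an assumption, and only the tar.extractall(...) call is taken verbatim from the diff:

import shutil
import tarfile
from pathlib import Path

# Hypothetical standalone version; mirrors the try/except structure the
# context lines describe, not the library's exact implementation.
def decompress_to_cache(targz_path: str, cache_dir: str) -> str:
    created = not Path(cache_dir).exists()
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
    try:
        # Open the tar.gz file
        with tarfile.open(targz_path, "r:gz") as tar:
            # Extract all files into the cache directory
            tar.extractall(
                path=cache_dir,
            )
    except tarfile.TarError as e:
        # If any error occurs while opening or extracting the tar.gz file,
        # delete the cache directory (if it was created in this function)
        if created:
            shutil.rmtree(cache_dir, ignore_errors=True)
        raise ValueError(f"Failed to extract {targz_path}") from e
    return cache_dir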

fastembed/common/preprocessor_utils.py (+11 −5)

@@ -1,7 +1,6 @@
 import json
 from pathlib import Path
 from typing import Tuple
-
 from tokenizers import AddedToken, Tokenizer
 
 from fastembed.image.transform.operators import Compose

@@ -18,7 +17,7 @@ def load_special_tokens(model_dir: Path) -> dict:
     return tokens_map
 
 
-def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tuple[Tokenizer, dict]:
+def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict]:
     config_path = model_dir / "config.json"
     if not config_path.exists():
         raise ValueError(f"Could not find config.json in {model_dir}")

@@ -36,13 +35,20 @@ def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tuple[Tokenizer, dict]:
 
     with open(str(tokenizer_config_path)) as tokenizer_config_file:
         tokenizer_config = json.load(tokenizer_config_file)
+    assert (
+        "model_max_length" in tokenizer_config or "max_length" in tokenizer_config
+    ), "Models without model_max_length or max_length are not supported."
+    if "model_max_length" not in tokenizer_config:
+        max_context = tokenizer_config["max_length"]
+    elif "max_length" not in tokenizer_config:
+        max_context = tokenizer_config["model_max_length"]
+    else:
+        max_context = min(tokenizer_config["model_max_length"], tokenizer_config["max_length"])
 
     tokens_map = load_special_tokens(model_dir)
 
     tokenizer = Tokenizer.from_file(str(tokenizer_path))
-    tokenizer.enable_truncation(
-        max_length=min(tokenizer_config["model_max_length"], max_length)
-    )
+    tokenizer.enable_truncation(max_length=max_context)
     tokenizer.enable_padding(
         pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"]
     )
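Why the new branch exists: HF-style tokenizer configs often record model_max_length as a huge sentinel (on the order of 1e30) when no real limit is known, and passing that value straight through to truncation is the overflow the commit title refers to; max_length, when present, bounds it. A standalone sketch of the selection logic, assuming a parsed tokenizer_config dict (resolve_max_context is a hypothetical name; the real code inlines this in load_tokenizer):

# Hypothetical standalone version of the selection logic in the hunk above.
def resolve_max_context(tokenizer_config: dict) -> int:
    assert (
        "model_max_length" in tokenizer_config or "max_length" in tokenizer_config
    ), "Models without model_max_length or max_length are not supported."
    if "model_max_length" not in tokenizer_config:
        return tokenizer_config["max_length"]
    if "max_length" not in tokenizer_config:
        return tokenizer_config["model_max_length"]
    return min(tokenizer_config["model_max_length"], tokenizer_config["max_length"])

# A config carrying a sentinel model_max_length now resolves to the real bound:
print(resolve_max_context({"model_max_length": int(1e30), "max_length": 512}))  # 512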

fastembed/text/onnx_embedding.py (+2 −6)

@@ -244,9 +244,7 @@ def _preprocess_onnx_input(
         """
         return onnx_input
 
-    def _post_process_onnx_output(
-        self, output: OnnxOutputContext
-    ) -> Iterable[np.ndarray]:
+    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
         embeddings = output.model_output
         return normalize(embeddings[:, 0]).astype(np.float32)
 

@@ -258,6 +256,4 @@ def init_embedding(
         cache_dir: str,
         **kwargs,
     ) -> OnnxTextEmbedding:
-        return OnnxTextEmbedding(
-            model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
-        )
+        return OnnxTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
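Behavior is unchanged by the reflow: _post_process_onnx_output takes the first (CLS) token embedding of each document and L2-normalizes it. A rough plain-NumPy equivalent, assuming model_output has shape (batch, seq_len, dim) and that normalize is the library's L2-normalization helper:

import numpy as np

# Rough equivalent of normalize(embeddings[:, 0]).astype(np.float32).
def cls_pool_and_normalize(model_output: np.ndarray) -> np.ndarray:
    cls = model_output[:, 0]  # CLS token embedding per document: (batch, dim)
    norms = np.linalg.norm(cls, axis=1, keepdims=True)  # per-row L2 norm
    return (cls / norms).astype(np.float32)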

tests/__init__.py (+4)

@@ -0,0 +1,4 @@
+import os
+
+# disable DeprecationWarning https://github.com/jupyter/jupyter_core/issues/398
+os.environ["JUPYTER_PLATFORM_DIRS"] = "1"
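Putting this in tests/__init__.py rather than in a single test module presumably ensures the variable is set before anything imports jupyter_core during test collection; the DeprecationWarning is only emitted while JUPYTER_PLATFORM_DIRS is unset. The same effect, sketched for an ad-hoc script:

import os

# Must run before anything imports jupyter_core; "1" opts into the
# platformdirs-based locations and silences the DeprecationWarning.
os.environ.setdefault("JUPYTER_PLATFORM_DIRS", "1")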

tests/test_text_onnx_embeddings.py (+11 −32)

@@ -32,54 +32,34 @@
     "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": np.array(
         [0.0094, 0.0184, 0.0328, 0.0072, -0.0351]
     ),
-    "intfloat/multilingual-e5-large": np.array(
-        [0.0098, 0.0045, 0.0066, -0.0354, 0.0070]
-    ),
+    "intfloat/multilingual-e5-large": np.array([0.0098, 0.0045, 0.0066, -0.0354, 0.0070]),
     "sentence-transformers/paraphrase-multilingual-mpnet-base-v2": np.array(
         [-0.01341097, 0.0416553, -0.00480805, 0.02844842, 0.0505299]
     ),
-    "jinaai/jina-embeddings-v2-small-en": np.array(
-        [-0.0455, -0.0428, -0.0122, 0.0613, 0.0015]
-    ),
-    "jinaai/jina-embeddings-v2-base-en": np.array(
-        [-0.0332, -0.0509, 0.0287, -0.0043, -0.0077]
-    ),
-    "jinaai/jina-embeddings-v2-base-de": np.array(
-        [-0.0085, 0.0417, 0.0342, 0.0309, -0.0149]
-    ),
-    "jinaai/jina-embeddings-v2-base-code": np.array(
-        [0.0145, -0.0164, 0.0136, -0.0170, 0.0734]
-    ),
-    "nomic-ai/nomic-embed-text-v1": np.array(
-        [0.3708 , 0.2031, -0.3406, -0.2114, -0.3230]
-    ),
+    "jinaai/jina-embeddings-v2-small-en": np.array([-0.0455, -0.0428, -0.0122, 0.0613, 0.0015]),
+    "jinaai/jina-embeddings-v2-base-en": np.array([-0.0332, -0.0509, 0.0287, -0.0043, -0.0077]),
+    "jinaai/jina-embeddings-v2-base-de": np.array([-0.0085, 0.0417, 0.0342, 0.0309, -0.0149]),
+    "jinaai/jina-embeddings-v2-base-code": np.array([0.0145, -0.0164, 0.0136, -0.0170, 0.0734]),
+    "nomic-ai/nomic-embed-text-v1": np.array([0.3708, 0.2031, -0.3406, -0.2114, -0.3230]),
     "nomic-ai/nomic-embed-text-v1.5": np.array(
         [-0.15407836, -0.03053198, -3.9138033, 0.1910364, 0.13224715]
     ),
     "nomic-ai/nomic-embed-text-v1.5-Q": np.array(
-        [-0.12525563, 0.38030425, -3.961622 , 0.04176439, -0.0758301]
+        [-0.12525563, 0.38030425, -3.961622, 0.04176439, -0.0758301]
     ),
     "thenlper/gte-large": np.array(
         [-0.01920587, 0.00113156, -0.00708992, -0.00632304, -0.04025577]
     ),
     "mixedbread-ai/mxbai-embed-large-v1": np.array(
         [0.02295546, 0.03196154, 0.016512, -0.04031524, -0.0219634]
     ),
-    "snowflake/snowflake-arctic-embed-xs": np.array(
-        [0.0092, 0.0619, 0.0196, 0.009, -0.0114]
-    ),
-    "snowflake/snowflake-arctic-embed-s": np.array(
-        [-0.0416, -0.0867, 0.0209, 0.0554, -0.0272]
-    ),
-    "snowflake/snowflake-arctic-embed-m": np.array(
-        [-0.0329, 0.0364, 0.0481, 0.0016, 0.0328]
-    ),
+    "snowflake/snowflake-arctic-embed-xs": np.array([0.0092, 0.0619, 0.0196, 0.009, -0.0114]),
+    "snowflake/snowflake-arctic-embed-s": np.array([-0.0416, -0.0867, 0.0209, 0.0554, -0.0272]),
+    "snowflake/snowflake-arctic-embed-m": np.array([-0.0329, 0.0364, 0.0481, 0.0016, 0.0328]),
     "snowflake/snowflake-arctic-embed-m-long": np.array(
         [0.0080, -0.0266, -0.0335, 0.0282, 0.0143]
     ),
-    "snowflake/snowflake-arctic-embed-l": np.array(
-        [0.0189, -0.0673, 0.0183, 0.0124, 0.0146]
-    ),
+    "snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]),
     "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
 }
 

@@ -94,7 +74,6 @@ def test_embedding():
         dim = model_desc["dim"]
 
         model = TextEmbedding(model_name=model_desc["model"])
-
         docs = ["hello world", "flag embedding"]
         embeddings = list(model.embed(docs))
         embeddings = np.stack(embeddings, axis=0)
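For reference, a hedged sketch of how such canonical vectors are typically consumed a few lines later in the test, assuming the dict shown above is named CANONICAL_VECTOR_VALUES (its name is outside the hunk), that each stored array holds the leading components of the first document's embedding, and a tolerance such as 1e-3:

canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]]

# Shape sanity check, then compare the leading components of the first
# embedding against the stored reference values.
assert embeddings.shape == (len(docs), dim)
assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3)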
