Commit cef3caa

Merge branch 'main' into speedup-ci
2 parents: 3826294 + 1729aab

9 files changed: +26 / -20 lines changed

fastembed/common/model_description.py

Lines changed: 5 additions & 0 deletions
@@ -7,6 +7,11 @@
 class ModelSource:
     hf: Optional[str] = None
     url: Optional[str] = None
+    _deprecated_tar_struct: bool = False
+
+    @property
+    def deprecated_tar_struct(self) -> bool:
+        return self._deprecated_tar_struct
 
     def __post_init__(self) -> None:
         if self.hf is None and self.url is None:
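For readers following along, here is a minimal sketch of what the dataclass looks like once this hunk is applied. The decorator, imports, and the error raised in __post_init__ are assumptions based on the surrounding module; only the fields, the property, and the None check come from the diff.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelSource:
    hf: Optional[str] = None
    url: Optional[str] = None
    # New flag: True only for models still shipped as legacy "fast-*" GCS tarballs.
    _deprecated_tar_struct: bool = False

    @property
    def deprecated_tar_struct(self) -> bool:
        # Read-only view of the private flag, consumed by download_model() below.
        return self._deprecated_tar_struct

    def __post_init__(self) -> None:
        if self.hf is None and self.url is None:
            # Exact message assumed; the diff only shows the condition.
            raise ValueError("Either hf or url must be provided")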

fastembed/common/model_management.py

Lines changed: 3 additions & 1 deletion
@@ -330,9 +330,10 @@ def retrieve_model_gcs(
         model_name: str,
         source_url: str,
         cache_dir: str,
+        deprecated_tar_struct: bool = False,
         local_files_only: bool = False,
     ) -> Path:
-        fast_model_name = f"fast-{model_name.split('/')[-1]}"
+        fast_model_name = f"{'fast-' if deprecated_tar_struct else ''}{model_name.split('/')[-1]}"
         cache_tmp_dir = Path(cache_dir) / "tmp"
         model_tmp_dir = cache_tmp_dir / fast_model_name
         model_dir = Path(cache_dir) / fast_model_name
@@ -438,6 +439,7 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: Any
                 model.model,
                 str(url_source),
                 str(cache_dir),
+                deprecated_tar_struct=model.sources.deprecated_tar_struct,
                 local_files_only=local_files_only,
             )
         except Exception:
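The practical effect of the new parameter is easiest to see in isolation. A small sketch of the naming logic (the helper name cache_dir_name and the example model name are hypothetical; the f-string itself is taken from the hunk above):

# Only sources flagged with the legacy tarball layout keep the historical
# "fast-" prefix in their cache directory name.
def cache_dir_name(model_name: str, deprecated_tar_struct: bool = False) -> str:
    return f"{'fast-' if deprecated_tar_struct else ''}{model_name.split('/')[-1]}"


assert cache_dir_name("BAAI/bge-base-en", deprecated_tar_struct=True) == "fast-bge-base-en"
assert cache_dir_name("BAAI/bge-base-en") == "bge-base-en"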

fastembed/image/onnx_embedding.py

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,5 @@
 from typing import Any, Iterable, Optional, Sequence, Type, Union
 
-import numpy as np
 
 from fastembed.common.types import NumpyArray
 from fastembed.common import ImageInput, OnnxProvider
@@ -195,7 +194,7 @@ def _preprocess_onnx_input(
         return onnx_input
 
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
-        return normalize(output.model_output).astype(np.float32)
+        return normalize(output.model_output)
 
 
 class OnnxImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]):
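This and the following files drop the trailing .astype(np.float32) after normalize/pooling calls, which also lets the numpy import go. That is only safe if the shared helpers already produce float32 output; below is a hypothetical sketch of a normalize utility with that property (not necessarily the actual fastembed.common.utils.normalize):

import numpy as np


def normalize(x: np.ndarray) -> np.ndarray:
    # L2-normalize along the last axis and return float32, so callers
    # no longer need a trailing .astype(np.float32).
    x = x.astype(np.float32, copy=False)
    norms = np.linalg.norm(x, ord=2, axis=-1, keepdims=True)
    return x / np.maximum(norms, 1e-12)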

fastembed/late_interaction/colbert.py

Lines changed: 3 additions & 3 deletions
@@ -46,7 +46,7 @@ def _post_process_onnx_output(
         self, output: OnnxOutputContext, is_doc: bool = True
     ) -> Iterable[NumpyArray]:
         if not is_doc:
-            return output.model_output.astype(np.float32)
+            return output.model_output
 
         if output.input_ids is None or output.attention_mask is None:
             raise ValueError(
@@ -58,11 +58,11 @@ def _post_process_onnx_output(
                 if token_id in self.skip_list or token_id == self.pad_token_id:
                     output.attention_mask[i, j] = 0
 
-        output.model_output *= np.expand_dims(output.attention_mask, 2).astype(np.float32)
+        output.model_output *= np.expand_dims(output.attention_mask, 2)
         norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
         norm_clamped = np.maximum(norm, 1e-12)
         output.model_output /= norm_clamped
-        return output.model_output.astype(np.float32)
+        return output.model_output
 
     def _preprocess_onnx_input(
         self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
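A self-contained illustration of the masking-and-normalization step in this hunk, with made-up shapes (batch x tokens x dim embeddings and a batch x tokens attention mask):

import numpy as np

embeddings = np.random.rand(1, 4, 8).astype(np.float32)   # (batch, tokens, dim)
attention_mask = np.array([[1, 1, 1, 0]], dtype=np.int64)  # last token is padding

# Zero out padded token vectors, then L2-normalize each remaining token vector,
# clamping the norm to avoid division by zero, as in the diff above.
embeddings *= np.expand_dims(attention_mask, 2)
norms = np.linalg.norm(embeddings, ord=2, axis=2, keepdims=True)
embeddings /= np.maximum(norms, 1e-12)

print(np.linalg.norm(embeddings, axis=2))  # ~[[1. 1. 1. 0.]]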

fastembed/late_interaction_multimodal/colpali.py

Lines changed: 2 additions & 2 deletions
@@ -142,7 +142,7 @@ def _post_process_onnx_image_output(
         assert self.model_description.dim is not None, "Model dim is not defined"
         return output.model_output.reshape(
             output.model_output.shape[0], -1, self.model_description.dim
-        ).astype(np.float32)
+        )
 
     def _post_process_onnx_text_output(
         self,
@@ -157,7 +157,7 @@ def _post_process_onnx_text_output(
         Returns:
             Iterable[NumpyArray]: Post-processed output as NumPy arrays.
         """
-        return output.model_output.astype(np.float32)
+        return output.model_output
 
     def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
         texts_query: list[str] = []
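The reshape kept in _post_process_onnx_image_output turns the raw ONNX output into per-patch vectors of the configured dimension. A toy example with illustrative shapes (the real dim and output shape come from the model description and the ONNX session):

import numpy as np

dim = 128  # stands in for self.model_description.dim
raw = np.zeros((2, 16 * dim), dtype=np.float32)  # pretend flat ONNX output
per_patch = raw.reshape(raw.shape[0], -1, dim)
print(per_patch.shape)                           # (2, 16, 128): batch x patches x dim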

fastembed/text/onnx_embedding.py

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,5 @@
 from typing import Any, Iterable, Optional, Sequence, Type, Union
 
-import numpy as np
 from fastembed.common.types import NumpyArray, OnnxProvider
 from fastembed.common.onnx_model import OnnxOutputContext
 from fastembed.common.utils import define_cache_dir, normalize
@@ -21,6 +20,7 @@
         sources=ModelSource(
             hf="Qdrant/fast-bge-base-en",
             url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz",
+            _deprecated_tar_struct=True,
         ),
         model_file="model_optimized.onnx",
     ),
@@ -36,6 +36,7 @@
         sources=ModelSource(
             hf="qdrant/bge-base-en-v1.5-onnx-q",
             url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
+            _deprecated_tar_struct=True,
         ),
         model_file="model_optimized.onnx",
     ),
@@ -63,6+64,7 @@
         sources=ModelSource(
             hf="Qdrant/bge-small-en",
             url="https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz",
+            _deprecated_tar_struct=True,
         ),
         model_file="model_optimized.onnx",
     ),
@@ -90,6 +92,7 @@
         sources=ModelSource(
             hf="Qdrant/bge-small-zh-v1.5",
             url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz",
+            _deprecated_tar_struct=True,
         ),
         model_file="model_optimized.onnx",
     ),
@@ -309,7 +312,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
             processed_embeddings = embeddings
         else:
             raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
-        return normalize(processed_embeddings).astype(np.float32)
+        return normalize(processed_embeddings)
 
     def load_onnx_model(self) -> None:
         self._load_onnx_model(
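Taken together with the model_description.py change, a model entry flagged this way exposes the setting through the read-only property, which download_model() then forwards to retrieve_model_gcs(). A small usage sketch, assuming a fastembed checkout that includes this change:

from fastembed.common.model_description import ModelSource

source = ModelSource(
    hf="Qdrant/fast-bge-base-en",
    url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz",
    _deprecated_tar_struct=True,
)

# The legacy flag keeps the "fast-" prefixed cache directory for this model.
print(source.deprecated_tar_struct)  # True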

fastembed/text/pooled_embedding.py

Lines changed: 2 additions & 1 deletion
@@ -82,6 +82,7 @@
         sources=ModelSource(
             hf="qdrant/multilingual-e5-large-onnx",
             url="https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz",
+            _deprecated_tar_struct=True,
         ),
         model_file="model.onnx",
         additional_files=["model.onnx_data"],
@@ -115,7 +116,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
 
         embeddings = output.model_output
         attn_mask = output.attention_mask
-        return self.mean_pooling(embeddings, attn_mask).astype(np.float32)
+        return self.mean_pooling(embeddings, attn_mask)
 
 
 class PooledEmbeddingWorker(OnnxTextEmbeddingWorker):
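For reference, attention-mask-weighted mean pooling of the kind called here can be sketched as follows (a generic illustration, not necessarily the exact fastembed implementation):

import numpy as np


def mean_pooling(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    # Average token vectors over the sequence axis, counting only the
    # non-padded positions indicated by the attention mask.
    mask = np.expand_dims(attention_mask, -1).astype(np.float32)
    summed = (token_embeddings * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)
    return summed / counts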

fastembed/text/pooled_normalized_embedding.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,5 @@
 from typing import Any, Iterable, Type
 
-import numpy as np
 
 from fastembed.common.types import NumpyArray
 from fastembed.common.onnx_model import OnnxOutputContext
@@ -22,6 +21,7 @@
         sources=ModelSource(
             url="https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
             hf="qdrant/all-MiniLM-L6-v2-onnx",
+            _deprecated_tar_struct=True,
         ),
         model_file="model.onnx",
     ),
@@ -144,7 +144,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
 
         embeddings = output.model_output
         attn_mask = output.attention_mask
-        return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
+        return normalize(self.mean_pooling(embeddings, attn_mask))
 
 
 class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):

tests/test_custom_models.py

Lines changed: 3 additions & 7 deletions
@@ -91,15 +91,11 @@ def test_mock_add_custom_models():
     expected_output = {
         f"{PoolingType.MEAN.lower()}-normalized": normalize(
             mean_pooling(dummy_token_embedding, dummy_attention_mask)
-        ).astype(np.float32),
-        f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask),
-        f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]).astype(
-            np.float32
         ),
+        f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask),
+        f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]),
         f"{PoolingType.CLS.lower()}": dummy_token_embedding[:, 0],
-        f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding).astype(
-            np.float32
-        ),
+        f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding),
         f"{PoolingType.DISABLED.lower()}": dummy_pooled_embedding,
     }
 
