Skip to content

Commit

Permalink
Merge branch 'main' into speedup-ci
Browse files Browse the repository at this point in the history
  • Loading branch information
hh-space-invader authored Mar 5, 2025
2 parents 3826294 + 1729aab commit cef3caa
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 20 deletions.
5 changes: 5 additions & 0 deletions fastembed/common/model_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
class ModelSource:
hf: Optional[str] = None
url: Optional[str] = None
_deprecated_tar_struct: bool = False

@property
def deprecated_tar_struct(self) -> bool:
return self._deprecated_tar_struct

def __post_init__(self) -> None:
if self.hf is None and self.url is None:
Expand Down
4 changes: 3 additions & 1 deletion fastembed/common/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,10 @@ def retrieve_model_gcs(
model_name: str,
source_url: str,
cache_dir: str,
deprecated_tar_struct: bool = False,
local_files_only: bool = False,
) -> Path:
fast_model_name = f"fast-{model_name.split('/')[-1]}"
fast_model_name = f"{'fast-' if deprecated_tar_struct else ''}{model_name.split('/')[-1]}"
cache_tmp_dir = Path(cache_dir) / "tmp"
model_tmp_dir = cache_tmp_dir / fast_model_name
model_dir = Path(cache_dir) / fast_model_name
Expand Down Expand Up @@ -438,6 +439,7 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: An
model.model,
str(url_source),
str(cache_dir),
deprecated_tar_struct=model.sources.deprecated_tar_struct,
local_files_only=local_files_only,
)
except Exception:
Expand Down
3 changes: 1 addition & 2 deletions fastembed/image/onnx_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np

from fastembed.common.types import NumpyArray
from fastembed.common import ImageInput, OnnxProvider
Expand Down Expand Up @@ -195,7 +194,7 @@ def _preprocess_onnx_input(
return onnx_input

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
return normalize(output.model_output).astype(np.float32)
return normalize(output.model_output)


class OnnxImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]):
Expand Down
6 changes: 3 additions & 3 deletions fastembed/late_interaction/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _post_process_onnx_output(
self, output: OnnxOutputContext, is_doc: bool = True
) -> Iterable[NumpyArray]:
if not is_doc:
return output.model_output.astype(np.float32)
return output.model_output

if output.input_ids is None or output.attention_mask is None:
raise ValueError(
Expand All @@ -58,11 +58,11 @@ def _post_process_onnx_output(
if token_id in self.skip_list or token_id == self.pad_token_id:
output.attention_mask[i, j] = 0

output.model_output *= np.expand_dims(output.attention_mask, 2).astype(np.float32)
output.model_output *= np.expand_dims(output.attention_mask, 2)
norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
norm_clamped = np.maximum(norm, 1e-12)
output.model_output /= norm_clamped
return output.model_output.astype(np.float32)
return output.model_output

def _preprocess_onnx_input(
self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
Expand Down
4 changes: 2 additions & 2 deletions fastembed/late_interaction_multimodal/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _post_process_onnx_image_output(
assert self.model_description.dim is not None, "Model dim is not defined"
return output.model_output.reshape(
output.model_output.shape[0], -1, self.model_description.dim
).astype(np.float32)
)

def _post_process_onnx_text_output(
self,
Expand All @@ -157,7 +157,7 @@ def _post_process_onnx_text_output(
Returns:
Iterable[NumpyArray]: Post-processed output as NumPy arrays.
"""
return output.model_output.astype(np.float32)
return output.model_output

def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
texts_query: list[str] = []
Expand Down
7 changes: 5 additions & 2 deletions fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np
from fastembed.common.types import NumpyArray, OnnxProvider
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import define_cache_dir, normalize
Expand All @@ -21,6 +20,7 @@
sources=ModelSource(
hf="Qdrant/fast-bge-base-en",
url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz",
_deprecated_tar_struct=True,
),
model_file="model_optimized.onnx",
),
Expand All @@ -36,6 +36,7 @@
sources=ModelSource(
hf="qdrant/bge-base-en-v1.5-onnx-q",
url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
_deprecated_tar_struct=True,
),
model_file="model_optimized.onnx",
),
Expand Down Expand Up @@ -63,6 +64,7 @@
sources=ModelSource(
hf="Qdrant/bge-small-en",
url="https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz",
_deprecated_tar_struct=True,
),
model_file="model_optimized.onnx",
),
Expand Down Expand Up @@ -90,6 +92,7 @@
sources=ModelSource(
hf="Qdrant/bge-small-zh-v1.5",
url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz",
_deprecated_tar_struct=True,
),
model_file="model_optimized.onnx",
),
Expand Down Expand Up @@ -309,7 +312,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy
processed_embeddings = embeddings
else:
raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
return normalize(processed_embeddings).astype(np.float32)
return normalize(processed_embeddings)

def load_onnx_model(self) -> None:
self._load_onnx_model(
Expand Down
3 changes: 2 additions & 1 deletion fastembed/text/pooled_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
sources=ModelSource(
hf="qdrant/multilingual-e5-large-onnx",
url="https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz",
_deprecated_tar_struct=True,
),
model_file="model.onnx",
additional_files=["model.onnx_data"],
Expand Down Expand Up @@ -115,7 +116,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy

embeddings = output.model_output
attn_mask = output.attention_mask
return self.mean_pooling(embeddings, attn_mask).astype(np.float32)
return self.mean_pooling(embeddings, attn_mask)


class PooledEmbeddingWorker(OnnxTextEmbeddingWorker):
Expand Down
4 changes: 2 additions & 2 deletions fastembed/text/pooled_normalized_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Iterable, Type

import numpy as np

from fastembed.common.types import NumpyArray
from fastembed.common.onnx_model import OnnxOutputContext
Expand All @@ -22,6 +21,7 @@
sources=ModelSource(
url="https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
hf="qdrant/all-MiniLM-L6-v2-onnx",
_deprecated_tar_struct=True,
),
model_file="model.onnx",
),
Expand Down Expand Up @@ -144,7 +144,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy

embeddings = output.model_output
attn_mask = output.attention_mask
return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
return normalize(self.mean_pooling(embeddings, attn_mask))


class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):
Expand Down
10 changes: 3 additions & 7 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,11 @@ def test_mock_add_custom_models():
expected_output = {
f"{PoolingType.MEAN.lower()}-normalized": normalize(
mean_pooling(dummy_token_embedding, dummy_attention_mask)
).astype(np.float32),
f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask),
f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]).astype(
np.float32
),
f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask),
f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]),
f"{PoolingType.CLS.lower()}": dummy_token_embedding[:, 0],
f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding).astype(
np.float32
),
f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding),
f"{PoolingType.DISABLED.lower()}": dummy_pooled_embedding,
}

Expand Down

0 comments on commit cef3caa

Please sign in to comment.