diff --git a/fastembed/common/model_description.py b/fastembed/common/model_description.py index 432f8186..f1d7ac45 100644 --- a/fastembed/common/model_description.py +++ b/fastembed/common/model_description.py @@ -7,6 +7,11 @@ class ModelSource: hf: Optional[str] = None url: Optional[str] = None + _deprecated_tar_struct: bool = False + + @property + def deprecated_tar_struct(self) -> bool: + return self._deprecated_tar_struct def __post_init__(self) -> None: if self.hf is None and self.url is None: diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py index 33e0994f..ce004cb6 100644 --- a/fastembed/common/model_management.py +++ b/fastembed/common/model_management.py @@ -330,9 +330,10 @@ def retrieve_model_gcs( model_name: str, source_url: str, cache_dir: str, + deprecated_tar_struct: bool = False, local_files_only: bool = False, ) -> Path: - fast_model_name = f"fast-{model_name.split('/')[-1]}" + fast_model_name = f"{'fast-' if deprecated_tar_struct else ''}{model_name.split('/')[-1]}" cache_tmp_dir = Path(cache_dir) / "tmp" model_tmp_dir = cache_tmp_dir / fast_model_name model_dir = Path(cache_dir) / fast_model_name @@ -438,6 +439,7 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: An model.model, str(url_source), str(cache_dir), + deprecated_tar_struct=model.sources.deprecated_tar_struct, local_files_only=local_files_only, ) except Exception: diff --git a/fastembed/image/onnx_embedding.py b/fastembed/image/onnx_embedding.py index 9db22fa0..ee4f8f55 100644 --- a/fastembed/image/onnx_embedding.py +++ b/fastembed/image/onnx_embedding.py @@ -1,6 +1,5 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union -import numpy as np from fastembed.common.types import NumpyArray from fastembed.common import ImageInput, OnnxProvider @@ -195,7 +194,7 @@ def _preprocess_onnx_input( return onnx_input def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]: - return normalize(output.model_output).astype(np.float32) + return normalize(output.model_output) class OnnxImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]): diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 841d3a73..eb9545b3 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -46,7 +46,7 @@ def _post_process_onnx_output( self, output: OnnxOutputContext, is_doc: bool = True ) -> Iterable[NumpyArray]: if not is_doc: - return output.model_output.astype(np.float32) + return output.model_output if output.input_ids is None or output.attention_mask is None: raise ValueError( @@ -58,11 +58,11 @@ def _post_process_onnx_output( if token_id in self.skip_list or token_id == self.pad_token_id: output.attention_mask[i, j] = 0 - output.model_output *= np.expand_dims(output.attention_mask, 2).astype(np.float32) + output.model_output *= np.expand_dims(output.attention_mask, 2) norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True) norm_clamped = np.maximum(norm, 1e-12) output.model_output /= norm_clamped - return output.model_output.astype(np.float32) + return output.model_output def _preprocess_onnx_input( self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 0193bed9..c43ff9d0 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -142,7 +142,7 @@ def _post_process_onnx_image_output( assert self.model_description.dim is not None, "Model dim is not defined" return output.model_output.reshape( output.model_output.shape[0], -1, self.model_description.dim - ).astype(np.float32) + ) def _post_process_onnx_text_output( self, @@ -157,7 +157,7 @@ def _post_process_onnx_text_output( Returns: Iterable[NumpyArray]: Post-processed output as NumPy arrays. """ - return output.model_output.astype(np.float32) + return output.model_output def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: texts_query: list[str] = [] diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index d272c37a..6e8212f1 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -1,6 +1,5 @@ from typing import Any, Iterable, Optional, Sequence, Type, Union -import numpy as np from fastembed.common.types import NumpyArray, OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir, normalize @@ -21,6 +20,7 @@ sources=ModelSource( hf="Qdrant/fast-bge-base-en", url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz", + _deprecated_tar_struct=True, ), model_file="model_optimized.onnx", ), @@ -36,6 +36,7 @@ sources=ModelSource( hf="qdrant/bge-base-en-v1.5-onnx-q", url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz", + _deprecated_tar_struct=True, ), model_file="model_optimized.onnx", ), @@ -63,6 +64,7 @@ sources=ModelSource( hf="Qdrant/bge-small-en", url="https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz", + _deprecated_tar_struct=True, ), model_file="model_optimized.onnx", ), @@ -90,6 +92,7 @@ sources=ModelSource( hf="Qdrant/bge-small-zh-v1.5", url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz", + _deprecated_tar_struct=True, ), model_file="model_optimized.onnx", ), @@ -309,7 +312,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy processed_embeddings = embeddings else: raise ValueError(f"Unsupported embedding shape: {embeddings.shape}") - return normalize(processed_embeddings).astype(np.float32) + return normalize(processed_embeddings) def load_onnx_model(self) -> None: self._load_onnx_model( diff --git a/fastembed/text/pooled_embedding.py b/fastembed/text/pooled_embedding.py index 1dc8c9f5..a03b8003 100644 --- a/fastembed/text/pooled_embedding.py +++ b/fastembed/text/pooled_embedding.py @@ -82,6 +82,7 @@ sources=ModelSource( hf="qdrant/multilingual-e5-large-onnx", url="https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz", + _deprecated_tar_struct=True, ), model_file="model.onnx", additional_files=["model.onnx_data"], @@ -115,7 +116,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy embeddings = output.model_output attn_mask = output.attention_mask - return self.mean_pooling(embeddings, attn_mask).astype(np.float32) + return self.mean_pooling(embeddings, attn_mask) class PooledEmbeddingWorker(OnnxTextEmbeddingWorker): diff --git a/fastembed/text/pooled_normalized_embedding.py b/fastembed/text/pooled_normalized_embedding.py index ed825eca..8c11b43e 100644 --- a/fastembed/text/pooled_normalized_embedding.py +++ b/fastembed/text/pooled_normalized_embedding.py @@ -1,6 +1,5 @@ from typing import Any, Iterable, Type -import numpy as np from fastembed.common.types import NumpyArray from fastembed.common.onnx_model import OnnxOutputContext @@ -22,6 +21,7 @@ sources=ModelSource( url="https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz", hf="qdrant/all-MiniLM-L6-v2-onnx", + _deprecated_tar_struct=True, ), model_file="model.onnx", ), @@ -144,7 +144,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy embeddings = output.model_output attn_mask = output.attention_mask - return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32) + return normalize(self.mean_pooling(embeddings, attn_mask)) class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker): diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index f893c57d..082d4c8c 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -91,15 +91,11 @@ def test_mock_add_custom_models(): expected_output = { f"{PoolingType.MEAN.lower()}-normalized": normalize( mean_pooling(dummy_token_embedding, dummy_attention_mask) - ).astype(np.float32), - f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask), - f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]).astype( - np.float32 ), + f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask), + f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]), f"{PoolingType.CLS.lower()}": dummy_token_embedding[:, 0], - f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding).astype( - np.float32 - ), + f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding), f"{PoolingType.DISABLED.lower()}": dummy_pooled_embedding, }