
Commit 32112dd

new: preserve embeddings in a type set by their model
1 parent 2082108 commit 32112dd

7 files changed: +14 −19 lines

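In short: the unconditional `.astype(np.float32)` casts are removed from the post-processing paths, so embeddings keep whatever dtype the ONNX model emits. A minimal sketch of the behavioral difference (the shape is illustrative):

```python
import numpy as np

# Hypothetical ONNX model output; np.random.random produces float64.
model_output = np.random.random((1, 384))

# Before this commit: post-processing always forced float32.
assert model_output.astype(np.float32).dtype == np.float32

# After this commit: the array passes through with its original dtype.
assert model_output.dtype == np.float64
```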

fastembed/common/types.py

+1
```diff
@@ -16,6 +16,7 @@
 
 OnnxProvider: TypeAlias = Union[str, tuple[str, dict[Any, Any]]]
 NumpyArray = Union[
+    NDArray[np.float64],
     NDArray[np.float32],
     NDArray[np.float16],
     NDArray[np.int8],
```
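With `NDArray[np.float64]` in the union, a function annotated as returning `NumpyArray` can hand back float64 arrays without a cast. A sketch of how the widened alias type-checks (the `passthrough` function is illustrative, and only the members visible in this hunk are reproduced):

```python
import numpy as np
from numpy.typing import NDArray
from typing import Union

NumpyArray = Union[
    NDArray[np.float64],  # newly added member
    NDArray[np.float32],
    NDArray[np.float16],
    NDArray[np.int8],
]

def passthrough(x: NumpyArray) -> NumpyArray:
    return x  # a float64 array now satisfies the annotation
```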

fastembed/late_interaction/colbert.py

+3 −3

```diff
@@ -46,7 +46,7 @@ def _post_process_onnx_output(
         self, output: OnnxOutputContext, is_doc: bool = True
     ) -> Iterable[NumpyArray]:
         if not is_doc:
-            return output.model_output.astype(np.float32)
+            return output.model_output
 
         if output.input_ids is None or output.attention_mask is None:
             raise ValueError(
@@ -58,11 +58,11 @@ def _post_process_onnx_output(
                 if token_id in self.skip_list or token_id == self.pad_token_id:
                     output.attention_mask[i, j] = 0
 
-        output.model_output *= np.expand_dims(output.attention_mask, 2).astype(np.float32)
+        output.model_output *= np.expand_dims(output.attention_mask, 2)
         norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
         norm_clamped = np.maximum(norm, 1e-12)
         output.model_output /= norm_clamped
-        return output.model_output.astype(np.float32)
+        return output.model_output
 
     def _preprocess_onnx_input(
         self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
```
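Dropping the mask's `.astype(np.float32)` is safe because the in-place `*=` and `/=` keep the dtype of the left-hand array: NumPy computes the product and casts it back to `model_output`'s dtype. A small sketch of that behavior (shapes and dtypes are illustrative):

```python
import numpy as np

emb = np.random.random((1, 4, 8)).astype(np.float32)  # token embeddings
mask = np.ones((1, 4), dtype=np.int64)                # attention mask

emb *= np.expand_dims(mask, 2)  # float32 *= int64 is cast back to float32
norm = np.linalg.norm(emb, ord=2, axis=2, keepdims=True)
emb /= np.maximum(norm, 1e-12)  # normalization also stays in float32
assert emb.dtype == np.float32
```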

fastembed/late_interaction_multimodal/colpali.py

+2 −2

```diff
@@ -142,7 +142,7 @@ def _post_process_onnx_image_output(
         assert self.model_description.dim is not None, "Model dim is not defined"
         return output.model_output.reshape(
             output.model_output.shape[0], -1, self.model_description.dim
-        ).astype(np.float32)
+        )
 
     def _post_process_onnx_text_output(
         self,
@@ -157,7 +157,7 @@ def _post_process_onnx_text_output(
         Returns:
             Iterable[NumpyArray]: Post-processed output as NumPy arrays.
         """
-        return output.model_output.astype(np.float32)
+        return output.model_output
 
     def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
         texts_query: list[str] = []
```
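`reshape` never changes a dtype, so with the trailing casts gone both image and text outputs are forwarded in the model's own dtype; for instance (the dim of 128 is illustrative):

```python
import numpy as np

flat = np.random.random((2, 16 * 128))         # hypothetical model output, float64
tokens = flat.reshape(flat.shape[0], -1, 128)  # same reshape pattern as above
assert tokens.dtype == flat.dtype == np.float64
```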

fastembed/text/onnx_embedding.py

+1 −2

```diff
@@ -1,6 +1,5 @@
 from typing import Any, Iterable, Optional, Sequence, Type, Union
 
-import numpy as np
 from fastembed.common.types import NumpyArray, OnnxProvider
 from fastembed.common.onnx_model import OnnxOutputContext
 from fastembed.common.utils import define_cache_dir, normalize
@@ -309,7 +308,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
             processed_embeddings = embeddings
         else:
             raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
-        return normalize(processed_embeddings).astype(np.float32)
+        return normalize(processed_embeddings)
 
     def load_onnx_model(self) -> None:
         self._load_onnx_model(
```
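With the cast removed, the returned dtype now depends on `normalize` preserving its input dtype, which a plain NumPy L2 normalization does. A sketch under that assumption (the real `fastembed.common.utils.normalize` may differ in details):

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    # Divide by the clamped L2 norm along the last axis; the dtype is kept.
    norm = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norm, 1e-12)

emb = np.random.random((2, 384)).astype(np.float32)
assert l2_normalize(emb).dtype == np.float32
```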

fastembed/text/pooled_embedding.py

+1 −1

```diff
@@ -115,7 +115,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
 
         embeddings = output.model_output
         attn_mask = output.attention_mask
-        return self.mean_pooling(embeddings, attn_mask).astype(np.float32)
+        return self.mean_pooling(embeddings, attn_mask)
 
 
 class PooledEmbeddingWorker(OnnxTextEmbeddingWorker):
```
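Here the returned dtype rests on `mean_pooling`. A mask-weighted mean keeps the embedding dtype only if the mask is matched to it first, since an int64 mask would promote a float32 product to float64 outside of in-place ops. A hedged sketch of such a pooling (fastembed's actual `mean_pooling` may differ):

```python
import numpy as np

def mean_pool(emb: np.ndarray, mask: np.ndarray) -> np.ndarray:
    # Cast the mask to the embedding dtype so the product does not upcast.
    m = np.expand_dims(mask, -1).astype(emb.dtype)
    return (emb * m).sum(axis=1) / np.maximum(m.sum(axis=1), 1e-9)

emb = np.random.random((1, 10, 8)).astype(np.float32)
mask = np.ones((1, 10), dtype=np.int64)
assert mean_pool(emb, mask).dtype == np.float32
```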

fastembed/text/pooled_normalized_embedding.py

+1 −2

```diff
@@ -1,6 +1,5 @@
 from typing import Any, Iterable, Type
 
-import numpy as np
 
 from fastembed.common.types import NumpyArray
 from fastembed.common.onnx_model import OnnxOutputContext
@@ -144,7 +143,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
 
         embeddings = output.model_output
         attn_mask = output.attention_mask
-        return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
+        return normalize(self.mean_pooling(embeddings, attn_mask))
 
 
 class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):
```
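This path composes the two pieces sketched above, mean pooling followed by L2 normalization, so the same dtype-preservation reasoning applies; the now-unused `numpy` import is dropped here as well.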

tests/test_custom_models.py

+5 −9

```diff
@@ -71,8 +71,8 @@ def test_mock_add_custom_models():
     source = ModelSource(hf="artificial")
 
     num_tokens = 10
-    dummy_pooled_embedding = np.random.random((1, dim)).astype(np.float32)
-    dummy_token_embedding = np.random.random((1, num_tokens, dim)).astype(np.float32)
+    dummy_pooled_embedding = np.random.random((1, dim))
+    dummy_token_embedding = np.random.random((1, num_tokens, dim))
     dummy_attention_mask = np.ones((1, num_tokens)).astype(np.int64)
 
     dummy_token_output = OnnxOutputContext(
@@ -91,15 +91,11 @@ def test_mock_add_custom_models():
     expected_output = {
         f"{PoolingType.MEAN.lower()}-normalized": normalize(
             mean_pooling(dummy_token_embedding, dummy_attention_mask)
-        ).astype(np.float32),
-        f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask),
-        f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]).astype(
-            np.float32
         ),
+        f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask),
+        f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]),
         f"{PoolingType.CLS.lower()}": dummy_token_embedding[:, 0],
-        f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding).astype(
-            np.float32
-        ),
+        f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding),
         f"{PoolingType.DISABLED.lower()}": dummy_pooled_embedding,
     }
```
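The fixture change matters because `np.random.random` returns float64 by default, so the mock now drives float64 through every pooling branch and the expected values stay float64 as well:

```python
import numpy as np

# Default dtype of np.random.random is float64, matching the new fixtures.
assert np.random.random((1, 10, 8)).dtype == np.float64
```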
