Commit d2348ad

feat: SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder can accept and pass any arguments to SentenceTransformer.encode (#8806)
* feat: SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder can accept and pass any arguments to SentenceTransformer.encode
* refactor: encode_kwargs parameter of SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder made the last positional parameter for backward compatibility reasons
* docs: added explanation for encode_kwargs in SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
* test: added tests for encode_kwargs in SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
* doc: removed empty lines from docstrings of SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
* refactor: encode_kwargs parameter of SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder made the last positional parameter for backward compatibility (part II.)
1 parent 2828d9e commit d2348ad

5 files changed, +59 −3 lines changed
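
As a usage illustration (not part of this commit), a minimal sketch of passing `encode_kwargs` through the two components. The model name is a placeholder, and the `task` values simply mirror the ones used in the tests below, so they only make sense for models whose `encode` accepts a `task` argument:

from haystack import Document
from haystack.components.embedders import (
    SentenceTransformersDocumentEmbedder,
    SentenceTransformersTextEmbedder,
)

# Placeholder model name; any extra keys in encode_kwargs are forwarded verbatim
# to SentenceTransformer.encode at embedding time.
doc_embedder = SentenceTransformersDocumentEmbedder(
    model="my-org/my-task-aware-embedding-model",  # hypothetical model
    encode_kwargs={"task": "retrieval.passage"},
)
doc_embedder.warm_up()
docs = doc_embedder.run(documents=[Document(content="Haystack forwards encode kwargs now.")])["documents"]

text_embedder = SentenceTransformersTextEmbedder(
    model="my-org/my-task-aware-embedding-model",  # hypothetical model
    encode_kwargs={"task": "retrieval.query"},
)
text_embedder.warm_up()
embedding = text_embedder.run(text="How do I pass extra encode arguments?")["embedding"]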

Diff for: haystack/components/embedders/sentence_transformers_document_embedder.py

+8
@@ -56,6 +56,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
         config_kwargs: Optional[Dict[str, Any]] = None,
         precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
+        encode_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Creates a SentenceTransformersDocumentEmbedder component.
@@ -104,6 +105,10 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
             All non-float32 precisions are quantized embeddings.
             Quantized embeddings are smaller and faster to compute, but may have a lower accuracy.
             They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
+        :param encode_kwargs:
+            Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
+            This parameter is provided for fine customization. Be careful not to clash with already set parameters and
+            avoid passing parameters that change the output type.
         """
 
         self.model = model
@@ -121,6 +126,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs
         self.config_kwargs = config_kwargs
+        self.encode_kwargs = encode_kwargs
         self.embedding_backend = None
         self.precision = precision
 
@@ -155,6 +161,7 @@ def to_dict(self) -> Dict[str, Any]:
             tokenizer_kwargs=self.tokenizer_kwargs,
             config_kwargs=self.config_kwargs,
             precision=self.precision,
+            encode_kwargs=self.encode_kwargs,
         )
         if serialization_dict["init_parameters"].get("model_kwargs") is not None:
             serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -232,6 +239,7 @@ def run(self, documents: List[Document]):
             show_progress_bar=self.progress_bar,
             normalize_embeddings=self.normalize_embeddings,
             precision=self.precision,
+            **(self.encode_kwargs if self.encode_kwargs else {}),
         )
 
         for doc, emb in zip(documents, embeddings):
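
To illustrate why the docstring warns against clashing keys, here is a tiny standalone sketch (plain Python, not Haystack code) of the dict-unpacking pattern used in `run()` above:

def encode(texts, *, batch_size=32, precision="float32", **kwargs):
    """Stand-in for SentenceTransformer.encode, used only to illustrate kwarg forwarding."""
    return {"batch_size": batch_size, "precision": precision, **kwargs}

encode_kwargs = {"task": "retrieval.passage"}

# Fixed arguments first, then the user-supplied dict unpacked last, as in run() above.
print(encode(["doc"], batch_size=32, precision="float32", **(encode_kwargs if encode_kwargs else {})))
# -> {'batch_size': 32, 'precision': 'float32', 'task': 'retrieval.passage'}

# A clashing key such as "precision" inside encode_kwargs would fail at call time:
# TypeError: encode() got multiple values for keyword argument 'precision'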

Diff for: haystack/components/embedders/sentence_transformers_text_embedder.py

+8
@@ -50,6 +50,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
         config_kwargs: Optional[Dict[str, Any]] = None,
         precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
+        encode_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Create a SentenceTransformersTextEmbedder component.
@@ -94,6 +95,10 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
             All non-float32 precisions are quantized embeddings.
             Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy.
             They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
+        :param encode_kwargs:
+            Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
+            This parameter is provided for fine customization. Be careful not to clash with already set parameters and
+            avoid passing parameters that change the output type.
         """
 
         self.model = model
@@ -109,6 +114,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs
         self.config_kwargs = config_kwargs
+        self.encode_kwargs = encode_kwargs
         self.embedding_backend = None
         self.precision = precision
 
@@ -141,6 +147,7 @@ def to_dict(self) -> Dict[str, Any]:
             tokenizer_kwargs=self.tokenizer_kwargs,
             config_kwargs=self.config_kwargs,
             precision=self.precision,
+            encode_kwargs=self.encode_kwargs,
         )
         if serialization_dict["init_parameters"].get("model_kwargs") is not None:
             serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -209,5 +216,6 @@ def run(self, text: str):
             show_progress_bar=self.progress_bar,
             normalize_embeddings=self.normalize_embeddings,
             precision=self.precision,
+            **(self.encode_kwargs if self.encode_kwargs else {}),
         )[0]
         return {"embedding": embedding}
@@ -0,0 +1,6 @@
+---
+enhancements:
+  - |
+    Enhanced `SentenceTransformersDocumentEmbedder` and `SentenceTransformersTextEmbedder` to accept
+    an additional parameter, which is passed directly to the underlying `SentenceTransformer.encode` method
+    for greater flexibility in embedding customization.

Diff for: test/components/embedders/test_sentence_transformers_document_embedder.py

+18-1
@@ -1,9 +1,9 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
 #
 # SPDX-License-Identifier: Apache-2.0
+import random
 from unittest.mock import MagicMock, patch
 
-import random
 import pytest
 import torch
 
@@ -79,6 +79,7 @@ def test_to_dict(self):
                 "truncate_dim": None,
                 "model_kwargs": None,
                 "tokenizer_kwargs": None,
+                "encode_kwargs": None,
                 "config_kwargs": None,
                 "precision": "float32",
             },
@@ -102,6 +103,7 @@ def test_to_dict_with_custom_init_parameters(self):
             tokenizer_kwargs={"model_max_length": 512},
             config_kwargs={"use_memory_efficient_attention": True},
             precision="int8",
+            encode_kwargs={"task": "clustering"},
         )
         data = component.to_dict()
 
@@ -124,6 +126,7 @@ def test_to_dict_with_custom_init_parameters(self):
                 "tokenizer_kwargs": {"model_max_length": 512},
                 "config_kwargs": {"use_memory_efficient_attention": True},
                 "precision": "int8",
+                "encode_kwargs": {"task": "clustering"},
             },
         }
 
@@ -316,6 +319,20 @@ def test_embed_metadata(self):
             precision="float32",
         )
 
+    def test_embed_encode_kwargs(self):
+        embedder = SentenceTransformersDocumentEmbedder(model="model", encode_kwargs={"task": "retrieval.passage"})
+        embedder.embedding_backend = MagicMock()
+        documents = [Document(content=f"document number {i}") for i in range(5)]
+        embedder.run(documents=documents)
+        embedder.embedding_backend.embed.assert_called_once_with(
+            ["document number 0", "document number 1", "document number 2", "document number 3", "document number 4"],
+            batch_size=32,
+            show_progress_bar=True,
+            normalize_embeddings=False,
+            precision="float32",
+            task="retrieval.passage",
+        )
+
     def test_prefix_suffix(self):
         embedder = SentenceTransformersDocumentEmbedder(
             model="model",

Diff for: test/components/embedders/test_sentence_transformers_text_embedder.py

+19-2
@@ -1,11 +1,11 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
 #
 # SPDX-License-Identifier: Apache-2.0
+import random
 from unittest.mock import MagicMock, patch
 
-import torch
-import random
 import pytest
+import torch
 
 from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
 from haystack.utils import ComponentDevice, Secret
@@ -70,6 +70,7 @@ def test_to_dict(self):
                 "truncate_dim": None,
                 "model_kwargs": None,
                 "tokenizer_kwargs": None,
+                "encode_kwargs": None,
                 "config_kwargs": None,
                 "precision": "float32",
             },
@@ -91,6 +92,7 @@ def test_to_dict_with_custom_init_parameters(self):
             tokenizer_kwargs={"model_max_length": 512},
             config_kwargs={"use_memory_efficient_attention": False},
             precision="int8",
+            encode_kwargs={"task": "clustering"},
         )
         data = component.to_dict()
         assert data == {
@@ -110,6 +112,7 @@ def test_to_dict_with_custom_init_parameters(self):
                 "tokenizer_kwargs": {"model_max_length": 512},
                 "config_kwargs": {"use_memory_efficient_attention": False},
                 "precision": "int8",
+                "encode_kwargs": {"task": "clustering"},
             },
         }
 
@@ -297,3 +300,17 @@ def test_run_quantization(self):
 
         assert len(embedding_def) == 768
         assert all(isinstance(el, int) for el in embedding_def)
+
+    def test_embed_encode_kwargs(self):
+        embedder = SentenceTransformersTextEmbedder(model="model", encode_kwargs={"task": "retrieval.query"})
+        embedder.embedding_backend = MagicMock()
+        text = "a nice text to embed"
+        embedder.run(text=text)
+        embedder.embedding_backend.embed.assert_called_once_with(
+            [text],
+            batch_size=32,
+            show_progress_bar=True,
+            normalize_embeddings=False,
+            precision="float32",
+            task="retrieval.query",
+        )
