diff --git a/haystack/components/embedders/backends/sentence_transformers_backend.py b/haystack/components/embedders/backends/sentence_transformers_backend.py
index 9d86b80ca0..cff9135c86 100644
--- a/haystack/components/embedders/backends/sentence_transformers_backend.py
+++ b/haystack/components/embedders/backends/sentence_transformers_backend.py
@@ -29,6 +29,7 @@ def get_embedding_backend(
         truncate_dim: Optional[int] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}"

@@ -42,6 +43,7 @@ def get_embedding_backend(
             truncate_dim=truncate_dim,
             model_kwargs=model_kwargs,
             tokenizer_kwargs=tokenizer_kwargs,
+            config_kwargs=config_kwargs,
         )
         _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
         return embedding_backend
@@ -61,6 +63,7 @@ def __init__(
         truncate_dim: Optional[int] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         sentence_transformers_import.check()
         self.model = SentenceTransformer(
@@ -71,6 +74,7 @@ def __init__(
             truncate_dim=truncate_dim,
             model_kwargs=model_kwargs,
             tokenizer_kwargs=tokenizer_kwargs,
+            config_kwargs=config_kwargs,
         )

     def embed(self, data: List[str], **kwargs) -> List[List[float]]:
diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py
index d7f78b54f6..195c5d10c8 100644
--- a/haystack/components/embedders/sentence_transformers_document_embedder.py
+++ b/haystack/components/embedders/sentence_transformers_document_embedder.py
@@ -54,6 +54,7 @@ def __init__(  # noqa: PLR0913
         truncate_dim: Optional[int] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
         precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
     ):
         """
@@ -96,10 +97,12 @@ def __init__(  # noqa: PLR0913
         :param tokenizer_kwargs:
             Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
             Refer to specific model documentation for available kwargs.
+        :param config_kwargs:
+            Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
         :param precision:
             The precision to use for the embeddings.
             All non-float32 precisions are quantized embeddings.
-            Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy.
+            Quantized embeddings are smaller and faster to compute, but may have a lower accuracy.
             They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
         """

@@ -117,6 +120,7 @@ def __init__(  # noqa: PLR0913
         self.truncate_dim = truncate_dim
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs
+        self.config_kwargs = config_kwargs
         self.embedding_backend = None
         self.precision = precision

@@ -149,6 +153,7 @@ def to_dict(self) -> Dict[str, Any]:
             truncate_dim=self.truncate_dim,
             model_kwargs=self.model_kwargs,
             tokenizer_kwargs=self.tokenizer_kwargs,
+            config_kwargs=self.config_kwargs,
             precision=self.precision,
         )
         if serialization_dict["init_parameters"].get("model_kwargs") is not None:
@@ -186,6 +191,7 @@ def warm_up(self):
                 truncate_dim=self.truncate_dim,
                 model_kwargs=self.model_kwargs,
                 tokenizer_kwargs=self.tokenizer_kwargs,
+                config_kwargs=self.config_kwargs,
             )
             if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
                 self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
diff --git a/releasenotes/notes/sentence-transformer-doc-embedder-config_kwargs-d7d254c6b94887c4.yaml b/releasenotes/notes/sentence-transformer-doc-embedder-config_kwargs-d7d254c6b94887c4.yaml
new file mode 100644
index 0000000000..8693b75a50
--- /dev/null
+++ b/releasenotes/notes/sentence-transformer-doc-embedder-config_kwargs-d7d254c6b94887c4.yaml
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    SentenceTransformersDocumentEmbedder now supports config_kwargs for additional parameters when loading the model configuration
diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py
index a5e6af8278..1c0bc526ed 100644
--- a/test/components/embedders/test_sentence_transformers_document_embedder.py
+++ b/test/components/embedders/test_sentence_transformers_document_embedder.py
@@ -79,6 +79,7 @@ def test_to_dict(self):
                 "truncate_dim": None,
                 "model_kwargs": None,
                 "tokenizer_kwargs": None,
+                "config_kwargs": None,
                 "precision": "float32",
             },
         }
@@ -99,6 +100,7 @@ def test_to_dict_with_custom_init_parameters(self):
             truncate_dim=256,
             model_kwargs={"torch_dtype": torch.float32},
             tokenizer_kwargs={"model_max_length": 512},
+            config_kwargs={"use_memory_efficient_attention": True},
             precision="int8",
         )
         data = component.to_dict()
@@ -120,6 +122,7 @@ def test_to_dict_with_custom_init_parameters(self):
                 "truncate_dim": 256,
                 "model_kwargs": {"torch_dtype": "torch.float32"},
                 "tokenizer_kwargs": {"model_max_length": 512},
+                "config_kwargs": {"use_memory_efficient_attention": True},
                 "precision": "int8",
             },
         }
@@ -140,6 +143,7 @@ def test_from_dict(self):
             "truncate_dim": 256,
             "model_kwargs": {"torch_dtype": "torch.float32"},
             "tokenizer_kwargs": {"model_max_length": 512},
+            "config_kwargs": {"use_memory_efficient_attention": True},
             "precision": "int8",
         }
         component = SentenceTransformersDocumentEmbedder.from_dict(
@@ -162,6 +166,7 @@ def test_from_dict(self):
         assert component.truncate_dim == 256
         assert component.model_kwargs == {"torch_dtype": torch.float32}
         assert component.tokenizer_kwargs == {"model_max_length": 512}
+        assert component.config_kwargs == {"use_memory_efficient_attention": True}
         assert component.precision == "int8"

     def test_from_dict_no_default_parameters(self):
@@ -230,6 +235,7 @@ def test_warmup(self, mocked_factory):
             token=None,
             device=ComponentDevice.from_str("cpu"),
             tokenizer_kwargs={"model_max_length": 512},
+            config_kwargs={"use_memory_efficient_attention": True},
         )
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
@@ -242,6 +248,7 @@ def test_warmup(self, mocked_factory):
             truncate_dim=None,
             model_kwargs=None,
             tokenizer_kwargs={"model_max_length": 512},
+            config_kwargs={"use_memory_efficient_attention": True},
         )

     @patch(
@@ -291,11 +298,8 @@ def test_embed_metadata(self):
             model="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
         )
         embedder.embedding_backend = MagicMock()
-
         documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
-
         embedder.run(documents=documents)
-
         embedder.embedding_backend.embed.assert_called_once_with(
             [
                 "meta_value 0\ndocument number 0",
@@ -319,11 +323,8 @@ def test_prefix_suffix(self):
             embedding_separator="\n",
         )
         embedder.embedding_backend = MagicMock()
-
         documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
-
         embedder.run(documents=documents)
-
         embedder.embedding_backend.embed.assert_called_once_with(
             [
                 "my_prefix meta_value 0\ndocument number 0 my_suffix",
diff --git a/test/components/embedders/test_sentence_transformers_embedding_backend.py b/test/components/embedders/test_sentence_transformers_embedding_backend.py
index 7ca42aab91..55014183b2 100644
--- a/test/components/embedders/test_sentence_transformers_embedding_backend.py
+++ b/test/components/embedders/test_sentence_transformers_embedding_backend.py
@@ -42,6 +42,7 @@ def test_model_initialization(mock_sentence_transformer):
         truncate_dim=256,
         model_kwargs=None,
         tokenizer_kwargs=None,
+        config_kwargs=None,
     )
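
For reference, a minimal usage sketch of the new parameter (not part of the change set). The model name is only an illustration, and the example kwarg mirrors the one used in the tests; which config kwargs are honored depends on the underlying model's `AutoConfig.from_pretrained`.

# Minimal usage sketch (illustrative): config_kwargs is forwarded to
# AutoConfig.from_pretrained when the model configuration is loaded.
from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder

embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model choice
    config_kwargs={"use_memory_efficient_attention": True},  # example kwarg; support is model-dependent
)
embedder.warm_up()  # loads the model, passing config_kwargs to the backend
result = embedder.run(documents=[Document(content="config_kwargs reaches the model config loader")])
print(len(result["documents"][0].embedding))  # embedding dimension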