Skip to content

Commit f233e06

Browse files
feat : adding a new Protocol for TextEmbedder (#9353)
* initial import * removing unused imports * adding an Embedder Protocol * adding tests * adding tests * adding release notes * renaming dir * removing dir * cleaning * adding clean tests * dealing with ellipsis and pylint * wip: extending tests * cleaning extended tests * adding an invalid TextEmbedder
1 parent 2ccdba3 commit f233e06

File tree

4 files changed

+104
-0
lines changed

4 files changed

+104
-0
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from .protocol import TextEmbedder
6+
7+
__all__ = ["TextEmbedder"]
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import Any, Dict, Protocol, TypeVar
6+
7+
# NOTE(review): `T` is not referenced anywhere else in the visible part of this module - confirm it is still needed.
T = TypeVar("T", bound="TextEmbedder")
8+
9+
# See https://github.com/pylint-dev/pylint/issues/9319.
10+
# pylint: disable=unnecessary-ellipsis
11+
12+
13+
class TextEmbedder(Protocol):
    """
    Structural interface for Text Embedders.

    An object conforms to this protocol when it exposes a compatible `run` method.
    """

    def run(self, text: str) -> Dict[str, Any]:
        """
        Compute the embedding of a single piece of text.

        Concrete embedders may extend `run` with additional optional parameters,
        e.g. `def run(self, text: str, param_a="default", param_b="another_default")`.

        :param text:
            The text to embed.
        :returns:
            A dictionary with:
            - 'embedding': expected to be a List[float] holding the embedding.
            - optionally, further keys such as 'metadata'.
        """
        ...
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
enhancements:
3+
- |
4+
We now have a Protocol for TextEmbedder. The protocol makes it easier to create custom components or SuperComponents
5+
that expect any TextEmbedder as an init parameter.
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import inspect
2+
from typing import Any, Dict
3+
4+
import pytest
5+
6+
from haystack import component
7+
from haystack.components.embedders.types.protocol import TextEmbedder
8+
9+
10+
@component
class MockTextEmbedder:
    """Minimal embedder whose `run` takes `text` plus two optional string parameters."""

    def run(self, text: str, param_a: str = "default", param_b: str = "another_default") -> Dict[str, Any]:
        """Return a fixed three-float embedding and echo the received arguments under 'metadata'."""
        received = {"text": text, "param_a": param_a, "param_b": param_b}
        return {"embedding": [0.1, 0.2, 0.3], "metadata": received}
14+
15+
16+
@component
class MockInvalidTextEmbedder:
    """Component whose `run` takes no `text` parameter and does not return an embedding."""

    def run(self, something_else: float) -> dict[str, bool]:
        """Ignore the argument and return a constant `{"result": True}` dictionary."""
        outcome = {"result": True}
        return outcome
20+
21+
22+
def test_protocol_implementation():
    """Check that MockTextEmbedder exposes the expected `run` signature and returns a well-formed result."""
    # Annotated assignment: a static type checker validates the protocol match on this line.
    embedder: TextEmbedder = MockTextEmbedder()

    # The declared signature must carry a `text: str` parameter and the Dict[str, Any] return type.
    signature = inspect.signature(MockTextEmbedder.run)
    parameters = signature.parameters
    assert "text" in parameters
    assert parameters["text"].annotation == str
    assert signature.return_annotation == Dict[str, Any]

    # The runtime result must be a dict holding a list of floats and a metadata dict.
    output = embedder.run("test text")
    assert isinstance(output, dict)
    assert "embedding" in output
    embedding = output["embedding"]
    assert isinstance(embedding, list)
    assert all(isinstance(value, float) for value in embedding)
    assert isinstance(output["metadata"], dict)
37+
38+
39+
def test_protocol_optional_parameters():
    """Check that the optional `run` parameters use their defaults and can be overridden by keyword."""
    embedder = MockTextEmbedder()

    # Only the required `text` argument is passed: both optional parameters must take their defaults.
    defaults_result = embedder.run("test text")
    assert defaults_result["metadata"]["param_a"] == "default"
    assert defaults_result["metadata"]["param_b"] == "another_default"

    # Both optional parameters are overridden by keyword.
    custom_result = embedder.run("test text", param_a="custom_a", param_b="custom_b")
    assert custom_result["metadata"]["param_a"] == "custom_a"
    assert custom_result["metadata"]["param_b"] == "custom_b"
51+
52+
53+
def test_protocol_invalid_implementation():
    """Check that MockInvalidTextEmbedder.run lacks the `text: str` parameter and the Dict[str, Any] return type."""
    run_signature = inspect.signature(MockInvalidTextEmbedder.run)

    # The mismatch is asserted directly instead of wrapping a failing `assert` in
    # `pytest.raises(AssertionError)`: that form only works while assert statements are
    # enabled and hides which condition is actually being checked.
    text_param = run_signature.parameters.get("text")
    assert text_param is None or text_param.annotation != str

    assert run_signature.return_annotation != Dict[str, Any]

0 commit comments

Comments
 (0)