Skip to content

Commit b0809b7

Browse files
authored
feat: Add a ListJoiner component (#8810)
* Add a ListJoiner * Add tests and release notes
1 parent d2348ad commit b0809b7

File tree

4 files changed

+182
-1
lines changed

4 files changed

+182
-1
lines changed

haystack/components/joiners/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .answer_joiner import AnswerJoiner
66
from .branch import BranchJoiner
77
from .document_joiner import DocumentJoiner
8+
from .list_joiner import ListJoiner
89
from .string_joiner import StringJoiner
910

10-
__all__ = ["DocumentJoiner", "BranchJoiner", "AnswerJoiner", "StringJoiner"]
11+
__all__ = ["DocumentJoiner", "BranchJoiner", "AnswerJoiner", "StringJoiner", "ListJoiner"]
+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from itertools import chain
6+
from typing import Any, Dict, Type
7+
8+
from haystack import component, default_from_dict, default_to_dict
9+
from haystack.core.component.types import Variadic
10+
from haystack.utils import deserialize_type, serialize_type
11+
12+
13+
@component
14+
class ListJoiner:
15+
"""
16+
A component that joins multiple lists into a single flat list.
17+
18+
The ListJoiner receives multiple lists of the same type and concatenates them into a single flat list.
19+
The output order respects the pipeline's execution sequence, with earlier inputs being added first.
20+
21+
Usage example:
22+
```python
23+
from haystack.components.builders import ChatPromptBuilder
24+
from haystack.components.generators.chat import OpenAIChatGenerator
25+
from haystack.dataclasses import ChatMessage
26+
from haystack import Pipeline
27+
from haystack.components.joiners import ListJoiner
28+
from typing import List
29+
30+
31+
user_message = [ChatMessage.from_user("Give a brief answer the following question: {{query}}")]
32+
33+
feedback_prompt = \"""
34+
You are given a question and an answer.
35+
Your task is to provide a score and a brief feedback on the answer.
36+
Question: {{query}}
37+
Answer: {{response}}
38+
\"""
39+
feedback_message = [ChatMessage.from_system(feedback_prompt)]
40+
41+
prompt_builder = ChatPromptBuilder(template=user_message)
42+
feedback_prompt_builder = ChatPromptBuilder(template=feedback_message)
43+
llm = OpenAIChatGenerator(model="gpt-4o-mini")
44+
feedback_llm = OpenAIChatGenerator(model="gpt-4o-mini")
45+
46+
pipe = Pipeline()
47+
pipe.add_component("prompt_builder", prompt_builder)
48+
pipe.add_component("llm", llm)
49+
pipe.add_component("feedback_prompt_builder", feedback_prompt_builder)
50+
pipe.add_component("feedback_llm", feedback_llm)
51+
pipe.add_component("list_joiner", ListJoiner(List[ChatMessage]))
52+
53+
pipe.connect("prompt_builder.prompt", "llm.messages")
54+
pipe.connect("prompt_builder.prompt", "list_joiner")
55+
pipe.connect("llm.replies", "list_joiner")
56+
pipe.connect("llm.replies", "feedback_prompt_builder.response")
57+
pipe.connect("feedback_prompt_builder.prompt", "feedback_llm.messages")
58+
pipe.connect("feedback_llm.replies", "list_joiner")
59+
60+
query = "What is nuclear physics?"
61+
ans = pipe.run(data={"prompt_builder": {"template_variables":{"query": query}},
62+
"feedback_prompt_builder": {"template_variables":{"query": query}}})
63+
64+
print(ans["list_joiner"]["values"])
65+
```
66+
"""
67+
68+
def __init__(self, list_type_: Type):
69+
"""
70+
Creates a ListJoiner component.
71+
72+
:param list_type_: The type of list that this joiner will handle (e.g., List[ChatMessage]).
73+
All input lists must be of this type.
74+
"""
75+
self.list_type_ = list_type_
76+
component.set_output_types(self, values=list_type_)
77+
78+
def to_dict(self) -> Dict[str, Any]:
79+
"""
80+
Serializes the component to a dictionary.
81+
82+
:returns: Dictionary with serialized data.
83+
"""
84+
return default_to_dict(self, list_type_=serialize_type(self.list_type_))
85+
86+
@classmethod
87+
def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner":
88+
"""
89+
Deserializes the component from a dictionary.
90+
91+
:param data: Dictionary to deserialize from.
92+
:returns: Deserialized component.
93+
"""
94+
data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"])
95+
return default_from_dict(cls, data)
96+
97+
def run(self, values: Variadic[Any]) -> Dict[str, Any]:
98+
"""
99+
Joins multiple lists into a single flat list.
100+
101+
:param values:The list to be joined.
102+
:returns: Dictionary with 'values' key containing the joined list.
103+
"""
104+
result = list(chain(*values))
105+
return {"values": result}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
Added a new component `ListJoiner` which joins lists of values from different components to a single list.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import List
6+
7+
from haystack import Document
8+
from haystack.dataclasses import ChatMessage
9+
from haystack.dataclasses.answer import GeneratedAnswer
10+
from haystack.components.joiners.list_joiner import ListJoiner
11+
12+
13+
class TestListJoiner:
14+
def test_init(self):
15+
joiner = ListJoiner(List[ChatMessage])
16+
assert isinstance(joiner, ListJoiner)
17+
assert joiner.list_type_ == List[ChatMessage]
18+
19+
def test_to_dict(self):
20+
joiner = ListJoiner(List[ChatMessage])
21+
data = joiner.to_dict()
22+
assert data == {
23+
"type": "haystack.components.joiners.list_joiner.ListJoiner",
24+
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
25+
}
26+
27+
def test_from_dict(self):
28+
data = {
29+
"type": "haystack.components.joiners.list_joiner.ListJoiner",
30+
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
31+
}
32+
list_joiner = ListJoiner.from_dict(data)
33+
assert isinstance(list_joiner, ListJoiner)
34+
assert list_joiner.list_type_ == List[ChatMessage]
35+
36+
def test_empty_list(self):
37+
joiner = ListJoiner(List[ChatMessage])
38+
result = joiner.run([])
39+
assert result == {"values": []}
40+
41+
def test_list_of_empty_lists(self):
42+
joiner = ListJoiner(List[ChatMessage])
43+
result = joiner.run([[], []])
44+
assert result == {"values": []}
45+
46+
def test_single_list_of_chat_messages(self):
47+
joiner = ListJoiner(List[ChatMessage])
48+
messages = [ChatMessage.from_user("Hello"), ChatMessage.from_assistant("Hi there")]
49+
result = joiner.run([messages])
50+
assert result == {"values": messages}
51+
52+
def test_multiple_lists_of_chat_messages(self):
53+
joiner = ListJoiner(List[ChatMessage])
54+
messages1 = [ChatMessage.from_user("Hello")]
55+
messages2 = [ChatMessage.from_assistant("Hi there")]
56+
messages3 = [ChatMessage.from_system("System message")]
57+
result = joiner.run([messages1, messages2, messages3])
58+
assert result == {"values": messages1 + messages2 + messages3}
59+
60+
def test_list_of_generated_answers(self):
61+
joiner = ListJoiner(List[GeneratedAnswer])
62+
answers1 = [GeneratedAnswer(query="q1", data="a1", meta={}, documents=[Document(content="d1")])]
63+
answers2 = [GeneratedAnswer(query="q2", data="a2", meta={}, documents=[Document(content="d2")])]
64+
result = joiner.run([answers1, answers2])
65+
assert result == {"values": answers1 + answers2}
66+
67+
def test_mixed_empty_and_non_empty_lists(self):
68+
joiner = ListJoiner(List[ChatMessage])
69+
messages = [ChatMessage.from_user("Hello")]
70+
result = joiner.run([messages, [], messages])
71+
assert result == {"values": messages + messages}

0 commit comments

Comments
 (0)