|
| 1 | +# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +from itertools import chain |
| 6 | +from typing import Any, Dict, Type |
| 7 | + |
| 8 | +from haystack import component, default_from_dict, default_to_dict |
| 9 | +from haystack.core.component.types import Variadic |
| 10 | +from haystack.utils import deserialize_type, serialize_type |
| 11 | + |
| 12 | + |
| 13 | +@component |
| 14 | +class ListJoiner: |
| 15 | + """ |
| 16 | + A component that joins multiple lists into a single flat list. |
| 17 | +
|
| 18 | + The ListJoiner receives multiple lists of the same type and concatenates them into a single flat list. |
| 19 | + The output order respects the pipeline's execution sequence, with earlier inputs being added first. |
| 20 | +
|
| 21 | + Usage example: |
| 22 | + ```python |
| 23 | + from haystack.components.builders import ChatPromptBuilder |
| 24 | + from haystack.components.generators.chat import OpenAIChatGenerator |
| 25 | + from haystack.dataclasses import ChatMessage |
| 26 | + from haystack import Pipeline |
| 27 | + from haystack.components.joiners import ListJoiner |
| 28 | + from typing import List |
| 29 | +
|
| 30 | +
|
| 31 | + user_message = [ChatMessage.from_user("Give a brief answer the following question: {{query}}")] |
| 32 | +
|
| 33 | + feedback_prompt = \""" |
| 34 | + You are given a question and an answer. |
| 35 | + Your task is to provide a score and a brief feedback on the answer. |
| 36 | + Question: {{query}} |
| 37 | + Answer: {{response}} |
| 38 | + \""" |
| 39 | + feedback_message = [ChatMessage.from_system(feedback_prompt)] |
| 40 | +
|
| 41 | + prompt_builder = ChatPromptBuilder(template=user_message) |
| 42 | + feedback_prompt_builder = ChatPromptBuilder(template=feedback_message) |
| 43 | + llm = OpenAIChatGenerator(model="gpt-4o-mini") |
| 44 | + feedback_llm = OpenAIChatGenerator(model="gpt-4o-mini") |
| 45 | +
|
| 46 | + pipe = Pipeline() |
| 47 | + pipe.add_component("prompt_builder", prompt_builder) |
| 48 | + pipe.add_component("llm", llm) |
| 49 | + pipe.add_component("feedback_prompt_builder", feedback_prompt_builder) |
| 50 | + pipe.add_component("feedback_llm", feedback_llm) |
| 51 | + pipe.add_component("list_joiner", ListJoiner(List[ChatMessage])) |
| 52 | +
|
| 53 | + pipe.connect("prompt_builder.prompt", "llm.messages") |
| 54 | + pipe.connect("prompt_builder.prompt", "list_joiner") |
| 55 | + pipe.connect("llm.replies", "list_joiner") |
| 56 | + pipe.connect("llm.replies", "feedback_prompt_builder.response") |
| 57 | + pipe.connect("feedback_prompt_builder.prompt", "feedback_llm.messages") |
| 58 | + pipe.connect("feedback_llm.replies", "list_joiner") |
| 59 | +
|
| 60 | + query = "What is nuclear physics?" |
| 61 | + ans = pipe.run(data={"prompt_builder": {"template_variables":{"query": query}}, |
| 62 | + "feedback_prompt_builder": {"template_variables":{"query": query}}}) |
| 63 | +
|
| 64 | + print(ans["list_joiner"]["values"]) |
| 65 | + ``` |
| 66 | + """ |
| 67 | + |
| 68 | + def __init__(self, list_type_: Type): |
| 69 | + """ |
| 70 | + Creates a ListJoiner component. |
| 71 | +
|
| 72 | + :param list_type_: The type of list that this joiner will handle (e.g., List[ChatMessage]). |
| 73 | + All input lists must be of this type. |
| 74 | + """ |
| 75 | + self.list_type_ = list_type_ |
| 76 | + component.set_output_types(self, values=list_type_) |
| 77 | + |
| 78 | + def to_dict(self) -> Dict[str, Any]: |
| 79 | + """ |
| 80 | + Serializes the component to a dictionary. |
| 81 | +
|
| 82 | + :returns: Dictionary with serialized data. |
| 83 | + """ |
| 84 | + return default_to_dict(self, list_type_=serialize_type(self.list_type_)) |
| 85 | + |
| 86 | + @classmethod |
| 87 | + def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner": |
| 88 | + """ |
| 89 | + Deserializes the component from a dictionary. |
| 90 | +
|
| 91 | + :param data: Dictionary to deserialize from. |
| 92 | + :returns: Deserialized component. |
| 93 | + """ |
| 94 | + data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"]) |
| 95 | + return default_from_dict(cls, data) |
| 96 | + |
| 97 | + def run(self, values: Variadic[Any]) -> Dict[str, Any]: |
| 98 | + """ |
| 99 | + Joins multiple lists into a single flat list. |
| 100 | +
|
| 101 | + :param values:The list to be joined. |
| 102 | + :returns: Dictionary with 'values' key containing the joined list. |
| 103 | + """ |
| 104 | + result = list(chain(*values)) |
| 105 | + return {"values": result} |
0 commit comments