Skip to content

Commit 64f384b

Browse files
Amnah199 and sjrl authored
feat: enable streaming ToolCall/Result from Agent (#9290)
* Testing solutions for streaming
* Remove unused methods
* Add fixes
* Update docstrings
* Add release notes and test
* PR comments
* Add a new util function
* Adjust emit_tool_info
* PR comments
* Remove emit function, add streaming for tool_call

---------

Co-authored-by: Sebastian Husch Lee <[email protected]>
1 parent 7db7199 commit 64f384b

File tree

7 files changed

+120
-12
lines changed

7 files changed

+120
-12
lines changed

haystack/components/agents/agent.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from haystack.dataclasses import ChatMessage
1616
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
1717
from haystack.dataclasses.state_utils import merge_lists
18-
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
18+
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
1919
from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
2020
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
2121
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
@@ -84,6 +84,7 @@ def __init__(
8484
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
8585
If set to False, the exception will be turned into a chat message and passed to the LLM.
8686
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
87+
The same callback can be configured to emit tool results when a tool is called.
8788
:raises TypeError: If the chat_generator does not support tools parameter in its run method.
8889
"""
8990
# Check if chat_generator supports tools parameter
@@ -201,9 +202,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "Agent":
201202
def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]:
202203
"""Prepare inputs for the chat generator."""
203204
generator_inputs: Dict[str, Any] = {"tools": self.tools}
204-
selected_callback = streaming_callback or self.streaming_callback
205-
if selected_callback is not None:
206-
generator_inputs["streaming_callback"] = selected_callback
205+
if streaming_callback is not None:
206+
generator_inputs["streaming_callback"] = streaming_callback
207207
return generator_inputs
208208

209209
def _create_agent_span(self) -> Any:
@@ -229,6 +229,7 @@ def run(
229229
230230
:param messages: List of chat messages to process
231231
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
232+
The same callback can be configured to emit tool results when a tool is called.
232233
:param kwargs: Additional data to pass to the State schema used by the Agent.
233234
The keys must match the schema defined in the Agent's `state_schema`.
234235
:return: Dictionary containing messages and outputs matching the defined output types
@@ -239,6 +240,10 @@ def run(
239240
if self.system_prompt is not None:
240241
messages = [ChatMessage.from_system(self.system_prompt)] + messages
241242

243+
streaming_callback = select_streaming_callback(
244+
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
245+
)
246+
242247
input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})
243248

244249
state = State(schema=self.state_schema, data=kwargs)
@@ -271,7 +276,7 @@ def run(
271276
tool_invoker_result = Pipeline._run_component(
272277
component_name="tool_invoker",
273278
component={"instance": self._tool_invoker},
274-
inputs={"messages": llm_messages, "state": state},
279+
inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback},
275280
component_visits=component_visits,
276281
parent_span=span,
277282
)
@@ -312,6 +317,7 @@ async def run_async(
312317
313318
:param messages: List of chat messages to process
314319
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
320+
The same callback can be configured to emit tool results when a tool is called.
315321
:param kwargs: Additional data to pass to the State schema used by the Agent.
316322
The keys must match the schema defined in the Agent's `state_schema`.
317323
:return: Dictionary containing messages and outputs matching the defined output types
@@ -322,6 +328,10 @@ async def run_async(
322328
if self.system_prompt is not None:
323329
messages = [ChatMessage.from_system(self.system_prompt)] + messages
324330

331+
streaming_callback = select_streaming_callback(
332+
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
333+
)
334+
325335
input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})
326336

327337
state = State(schema=self.state_schema, data=kwargs)

haystack/components/generators/chat/openai.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def run(
279279
chat_completion, # type: ignore
280280
streaming_callback, # type: ignore
281281
)
282+
282283
else:
283284
assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
284285
completions = [
@@ -355,6 +356,7 @@ async def run_async(
355356
chat_completion, # type: ignore
356357
streaming_callback, # type: ignore
357358
)
359+
358360
else:
359361
assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
360362
completions = [

haystack/components/generators/utils.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,47 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
6+
57
from haystack.dataclasses import StreamingChunk
68

79

810
def print_streaming_chunk(chunk: StreamingChunk) -> None:
911
"""
10-
Default callback function for streaming responses.
12+
Callback function to handle and display streaming output chunks.
13+
14+
This function processes a `StreamingChunk` object by:
15+
- Printing tool call metadata (if any), including function names and arguments, as they arrive.
16+
- Printing tool call results when available.
17+
- Printing the main content (e.g., text tokens) of the chunk as it is received.
18+
19+
The function outputs data directly to stdout and flushes output buffers to ensure immediate display during
20+
streaming.
1121
12-
Prints the tokens of the first completion to stdout as soon as they are received
22+
:param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and
23+
tool results.
1324
"""
14-
print(chunk.content, flush=True, end="")
25+
# Print tool call metadata if available (from ChatGenerator)
26+
if chunk.meta.get("tool_calls"):
27+
for tool_call in chunk.meta["tool_calls"]:
28+
if isinstance(tool_call, ChoiceDeltaToolCall) and tool_call.function:
29+
# print the tool name
30+
if tool_call.function.name and not tool_call.function.arguments:
31+
print("[TOOL CALL]\n", flush=True, end="")
32+
print(f"Tool: {tool_call.function.name} ", flush=True, end="")
33+
34+
# print the tool arguments
35+
if tool_call.function.arguments:
36+
if tool_call.function.arguments.startswith("{"):
37+
print("\nArguments: ", flush=True, end="")
38+
print(tool_call.function.arguments, flush=True, end="")
39+
if tool_call.function.arguments.endswith("}"):
40+
print("\n\n", flush=True, end="")
41+
42+
# Print tool call results if available (from ToolInvoker)
43+
if chunk.meta.get("tool_result"):
44+
print(f"[TOOL RESULT]\n{chunk.meta['tool_result']}\n\n", flush=True, end="")
45+
46+
# Print the main content of the chunk (from ChatGenerator)
47+
if chunk.content:
48+
print(chunk.content, flush=True, end="")

haystack/components/tools/tool_invoker.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from haystack import component, default_from_dict, default_to_dict, logging
1010
from haystack.core.component.sockets import Sockets
1111
from haystack.dataclasses import ChatMessage, State, ToolCall
12+
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback
1213
from haystack.tools import (
1314
ComponentTool,
1415
Tool,
@@ -159,6 +160,7 @@ def __init__(
159160
tools: Union[List[Tool], Toolset],
160161
raise_on_failure: bool = True,
161162
convert_result_to_json_string: bool = False,
163+
streaming_callback: Optional[StreamingCallbackT] = None,
162164
):
163165
"""
164166
Initialize the ToolInvoker component.
@@ -173,6 +175,10 @@ def __init__(
173175
:param convert_result_to_json_string:
174176
If True, the tool invocation result will be converted to a string using `json.dumps`.
175177
If False, the tool invocation result will be converted to a string using `str`.
178+
:param streaming_callback:
179+
A callback function that will be called to emit tool results.
180+
Note that the result is only emitted once it becomes available — it is not
181+
streamed incrementally in real time.
176182
:raises ValueError:
177183
If no tools are provided or if duplicate tool names are found.
178184
"""
@@ -181,6 +187,7 @@ def __init__(
181187

182188
# could be a Toolset instance or a list of Tools
183189
self.tools = tools
190+
self.streaming_callback = streaming_callback
184191

185192
# Convert Toolset to list for internal use
186193
if isinstance(tools, Toolset):
@@ -272,7 +279,6 @@ def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to
272279
except StringConversionError as conversion_error:
273280
# If _handle_error re-raises, this properly preserves the chain
274281
raise conversion_error from e
275-
276282
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)
277283

278284
@staticmethod
@@ -358,13 +364,21 @@ def _merge_tool_outputs(tool: Tool, result: Any, state: State) -> None:
358364
state.set(state_key, output_value, handler_override=handler)
359365

360366
@component.output_types(tool_messages=List[ChatMessage], state=State)
361-
def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dict[str, Any]:
367+
def run(
368+
self,
369+
messages: List[ChatMessage],
370+
state: Optional[State] = None,
371+
streaming_callback: Optional[StreamingCallbackT] = None,
372+
) -> Dict[str, Any]:
362373
"""
363374
Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.
364375
365376
:param messages:
366377
A list of ChatMessage objects.
367378
:param state: The runtime state that should be used by the tools.
379+
:param streaming_callback: A callback function that will be called to emit tool results.
380+
Note that the result is only emitted once it becomes available — it is not
381+
streamed incrementally in real time.
368382
:returns:
369383
A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
370384
Each ChatMessage objects wraps the result of a tool invocation.
@@ -383,6 +397,9 @@ def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dic
383397

384398
# Only keep messages with tool calls
385399
messages_with_tool_calls = [message for message in messages if message.tool_calls]
400+
streaming_callback = select_streaming_callback(
401+
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
402+
)
386403

387404
tool_messages = []
388405
for message in messages_with_tool_calls:
@@ -406,6 +423,7 @@ def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dic
406423
# 2) Invoke the tool
407424
try:
408425
tool_result = tool_to_invoke.invoke(**final_args)
426+
409427
except ToolInvocationError as e:
410428
error_message = self._handle_error(e)
411429
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
@@ -434,6 +452,11 @@ def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dic
434452
)
435453
)
436454

455+
if streaming_callback is not None:
456+
streaming_callback(
457+
StreamingChunk(content="", meta={"tool_result": tool_result, "tool_call": tool_call})
458+
)
459+
437460
return {"tool_messages": tool_messages, "state": state}
438461

439462
def to_dict(self) -> Dict[str, Any]:
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
features:
3+
- |
4+
Add a `streaming_callback` parameter to `ToolInvoker` to enable streaming of tool results.
5+
Note that tool_result is emitted only after the tool execution completes and is not streamed incrementally.
6+
7+
- Update `print_streaming_chunk` to print ToolCall information if it is present in the chunk's metadata.
8+
9+
- Update `Agent` to forward the `streaming_callback` to `ToolInvoker` to emit tool results during tool invocation.

test/components/agents/test_agent.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from haystack.dataclasses import ChatMessage, ToolCall
2525
from haystack.dataclasses.chat_message import ChatRole, TextContent
2626
from haystack.dataclasses.streaming_chunk import StreamingChunk
27+
2728
from haystack.tools import Tool, ComponentTool
2829
from haystack.tools.toolset import Toolset
2930
from haystack.utils import serialize_callable, Secret
@@ -778,6 +779,25 @@ async def test_run_async_uses_chat_generator_run_async_when_available(self, weat
778779
assert [isinstance(reply, ChatMessage) for reply in result["messages"]]
779780
assert "Hello from run_async" in result["messages"][1].text
780781

782+
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
783+
def test_agent_streaming_with_tool_call(self, monkeypatch, weather_tool):
784+
chat_generator = OpenAIChatGenerator()
785+
agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
786+
agent.warm_up()
787+
streaming_callback_called = False
788+
789+
def streaming_callback(chunk: StreamingChunk) -> None:
790+
nonlocal streaming_callback_called
791+
streaming_callback_called = True
792+
793+
result = agent.run(
794+
[ChatMessage.from_user("What's the weather in Paris?")], streaming_callback=streaming_callback
795+
)
796+
797+
assert result is not None
798+
assert result["messages"] is not None
799+
assert streaming_callback_called
800+
781801

782802
class TestAgentTracing:
783803
def test_agent_tracing_span_run(self, caplog, monkeypatch, weather_tool):

test/components/tools/test_tool_invoker.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from haystack.dataclasses.state import State
1212
from haystack.tools import ComponentTool, Tool, Toolset
1313
from haystack.tools.errors import ToolInvocationError
14+
from haystack.dataclasses import StreamingChunk
1415

1516

1617
def weather_function(location):
@@ -162,14 +163,23 @@ def test_inject_state_args_param_in_state_and_llm(self):
162163
args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={"location": "Paris"}, state=state)
163164
assert args == {"location": "Paris"}
164165

165-
def test_run(self, invoker):
166+
def test_run_with_streaming_callback(self, invoker):
167+
streaming_callback_called = False
168+
169+
def streaming_callback(chunk: StreamingChunk) -> None:
170+
nonlocal streaming_callback_called
171+
streaming_callback_called = True
172+
166173
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
167174
message = ChatMessage.from_assistant(tool_calls=[tool_call])
168175

169-
result = invoker.run(messages=[message])
176+
result = invoker.run(messages=[message], streaming_callback=streaming_callback)
170177
assert "tool_messages" in result
171178
assert len(result["tool_messages"]) == 1
172179

180+
# check we called the streaming callback
181+
assert streaming_callback_called
182+
173183
tool_message = result["tool_messages"][0]
174184
assert isinstance(tool_message, ChatMessage)
175185
assert tool_message.is_from(ChatRole.TOOL)

0 commit comments

Comments (0)