fix: Look through all streaming chunks for tool calls #8829

Merged: 13 commits, Feb 11, 2025
58 changes: 29 additions & 29 deletions haystack/components/generators/chat/openai.py
@@ -337,46 +337,46 @@ def _check_finish_reason(self, meta: Dict[str, Any]) -> None:
                 finish_reason=meta["finish_reason"],
             )
 
-    def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
+    def _convert_streaming_chunks_to_chat_message(
+        self, chunk: ChatCompletionChunk, chunks: List[StreamingChunk]
+    ) -> ChatMessage:
         """
         Connects the streaming chunks into a single ChatMessage.
 
         :param chunk: The last chunk returned by the OpenAI API.
         :param chunks: The list of all `StreamingChunk` objects.
         """
 
         text = "".join([chunk.content for chunk in chunks])
         tool_calls = []
 
-        # if it's a tool call, we need to build the payload dict from all the chunks
-        if bool(chunks[0].meta.get("tool_calls")):
-            tools_len = len(chunks[0].meta.get("tool_calls", []))
-
-            payloads = [{"arguments": "", "name": ""} for _ in range(tools_len)]
-            for chunk_payload in chunks:
-                deltas = chunk_payload.meta.get("tool_calls") or []
+        # Process tool calls if present in any chunk
+        tool_call_data: Dict[str, Dict[str, str]] = {}  # Track tool calls by ID
+        for chunk_payload in chunks:
+            tool_calls_meta = chunk_payload.meta.get("tool_calls")
+            if tool_calls_meta:
+                for delta in tool_calls_meta:
+                    if not delta.id in tool_call_data:
+                        tool_call_data[delta.id] = {"id": delta.id, "name": "", "arguments": ""}
 
-                # deltas is a list of ChoiceDeltaToolCall or ChoiceDeltaFunctionCall
-                for i, delta in enumerate(deltas):
-                    payloads[i]["id"] = delta.id or payloads[i].get("id", "")
-                    if delta.function:
-                        payloads[i]["name"] += delta.function.name or ""
-                        payloads[i]["arguments"] += delta.function.arguments or ""
-
-            for payload in payloads:
-                arguments_str = payload["arguments"]
-                try:
-                    arguments = json.loads(arguments_str)
-                    tool_calls.append(ToolCall(id=payload["id"], tool_name=payload["name"], arguments=arguments))
-                except json.JSONDecodeError:
-                    logger.warning(
-                        "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
-                        "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
-                        "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
-                        _id=payload["id"],
-                        _name=payload["name"],
-                        _arguments=arguments_str,
-                    )
+                    if delta.function.name:
+                        tool_call_data[delta.id]["name"] = delta.function.name
+                    if delta.function.arguments:
+                        tool_call_data[delta.id]["arguments"] = delta.function.arguments
+
+        # Convert accumulated tool call data into ToolCall objects
+        for call_data in tool_call_data.values():
+            try:
+                arguments = json.loads(call_data["arguments"])
+                tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
+            except json.JSONDecodeError:
+                logger.warning(
+                    "Skipping malformed tool call due to invalid JSON. Set `tools_strict=True` for valid JSON. "
+                    "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
+                    _id=call_data["id"],
+                    _name=call_data["name"],
+                    _arguments=call_data["arguments"],
+                )
 
         meta = {
             "model": chunk.model,
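In short, the fix drops the assumption that tool calls start in the first chunk: every chunk is scanned and its deltas are grouped under their tool-call ID. A minimal standalone sketch of that pattern, using plain dicts in place of the OpenAI delta objects (the dict shapes here are illustrative assumptions, not Haystack types):

import json

def merge_tool_call_deltas(chunks):
    """Group tool-call deltas from every chunk by their tool-call ID."""
    merged = {}
    for chunk_meta in chunks:
        for delta in chunk_meta.get("tool_calls") or []:
            entry = merged.setdefault(delta["id"], {"name": "", "arguments": ""})
            # Mirror the diff above: a later non-empty value replaces the earlier one
            if delta.get("name"):
                entry["name"] = delta["name"]
            if delta.get("arguments"):
                entry["arguments"] = delta["arguments"]
    return merged

# Tool-call data arriving only after the first chunk, as in the test further down
chunks = [
    {},  # text-only chunk, no tool calls
    {"tool_calls": [{"id": "call_1", "name": "get_weather", "arguments": '{"city": "London"}'}]},
    {"tool_calls": [{"id": "call_1", "name": "get_weather", "arguments": '{"city": "Paris"}'}]},
]
merged = merge_tool_call_deltas(chunks)
assert json.loads(merged["call_1"]["arguments"]) == {"city": "Paris"}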
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Improved `OpenAIChatGenerator` streaming response tool call processing: The logic now scans all chunks to correctly identify the first chunk with tool calls, ensuring accurate payload construction and preventing errors when tool call data isn't confined to the initial chunk.
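The error the note refers to came from gating everything on the first chunk: when OpenAI opens a stream with plain assistant text and only later emits tool-call deltas, the old gate never fired. A tiny sketch of that pre-fix failure mode (illustrative dicts, not the exact production types):

chunks_meta = [
    {},  # first chunk: assistant text only, no "tool_calls" key
    {"tool_calls": [{"id": "call_1", "name": "get_weather", "arguments": '{"city": "Paris"}'}]},
]

tool_calls = []
if chunks_meta[0].get("tool_calls"):  # False, so the tool call in the second chunk is never seen
    tool_calls.append("... build payloads ...")
assert tool_calls == []  # pre-fix behavior; the new code scans every chunk instead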
91 changes: 91 additions & 0 deletions test/components/generators/chat/test_openai.py
@@ -570,3 +570,94 @@ def test_live_run_with_tools(self, tools):
         assert tool_call.tool_name == "weather"
         assert tool_call.arguments == {"city": "Paris"}
         assert message.meta["finish_reason"] == "tool_calls"

+    def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
+        """Test that tool calls can be found in any chunk of the streaming response."""
+        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
+
+        # Create a list of chunks where tool calls appear in different positions
+        chunks = [
+            # First chunk has no tool calls
+            StreamingChunk("Hello! Let me help you with that. "),
+            # Second chunk has the first tool call
+            StreamingChunk("I'll check the weather. "),
+            # Third chunk has no tool calls
+            StreamingChunk("Now, let me check another city. "),
+            # Fourth chunk has another tool call
+            StreamingChunk(""),
+        ]
+
+        # Add received_at to first chunk
+        chunks[0].meta["received_at"] = "2024-02-07T14:21:47.446186Z"
+
+        # Add tool calls meta to second chunk
+        chunks[1].meta["tool_calls"] = [
+            chat_completion_chunk.ChoiceDeltaToolCall(
+                index=0,
+                id="call_1",
+                type="function",
+                function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                    name="get_weather", arguments='{"city": "London"}'
+                ),
+            )
+        ]
+
+        # Add tool calls meta to fourth chunk
+        chunks[3].meta["tool_calls"] = [
+            chat_completion_chunk.ChoiceDeltaToolCall(
+                index=0,  # Same index as first tool call since it's the same function
+                id="call_1",  # Same ID as first tool call since it's the same function
+                type="function",
+                function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                    name="get_weather", arguments='{"city": "Paris"}'
+                ),
+            )
+        ]
+
+        # Add required meta information to the last chunk
+        chunks[-1].meta.update({"model": "gpt-4", "index": 0, "finish_reason": "tool_calls"})
+
+        # Create the final ChatCompletionChunk that would be passed as the first parameter
+        final_chunk = ChatCompletionChunk(
+            id="chatcmpl-123",
+            model="gpt-4",
+            object="chat.completion.chunk",
+            created=1234567890,
+            choices=[
+                chat_completion_chunk.Choice(
+                    index=0,
+                    finish_reason="tool_calls",
+                    delta=chat_completion_chunk.ChoiceDelta(
+                        tool_calls=[
+                            chat_completion_chunk.ChoiceDeltaToolCall(
+                                index=0,
+                                id="call_1",
+                                type="function",
+                                function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                                    name="get_weather", arguments='{"city": "Paris"}'
+                                ),
+                            )
+                        ]
+                    ),
+                )
+            ],
+        )
+
+        # Convert chunks to a chat message
+        result = component._convert_streaming_chunks_to_chat_message(final_chunk, chunks)
+
+        # Verify the content is concatenated correctly
+        expected_text = "Hello! Let me help you with that. I'll check the weather. Now, let me check another city. "
+        assert result.text == expected_text
+
+        # Verify the tool call deltas were merged into a single call
+        assert len(result.tool_calls) == 1  # Only one tool call is expected since both deltas share the same ID
+        assert result.tool_calls[0].id == "call_1"
+        assert result.tool_calls[0].tool_name == "get_weather"
+        assert result.tool_calls[0].arguments == {"city": "Paris"}  # The last value overwrites the previous one
+
+        # Verify meta information
+        assert result.meta["model"] == "gpt-4"
+        assert result.meta["finish_reason"] == "tool_calls"
+        assert result.meta["index"] == 0
+        assert result.meta["completion_start_time"] == "2024-02-07T14:21:47.446186Z"
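For reference, a minimal end-to-end usage sketch of the fixed behavior. It assumes the Haystack 2.x tools API (`haystack.tools.Tool`), a tool-calling chat model, and an `OPENAI_API_KEY` in the environment; it is not part of this PR:

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.tools import Tool

def get_weather(city: str) -> str:
    # Hypothetical tool function, for illustration only
    return f"Sunny in {city}"

weather_tool = Tool(
    name="weather",
    description="Get the current weather for a city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

generator = OpenAIChatGenerator(
    model="gpt-4o-mini",  # assumed model name; any tool-calling chat model works
    tools=[weather_tool],
    streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
)

reply = generator.run(messages=[ChatMessage.from_user("What is the weather in Paris?")])["replies"][0]
# With this fix, tool calls are recovered even when their deltas arrive after the first streamed chunk
print(reply.tool_calls)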