
Commit

Merge pull request #1883 from SciPhi-AI/features/fix-tool-calling
fix tool calling
emrgnt-cmplxty authored Jan 25, 2025
2 parents 956b738 + b17e19e commit 82fb5d9
Showing 4 changed files with 270 additions and 270 deletions.
197 changes: 83 additions & 114 deletions py/core/agent/base.py
@@ -158,158 +158,120 @@ async def process_llm_response(
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """
        Updated to:
        1) Accumulate interleaved content and tool calls gracefully.
        2) Finalize content even if no tool calls are made.
        3) Support processing of both content and tool calls in parallel.
        """
        # Partial tool calls, keyed by their .index:
        #   pending_tool_calls[index] = {"id": str | None, "name": str, "arguments": str}
        pending_tool_calls = {}

        # For single function_call logic
        function_name = None
        function_arguments = ""

        # Buffer for normal text
        content_buffer = ""
        tool_calls_active = False
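        # `tool_calls_active` is initialized here and set when tool-call
        # deltas arrive; it is not read within the lines shown in this hunk.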

        async for chunk in stream:
            delta = chunk.choices[0].delta
            finish_reason = chunk.choices[0].finish_reason

            # 1) Handle interleaved tool_calls
            if delta.tool_calls:
                tool_calls_active = True
                for tc in delta.tool_calls:
                    idx = tc.index
                    if idx not in pending_tool_calls:
                        pending_tool_calls[idx] = {
                            "id": tc.id,  # could be None
                            "name": tc.function.name or "",
                            "arguments": tc.function.arguments or "",
                        }
                    else:
                        # Accumulate partial tool call details
                        if tc.function.name:
                            pending_tool_calls[idx]["name"] = tc.function.name
                        if tc.function.arguments:
                            pending_tool_calls[idx]["arguments"] += (
                                tc.function.arguments
                            )
                        # Set the ID if it appears in later chunks
                        if tc.id and not pending_tool_calls[idx]["id"]:
                            pending_tool_calls[idx]["id"] = tc.id
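                    # Illustrative example: two deltas for index 0 carrying
                    # arguments '{"query": ' and '"weather"}' accumulate to
                    # pending_tool_calls[0]["arguments"] == '{"query": "weather"}'.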

            # 2) Handle partial function_call (single-call logic)
            if delta.function_call:
                if delta.function_call.name:
                    function_name = delta.function_call.name
                if delta.function_call.arguments:
                    function_arguments += delta.function_call.arguments

            # 3) Handle normal content
            elif delta.content:
                if not content_buffer:
                    yield "<completion>"
                content_buffer += delta.content
                yield delta.content
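            # The first content token opens a <completion> block; the matching
            # </completion> is emitted on finish_reason == "stop" or by the
            # post-stream flush at the end of this method.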

            # 4) Check finish_reason for tool calls
            if finish_reason == "tool_calls":
                # Finalize the tool calls
                calls_list = []
                sorted_indexes = sorted(pending_tool_calls.keys())
                for idx in sorted_indexes:
                    call_info = pending_tool_calls[idx]
                    call_id = call_info["id"] or f"call_{idx}"
                    calls_list.append(
                        {
                            "id": call_id,
                            "type": "function",
                            "function": {
                                "name": call_info["name"],
                                "arguments": call_info["arguments"],
                            },
                        }
                    )

                assistant_msg = Message(
                    role="assistant",
                    content=content_buffer or None,
                    tool_calls=calls_list,
                )
                await self.conversation.add_message(assistant_msg)

                # Execute tool calls in parallel
                async_calls = [
                    self.handle_function_or_tool_call(
                        call_info["name"],
                        call_info["arguments"],
                        tool_id=(call_info["id"] or f"call_{idx}"),
                        *args,
                        **kwargs,
                    )
                    # Iterate in sorted index order so the results align with
                    # sorted_indexes in the zip below.
                    for idx, call_info in sorted(pending_tool_calls.items())
                ]
                results = await asyncio.gather(*async_calls)

                # Yield tool call results
                for idx, tool_result in zip(sorted_indexes, results):
                    call_info = pending_tool_calls[idx]
                    yield "<tool_call>"
                    yield f"<name>{call_info['name']}</name>"
                    yield f"<arguments>{call_info['arguments']}</arguments>"
                    if tool_result.stream_result:
                        yield f"<results>{tool_result.stream_result}</results>"
                    else:
                        yield f"<results>{tool_result.llm_formatted_result}</results>"
                    yield "</tool_call>"

                # Clear the tool call state
                pending_tool_calls.clear()
                content_buffer = ""
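                # Illustrative shape of the emitted stream for a single call:
                #   <tool_call><name>search</name><arguments>{"q": "x"}</arguments>
                #   <results>...</results></tool_call>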

            elif finish_reason == "function_call":
                # Single function call handling
                if not function_name:
                    logger.warning("Function name not found in function call.")
                    continue

                # Add the assistant message with function_call
                assistant_msg = Message(
                    role="assistant",
                    content=content_buffer if content_buffer else None,
@@ -331,31 +293,38 @@ async def process_llm_response(
                    yield f"<results>{tool_result.stream_result}</results>"
                else:
                    yield f"<results>{tool_result.llm_formatted_result}</results>"

                yield "</function_call>"

                # Add a function-role message
                await self.conversation.add_message(
                    Message(
                        role="function",
                        name=function_name,
                        content=tool_result.llm_formatted_result,
                    )
                )
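                # Recording tool output as a `function`-role message makes the
                # result visible to the model on its next turn.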

                function_name, function_arguments, content_buffer = (
                    None,
                    "",
                    "",
                )

            elif finish_reason == "stop":
                # Finalize content if streaming stops
                if content_buffer:
                    await self.conversation.add_message(
                        Message(role="assistant", content=content_buffer)
                    )
                self._completed = True
                yield "</completion>"

        # If the stream ends without `finish_reason=stop`
        if not self._completed and content_buffer:
            await self.conversation.add_message(
                Message(role="assistant", content=content_buffer)
            )
            self._completed = True
            yield "</completion>"
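        # Some providers end the stream without a final `finish_reason`; this
        # flush keeps the conversation state and the <completion> wrapper
        # balanced in that case as well.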

