Merge pull request #1919 from SciPhi-AI/feature/extend-reasoning-models
Extend reasoning model support: rename the reasoning-agent endpoint to retrieval/rawr in the JS SDK and the Python router, bump r2r-js to 0.4.22, and rework streamed tool-call accumulation in py/core/agent/base.py.
emrgnt-cmplxty authored Jan 31, 2025
2 parents 5941e72 + 28c441c commit 7b9a196
Showing 9 changed files with 202 additions and 113 deletions.
2 changes: 1 addition & 1 deletion js/sdk/package.json
@@ -1,6 +1,6 @@
{
"name": "r2r-js",
"version": "0.4.21",
"version": "0.4.22",
"description": "",
"main": "dist/index.js",
"browser": "dist/index.browser.js",
4 changes: 2 additions & 2 deletions js/sdk/src/v3/clients/retrieval.ts
@@ -327,7 +327,7 @@ async reasoningAgent(options: {
if (options.ragGenerationConfig && options.ragGenerationConfig.stream) {
return this.streamReasoningAgent(data);
} else {
return await this.client.makeRequest("POST", "retrieval/reasoning_agent", {
return await this.client.makeRequest("POST", "retrieval/rawr", {
data: data,
});
}
@@ -339,7 +339,7 @@ private async streamReasoningAgent(
): Promise<ReadableStream<Uint8Array>> {
return this.client.makeRequest<ReadableStream<Uint8Array>>(
"POST",
"retrieval/reasoning_agent",
"retrieval/rawr",
{
data: agentData,
headers: {
207 changes: 128 additions & 79 deletions py/core/agent/base.py
@@ -322,162 +322,211 @@ async def process_llm_response(
**kwargs,
) -> AsyncGenerator[str, None]:
"""
Updated to:
1) Accumulate interleaved content and tool calls gracefully.
2) Finalize content even if no tool calls are made.
3) Support processing of both content and tool calls in parallel.
Revised processing for the reasoning agent.
This version:
1. Accumulates tool calls in a list (each with a unique internal_id).
2. When finish_reason == "tool_calls", it records the tool calls in the conversation,
emits Thought messages, and then executes all calls in parallel.
3. Most importantly, it then yields a matching tool result block (with the same id)
for each tool call so that Anthropic sees a proper correspondence.
"""
pending_tool_calls = {}
pending_calls = (
[]
) # list of dicts: each has "internal_id", "original_id", "name", "arguments"
content_buffer = ""
function_arguments = ""

inside_thoughts = False

async for chunk in stream:
delta = chunk.choices[0].delta
if delta.content and delta.content.count(
"<Thought>"
) > delta.content.count("</Thought>"):
inside_thoughts = True
elif (
delta.content
and inside_thoughts
and delta.content.count("</Thought>")
> delta.content.count("<Thought>")
):
inside_thoughts = False
finish_reason = chunk.choices[0].finish_reason

# 1) Handle interleaved tool_calls
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in pending_tool_calls:
pending_tool_calls[idx] = {
"id": tc.id, # could be None
"name": tc.function.name or "",
"arguments": tc.function.arguments or "",
}
else:
# Accumulate partial tool call details
if tc.function.name:
pending_tool_calls[idx]["name"] = tc.function.name
if tc.function.arguments:
pending_tool_calls[idx][
"arguments"
] += tc.function.arguments
# Set the ID if it appears in later chunks
if tc.id and not pending_tool_calls[idx]["id"]:
pending_tool_calls[idx]["id"] = tc.id
# --- Update our chain-of-thought status based on <Thought> tags ---
if delta.content:
num_open = delta.content.count("<Thought>")
num_close = delta.content.count("</Thought>")
if num_open > num_close:
inside_thoughts = True
elif inside_thoughts and num_close >= num_open:
inside_thoughts = False

# 2) Handle partial function_call (single-call logic)
# --- 1. Process any incoming tool_calls ---
if delta.tool_calls:
if (
"anthropic" in self.rag_generation_config.model
or "claude" in self.rag_generation_config.model
):
for tc in delta.tool_calls:
original_id = tc.id if tc.id else None
# Check if an existing pending call with this original_id is incomplete.
found = None
for call in pending_calls:
if call["original_id"] == original_id:
# If the accumulated arguments do not appear complete (e.g. not ending with "}")
if not call["arguments"].strip().endswith("}"):
found = call
break
if found is not None:
if tc.function.name:
found["name"] = tc.function.name
if tc.function.arguments:
found["arguments"] += tc.function.arguments
else:
# Create a new call entry. If the original_id is reused,
# add a suffix so that each call gets a unique internal_id.
new_internal_id = (
original_id
if original_id
else f"call_{len(pending_calls)}"
)
if original_id is not None:
count = sum(
1
for call in pending_calls
if call["original_id"] == original_id
)
if count > 0:
new_internal_id = f"{original_id}_{count}"
pending_calls.append(
{
"internal_id": new_internal_id,
"original_id": original_id,
"name": tc.function.name or "",
"arguments": tc.function.arguments or "",
}
)
else:
for tc in delta.tool_calls:
idx = tc.index
if len(pending_calls) <= idx:
pending_calls.append(
{
"internal_id": tc.id, # could be None
"name": tc.function.name or "",
"arguments": tc.function.arguments or "",
}
)
else:
# Accumulate partial tool call details
if tc.function.arguments:
pending_calls[idx][
"arguments"
] += tc.function.arguments

# --- 2. Process a function_call (if any) ---
if delta.function_call:
if delta.function_call.name:
function_name = delta.function_call.name
if delta.function_call.arguments:
function_arguments += delta.function_call.arguments

# 3) Handle normal content
# --- 3. Process normal content tokens ---
elif delta.content:
content_buffer += delta.content
yield delta.content

# 4) Check finish_reason for tool calls
# --- 4. Finalize on finish_reason == "tool_calls" ---
if finish_reason == "tool_calls":
# Finalize the tool calls
# Build a list of tool call descriptors for the conversation message.
calls_list = []
sorted_indexes = sorted(pending_tool_calls.keys())
for idx in sorted_indexes:
call_info = pending_tool_calls[idx]
call_id = call_info["id"] or f"call_{idx}"
for call in pending_calls:
calls_list.append(
{
"id": call_id,
"id": call["internal_id"],
"type": "function",
"function": {
"name": call_info["name"],
"arguments": call_info["arguments"],
"name": call["name"],
"arguments": call["arguments"],
},
}
)

assistant_msg = Message(
role="assistant",
content=content_buffer or None,
tool_calls=calls_list,
)
await self.conversation.add_message(assistant_msg)

# Execute tool calls in parallel
for idx, tool_call in pending_tool_calls.items():
# Optionally emit a Thought message for each tool call.
for call in pending_calls:
if inside_thoughts:
yield "</Thought>"
yield "<Thought>"
name = tool_call["name"]
arguments = tool_call["arguments"]
yield f"Calling function: {name}, with payload {arguments}"
yield f"\n\nCalling function: {call['name']}, with payload {call['arguments']}"
yield "</Thought>"
if inside_thoughts:
yield "<Thought>"

# Execute all tool calls in parallel.
async_calls = [
self.handle_function_or_tool_call(
call_info["name"],
call_info["arguments"],
tool_id=(call_info["id"] or f"call_{idx}"),
call["name"],
call["arguments"],
tool_id=call["internal_id"],
*args,
**kwargs,
)
for idx, call_info in pending_tool_calls.items()
for call in pending_calls
]
await asyncio.gather(*async_calls)

# Clear the tool call state
pending_tool_calls.clear()
# Reset state after processing.
pending_calls = []
content_buffer = ""

# --- 5. Finalize on finish_reason == "stop" ---
elif finish_reason == "stop":
# Finalize content if streaming stops
if content_buffer:
await self.conversation.add_message(
Message(role="assistant", content=content_buffer)
)
elif pending_tool_calls:
# TODO: deduplicate with the finish_reason == "tool_calls" branch above.
elif pending_calls:
# In case there are pending calls not triggered by a tool_calls finish.
calls_list = []
sorted_indexes = sorted(pending_tool_calls.keys())
for idx in sorted_indexes:
call_info = pending_tool_calls[idx]
call_id = call_info["id"] or f"call_{idx}"
for call in pending_calls:
calls_list.append(
{
"id": call_id,
"id": call["internal_id"],
"type": "function",
"function": {
"name": call_info["name"],
"arguments": call_info["arguments"],
"name": call["name"],
"arguments": call["arguments"],
},
}
)

assistant_msg = Message(
role="assistant",
content=content_buffer or None,
tool_calls=calls_list,
)
await self.conversation.add_message(assistant_msg)
return

self._completed = True
return

# If the stream ends without `finish_reason=stop`
# --- Finalize if stream ends unexpectedly ---
if not self._completed and content_buffer:
await self.conversation.add_message(
Message(role="assistant", content=content_buffer)
)
self._completed = True

# After the stream ends
if content_buffer and not self._completed:
await self.conversation.add_message(
Message(role="assistant", content=content_buffer)
if not self._completed and pending_calls:
calls_list = []
for call in pending_calls:
calls_list.append(
{
"id": call["internal_id"],
"type": "function",
"function": {
"name": call["name"],
"arguments": call["arguments"],
},
}
)
assistant_msg = Message(
role="assistant",
content=content_buffer or None,
tool_calls=calls_list,
)
await self.conversation.add_message(assistant_msg)
self._completed = True
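
The heart of the base.py change is how streamed tool-call fragments are merged: Anthropic/Claude models may reuse the same tool-call id across chunks, so the agent appends argument fragments to the most recent incomplete entry and suffixes reused ids so every call keeps a unique internal_id. Below is a minimal, self-contained sketch of that accumulation strategy; ToolCallDelta and merge_tool_call_delta are illustrative names, not part of the repository.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToolCallDelta:
    """Illustrative stand-in for one streamed tool-call fragment."""

    id: Optional[str]    # original id reported by the model (may repeat)
    name: Optional[str]  # function name (often only in the first fragment)
    arguments: str       # partial JSON string for the arguments


def merge_tool_call_delta(pending_calls: list[dict], tc: ToolCallDelta) -> None:
    """Accumulate a streamed fragment, mirroring the Anthropic branch above.

    - A fragment whose original id matches an entry with incomplete arguments
      (not yet ending in "}") is appended to that entry.
    - Otherwise a new entry is created; if the original id was already used,
      a numeric suffix keeps internal_id unique.
    """
    original_id = tc.id or None

    # Try to extend an existing, still-incomplete call with the same id.
    for call in pending_calls:
        if call["original_id"] == original_id and not call[
            "arguments"
        ].strip().endswith("}"):
            if tc.name:
                call["name"] = tc.name
            if tc.arguments:
                call["arguments"] += tc.arguments
            return

    # Otherwise start a new call, disambiguating reused ids.
    internal_id = original_id or f"call_{len(pending_calls)}"
    if original_id is not None:
        reuse_count = sum(
            1 for c in pending_calls if c["original_id"] == original_id
        )
        if reuse_count > 0:
            internal_id = f"{original_id}_{reuse_count}"

    pending_calls.append(
        {
            "internal_id": internal_id,
            "original_id": original_id,
            "name": tc.name or "",
            "arguments": tc.arguments or "",
        }
    )


if __name__ == "__main__":
    # Two calls that share the same original id arrive as three fragments.
    pending: list[dict] = []
    merge_tool_call_delta(pending, ToolCallDelta("toolu_1", "search", '{"query": "Ari'))
    merge_tool_call_delta(pending, ToolCallDelta("toolu_1", None, 'stotle"}'))
    merge_tool_call_delta(pending, ToolCallDelta("toolu_1", "search", '{"query": "logic"}'))
    print([c["internal_id"] for c in pending])  # ['toolu_1', 'toolu_1_1']

Each internal_id is later passed as tool_id when the call is executed, which is what gives Anthropic the one-to-one correspondence between tool calls and tool results that the new docstring calls out.
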
6 changes: 3 additions & 3 deletions py/core/main/api/v3/retrieval_router.py
@@ -703,7 +703,7 @@ async def stream_generator():
raise R2RException(str(e), 500)

@self.router.post(
"/retrieval/reasoning_agent",
"/retrieval/rawr",
dependencies=[Depends(self.rate_limit_dependency)],
summary="Reasoning RAG Agent (Chain-of-Thought + Tools)",
openapi_extra={
@@ -717,7 +717,7 @@ async def stream_generator():
client = R2RClient()
# when using auth, do client.login(...)
response = client.retrieval.reasoning_agent(
response = client.retrieval.rawr(
message={
"role": "user",
"content": "What were the key contributions of Aristotle to logic and how did they influence later philosophers?"
@@ -852,7 +852,7 @@ async def reasoning_agent_app(
),
use_system_context=False,
override_tools=tools,
reasoning_agent=True,
rawr=True,
)

if rag_generation_config.stream:
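
With the route renamed from retrieval/reasoning_agent to retrieval/rawr, client code follows the pattern in the updated docstring above. A minimal usage sketch, assuming the Python SDK mirrors the renamed route as client.retrieval.rawr with the same message and rag_generation_config parameters; the import path and the streaming iteration are assumptions, not taken from this diff.

from r2r import R2RClient  # assumed import path for the Python SDK

client = R2RClient()
# when using auth, call client.login(...) first

response = client.retrieval.rawr(
    message={
        "role": "user",
        "content": (
            "What were the key contributions of Aristotle to logic "
            "and how did they influence later philosophers?"
        ),
    },
    # rag_generation_config.stream selects the streaming code path,
    # mirroring the stream check in retrieval_router.py above.
    rag_generation_config={"stream": True},
)

# Assumption: a streaming response is iterable chunk by chunk,
# including the <Thought> blocks emitted by the reasoning agent.
for chunk in response:
    print(chunk, end="", flush=True)
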
