diff --git a/js/sdk/package.json b/js/sdk/package.json
index 5f991f7a7..7e9820d04 100644
--- a/js/sdk/package.json
+++ b/js/sdk/package.json
@@ -1,6 +1,6 @@
 {
   "name": "r2r-js",
-  "version": "0.4.21",
+  "version": "0.4.22",
   "description": "",
   "main": "dist/index.js",
   "browser": "dist/index.browser.js",
diff --git a/js/sdk/src/v3/clients/retrieval.ts b/js/sdk/src/v3/clients/retrieval.ts
index 011f33531..37d0a0d81 100644
--- a/js/sdk/src/v3/clients/retrieval.ts
+++ b/js/sdk/src/v3/clients/retrieval.ts
@@ -327,7 +327,7 @@ async reasoningAgent(options: {
     if (options.ragGenerationConfig && options.ragGenerationConfig.stream) {
       return this.streamReasoningAgent(data);
     } else {
-      return await this.client.makeRequest("POST", "retrieval/reasoning_agent", {
+      return await this.client.makeRequest("POST", "retrieval/rawr", {
        data: data,
      });
    }
@@ -339,7 +339,7 @@ private async streamReasoningAgent(
  ): Promise> {
    return this.client.makeRequest>(
      "POST",
-     "retrieval/reasoning_agent",
+     "retrieval/rawr",
      {
        data: agentData,
        headers: {
diff --git a/py/core/agent/base.py b/py/core/agent/base.py
index bb6e7e278..bc3fea47b 100644
--- a/py/core/agent/base.py
+++ b/py/core/agent/base.py
@@ -322,84 +322,125 @@ async def process_llm_response(
         **kwargs,
     ) -> AsyncGenerator[str, None]:
         """
-        Updated to:
-        1) Accumulate interleaved content and tool calls gracefully.
-        2) Finalize content even if no tool calls are made.
-        3) Support processing of both content and tool calls in parallel.
+        Revised processing for the reasoning agent.
+        This version:
+        1. Accumulates tool calls in a list (each with a unique internal_id).
+        2. When finish_reason == "tool_calls", it records the tool calls in the conversation,
+           emits Thought messages, and then executes all calls in parallel.
+        3. Most importantly, it then yields a matching tool result block (with the same id)
+           for each tool call so that Anthropic sees a proper correspondence.
         """
-        pending_tool_calls = {}
+        pending_calls = (
+            []
+        )  # list of dicts: each has "internal_id", "original_id", "name", "arguments"
         content_buffer = ""
         function_arguments = ""
         inside_thoughts = False
+
         async for chunk in stream:
             delta = chunk.choices[0].delta
-            if delta.content and delta.content.count(
-                ""
-            ) > delta.content.count(""):
-                inside_thoughts = True
-            elif (
-                delta.content
-                and inside_thoughts
-                and delta.content.count("")
-                > delta.content.count("")
-            ):
-                inside_thoughts = False
             finish_reason = chunk.choices[0].finish_reason

-            # 1) Handle interleaved tool_calls
-            if delta.tool_calls:
-                for tc in delta.tool_calls:
-                    idx = tc.index
-                    if idx not in pending_tool_calls:
-                        pending_tool_calls[idx] = {
-                            "id": tc.id,  # could be None
-                            "name": tc.function.name or "",
-                            "arguments": tc.function.arguments or "",
-                        }
-                    else:
-                        # Accumulate partial tool call details
-                        if tc.function.name:
-                            pending_tool_calls[idx]["name"] = tc.function.name
-                        if tc.function.arguments:
-                            pending_tool_calls[idx][
-                                "arguments"
-                            ] += tc.function.arguments
-                        # Set the ID if it appears in later chunks
-                        if tc.id and not pending_tool_calls[idx]["id"]:
-                            pending_tool_calls[idx]["id"] = tc.id
+            # --- Update our chain-of-thought status based on tags ---
+            if delta.content:
+                num_open = delta.content.count("")
+                num_close = delta.content.count("")
+                if num_open > num_close:
+                    inside_thoughts = True
+                elif inside_thoughts and num_close >= num_open:
+                    inside_thoughts = False

-            # 2) Handle partial function_call (single-call logic)
+            # --- 1. Process any incoming tool_calls ---
+            if delta.tool_calls:
+                if (
+                    "anthropic" in self.rag_generation_config.model
+                    or "claude" in self.rag_generation_config.model
+                ):
+                    for tc in delta.tool_calls:
+                        original_id = tc.id if tc.id else None
+                        # Check if an existing pending call with this original_id is incomplete.
+                        found = None
+                        for call in pending_calls:
+                            if call["original_id"] == original_id:
+                                # If the accumulated arguments do not appear complete (e.g. not ending with "}")
+                                if not call["arguments"].strip().endswith("}"):
+                                    found = call
+                                    break
+                        if found is not None:
+                            if tc.function.name:
+                                found["name"] = tc.function.name
+                            if tc.function.arguments:
+                                found["arguments"] += tc.function.arguments
+                        else:
+                            # Create a new call entry. If the original_id is reused,
+                            # add a suffix so that each call gets a unique internal_id.
+                            new_internal_id = (
+                                original_id
+                                if original_id
+                                else f"call_{len(pending_calls)}"
+                            )
+                            if original_id is not None:
+                                count = sum(
+                                    1
+                                    for call in pending_calls
+                                    if call["original_id"] == original_id
+                                )
+                                if count > 0:
+                                    new_internal_id = f"{original_id}_{count}"
+                            pending_calls.append(
+                                {
+                                    "internal_id": new_internal_id,
+                                    "original_id": original_id,
+                                    "name": tc.function.name or "",
+                                    "arguments": tc.function.arguments or "",
+                                }
+                            )
+                else:
+                    for tc in delta.tool_calls:
+                        idx = tc.index
+                        if len(pending_calls) <= idx:
+                            pending_calls.append(
+                                {
+                                    "internal_id": tc.id,  # could be None
+                                    "name": tc.function.name or "",
+                                    "arguments": tc.function.arguments or "",
+                                }
+                            )
+                        else:
+                            # Accumulate partial tool call details
+                            if tc.function.arguments:
+                                pending_calls[idx][
+                                    "arguments"
+                                ] += tc.function.arguments
+
+            # --- 2. Process a function_call (if any) ---
             if delta.function_call:
                 if delta.function_call.name:
                     function_name = delta.function_call.name
                 if delta.function_call.arguments:
                     function_arguments += delta.function_call.arguments

-            # 3) Handle normal content
+            # --- 3. Process normal content tokens ---
             elif delta.content:
                 content_buffer += delta.content
                 yield delta.content

-            # 4) Check finish_reason for tool calls
+            # --- 4. Finalize on finish_reason == "tool_calls" ---
             if finish_reason == "tool_calls":
-                # Finalize the tool calls
+                # Build a list of tool call descriptors for the conversation message.
                 calls_list = []
-                sorted_indexes = sorted(pending_tool_calls.keys())
-                for idx in sorted_indexes:
-                    call_info = pending_tool_calls[idx]
-                    call_id = call_info["id"] or f"call_{idx}"
+                for call in pending_calls:
                     calls_list.append(
                         {
-                            "id": call_id,
+                            "id": call["internal_id"],
                             "type": "function",
                             "function": {
-                                "name": call_info["name"],
-                                "arguments": call_info["arguments"],
+                                "name": call["name"],
+                                "arguments": call["arguments"],
                             },
                         }
                     )
-
                 assistant_msg = Message(
                     role="assistant",
                     content=content_buffer or None,
@@ -407,77 +448,85 @@ async def process_llm_response(
                 )
                 await self.conversation.add_message(assistant_msg)

-                # Execute tool calls in parallel
-                for idx, tool_call in pending_tool_calls.items():
+                # Optionally emit a Thought message for each tool call.
+                for call in pending_calls:
                     if inside_thoughts:
                         yield ""
                     yield ""
-                    name = tool_call["name"]
-                    arguments = tool_call["arguments"]
-                    yield f"Calling function: {name}, with payload {arguments}"
+                    yield f"\n\nCalling function: {call['name']}, with payload {call['arguments']}"
                     yield ""
                     if inside_thoughts:
                         yield ""
+
+                # Execute all tool calls in parallel.
                 async_calls = [
                     self.handle_function_or_tool_call(
-                        call_info["name"],
-                        call_info["arguments"],
-                        tool_id=(call_info["id"] or f"call_{idx}"),
+                        call["name"],
+                        call["arguments"],
+                        tool_id=call["internal_id"],
                         *args,
                         **kwargs,
                     )
-                    for idx, call_info in pending_tool_calls.items()
+                    for call in pending_calls
                 ]
                 await asyncio.gather(*async_calls)
-
-                # Clear the tool call state
-                pending_tool_calls.clear()
+                # Reset state after processing.
+                pending_calls = []
                 content_buffer = ""

+            # --- 5. Finalize on finish_reason == "stop" ---
             elif finish_reason == "stop":
-                # Finalize content if streaming stops
                 if content_buffer:
                     await self.conversation.add_message(
                         Message(role="assistant", content=content_buffer)
                     )
-                elif pending_tool_calls:
-                    # TODO - RM COPY PASTA.
+                elif pending_calls:
+                    # In case there are pending calls not triggered by a tool_calls finish.
                     calls_list = []
-                    sorted_indexes = sorted(pending_tool_calls.keys())
-                    for idx in sorted_indexes:
-                        call_info = pending_tool_calls[idx]
-                        call_id = call_info["id"] or f"call_{idx}"
+                    for call in pending_calls:
                         calls_list.append(
                             {
-                                "id": call_id,
+                                "id": call["internal_id"],
                                 "type": "function",
                                 "function": {
-                                    "name": call_info["name"],
-                                    "arguments": call_info["arguments"],
+                                    "name": call["name"],
+                                    "arguments": call["arguments"],
                                 },
                             }
                         )
-
                     assistant_msg = Message(
                         role="assistant",
                         content=content_buffer or None,
                         tool_calls=calls_list,
                     )
                     await self.conversation.add_message(assistant_msg)
-                    return
-                self._completed = True
+                return

-        # If the stream ends without `finish_reason=stop`
+        # --- Finalize if stream ends unexpectedly ---
         if not self._completed and content_buffer:
             await self.conversation.add_message(
                 Message(role="assistant", content=content_buffer)
             )
             self._completed = True

-        # After the stream ends
-        if content_buffer and not self._completed:
-            await self.conversation.add_message(
-                Message(role="assistant", content=content_buffer)
+        if not self._completed and pending_calls:
+            calls_list = []
+            for call in pending_calls:
+                calls_list.append(
+                    {
+                        "id": call["internal_id"],
+                        "type": "function",
+                        "function": {
+                            "name": call["name"],
+                            "arguments": call["arguments"],
+                        },
+                    }
+                )
+            assistant_msg = Message(
+                role="assistant",
+                content=content_buffer or None,
+                tool_calls=calls_list,
             )
+            await self.conversation.add_message(assistant_msg)
             self._completed = True
diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py
index 9879a6cc1..4de9bd14d 100644
--- a/py/core/main/api/v3/retrieval_router.py
+++ b/py/core/main/api/v3/retrieval_router.py
@@ -703,7 +703,7 @@ async def stream_generator():
                     raise R2RException(str(e), 500)

         @self.router.post(
-            "/retrieval/reasoning_agent",
+            "/retrieval/rawr",
             dependencies=[Depends(self.rate_limit_dependency)],
             summary="Reasoning RAG Agent (Chain-of-Thought + Tools)",
             openapi_extra={
@@ -717,7 +717,7 @@ async def stream_generator():
                            client = R2RClient()
                            # when using auth, do client.login(...)

-                            response =client.retrieval.reasoning_agent(
+                            response =client.retrieval.rawr(
                                message={
                                    "role": "user",
                                    "content": "What were the key contributions of Aristotle to logic and how did they influence later philosophers?"
@@ -852,7 +852,7 @@ async def reasoning_agent_app(
                 ),
                 use_system_context=False,
                 override_tools=tools,
-                reasoning_agent=True,
+                rawr=True,
             )

             if rag_generation_config.stream:
diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py
index e56b2958d..48aae39b4 100644
--- a/py/core/main/services/retrieval_service.py
+++ b/py/core/main/services/retrieval_service.py
@@ -616,7 +616,7 @@ async def agent(
         use_system_context: bool = False,
         max_tool_context_length: int = 32_768,
         override_tools: Optional[list[dict[str, Any]]] = None,
-        reasoning_agent: bool = False,
+        rawr: bool = False,
     ):
         try:
             if message and messages:
@@ -747,7 +747,7 @@ async def agent(

             # STEP 1: Determine the final system prompt content
             if use_system_context:
-                if reasoning_agent:
+                if rawr:
                     raise R2RException(
                         status_code=400,
                         message="Reasoning agent not supported with extended prompt",
@@ -759,36 +759,35 @@ async def agent(
                         filter_user_id=filter_user_id,
                         filter_collection_ids=filter_collection_ids,
                         model=rag_generation_config.model,
-                        reasoning_agent=reasoning_agent,
+                        rawr=rawr,
                     )
                 )
            elif task_prompt_override:
-                if reasoning_agent:
+                if rawr:
                    raise R2RException(
                        status_code=400,
                        message="Reasoning agent not supported with task prompt override",
                    )
                system_instruction = task_prompt_override
-            elif reasoning_agent:
+            elif rawr:
                system_instruction = (
                    await self._build_aware_system_instruction(
                        max_tool_context_length=max_tool_context_length,
                        filter_user_id=filter_user_id,
                        filter_collection_ids=filter_collection_ids,
                        model=rag_generation_config.model,
-                        reasoning_agent=reasoning_agent,
+                        rawr=rawr,
                    )
                )

            agent_config = deepcopy(self.config.agent)
            agent_config.tools = override_tools or agent_config.tools
-
            if rag_generation_config.stream:

                async def stream_response():
                    try:
-                        if not reasoning_agent:
+                        if not rawr:
                            agent = R2RStreamingRAGAgent(
                                database_provider=self.providers.database,
                                llm_provider=self.providers.llm,
@@ -817,6 +816,8 @@ async def stream_response():
                        elif (
                            "claude-3-5-sonnet-20241022"
                            in rag_generation_config.model
+                            or "o3-mini" in rag_generation_config.model
+                            or "gpt-4o" in rag_generation_config.model
                        ):
                            agent = R2RStreamingReasoningRAGAgent(
                                database_provider=self.providers.database,
@@ -831,7 +832,7 @@ async def stream_response():
                        else:
                            raise R2RException(
                                status_code=400,
-                                message="Reasoning agent not supported for this model",
+                                message=f"Reasoning agent not supported for this model {rag_generation_config.model}",
                            )

                        async for chunk in agent.arun(
@@ -1147,7 +1148,7 @@ async def _build_aware_system_instruction(
         filter_user_id: Optional[UUID] = None,
         filter_collection_ids: Optional[list[UUID]] = None,
         model: Optional[str] = None,
-        reasoning_agent: bool = False,
+        rawr: bool = False,
     ) -> str:
         """
         High-level method that:
@@ -1165,12 +1166,16 @@ async def _build_aware_system_instruction(
             filter_collection_ids=filter_collection_ids,
         )

-        if not reasoning_agent:
+        if not rawr:
            prompt_name = "aware_rag_agent"
        else:
            if "gemini-2.0-flash-thinking-exp-01-21" in model:
                prompt_name = "aware_rag_agent_reasoning_xml_tooling"
-            elif "claude-3-5-sonnet-20241022" in model:
+            elif (
+                "claude-3-5-sonnet-20241022" in model
+                or "o3-mini" in model
+                or "gpt-4o" in model
+            ):
                prompt_name = "aware_rag_agent_reasoning_prompted"
            else:
                raise R2RException(
diff --git a/py/core/providers/database/prompts/aware_rag_agent_reasoning_prompted.yaml b/py/core/providers/database/prompts/aware_rag_agent_reasoning_prompted.yaml
index 9707598ea..160121e4c 100644
--- a/py/core/providers/database/prompts/aware_rag_agent_reasoning_prompted.yaml
+++ b/py/core/providers/database/prompts/aware_rag_agent_reasoning_prompted.yaml
@@ -28,6 +28,7 @@ aware_rag_agent_reasoning_prompted:
     5. NEVER provide a response without first showing your thinking
     6. NEVER include response content inside thought tags
     7. NEVER skip the thought process
+    8. ATTEMPT TO DO MULTIPLE TOOL CALLS AT ONCE TO SAVE TIME

     Response Protocol:
     1. If no relevant results are found, clearly state this
diff --git a/py/core/providers/llm/anthropic.py b/py/core/providers/llm/anthropic.py
index 42e7627ad..81ea7863b 100644
--- a/py/core/providers/llm/anthropic.py
+++ b/py/core/providers/llm/anthropic.py
@@ -3,6 +3,7 @@
 import logging
 import os
 import time
+import uuid
 from typing import Any, AsyncGenerator, Generator, Optional

 from anthropic import Anthropic, AsyncAnthropic
@@ -23,6 +24,11 @@
 logger = logging.getLogger(__name__)


+def generate_tool_id() -> str:
+    """Generate a unique tool ID using UUID4."""
+    return f"tool_{uuid.uuid4().hex[:12]}"
+
+
 def openai_message_to_anthropic_block(msg: dict) -> dict:
     """
     Converts a single OpenAI-style message (including function/tool calls)
@@ -234,19 +240,27 @@ def _split_system_messages(
         self, messages: list[dict]
     ) -> (list[dict], Optional[str]):
         """
-        Extract the system message (if any) from a combined list of messages.
-        Return (filtered_messages, system_message).
+        Extract the system message and properly group tool results with their calls.
         """
         system_msg = None
         filtered = []
+        pending_tool_results = []
+
         for m in copy.deepcopy(messages):
             if m["role"] == "system" and system_msg is None:
                 system_msg = m["content"]
-            else:
+                continue

-                m2 = None
-                if m.get("tool_calls") != None:
-                    m2 = {
+            if m.get("tool_calls"):
+                # First add any content as a regular message
+                if m.get("content"):
+                    filtered.append(
+                        {"role": "assistant", "content": m["content"]}
+                    )
+
+                # Add the tool calls message
+                filtered.append(
+                    {
                         "role": "assistant",
                         "content": [
                             {
@@ -260,12 +274,28 @@ def _split_system_messages(
                             for call in m["tool_calls"]
                         ],
                     }
-                    m.pop("tool_calls")
+                )
+
+            elif m["role"] in ["function", "tool"]:
+                # Collect tool results to combine them
+                pending_tool_results.append(
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": m.get("tool_call_id"),
+                        "content": m["content"],
+                    }
+                )
+
+                # If we have all expected results, add them as one message
+                if len(pending_tool_results) == len(filtered[-1]["content"]):
+                    filtered.append(
+                        {"role": "user", "content": pending_tool_results}
+                    )
+                    pending_tool_results = []
+            else:
+                # Regular message
+                filtered.append(openai_message_to_anthropic_block(m))

-                m = openai_message_to_anthropic_block(m)
-                filtered.append(m)
-                if m2:
-                    filtered.append(m2)
         return filtered, system_msg

     async def _execute_task(self, task: dict[str, Any]):
@@ -285,6 +315,7 @@ async def _execute_task(self, task: dict[str, Any]):
         base_args = self._get_base_args(generation_config)

         filtered_messages, system_msg = self._split_system_messages(messages)
+        base_args["messages"] = filtered_messages

         if system_msg:
             base_args["system"] = system_msg
@@ -493,7 +524,7 @@ def make_base_chunk() -> dict:
                         {
                             "index": 0,
                             "type": "function",
-                            "id": f"call_{buffer_data['message_id']}",
+                            "id": f"call_{generate_tool_id()}",
                             "function": {
                                 "name": buffer_data["tool_name"],
                                 "arguments": buffer_data[
diff --git a/py/core/providers/llm/openai.py b/py/core/providers/llm/openai.py
index 750c07dc7..05834e47a 100644
--- a/py/core/providers/llm/openai.py
+++ b/py/core/providers/llm/openai.py
@@ -210,7 +210,10 @@ def _get_base_args(self, generation_config: GenerationConfig) -> dict:
             "top_p": generation_config.top_p,
             "stream": generation_config.stream,
         }
-        if "o1" not in generation_config.model:
+        if (
+            "o1" not in generation_config.model
+            and "o3" not in generation_config.model
+        ):
             args["max_tokens"] = generation_config.max_tokens_to_sample
             args["temperature"] = generation_config.temperature
         else:
diff --git a/py/sdk/asnyc_methods/retrieval.py b/py/sdk/asnyc_methods/retrieval.py
index ea7a4db67..b92fd1f1f 100644
--- a/py/sdk/asnyc_methods/retrieval.py
+++ b/py/sdk/asnyc_methods/retrieval.py
@@ -217,7 +217,7 @@ async def agent(
             version="v3",
         )

-    async def reasoning_agent(
+    async def rawr(
         self,
         message: Optional[dict | Message] = None,
         rag_generation_config: Optional[dict | GenerationConfig] = None,
@@ -259,14 +259,14 @@
         ):
             return self.client._make_streaming_request(
                 "POST",
-                "retrieval/reasoning_agent",
+                "retrieval/rawr",
                 json=data,
                 version="v3",
             )
         else:
             return await self.client._make_request(
                 "POST",
-                "retrieval/reasoning_agent",
+                "retrieval/rawr",
                 json=data,
                 version="v3",
             )
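For reference, a minimal usage sketch of the renamed route through the Python SDK, based on the example embedded in retrieval_router.py above. It assumes the synchronous R2RClient mirrors the async rawr method from py/sdk/asnyc_methods/retrieval.py and that the model string chosen is one the service routes to the prompted reasoning path; both are assumptions, not part of this patch.

from r2r import R2RClient

client = R2RClient()  # when using auth, call client.login(...) first

# Non-streaming call against the renamed retrieval/rawr endpoint.
response = client.retrieval.rawr(
    message={
        "role": "user",
        "content": "What were the key contributions of Aristotle to logic?",
    },
    # Model string assumed; the service accepts claude-3-5-sonnet-20241022,
    # o3-mini, or gpt-4o variants for the prompted reasoning prompt.
    rag_generation_config={
        "model": "anthropic/claude-3-5-sonnet-20241022",
        "stream": False,
    },
)
print(response)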