From 60cb15adb2e59a67d0f172a6a1d8e5e06ddb5c27 Mon Sep 17 00:00:00 2001
From: emrgnt-cmplxty
Date: Fri, 24 Jan 2025 18:12:54 -0800
Subject: [PATCH] fix tool calling

---
 py/core/agent/base.py            | 392 ++++++++++++++++++++++---------
 py/core/agent/rag.py             | 339 ++++++++++++++------------
 py/core/base/agent/agent.py      |   1 +
 py/core/providers/llm/litellm.py |   2 +
 4 files changed, 465 insertions(+), 269 deletions(-)

diff --git a/py/core/agent/base.py b/py/core/agent/base.py
index 50a77cd22..a9417fe16 100644
--- a/py/core/agent/base.py
+++ b/py/core/agent/base.py
@@ -158,158 +158,119 @@ async def process_llm_response(
         **kwargs,
     ) -> AsyncGenerator[str, None]:
         """
-        Modified to:
-        1) Collect partial tool calls in a dict keyed by their .index
-        2) Execute them in parallel (asyncio.gather) once finish_reason="tool_calls"
+        Updated to:
+        1) Accumulate interleaved content and partial tool calls as they stream in.
+        2) Finalize buffered content even when no tool calls are made.
+        3) Execute the accumulated tool calls in parallel (asyncio.gather)
+           once finish_reason="tool_calls" arrives.
         """
-        # Dictionary:
-        #   pending_tool_calls[index] = {
-        #     "id": str or None,
-        #     "name": str,
-        #     "arguments": str,
-        #   }
         pending_tool_calls = {}
-
-        # For single function_call logic
+        content_buffer = ""
         function_name = None
         function_arguments = ""
-
-        # Buffer for normal text
-        content_buffer = ""

         async for chunk in stream:
             delta = chunk.choices[0].delta
+            finish_reason = chunk.choices[0].finish_reason

-            # 1) Handle partial tool_calls
+            # 1) Handle interleaved tool_calls
             if delta.tool_calls:
                 for tc in delta.tool_calls:
                     idx = tc.index
                     if idx not in pending_tool_calls:
                         pending_tool_calls[idx] = {
-                            "id": tc.id,  # might be None
+                            "id": tc.id,  # could be None
                             "name": tc.function.name or "",
                             "arguments": tc.function.arguments or "",
                         }
                     else:
-                        # Accumulate partial arguments
+                        # Accumulate partial tool call details
                         if tc.function.name:
                             pending_tool_calls[idx]["name"] = tc.function.name
                         if tc.function.arguments:
                             pending_tool_calls[idx][
                                 "arguments"
                             ] += tc.function.arguments
-                        # If we see an ID on a later chunk, set it now
+                        # Set the ID if it appears in later chunks
                         if tc.id and not pending_tool_calls[idx]["id"]:
                             pending_tool_calls[idx]["id"] = tc.id

-            # 2) Handle partial function_call
+            # 2) Handle partial function_call (single-call logic)
             if delta.function_call:
                 if delta.function_call.name:
                     function_name = delta.function_call.name
                 if delta.function_call.arguments:
                     function_arguments += delta.function_call.arguments

-            # 3) Handle normal text
+            # 3) Handle normal content
             elif delta.content:
-                if content_buffer == "":
+                if not content_buffer:
                     yield "<completion>"
                 content_buffer += delta.content
                 yield delta.content

-            # 4) Check finish_reason
-            finish_reason = chunk.choices[0].finish_reason
-
+            # 4) Check finish_reason for tool calls
             if finish_reason == "tool_calls":
-                # The model has finished specifying this entire set of tool calls in an assistant message.
-                if not pending_tool_calls:
-                    logger.warning(
-                        "Got finish_reason=tool_calls but no pending tool calls."
+                # Finalize the tool calls
+                calls_list = []
+                sorted_indexes = sorted(pending_tool_calls.keys())
+                for idx in sorted_indexes:
+                    call_info = pending_tool_calls[idx]
+                    call_id = call_info["id"] or f"call_{idx}"
+                    calls_list.append(
+                        {
+                            "id": call_id,
+                            "type": "function",
+                            "function": {
+                                "name": call_info["name"],
+                                "arguments": call_info["arguments"],
+                            },
+                        }
                     )
-                else:
-                    # 4a) Build a single 'assistant' message with all tool_calls
-                    calls_list = []
-                    # Sort by index to ensure consistent ordering
-                    sorted_indexes = sorted(pending_tool_calls.keys())
-                    for idx in sorted_indexes:
-                        call_info = pending_tool_calls[idx]
-                        call_id = (
-                            call_info["id"]
-                            if call_info["id"]
-                            else f"call_{idx}"
-                        )
-                        calls_list.append(
-                            {
-                                "id": call_id,
-                                "type": "function",
-                                "function": {
-                                    "name": call_info["name"],
-                                    "arguments": call_info["arguments"],
-                                },
-                            }
-                        )
-
-                    assistant_msg = Message(
-                        role="assistant",
-                        content=content_buffer or None,
-                        tool_calls=calls_list,
+
+                assistant_msg = Message(
+                    role="assistant",
+                    content=content_buffer or None,
+                    tool_calls=calls_list,
+                )
+                await self.conversation.add_message(assistant_msg)
+
+                # Execute tool calls in parallel, in ascending index order so
+                # the results line up with sorted_indexes below
+                async_calls = [
+                    self.handle_function_or_tool_call(
+                        call_info["name"],
+                        call_info["arguments"],
+                        tool_id=(call_info["id"] or f"call_{idx}"),
+                        *args,
+                        **kwargs,
                     )
-                    await self.conversation.add_message(assistant_msg)
-
-                    # 4b) Execute them in parallel using asyncio.gather
-                    async_calls = []
-                    for idx in sorted_indexes:
-                        call_info = pending_tool_calls[idx]
-                        call_id = call_info["id"] or f"call_{idx}"
-                        async_calls.append(
-                            self.handle_function_or_tool_call(
-                                call_info["name"],
-                                call_info["arguments"],
-                                tool_id=call_id,
-                                *args,
-                                **kwargs,
-                            )
-                        )
-                    results = await asyncio.gather(*async_calls)
-
-                    # 4c) Now yield the blocks in the same order
-                    for idx, tool_result in zip(sorted_indexes, results):
-                        # We re-lookup the name, arguments, id
-                        call_info = pending_tool_calls[idx]
-                        call_id = call_info["id"] or f"call_{idx}"
-                        call_name = call_info["name"]
-                        call_args = call_info["arguments"]
-
-                        yield "<function_call>"
-                        yield f"<name>{call_name}</name>"
-                        yield f"<arguments>{call_args}</arguments>"
-
-                        if tool_result.stream_result:
-                            yield f"<results>{tool_result.stream_result}</results>"
-                        else:
-                            yield f"<results>{tool_result.llm_formatted_result}</results>"
-
-                        yield "</function_call>"
-
-                        # 4d) Add a role="function" message
-                        await self.conversation.add_message(
-                            Message(
-                                role="function",
-                                name=call_id,
-                                content=tool_result.llm_formatted_result,
-                            )
-                        )
-
-                    # 4e) Reset
-                    pending_tool_calls.clear()
-                    content_buffer = ""
+                    for idx, call_info in sorted(pending_tool_calls.items())
+                ]
+                results = await asyncio.gather(*async_calls)
+
+                # Yield tool call results
+                for idx, tool_result in zip(sorted_indexes, results):
+                    call_info = pending_tool_calls[idx]
+                    yield "<function_call>"
+                    yield f"<name>{call_info['name']}</name>"
+                    yield f"<arguments>{call_info['arguments']}</arguments>"
+                    if tool_result.stream_result:
+                        yield f"<results>{tool_result.stream_result}</results>"
+                    else:
+                        yield f"<results>{tool_result.llm_formatted_result}</results>"
+                    yield "</function_call>"
+
+                # Clear the tool call state
+                pending_tool_calls.clear()
+                content_buffer = ""

             elif finish_reason == "function_call":
-                # Single function call approach
+                # Single function call handling
                 if not function_name:
-                    logger.info("Function name not found in function call.")
+                    logger.warning("Function name not found in function call.")
                     continue

-                # Add the assistant message with function_call
                 assistant_msg = Message(
                     role="assistant",
                     content=content_buffer if content_buffer else None,
@@ -331,10 +292,8 @@ async def process_llm_response(
                     yield f"<results>{tool_result.stream_result}</results>"
                 else:
                    yield f"<results>{tool_result.llm_formatted_result}</results>"
-
                 yield "</function_call>"
-                # Add a function-role message
                 await self.conversation.add_message(
                     Message(
                         role="function",
                         name=function_name,
                         content=tool_result.llm_formatted_result,
                     )
                 )
-
-                function_name = None
-                function_arguments = ""
-                content_buffer = ""
+                function_name, function_arguments, content_buffer = (
+                    None,
+                    "",
+                    "",
+                )

             elif finish_reason == "stop":
-                # The model is done producing text
+                # Finalize content if streaming stops
                 if content_buffer:
                     await self.conversation.add_message(
                         Message(role="assistant", content=content_buffer)
@@ -356,6 +316,210 @@ async def process_llm_response(
                 self._completed = True
                 yield "</completion>"

+        # If the stream ends without finish_reason="stop"
+        if not self._completed and content_buffer:
+            await self.conversation.add_message(
+                Message(role="assistant", content=content_buffer)
+            )
+            self._completed = True
+            yield "</completion>"
+
         # After the stream ends
         if content_buffer and not self._completed:
             await self.conversation.add_message(
diff --git a/py/core/agent/rag.py b/py/core/agent/rag.py
index 2e0866f76..6456dba06 100644
--- a/py/core/agent/rag.py
+++ b/py/core/agent/rag.py
@@ -140,31 +140,60 @@ def content(self) -> Tool:
         Tool to fetch entire documents from the local database.
        Typically used if the agent needs deeper or more structured context
        from documents, not just chunk-level hits.
        """
-        return Tool(
-            name="content",
-            description=(
-                "Fetches the complete contents of all user documents from the local database. "
-                "Can be used alongside filter criteria (e.g. doc IDs, collection IDs, etc.) to restrict the query."
-                "For instance, a single document can be returned with a filter like so:"
-                "{'document_id': {'$eq': '...'}}."
-            ),
-            results_function=self._content_function,
-            llm_format_function=self.format_search_results_for_llm,
-            stream_function=self.format_search_results_for_stream,
-            parameters={
-                "type": "object",
-                "properties": {
-                    "filters": {
-                        "type": "object",
-                        "description": (
-                            "Dictionary with filter criteria, such as "
-                            '{"$and": [{"document_id": {"$eq": "6c9d1c39..."}, {"collection_ids": {"$overlap": [...]}]}'
-                        ),
-                    },
-                },
-                "required": ["filters"],
-            },
-        )
+        # Gemini models reject a free-form object schema for `filters`, so
+        # advertise it as a JSON-encoded string there and as a plain object
+        # everywhere else; the two variants are otherwise identical.
+        filters_type = (
+            "string"
+            if "gemini" in self.rag_generation_config.model
+            else "object"
+        )
+        return Tool(
+            name="content",
+            description=(
+                "Fetches the complete contents of all user documents from the local database. "
+                "Can be used alongside filter criteria (e.g. doc IDs, collection IDs, etc.) "
+                "to restrict the query. For instance, a single document can be returned "
+                "with a filter like {'document_id': {'$eq': '...'}}."
+            ),
+            results_function=self._content_function,
+            llm_format_function=self.format_search_results_for_llm,
+            stream_function=self.format_search_results_for_stream,
+            parameters={
+                "type": "object",
+                "properties": {
+                    "filters": {
+                        "type": filters_type,
+                        "description": (
+                            "Dictionary with filter criteria, such as "
+                            '{"$and": [{"document_id": {"$eq": "6c9d1c39..."}}, '
+                            '{"collection_ids": {"$overlap": [...]}}]}'
+                        ),
+                    },
+                },
+                "required": ["filters"],
+            },
+        )

     async def _content_function(
         self,
@@ -267,138 +296,138 @@ async def _web_search_function(
             web_search_results=web_response.organic_results,
         )

-    # ---------------------------------------------------------------------
-    # MULTI_SEARCH IMPLEMENTATION
-    # ---------------------------------------------------------------------
-    def multi_search(self) -> Tool:
-        """
-        A tool that accepts multiple queries at once, runs local/web/content
-        searches *in parallel*, merges them, and returns aggregated results.
-        """
-        return Tool(
-            name="multi_search",
-            description=(
-                "Run parallel searches for multiple queries. 
Submit ALL queries in a SINGLE request with this exact format:\n" - '{"queries": ["query1", "query2", "query3"], "include_web": false}\n\n' - "Example valid input:\n" - '{"queries": ["latest research on GPT-4", "advances in robotics 2024"], "include_web": false}\n\n' - "IMPORTANT:\n" - "- All queries must be in a single array under the 'queries' key\n" - "- Do NOT submit multiple separate JSON objects\n" - "- Do NOT add empty JSON objects {}\n" - "- Each query should be a string in the array\n" - "You can submit up to 10 queries in a single request. Results are limited to 20 per query." - ), - results_function=self._multi_search, - llm_format_function=self.format_search_results_for_llm, - stream_function=self.format_search_results_for_stream, - parameters={ - "type": "object", - "properties": { - "queries": { - "type": "array", - "items": {"type": "string"}, - "description": "Array of search queries to run in parallel. Example: ['query1', 'query2']", - "maxItems": 10, - }, - "include_web": { - "type": "boolean", - "description": "Whether to include web search results", - "default": False, - }, - }, - "required": ["queries"], - }, - ) - - async def _multi_search( - self, - queries: list[str], - include_web: bool = False, - include_content: bool = False, - *args, - **kwargs, - ) -> list[Tuple[str, AggregateSearchResult]]: - """ - Run local, web, and content searches *in parallel* for each query, - merge results, and return them sorted by best "score". - - :param queries: a list of search queries - :param include_web: whether to run web search (True by default) - :param include_content: whether to fetch entire documents (True by default) - :return: A list of (query, merged_results), sorted by highest chunk score - """ - # Set search results to 10 - self.search_settings.limit = 20 - - # Build tasks (one per query) - tasks = [ - self._multi_search_for_single_query( - q, include_web, include_content - ) - for q in queries - ] - # Run them all in parallel - partial_results = await asyncio.gather(*tasks) - return self._merge_aggregate_results(partial_results) - - async def _multi_search_for_single_query( - self, - query: str, - include_web: bool, - include_content: bool, - ) -> AggregateSearchResult: - """ - For a single query, run local, web, and content searches in parallel, - then merge everything into one AggregateSearchResult. - """ - # local always - searches = [self._local_search_function(query)] - - # optionally web - if include_web: - searches.append(self._web_search_function(query)) - - # optionally content - if include_content: - # pass any needed filters/options - searches.append(self._content_function(filters={}, options={})) - - # gather them concurrently - partial_results = await asyncio.gather(*searches) - - # merge all partial AggregateSearchResults - merged_result = self._merge_aggregate_results(partial_results) - return merged_result - - def _merge_aggregate_results( - self, results: list[AggregateSearchResult] - ) -> AggregateSearchResult: - """ - Concatenate chunk_search_results, web_search_results, etc. from multiple - AggregateSearchResult objects into one. 
-        """
-        all_chunks = []
-        all_graphs = []
-        all_web = []
-        all_docs = []
-
-        for r in results:
-            if r.chunk_search_results:
-                all_chunks.extend(r.chunk_search_results)
-            if r.graph_search_results:
-                all_graphs.extend(r.graph_search_results)
-            if r.web_search_results:
-                all_web.extend(r.web_search_results)
-            if r.context_document_results:
-                all_docs.extend(r.context_document_results)
-
-        return AggregateSearchResult(
-            chunk_search_results=all_chunks if all_chunks else None,
-            graph_search_results=all_graphs if all_graphs else None,
-            web_search_results=all_web if all_web else None,
-            context_document_results=all_docs if all_docs else None,
-        )
     # ---------------------------------------------------------------------
     # 4) Utility format methods for search results
diff --git a/py/core/base/agent/agent.py b/py/core/base/agent/agent.py
index 59d8f0bca..22ab3c83d 100644
--- a/py/core/base/agent/agent.py
+++ b/py/core/base/agent/agent.py
@@ -47,6 +47,7 @@ def create_and_add_message(
     async def add_message(self, message):
         async with self._lock:
             self.messages.append(message)
+            logger.debug("Added message to conversation: %s", message)

     async def get_messages(self) -> list[dict[str, Any]]:
         async with self._lock:
diff --git a/py/core/providers/llm/litellm.py b/py/core/providers/llm/litellm.py
index 02747f902..cfcd9281a 100644
--- a/py/core/providers/llm/litellm.py
+++ b/py/core/providers/llm/litellm.py
@@ -11,8 +11,10 @@ class LiteLLMCompletionProvider(CompletionProvider):
     def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
         super().__init__(config)
        try:
+            import litellm
             from litellm import acompletion, completion

+            # Let LiteLLM adjust request params for provider quirks
+            litellm.modify_params = True
             self.acompletion = acompletion
             self.completion = completion
             logger.debug("LiteLLM imported successfully")
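
The accumulate-then-gather pattern implemented by the new process_llm_response can be exercised in isolation. Below is a minimal, self-contained sketch (Python 3.10+): ToolCallDelta mimics the shape of one OpenAI-style streaming tool-call fragment, and run_tool stands in for handle_function_or_tool_call; neither name ships with this patch.

import asyncio
from dataclasses import dataclass


@dataclass
class ToolCallDelta:
    """One tool-call fragment, as it might arrive in a streaming delta."""
    index: int
    id: str | None = None
    name: str | None = None
    arguments: str | None = None


async def run_tool(name: str, arguments: str, tool_id: str) -> str:
    # Stand-in executor; a real agent would dispatch to the named tool.
    await asyncio.sleep(0)
    return f"{name}({arguments}) -> ok [{tool_id}]"


async def drain(deltas: list[ToolCallDelta]) -> dict[int, str]:
    pending: dict[int, dict] = {}
    for tc in deltas:
        entry = pending.setdefault(
            tc.index, {"id": None, "name": "", "arguments": ""}
        )
        # Later fragments may carry the id, the name, or more argument text.
        entry["id"] = entry["id"] or tc.id
        if tc.name:
            entry["name"] = tc.name
        if tc.arguments:
            entry["arguments"] += tc.arguments

    # Once the stream signals finish_reason == "tool_calls": dispatch every
    # call in parallel, in index order, so results zip back against the keys.
    order = sorted(pending)
    results = await asyncio.gather(
        *(
            run_tool(
                pending[i]["name"],
                pending[i]["arguments"],
                pending[i]["id"] or f"call_{i}",
            )
            for i in order
        )
    )
    return dict(zip(order, results))


if __name__ == "__main__":
    fragments = [
        ToolCallDelta(index=0, id="call_a", name="search", arguments='{"que'),
        ToolCallDelta(index=0, arguments='ry": "r2r"}'),
        ToolCallDelta(index=1, name="content", arguments='{"filters": {}}'),
    ]
    print(asyncio.run(drain(fragments)))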
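
Because the Gemini branch of the content tool advertises `filters` as type "string" rather than "object", the argument may arrive as a JSON-encoded string instead of a dict. How that string gets decoded is outside this patch; a hypothetical normalization helper, shown only as an illustration, could look like:

import json
from typing import Any


def coerce_filters(filters: Any) -> dict:
    """Accept a dict (OpenAI-style) or a JSON string (Gemini-style)."""
    if isinstance(filters, str):
        try:
            return json.loads(filters)
        except json.JSONDecodeError:
            return {}  # treat undecodable input as "no filters"
    return filters or {}


# coerce_filters('{"document_id": {"$eq": "6c9d1c39"}}')
# -> {"document_id": {"$eq": "6c9d1c39"}}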
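
Finally, the litellm.modify_params flag enabled in the provider lets LiteLLM rewrite requests that would otherwise violate a provider's constraints (for example, Anthropic's message-ordering rules) instead of raising. A minimal usage sketch; the model name and messages are placeholders, and the call itself is commented out since it needs provider credentials:

import litellm

# With modify_params enabled, LiteLLM may adjust otherwise-invalid requests
# before sending them; exact behavior varies by provider.
litellm.modify_params = True

# response = litellm.completion(
#     model="anthropic/claude-3-5-sonnet-20241022",
#     messages=[{"role": "user", "content": "Hello"}],
# )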