diff --git a/apps/miroflow-agent/src/core/orchestrator.py b/apps/miroflow-agent/src/core/orchestrator.py index 5619bacf..273a462e 100644 --- a/apps/miroflow-agent/src/core/orchestrator.py +++ b/apps/miroflow-agent/src/core/orchestrator.py @@ -1198,5 +1198,11 @@ async def run_main_agent( "Main Agent | Task Completed", f"Main agent task {task_id} completed successfully", ) + + # Close all persistent MCP sessions to release subprocess resources + await self.main_agent_tool_manager.close_all_sessions() + for tm in self.sub_agent_tool_managers.values(): + await tm.close_all_sessions() + gc.collect() return final_summary, final_boxed_answer, failure_experience_summary diff --git a/libs/miroflow-tools/src/miroflow_tools/dev_mcp_servers/search_and_scrape_webpage.py b/libs/miroflow-tools/src/miroflow_tools/dev_mcp_servers/search_and_scrape_webpage.py index e8d7355e..f231ea84 100644 --- a/libs/miroflow-tools/src/miroflow_tools/dev_mcp_servers/search_and_scrape_webpage.py +++ b/libs/miroflow-tools/src/miroflow_tools/dev_mcp_servers/search_and_scrape_webpage.py @@ -36,6 +36,17 @@ # Initialize FastMCP server mcp = FastMCP("search_and_scrape_webpage") +# Module-level shared httpx client for connection pooling +_serper_client: httpx.AsyncClient | None = None + + +def _get_serper_client() -> httpx.AsyncClient: + """Get or create a shared httpx.AsyncClient for Serper API requests.""" + global _serper_client + if _serper_client is None or _serper_client.is_closed: + _serper_client = httpx.AsyncClient(timeout=60.0) + return _serper_client + @retry( stop=stop_after_attempt(3), @@ -47,15 +58,16 @@ async def make_serper_request( payload: Dict[str, Any], headers: Dict[str, str] ) -> httpx.Response: - """Make HTTP request to Serper API with retry logic.""" - async with httpx.AsyncClient() as client: - response = await client.post( - f"{SERPER_BASE_URL}/search", - json=payload, - headers=headers, - ) - response.raise_for_status() - return response + """Make HTTP request to Serper API with retry logic. + Uses a module-level shared client for connection pooling.""" + client = _get_serper_client() + response = await client.post( + f"{SERPER_BASE_URL}/search", + json=payload, + headers=headers, + ) + response.raise_for_status() + return response def _is_banned_url(url: str) -> bool: diff --git a/libs/miroflow-tools/src/miroflow_tools/manager.py b/libs/miroflow-tools/src/miroflow_tools/manager.py index f97964f6..f0d905c5 100644 --- a/libs/miroflow-tools/src/miroflow_tools/manager.py +++ b/libs/miroflow-tools/src/miroflow_tools/manager.py @@ -3,6 +3,7 @@ import asyncio import functools +import logging from typing import Any, Awaitable, Callable, Protocol, TypeVar from mcp import ClientSession, StdioServerParameters # (already imported in config.py) @@ -11,7 +12,7 @@ from .mcp_servers.browser_session import PlaywrightSession -# logger = logging.getLogger("miroflow_agent") +logger = logging.getLogger("miroflow_agent") R = TypeVar("R") @@ -36,6 +37,76 @@ async def wrapper(*args, **kwargs) -> R: return decorator +class PersistentMCPSession: + """Maintains a persistent MCP session for a single server (stdio or SSE). + + Instead of spawning a new subprocess for every tool call, this keeps the + subprocess alive and reuses the session across calls. On connection + failure it transparently reconnects once. + """ + + def __init__(self, server_name: str, server_params): + self.server_name = server_name + self.server_params = server_params + self._client = None + self._session: ClientSession | None = None + self._lock = asyncio.Lock() + + async def _connect(self): + """Establish the underlying transport + MCP session.""" + if isinstance(self.server_params, StdioServerParameters): + self._client = stdio_client(self.server_params) + elif isinstance(self.server_params, str) and self.server_params.startswith( + ("http://", "https://") + ): + self._client = sse_client(self.server_params) + else: + raise TypeError( + f"Unknown server params type for {self.server_name}: " + f"{type(self.server_params)}" + ) + + read, write = await self._client.__aenter__() + self._session = ClientSession(read, write, sampling_callback=None) + await self._session.__aenter__() + await self._session.initialize() + + async def ensure_connected(self): + """Connect if not already connected.""" + async with self._lock: + if self._session is None: + await self._connect() + + async def call_tool(self, tool_name: str, arguments: dict | None = None) -> str: + """Call a tool, reconnecting once on failure.""" + await self.ensure_connected() + try: + tool_result = await self._session.call_tool(tool_name, arguments=arguments) + return tool_result.content[-1].text if tool_result.content else "" + except Exception: + # Connection may have died — reconnect once and retry + await self.close() + await self.ensure_connected() + tool_result = await self._session.call_tool(tool_name, arguments=arguments) + return tool_result.content[-1].text if tool_result.content else "" + + async def close(self): + """Tear down session and transport.""" + async with self._lock: + if self._session is not None: + try: + await self._session.__aexit__(None, None, None) + except Exception: + pass + self._session = None + if self._client is not None: + try: + await self._client.__aexit__(None, None, None) + except Exception: + pass + self._client = None + + class ToolManagerProtocol(Protocol): """this enables other kinds of tool manager.""" @@ -58,6 +129,8 @@ def __init__(self, server_configs, tool_blacklist=None): self.browser_session = None self.tool_blacklist = tool_blacklist if tool_blacklist else set() self.task_log = None + # Persistent MCP sessions keyed by server name + self._sessions: dict[str, PersistentMCPSession] = {} def set_task_log(self, task_log): """Set the task logger for structured logging.""" @@ -101,110 +174,134 @@ def get_server_params(self, server_name): """Get parameters for the specified server""" return self.server_dict.get(server_name) - async def get_all_tool_definitions(self): - """ - Connect to all configured servers and get their tool definitions. - Returns a list suitable for passing to the Prompt generator. - """ - all_servers_for_prompt = [] - # Process remote server tools - for config in self.server_configs: - server_name = config["name"] - server_params = config["params"] - one_server_for_prompt = {"name": server_name, "tools": []} - self._log( - "info", - "ToolManager | Get Tool Definitions", - f"Getting tool definitions for server '{server_name}'...", + def _get_or_create_session(self, server_name: str) -> PersistentMCPSession: + """Get an existing persistent session or create a new one.""" + if server_name not in self._sessions: + server_params = self.server_dict[server_name] + self._sessions[server_name] = PersistentMCPSession( + server_name, server_params ) + return self._sessions[server_name] - try: - if isinstance(server_params, StdioServerParameters): - async with stdio_client(server_params) as (read, write): - async with ClientSession( - read, write, sampling_callback=None - ) as session: - await session.initialize() - tools_response = await session.list_tools() - # black list some tools - for tool in tools_response.tools: - if (server_name, tool.name) in self.tool_blacklist: - self._log( - "info", - "ToolManager | Tool Blacklisted", - f"Tool '{tool.name}' in server '{server_name}' is blacklisted, skipping.", - ) - continue - one_server_for_prompt["tools"].append( - { - "name": tool.name, - "description": tool.description, - "schema": tool.inputSchema, - } - ) - elif isinstance(server_params, str) and server_params.startswith( - ("http://", "https://") - ): - # SSE endpoint - async with sse_client(server_params) as (read, write): - async with ClientSession( - read, write, sampling_callback=None - ) as session: - await session.initialize() - tools_response = await session.list_tools() - for tool in tools_response.tools: - # Can add specific tool filtering logic here (if needed) - # if server_name == "tool-excel" and tool.name not in ["get_workbook_metadata", "read_data_from_excel"]: - # continue - one_server_for_prompt["tools"].append( - { - "name": tool.name, - "description": tool.description, - "schema": tool.inputSchema, - } - ) - else: - self._log( - "error", - "ToolManager | Unknown Parameter Type", - f"Error: Unknown parameter type for server '{server_name}': {type(server_params)}", - ) - raise TypeError( - f"Unknown server params type for {server_name}: {type(server_params)}" - ) + async def _get_single_server_tools(self, config): + """Connect to a single server and get its tool definitions.""" + server_name = config["name"] + server_params = config["params"] + one_server_for_prompt = {"name": server_name, "tools": []} + self._log( + "info", + "ToolManager | Get Tool Definitions", + f"Getting tool definitions for server '{server_name}'...", + ) + try: + if isinstance(server_params, StdioServerParameters): + async with stdio_client(server_params) as (read, write): + async with ClientSession( + read, write, sampling_callback=None + ) as session: + await session.initialize() + tools_response = await session.list_tools() + # black list some tools + for tool in tools_response.tools: + if (server_name, tool.name) in self.tool_blacklist: + self._log( + "info", + "ToolManager | Tool Blacklisted", + f"Tool '{tool.name}' in server '{server_name}' is blacklisted, skipping.", + ) + continue + one_server_for_prompt["tools"].append( + { + "name": tool.name, + "description": tool.description, + "schema": tool.inputSchema, + } + ) + elif isinstance(server_params, str) and server_params.startswith( + ("http://", "https://") + ): + # SSE endpoint + async with sse_client(server_params) as (read, write): + async with ClientSession( + read, write, sampling_callback=None + ) as session: + await session.initialize() + tools_response = await session.list_tools() + for tool in tools_response.tools: + one_server_for_prompt["tools"].append( + { + "name": tool.name, + "description": tool.description, + "schema": tool.inputSchema, + } + ) + else: self._log( - "info", - "ToolManager | Tool Definitions Success", - f"Successfully obtained {len(one_server_for_prompt['tools'])} tool definitions from server '{server_name}'.", + "error", + "ToolManager | Unknown Parameter Type", + f"Error: Unknown parameter type for server '{server_name}': {type(server_params)}", + ) + raise TypeError( + f"Unknown server params type for {server_name}: {type(server_params)}" ) - all_servers_for_prompt.append(one_server_for_prompt) - except Exception as e: + self._log( + "info", + "ToolManager | Tool Definitions Success", + f"Successfully obtained {len(one_server_for_prompt['tools'])} tool definitions from server '{server_name}'.", + ) + + except Exception as e: + self._log( + "error", + "ToolManager | Connection Error", + f"Error: Unable to connect or get tools from server '{server_name}': {e}", + ) + # Still add server entry, but mark tool list as empty or include error information + one_server_for_prompt["tools"] = [{"error": f"Unable to fetch tools: {e}"}] + + return one_server_for_prompt + + async def get_all_tool_definitions(self): + """ + Connect to all configured servers and get their tool definitions. + Servers are initialized in parallel using asyncio.gather() for speed. + Returns a list suitable for passing to the Prompt generator. + """ + # Launch all server connections in parallel + results = await asyncio.gather( + *(self._get_single_server_tools(config) for config in self.server_configs), + return_exceptions=True, + ) + + all_servers_for_prompt = [] + for config, result in zip(self.server_configs, results): + if isinstance(result, Exception): + server_name = config["name"] self._log( "error", "ToolManager | Connection Error", - f"Error: Unable to connect or get tools from server '{server_name}': {e}", + f"Error: Unable to connect or get tools from server '{server_name}': {result}", ) - # Still add server entry, but mark tool list as empty or include error information - one_server_for_prompt["tools"] = [ - {"error": f"Unable to fetch tools: {e}"} - ] - all_servers_for_prompt.append(one_server_for_prompt) + all_servers_for_prompt.append( + { + "name": server_name, + "tools": [{"error": f"Unable to fetch tools: {result}"}], + } + ) + else: + all_servers_for_prompt.append(result) return all_servers_for_prompt @with_timeout(1200) async def execute_tool_call(self, server_name, tool_name, arguments) -> Any: """ - Execute a single tool call. - :param server_name: Server name - :param tool_name: Tool name - :param arguments: Tool arguments dictionary - :return: Dictionary containing result or error + Execute a single tool call using a persistent MCP session. + Sessions are lazily created on first call and reused across subsequent calls. """ - # Original remote server call logic server_params = self.get_server_params(server_name) if not server_params: self._log( @@ -221,10 +318,11 @@ async def execute_tool_call(self, server_name, tool_name, arguments) -> Any: self._log( "info", "ToolManager | Tool Call Start", - f"Connecting to server '{server_name}' to call tool '{tool_name}'", + f"Calling tool '{tool_name}' on server '{server_name}' (persistent session)", metadata={"arguments": arguments}, ) + # Playwright keeps its own special session (browser state) if server_name == "playwright": try: if self.browser_session is None: @@ -244,138 +342,100 @@ async def execute_tool_call(self, server_name, tool_name, arguments) -> Any: "tool_name": tool_name, "error": f"Tool call failed: {str(e)}", } - else: - try: - result_content = None - if isinstance(server_params, StdioServerParameters): - async with stdio_client(server_params) as (read, write): - async with ClientSession( - read, write, sampling_callback=None - ) as session: - await session.initialize() - try: - tool_result = await session.call_tool( - tool_name, arguments=arguments - ) - result_content = ( - tool_result.content[-1].text - if tool_result.content - else "" - ) - # post hoc check for browsing agent reading answers from hf datsets - if self._should_block_hf_scraping(tool_name, arguments): - result_content = "You are trying to scrape a Hugging Face dataset for answers, please do not use the scrape tool for this purpose." - except Exception as tool_error: - self._log( - "error", - "ToolManager | Tool Execution Error", - f"Tool execution error: {tool_error}", - ) - return { - "server_name": server_name, - "tool_name": tool_name, - "error": f"Tool execution failed: {str(tool_error)}", - } - elif isinstance(server_params, str) and server_params.startswith( - ("http://", "https://") - ): - async with sse_client(server_params) as (read, write): - async with ClientSession( - read, write, sampling_callback=None - ) as session: - await session.initialize() - try: - tool_result = await session.call_tool( - tool_name, arguments=arguments - ) - result_content = ( - tool_result.content[-1].text - if tool_result.content - else "" - ) - # post hoc check for browsing agent reading answers from hf datsets - if self._should_block_hf_scraping(tool_name, arguments): - result_content = "You are trying to scrape a Hugging Face dataset for answers, please do not use the scrape tool for this purpose." - except Exception as tool_error: - self._log( - "error", - "ToolManager | Tool Execution Error", - f"Tool execution error: {tool_error}", - ) - return { - "server_name": server_name, - "tool_name": tool_name, - "error": f"Tool execution failed: {str(tool_error)}", - } - else: - raise TypeError( - f"Unknown server params type for {server_name}: {type(server_params)}" + + # All other servers: use persistent session + try: + session = self._get_or_create_session(server_name) + result_content = await session.call_tool(tool_name, arguments) + + # post hoc check for browsing agent reading answers from hf datasets + if self._should_block_hf_scraping(tool_name, arguments): + result_content = "You are trying to scrape a Hugging Face dataset for answers, please do not use the scrape tool for this purpose." + + self._log( + "info", + "ToolManager | Tool Call Success", + f"Tool '{tool_name}' (server: '{server_name}') called successfully.", + ) + + return { + "server_name": server_name, + "tool_name": tool_name, + "result": result_content, + } + + except Exception as outer_e: + self._log( + "error", + "ToolManager | Tool Call Failed", + f"Error: Failed to call tool '{tool_name}' (server: '{server_name}'): {outer_e}", + ) + + error_message = str(outer_e) + + if ( + tool_name in ["scrape", "scrape_website"] + and "unhandled errors" in error_message + and "url" in arguments + and arguments["url"] is not None + ): + try: + self._log( + "info", + "ToolManager | Fallback Attempt", + "Attempting fallback using MarkItDown...", + ) + from markitdown import MarkItDown + + md = MarkItDown( + docintel_endpoint="" ) + result = md.convert(arguments["url"]) + self._log( + "info", + "ToolManager | Fallback Success", + "MarkItDown fallback successful", + ) + return { + "server_name": server_name, + "tool_name": tool_name, + "result": result.text_content, + } + except Exception as inner_e: + self._log( + "error", + "ToolManager | Fallback Failed", + f"Fallback also failed: {inner_e}", + ) + + return { + "server_name": server_name, + "tool_name": tool_name, + "error": f"Tool call failed: {error_message}", + } + async def close_all_sessions(self): + """Close all persistent MCP sessions. Call at task end.""" + for name, session in self._sessions.items(): + try: + await session.close() self._log( "info", - "ToolManager | Tool Call Success", - f"Tool '{tool_name}' (server: '{server_name}') called successfully.", + "ToolManager | Session Closed", + f"Closed persistent session for server '{name}'", ) - - return { - "server_name": server_name, - "tool_name": tool_name, - "result": result_content, # Return extracted text content - } - - except Exception as outer_e: # Rename this to outer_e to avoid shadowing + except Exception as e: self._log( "error", - "ToolManager | Tool Call Failed", - f"Error: Failed to call tool '{tool_name}' (server: '{server_name}'): {outer_e}", + "ToolManager | Session Close Error", + f"Error closing session for server '{name}': {e}", ) + self._sessions.clear() - # Store the original error message for later use - error_message = str(outer_e) - - if ( - tool_name in ["scrape", "scrape_website"] - and "unhandled errors" in error_message - and "url" in arguments - and arguments["url"] is not None - ): - try: - self._log( - "info", - "ToolManager | Fallback Attempt", - "Attempting fallback using MarkItDown...", - ) - from markitdown import MarkItDown - - md = MarkItDown( - docintel_endpoint="" - ) - result = md.convert(arguments["url"]) - self._log( - "info", - "ToolManager | Fallback Success", - "MarkItDown fallback successful", - ) - return { - "server_name": server_name, - "tool_name": tool_name, - "result": result.text_content, # Return extracted text content - } - except ( - Exception - ) as inner_e: # Use a different name to avoid shadowing - # Log the inner exception if needed - self._log( - "error", - "ToolManager | Fallback Failed", - f"Fallback also failed: {inner_e}", - ) - # No need for pass here as we'll continue to the return statement - - # Always use the outer exception for the final error response - return { - "server_name": server_name, - "tool_name": tool_name, - "error": f"Tool call failed: {error_message}", - } + # Also close playwright session + if self.browser_session is not None: + try: + await self.browser_session.close() + except Exception: + pass + self.browser_session = None