diff --git a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py index 26c4c6df7..9c090c141 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py +++ b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py @@ -2,13 +2,37 @@ import asyncio import logging -from typing import Optional +from typing import Any, Dict, Optional +import httpx from google.adk.tools import BaseTool +from google.adk.tools.mcp_tool.mcp_tool import McpTool from google.adk.tools.mcp_tool.mcp_toolset import McpToolset, ReadonlyContext +from google.adk.tools.tool_context import ToolContext +from mcp.shared.exceptions import McpError logger = logging.getLogger("kagent_adk." + __name__) +# Connection errors that indicate an unreachable MCP server. +# When these occur, the tool should return an error message to the LLM +# instead of raising, so the LLM can respond to the user rather than +# retrying the broken tool indefinitely. +# +# - ConnectionError: stdlib base for ConnectionResetError, ConnectionRefusedError, etc. +# - TimeoutError: stdlib timeout (e.g. socket.timeout) +# - httpx.TransportError: covers httpx.NetworkError (ConnectError, ReadError, +# WriteError, CloseError), httpx.TimeoutException, httpx.ProtocolError, etc. +# These do NOT inherit from stdlib ConnectionError/OSError. +# - McpError: raised by mcp.shared.session.send_request() when the underlying +# SSE/HTTP stream drops or a tool call hits the session read timeout. The MCP +# client wraps the transport-level error into McpError before it reaches us. 
_CONNECTION_ERROR_TYPES = (
    ConnectionError,
    TimeoutError,
    # Before Python 3.11, asyncio.TimeoutError is its own class rather than an
    # alias of the builtin TimeoutError, so it has to be listed explicitly.
    # On 3.11+ this entry is a harmless duplicate.
    asyncio.TimeoutError,
    httpx.TransportError,
    # NOTE(review): McpError presumably also carries protocol-level errors
    # returned by a reachable server - confirm whether those should be
    # filtered (e.g. by error code) instead of being reported to the LLM as
    # connection failures.
    McpError,
)


class ConnectionSafeMcpTool(McpTool):
    """McpTool subclass that reports connection errors as tool output.

    Instead of letting a connection error propagate, ``run_async`` returns it
    as error text for the LLM.

    Without this, a persistent connection failure (e.g. "connection reset by
    peer") causes the LLM to retry the tool call in a tight loop, burning
    100% CPU for up to max_llm_calls iterations.

    See: https://github.com/kagent-dev/kagent/issues/1530
    """

    async def run_async(
        self,
        *,
        args: Dict[str, Any],
        tool_context: ToolContext,
    ) -> Dict[str, Any]:
        """Run the tool, converting connection errors into an error result.

        Args:
            args: Tool arguments, forwarded unchanged to ``McpTool.run_async``.
            tool_context: Tool context, forwarded unchanged to
                ``McpTool.run_async``.

        Returns:
            Whatever ``McpTool.run_async`` returns on success. If it raises
            one of ``_CONNECTION_ERROR_TYPES``, a dict with a single
            ``"error"`` key holding a message that names the tool, the
            exception type and its text.

        Raises:
            Any exception that is not an instance of one of
            ``_CONNECTION_ERROR_TYPES`` propagates unchanged.
        """
        try:
            return await super().run_async(args=args, tool_context=tool_context)
        except _CONNECTION_ERROR_TYPES as error:
            error_message = (
                f"MCP tool '{self.name}' failed due to a connection error: "
                f"{type(error).__name__}: {error}. "
                "The MCP server may be unreachable. "
                "Do not retry this tool — inform the user about the failure."
            )
            # The LLM only receives the text; keep the traceback in the log so
            # the underlying transport failure can still be diagnosed.
            logger.error(error_message, exc_info=error)
            return {"error": error_message}
+ # Uses __new__ + __dict__ copy to re-type the instance without calling + # McpTool.__init__ (which requires connection params we don't have). + # This is safe because McpTool uses plain instance attributes, not + # __slots__ or descriptors. + wrapped_tools: list[BaseTool] = [] + for tool in tools: + if isinstance(tool, McpTool) and not isinstance(tool, ConnectionSafeMcpTool): + safe_tool = ConnectionSafeMcpTool.__new__(ConnectionSafeMcpTool) + safe_tool.__dict__.update(tool.__dict__) + wrapped_tools.append(safe_tool) + else: + wrapped_tools.append(tool) + return wrapped_tools + async def close(self) -> None: """Close MCP sessions and suppress known anyio cancel scope cleanup errors. diff --git a/python/packages/kagent-adk/tests/unittests/test_mcp_connection_error_handling.py b/python/packages/kagent-adk/tests/unittests/test_mcp_connection_error_handling.py new file mode 100644 index 000000000..f8d2bbcfb --- /dev/null +++ b/python/packages/kagent-adk/tests/unittests/test_mcp_connection_error_handling.py @@ -0,0 +1,168 @@ +"""Tests for ConnectionSafeMcpTool — connection errors are returned as +error text to the LLM instead of raised, preventing tight retry loops. 
# See: https://github.com/kagent-dev/kagent/issues/1530

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from google.adk.tools.mcp_tool.mcp_tool import McpTool
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
from mcp.shared.exceptions import McpError

from kagent.adk._mcp_toolset import ConnectionSafeMcpTool, KAgentMcpToolset


def _make_connection_safe_tool(side_effect):
    """Build a ConnectionSafeMcpTool stub without running ``__init__``.

    The returned tool carries ``_parent_run_async``, an ``AsyncMock`` with the
    given *side_effect*, meant to be patched in as ``McpTool.run_async``.
    """
    stub = ConnectionSafeMcpTool.__new__(ConnectionSafeMcpTool)
    stub.name = "test-tool"
    stub._mcp_tool = MagicMock()
    stub._mcp_tool.name = "test-tool"
    stub._mcp_session_manager = AsyncMock()
    for unset_attr in (
        "_header_provider",
        "_auth_config",
        "_confirmation_config",
        "_progress_callback",
    ):
        setattr(stub, unset_attr, None)
    stub._parent_run_async = AsyncMock(side_effect=side_effect)
    return stub


async def _call_failing_tool(failure, call_args=None):
    """Await ``run_async`` on a stub tool whose parent call raises *failure*."""
    stub = _make_connection_safe_tool(failure)
    forwarded = {} if call_args is None else call_args
    with patch.object(McpTool, "run_async", stub._parent_run_async):
        return await stub.run_async(args=forwarded, tool_context=MagicMock())


@pytest.mark.asyncio
async def test_connection_reset_error_returns_error_dict():
    """A ConnectionResetError comes back as an error dict, not an exception."""
    outcome = await _call_failing_tool(
        ConnectionResetError("Connection reset by peer"),
        call_args={"key": "value"},
    )

    assert "error" in outcome
    for fragment in ("ConnectionResetError", "Connection reset by peer", "Do not retry"):
        assert fragment in outcome["error"]


@pytest.mark.asyncio
async def test_connection_refused_error_returns_error_dict():
    """A ConnectionRefusedError comes back as an error dict."""
    outcome = await _call_failing_tool(ConnectionRefusedError("Connection refused"))

    assert "error" in outcome
    assert "ConnectionRefusedError" in outcome["error"]


@pytest.mark.asyncio
async def test_timeout_error_returns_error_dict():
    """A builtin TimeoutError comes back as an error dict."""
    outcome = await _call_failing_tool(TimeoutError("timed out"))

    assert "error" in outcome
    assert "TimeoutError" in outcome["error"]


@pytest.mark.asyncio
async def test_httpx_connect_error_returns_error_dict():
    """httpx.ConnectError is handled through httpx.TransportError."""
    outcome = await _call_failing_tool(httpx.ConnectError("connection refused"))

    assert "error" in outcome
    assert "ConnectError" in outcome["error"]


@pytest.mark.asyncio
async def test_httpx_read_error_returns_error_dict():
    """httpx.ReadError (connection reset by peer) comes back as an error dict."""
    outcome = await _call_failing_tool(httpx.ReadError("peer closed connection"))

    assert "error" in outcome
    assert "ReadError" in outcome["error"]


@pytest.mark.asyncio
async def test_httpx_connect_timeout_returns_error_dict():
    """httpx.ConnectTimeout is handled through httpx.TransportError."""
    outcome = await _call_failing_tool(httpx.ConnectTimeout("timed out"))

    assert "error" in outcome
    assert "ConnectTimeout" in outcome["error"]


@pytest.mark.asyncio
async def test_mcp_error_returns_error_dict():
    """An McpError raised by the parent call comes back as an error dict."""
    from mcp.types import ErrorData

    failure = McpError(ErrorData(code=-1, message="session read timeout"))
    outcome = await _call_failing_tool(failure)

    assert "error" in outcome
    assert "McpError" in outcome["error"]
    assert "session read timeout" in outcome["error"]


@pytest.mark.asyncio
async def test_non_connection_error_still_raises():
    """Errors outside the connection set (here ValueError) still propagate."""
    with pytest.raises(ValueError, match="bad argument"):
        await _call_failing_tool(ValueError("bad argument"))


@pytest.mark.asyncio
async def test_cancelled_error_still_raises():
    """CancelledError is not a connection error and must propagate."""
    with pytest.raises(asyncio.CancelledError):
        await _call_failing_tool(asyncio.CancelledError("cancelled"))


@pytest.mark.asyncio
async def test_get_tools_wraps_mcp_tools():
    """get_tools re-types McpTool instances and passes other tools through."""
    # A genuine McpTool instance (created without __init__) so the
    # isinstance checks inside get_tools see the real type.
    plain_mcp_tool = McpTool.__new__(McpTool)
    plain_mcp_tool.name = "wrapped-tool"
    plain_mcp_tool._some_attr = "value"

    # Anything that is not an McpTool must come back as the same object.
    unrelated_tool = MagicMock()
    unrelated_tool.name = "other-tool"

    async def fake_parent_get_tools(self_arg, readonly_context=None):
        return [plain_mcp_tool, unrelated_tool]

    toolset = KAgentMcpToolset.__new__(KAgentMcpToolset)
    with patch.object(McpToolset, "get_tools", fake_parent_get_tools):
        tools = await toolset.get_tools()

    assert len(tools) == 2
    wrapped, passthrough = tools
    assert isinstance(wrapped, ConnectionSafeMcpTool)
    assert wrapped.name == "wrapped-tool"
    assert wrapped._some_attr == "value"
    assert passthrough is unrelated_tool