Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,37 @@

import asyncio
import logging
from typing import Optional
from typing import Any, Dict, Optional

import httpx
from google.adk.tools import BaseTool
from google.adk.tools.mcp_tool.mcp_tool import McpTool
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset, ReadonlyContext
from google.adk.tools.tool_context import ToolContext
from mcp.shared.exceptions import McpError

logger = logging.getLogger("kagent_adk." + __name__)

# Connection errors that indicate an unreachable MCP server.
# When these occur, the tool should return an error message to the LLM
# instead of raising, so the LLM can respond to the user rather than
# retrying the broken tool indefinitely.
#
# - ConnectionError: stdlib base for ConnectionResetError, ConnectionRefusedError, etc.
# - TimeoutError: stdlib timeout (e.g. socket.timeout)
# - httpx.TransportError: covers httpx.NetworkError (ConnectError, ReadError,
# WriteError, CloseError), httpx.TimeoutException, httpx.ProtocolError, etc.
# These do NOT inherit from stdlib ConnectionError/OSError.
# - McpError: raised by mcp.shared.session.send_request() when the underlying
# SSE/HTTP stream drops or a tool call hits the session read timeout. The MCP
# client wraps the transport-level error into McpError before it reaches us.
_CONNECTION_ERROR_TYPES = (
ConnectionError,
TimeoutError,
httpx.TransportError,
McpError,
)


def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError:
message = "Failed to create MCP session: operation cancelled"
Expand All @@ -17,6 +41,36 @@ def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError:
return asyncio.CancelledError(message)


class ConnectionSafeMcpTool(McpTool):
"""McpTool wrapper that catches connection errors and returns them as
error text to the LLM instead of raising.

Without this, a persistent connection failure (e.g. "connection reset by
peer") causes the LLM to retry the tool call in a tight loop, burning
100% CPU for up to max_llm_calls iterations.

See: https://github.com/kagent-dev/kagent/issues/1530
"""

async def run_async(
self,
*,
args: Dict[str, Any],
tool_context: ToolContext,
) -> Dict[str, Any]:
try:
return await super().run_async(args=args, tool_context=tool_context)
except _CONNECTION_ERROR_TYPES as error:
error_message = (
f"MCP tool '{self.name}' failed due to a connection error: "
f"{type(error).__name__}: {error}. "
"The MCP server may be unreachable. "
"Do not retry this tool — inform the user about the failure."
)
logger.error(error_message)
return {"error": error_message}


class KAgentMcpToolset(McpToolset):
"""McpToolset variant that catches and enriches errors during MCP session setup
and handles cancel scope issues during cleanup.
Expand All @@ -27,10 +81,26 @@ class KAgentMcpToolset(McpToolset):

async def get_tools(self, readonly_context: Optional[ReadonlyContext] = None) -> list[BaseTool]:
try:
return await super().get_tools(readonly_context)
tools = await super().get_tools(readonly_context)
except asyncio.CancelledError as error:
raise _enrich_cancelled_error(error) from error

# Wrap each McpTool with ConnectionSafeMcpTool so that connection
# errors are returned as error text instead of raised.
# Uses __new__ + __dict__ copy to re-type the instance without calling
# McpTool.__init__ (which requires connection params we don't have).
# This is safe because McpTool uses plain instance attributes, not
# __slots__ or descriptors.
wrapped_tools: list[BaseTool] = []
for tool in tools:
if isinstance(tool, McpTool) and not isinstance(tool, ConnectionSafeMcpTool):
safe_tool = ConnectionSafeMcpTool.__new__(ConnectionSafeMcpTool)
safe_tool.__dict__.update(tool.__dict__)
wrapped_tools.append(safe_tool)
else:
wrapped_tools.append(tool)
return wrapped_tools

async def close(self) -> None:
"""Close MCP sessions and suppress known anyio cancel scope cleanup errors.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""Tests for ConnectionSafeMcpTool — connection errors are returned as
error text to the LLM instead of raised, preventing tight retry loops.

See: https://github.com/kagent-dev/kagent/issues/1530
"""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from google.adk.tools.mcp_tool.mcp_tool import McpTool
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
from mcp.shared.exceptions import McpError

from kagent.adk._mcp_toolset import ConnectionSafeMcpTool, KAgentMcpToolset


def _make_connection_safe_tool(side_effect):
"""Create a ConnectionSafeMcpTool with a mocked super().run_async."""
tool = ConnectionSafeMcpTool.__new__(ConnectionSafeMcpTool)
tool.name = "test-tool"
tool._mcp_tool = MagicMock()
tool._mcp_tool.name = "test-tool"
tool._mcp_session_manager = AsyncMock()
tool._header_provider = None
tool._auth_config = None
tool._confirmation_config = None
tool._progress_callback = None
tool._parent_run_async = AsyncMock(side_effect=side_effect)
return tool


@pytest.mark.asyncio
async def test_connection_reset_error_returns_error_dict():
"""ConnectionResetError should be caught and returned as error text."""
tool = _make_connection_safe_tool(ConnectionResetError("Connection reset by peer"))

with patch.object(McpTool, "run_async", tool._parent_run_async):
result = await tool.run_async(args={"key": "value"}, tool_context=MagicMock())

assert "error" in result
assert "ConnectionResetError" in result["error"]
assert "Connection reset by peer" in result["error"]
assert "Do not retry" in result["error"]


@pytest.mark.asyncio
async def test_connection_refused_error_returns_error_dict():
"""ConnectionRefusedError should be caught and returned as error text."""
tool = _make_connection_safe_tool(ConnectionRefusedError("Connection refused"))

with patch.object(McpTool, "run_async", tool._parent_run_async):
result = await tool.run_async(args={}, tool_context=MagicMock())

assert "error" in result
assert "ConnectionRefusedError" in result["error"]


@pytest.mark.asyncio
async def test_timeout_error_returns_error_dict():
"""TimeoutError should be caught and returned as error text."""
tool = _make_connection_safe_tool(TimeoutError("timed out"))

with patch.object(McpTool, "run_async", tool._parent_run_async):
result = await tool.run_async(args={}, tool_context=MagicMock())

assert "error" in result
assert "TimeoutError" in result["error"]


@pytest.mark.asyncio
async def test_httpx_connect_error_returns_error_dict():
"""httpx.ConnectError should be caught via httpx.TransportError."""
tool = _make_connection_safe_tool(httpx.ConnectError("connection refused"))

with patch.object(McpTool, "run_async", tool._parent_run_async):
result = await tool.run_async(args={}, tool_context=MagicMock())

assert "error" in result
assert "ConnectError" in result["error"]


@pytest.mark.asyncio
async def test_httpx_read_error_returns_error_dict():
"""httpx.ReadError (connection reset by peer) should be caught."""
tool = _make_connection_safe_tool(httpx.ReadError("peer closed connection"))

with patch.object(McpTool, "run_async", tool._parent_run_async):
result = await tool.run_async(args={}, tool_context=MagicMock())

assert "error" in result
assert "ReadError" in result["error"]


@pytest.mark.asyncio
async def test_httpx_connect_timeout_returns_error_dict():
"""httpx.ConnectTimeout should be caught via httpx.TransportError."""
tool = _make_connection_safe_tool(httpx.ConnectTimeout("timed out"))

with patch.object(McpTool, "run_async", tool._parent_run_async):
result = await tool.run_async(args={}, tool_context=MagicMock())

assert "error" in result
assert "ConnectTimeout" in result["error"]


@pytest.mark.asyncio
async def test_mcp_error_returns_error_dict():
"""McpError (raised by MCP session on stream drop / read timeout) should be caught."""
from mcp.types import ErrorData

tool = _make_connection_safe_tool(McpError(ErrorData(code=-1, message="session read timeout")))

with patch.object(McpTool, "run_async", tool._parent_run_async):
result = await tool.run_async(args={}, tool_context=MagicMock())

assert "error" in result
assert "McpError" in result["error"]
assert "session read timeout" in result["error"]


@pytest.mark.asyncio
async def test_non_connection_error_still_raises():
"""Non-connection errors (e.g. ValueError) should still propagate."""
tool = _make_connection_safe_tool(ValueError("bad argument"))

with patch.object(McpTool, "run_async", tool._parent_run_async):
with pytest.raises(ValueError, match="bad argument"):
await tool.run_async(args={}, tool_context=MagicMock())


@pytest.mark.asyncio
async def test_cancelled_error_still_raises():
"""CancelledError must propagate — it's not a connection error."""
tool = _make_connection_safe_tool(asyncio.CancelledError("cancelled"))

with patch.object(McpTool, "run_async", tool._parent_run_async):
with pytest.raises(asyncio.CancelledError):
await tool.run_async(args={}, tool_context=MagicMock())


@pytest.mark.asyncio
async def test_get_tools_wraps_mcp_tools():
"""KAgentMcpToolset.get_tools should wrap McpTool instances with ConnectionSafeMcpTool."""
# Create a real McpTool instance (bypassing __init__) so isinstance checks work
fake_mcp_tool = McpTool.__new__(McpTool)
fake_mcp_tool.name = "wrapped-tool"
fake_mcp_tool._some_attr = "value"

# A non-McpTool object that should pass through unchanged
fake_other_tool = MagicMock()
fake_other_tool.name = "other-tool"

toolset = KAgentMcpToolset.__new__(KAgentMcpToolset)

async def mock_super_get_tools(self_arg, readonly_context=None):
return [fake_mcp_tool, fake_other_tool]

with patch.object(McpToolset, "get_tools", mock_super_get_tools):
tools = await toolset.get_tools()

assert len(tools) == 2
assert isinstance(tools[0], ConnectionSafeMcpTool)
assert tools[0].name == "wrapped-tool"
assert tools[0]._some_attr == "value"
# Non-McpTool should pass through unchanged
assert tools[1] is fake_other_tool
Loading