Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/mcp.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ agent = Agent(
mcp_config={
# Try to convert MCP tool schemas to strict JSON schema.
"convert_schemas_to_strict": True,
# Optional: namespace tools as "<server_name>_<tool_name>" to
# avoid collisions when multiple MCP servers expose the same tool name.
"prefix_tool_names_with_server_name": True,
# If None, MCP tool failures are raised as exceptions instead of
# returning model-visible error text.
"failure_error_function": None,
Expand All @@ -47,6 +50,7 @@ agent = Agent(
Notes:

- `convert_schemas_to_strict` is best-effort. If a schema cannot be converted, the original schema is used.
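- `prefix_tool_names_with_server_name` prefixes each MCP tool name with a sanitized form of its server name, for example `github_create_issue`. Characters in the server name other than letters, digits, `_`, and `-` are replaced with `_`, and leading/trailing `_`/`-` are trimmed; if nothing remains, the prefix falls back to `server`. Tool calls are still sent to the MCP server under the original tool name.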
- `failure_error_function` controls how MCP tool call failures are surfaced to the model.
- When `failure_error_function` is unset, the SDK uses the default tool error formatter.
- Server-level `failure_error_function` overrides `Agent.mcp_config["failure_error_function"]` for that server.
Expand Down
9 changes: 9 additions & 0 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ class MCPConfig(TypedDict):
default_tool_error_function.
"""

prefix_tool_names_with_server_name: NotRequired[bool]
"""If True, MCP tools are exposed as `<server_name>_<tool_name>` to avoid collisions across
servers that publish the same tool names. Defaults to False.
"""


@dataclass
class AgentBase(Generic[TContext]):
Expand Down Expand Up @@ -182,12 +187,16 @@ async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[
failure_error_function = self.mcp_config.get(
"failure_error_function", default_tool_error_function
)
prefix_tool_names_with_server_name = self.mcp_config.get(
"prefix_tool_names_with_server_name", False
)
return await MCPUtil.get_all_function_tools(
self.mcp_servers,
convert_schemas_to_strict,
run_context,
self,
failure_error_function=failure_error_function,
prefix_tool_names_with_server_name=prefix_tool_names_with_server_name,
)

async def get_all_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
Expand Down
53 changes: 38 additions & 15 deletions src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ async def get_all_function_tools(
run_context: RunContextWrapper[Any],
agent: AgentBase,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
prefix_tool_names_with_server_name: bool = False,
) -> list[Tool]:
"""Get all function tools from a list of MCP servers."""
tools = []
Expand All @@ -189,6 +190,7 @@ async def get_all_function_tools(
run_context,
agent,
failure_error_function=failure_error_function,
prefix_tool_names_with_server_name=prefix_tool_names_with_server_name,
)
server_tool_names = {tool.name for tool in server_tools}
if len(server_tool_names & tool_names) > 0:
Expand All @@ -209,24 +211,39 @@ async def get_function_tools(
run_context: RunContextWrapper[Any],
agent: AgentBase,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
prefix_tool_names_with_server_name: bool = False,
) -> list[Tool]:
"""Get all function tools from a single MCP server."""

with mcp_tools_span(server=server.name) as span:
tools = await server.list_tools(run_context, agent)
span.span_data.result = [tool.name for tool in tools]

tool_name_prefix = (
cls._server_tool_name_prefix(server.name) if prefix_tool_names_with_server_name else ""
)
return [
cls.to_function_tool(
tool,
server,
convert_schemas_to_strict,
agent,
failure_error_function=failure_error_function,
tool_name_override=f"{tool_name_prefix}{tool.name}" if tool_name_prefix else None,
)
for tool in tools
]

@staticmethod
def _server_tool_name_prefix(server_name: str) -> str:
    """Build the prefix that is prepended to a server's tool names.

    Every character of ``server_name`` that is not an ASCII letter, an ASCII
    digit, ``_`` or ``-`` is replaced with ``_``. Leading and trailing ``_``
    and ``-`` are then trimmed, and a single ``_`` separator is appended.

    Args:
        server_name: The MCP server's display name; may contain any text.

    Returns:
        The sanitized name followed by ``_``, or ``"server_"`` when nothing
        usable remains after sanitizing (for example ``"!!!"`` or ``""``).
    """
    # ``str.isalnum()`` alone is also true for non-ASCII letters and digits
    # (e.g. "é", "数"). Restrict to ASCII so the resulting tool name only
    # contains letters, digits, "_" and "-".
    normalized = "".join(
        char if (char.isascii() and char.isalnum()) or char in "_-" else "_"
        for char in server_name
    )
    # Fall back to a fixed slug when the name sanitizes to nothing.
    normalized = normalized.strip("_-") or "server"
    return f"{normalized}_"

@classmethod
def to_function_tool(
cls,
Expand All @@ -235,6 +252,7 @@ def to_function_tool(
convert_schemas_to_strict: bool,
agent: AgentBase | None = None,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
tool_name_override: str | None = None,
) -> FunctionTool:
"""Convert an MCP tool to an Agents SDK function tool.
Expand All @@ -243,7 +261,10 @@ def to_function_tool(
When omitted, this helper preserves the historical behavior and leaves
``needs_approval`` disabled.
"""
invoke_func_impl = functools.partial(cls.invoke_mcp_tool, server, tool)
tool_name = tool_name_override or tool.name
invoke_func_impl = functools.partial(
cls.invoke_mcp_tool, server, tool, tool_display_name=tool_name
)
effective_failure_error_function = server._get_failure_error_function(
failure_error_function
)
Expand Down Expand Up @@ -280,18 +301,18 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput:
SpanError(
message="Error running tool (non-fatal)",
data={
"tool_name": tool.name,
"tool_name": tool_name,
"error": str(e),
},
)
)

# Log the error.
if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"MCP tool {tool.name} failed")
logger.debug(f"MCP tool {tool_name} failed")
else:
logger.error(
f"MCP tool {tool.name} failed: {input_json} {e}",
f"MCP tool {tool_name} failed: {input_json} {e}",
exc_info=e,
)

Expand All @@ -302,7 +323,7 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput:
) = server._get_needs_approval_for_tool(tool, agent)

return FunctionTool(
name=tool.name,
name=tool_name,
description=tool.description or "",
params_json_schema=schema,
on_invoke_tool=invoke_func,
Expand Down Expand Up @@ -361,26 +382,28 @@ async def invoke_mcp_tool(
input_json: str,
*,
meta: dict[str, Any] | None = None,
tool_display_name: str | None = None,
) -> ToolOutput:
"""Invoke an MCP tool and return the result as ToolOutput."""
tool_name = tool_display_name or tool.name
try:
json_data: dict[str, Any] = json.loads(input_json) if input_json else {}
except Exception as e:
if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invalid JSON input for tool {tool.name}")
logger.debug(f"Invalid JSON input for tool {tool_name}")
else:
logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}")
logger.debug(f"Invalid JSON input for tool {tool_name}: {input_json}")
raise ModelBehaviorError(
f"Invalid JSON input for tool {tool.name}: {input_json}"
f"Invalid JSON input for tool {tool_name}: {input_json}"
) from e

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invoking MCP tool {tool.name}")
logger.debug(f"Invoking MCP tool {tool_name}")
else:
logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}")
logger.debug(f"Invoking MCP tool {tool_name} with input {input_json}")

try:
resolved_meta = await cls._resolve_meta(server, context, tool.name, json_data)
resolved_meta = await cls._resolve_meta(server, context, tool_name, json_data)
merged_meta = cls._merge_mcp_meta(resolved_meta, meta)
if merged_meta is None:
result = await server.call_tool(tool.name, json_data)
Expand All @@ -390,15 +413,15 @@ async def invoke_mcp_tool(
# Re-raise UserError as-is (it already has a good message)
raise
except Exception as e:
logger.error(f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}")
logger.error(f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}")
raise AgentsException(
f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}"
f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}"
) from e

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"MCP tool {tool.name} completed.")
logger.debug(f"MCP tool {tool_name} completed.")
else:
logger.debug(f"MCP tool {tool.name} returned {result}")
logger.debug(f"MCP tool {tool_name} returned {result}")

# If structured content is requested and available, use it exclusively
tool_output: ToolOutput
Expand Down
127 changes: 126 additions & 1 deletion tests/mcp/test_mcp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic import BaseModel, TypeAdapter

from agents import Agent, FunctionTool, RunContextWrapper, default_tool_error_function
from agents.exceptions import AgentsException, ModelBehaviorError
from agents.exceptions import AgentsException, ModelBehaviorError, UserError
from agents.mcp import MCPServer, MCPUtil
from agents.tool_context import ToolContext

Expand Down Expand Up @@ -80,6 +80,97 @@ async def test_get_all_function_tools():
assert all(tool.name in names for tool in tools)


@pytest.mark.asyncio
async def test_get_all_function_tools_duplicate_names_raise_by_default():
    """Two servers exposing the same tool name are rejected when prefixing is off."""
    github_server = FakeMCPServer(server_name="github")
    github_server.add_tool("create_issue", {})

    linear_server = FakeMCPServer(server_name="linear")
    linear_server.add_tool("create_issue", {})

    wrapper = RunContextWrapper(context=None)
    test_agent = Agent(name="test_agent", instructions="Test agent")
    servers = [github_server, linear_server]

    # The prefix flag is left at its default here.
    with pytest.raises(UserError, match="Duplicate tool names found across MCP servers"):
        await MCPUtil.get_all_function_tools(servers, False, wrapper, test_agent)


@pytest.mark.asyncio
async def test_get_all_function_tools_can_prefix_with_server_name():
    """With prefixing on, same-named tools from two servers stay distinct and callable."""
    github_name = "GitHub_MCP_Server_create_issue"
    linear_name = "linear_create_issue"
    github_args = '{"title":"a"}'
    linear_args = '{"title":"b"}'

    github_server = FakeMCPServer(server_name="GitHub MCP Server")
    github_server.add_tool("create_issue", {})

    linear_server = FakeMCPServer(server_name="linear")
    linear_server.add_tool("create_issue", {})

    wrapper = RunContextWrapper(context=None)
    agent = Agent(
        name="test_agent",
        instructions="Test agent",
        mcp_servers=[github_server, linear_server],
        mcp_config={"prefix_tool_names_with_server_name": True},
    )

    tools = await agent.get_mcp_tools(wrapper)
    # The exposed names carry the sanitized server name as a prefix.
    assert {tool.name for tool in tools} == {github_name, linear_name}

    github_tool = next(tool for tool in tools if tool.name == github_name)
    linear_tool = next(tool for tool in tools if tool.name == linear_name)
    assert isinstance(github_tool, FunctionTool)
    assert isinstance(linear_tool, FunctionTool)

    github_ctx = ToolContext(
        context=None,
        tool_name=github_tool.name,
        tool_call_id="prefixed_call_1",
        tool_arguments=github_args,
    )
    linear_ctx = ToolContext(
        context=None,
        tool_name=linear_tool.name,
        tool_call_id="prefixed_call_2",
        tool_arguments=linear_args,
    )

    github_result = await github_tool.on_invoke_tool(github_ctx, github_args)
    linear_result = await linear_tool.on_invoke_tool(linear_ctx, linear_args)
    assert isinstance(github_result, dict)
    assert isinstance(linear_result, dict)

    # Each server is still called with the original, unprefixed tool name.
    assert server_calls(github_server) == ["create_issue"] if False else github_server.tool_calls == [
        "create_issue"
    ]
    assert linear_server.tool_calls == ["create_issue"]


@pytest.mark.asyncio
async def test_get_all_function_tools_prefix_falls_back_for_empty_server_name_slug():
    """A server name with no usable characters gets the fallback ``server_`` prefix."""
    call_args = '{"query":"docs"}'

    punctuation_server = FakeMCPServer(server_name="!!!")
    punctuation_server.add_tool("search", {})

    wrapper = RunContextWrapper(context=None)
    agent = Agent(
        name="test_agent",
        instructions="Test agent",
        mcp_servers=[punctuation_server],
        mcp_config={"prefix_tool_names_with_server_name": True},
    )

    tools = await agent.get_mcp_tools(wrapper)
    assert len(tools) == 1

    only_tool = tools[0]
    assert isinstance(only_tool, FunctionTool)
    assert only_tool.name == "server_search"

    ctx = ToolContext(
        context=None,
        tool_name=only_tool.name,
        tool_call_id="prefixed_call_3",
        tool_arguments=call_args,
    )
    outcome = await only_tool.on_invoke_tool(ctx, call_args)
    assert isinstance(outcome, dict)

    # The server is called with the original, unprefixed tool name.
    assert punctuation_server.tool_calls == ["search"]


@pytest.mark.asyncio
async def test_invoke_mcp_tool():
"""Test that the invoke_mcp_tool function invokes an MCP tool and returns the result."""
Expand Down Expand Up @@ -290,6 +381,40 @@ async def call_tool(
assert "Timed out" in result


@pytest.mark.asyncio
async def test_mcp_tool_failure_logs_prefixed_name_when_tool_data_logging_enabled(
    caplog: pytest.LogCaptureFixture, monkeypatch: pytest.MonkeyPatch
):
    """A failing tool is logged under its overridden name when tool data logging is on."""
    import agents._debug as debug_settings

    override_name = "prefixed_crashing_tool"

    caplog.set_level(logging.ERROR)
    # Turn tool data logging on so the error-level log branch is exercised.
    monkeypatch.setattr(debug_settings, "DONT_LOG_TOOL_DATA", False)

    crashing_server = CrashingFakeMCPServer()
    crashing_server.add_tool("crashing_tool", {})

    function_tool = MCPUtil.to_function_tool(
        MCPTool(name="crashing_tool", inputSchema={}),
        crashing_server,
        convert_schemas_to_strict=False,
        agent=Agent(name="test-agent"),
        tool_name_override=override_name,
    )

    ctx = ToolContext(
        context=None,
        tool_name=override_name,
        tool_call_id="test_call_prefixed_log",
        tool_arguments="{}",
    )
    outcome = await function_tool.on_invoke_tool(ctx, "{}")

    # The failure comes back as text, and the log names the overridden tool.
    assert isinstance(outcome, str)
    assert "MCP tool prefixed_crashing_tool failed" in caplog.text


@pytest.mark.asyncio
async def test_to_function_tool_legacy_call_without_agent_uses_server_policy():
"""Legacy three-argument to_function_tool calls should honor server policy."""
Expand Down