5 changes: 5 additions & 0 deletions src/agents/mcp/server.py
@@ -709,7 +709,12 @@ async def cleanup(self):
else:
logger.error(f"Error cleaning up server: {e}")
finally:
# Always reset the exit stack so we don't retain callbacks/references from the
# previous connection. This keeps teardown deterministic and allows reconnecting
# with a fresh stack even if cleanup encountered recoverable errors.
self.exit_stack = AsyncExitStack()
self.session = None
self.server_initialize_result = None


class MCPServerStdioParams(TypedDict):
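Note on the cleanup change above: once an AsyncExitStack has been unwound, reusing it would replay callbacks registered for the previous connection, so the finally block swaps in a fresh stack before the attributes are cleared. A minimal, self-contained sketch of that connect/cleanup/reconnect cycle, using a hypothetical open_streams() stand-in rather than the real stdio client:

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def open_streams():
    # Hypothetical transport stand-in; the real code enters the stdio client here.
    yield (object(), object())


class Connection:
    def __init__(self) -> None:
        self.exit_stack = AsyncExitStack()
        self.session = None

    async def connect(self) -> None:
        # Register the transport with the stack so cleanup() can unwind it later.
        self.session = await self.exit_stack.enter_async_context(open_streams())

    async def cleanup(self) -> None:
        try:
            await self.exit_stack.aclose()
        finally:
            # Same idea as the diff: always swap in a fresh stack so a later
            # connect() starts from a clean slate, even if aclose() raised.
            self.exit_stack = AsyncExitStack()
            self.session = None


async def main() -> None:
    conn = Connection()
    await conn.connect()
    await conn.cleanup()
    await conn.connect()  # succeeds because the stack was reset
    await conn.cleanup()


asyncio.run(main())

Because the reset lives in finally, even a cleanup that hits a recoverable error leaves the object in a reconnectable state, which is what the new test below exercises.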
13 changes: 13 additions & 0 deletions src/agents/result.py
@@ -8,6 +8,9 @@
from dataclasses import InitVar, dataclass, field
from typing import Any, Literal, TypeVar, cast

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

from .agent import Agent
from .agent_output import AgentOutputSchemaBase
from .exceptions import (
@@ -124,6 +127,16 @@ class RunResultBase(abc.ABC):
_trace_state: TraceState | None = field(default=None, init=False, repr=False)
"""Serialized trace metadata captured during the run."""

@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
# RunResult objects are runtime values; schema generation should treat them as instances
# instead of recursively traversing internal dataclass annotations.
return core_schema.is_instance_schema(cls)

@property
@abc.abstractmethod
def last_agent(self) -> Agent[Any]:
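The hook above makes pydantic treat run results as opaque runtime values: schema generation reduces to an isinstance() check instead of recursing into the dataclass fields. A minimal sketch of the same mechanism applied to an unrelated, illustrative class (OpaqueHandle and Container are not SDK names):

from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
from pydantic_core import core_schema


class OpaqueHandle:
    """A runtime-only value that pydantic should not introspect."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Validation reduces to an isinstance() check; no field traversal.
        return core_schema.is_instance_schema(cls)


class Container(BaseModel):
    handle: OpaqueHandle | None = None


Container(handle=OpaqueHandle())  # accepted: it is an instance
try:
    Container(handle="not a handle")  # rejected by the is-instance schema
except ValidationError:
    pass

Validation accepts instances (or None) and rejects anything else, which is the behavior the new hook gives RunResult subclasses when they appear as fields in user-defined pydantic models.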
42 changes: 42 additions & 0 deletions tests/mcp/test_connect_disconnect.py
@@ -8,6 +8,18 @@
from .helpers import DummyStreamsContextManager, tee


class CountingStreamsContextManager:
def __init__(self, counter: dict[str, int]):
self.counter = counter

async def __aenter__(self):
self.counter["enter"] += 1
return (object(), object())

async def __aexit__(self, exc_type, exc_val, exc_tb):
self.counter["exit"] += 1


@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@@ -67,3 +79,33 @@ async def test_manual_connect_disconnect_works(

await server.cleanup()
assert server.session is None, "Server should be disconnected"


@pytest.mark.asyncio
@patch("agents.mcp.server.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("agents.mcp.server.stdio_client")
async def test_cleanup_resets_exit_stack_and_reconnects(
mock_stdio_client: AsyncMock, mock_initialize: AsyncMock
):
counter = {"enter": 0, "exit": 0}
mock_stdio_client.side_effect = lambda params: CountingStreamsContextManager(counter)

server = MCPServerStdio(
params={
"command": tee,
},
cache_tools_list=True,
)

await server.connect()
original_exit_stack = server.exit_stack

await server.cleanup()
assert server.session is None
assert server.exit_stack is not original_exit_stack
assert server.server_initialize_result is None
assert counter == {"enter": 1, "exit": 1}

await server.connect()
await server.cleanup()
assert counter == {"enter": 2, "exit": 2}
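The counter dict shared by every CountingStreamsContextManager instance is what proves the teardown is balanced: each __aenter__ increments "enter", each __aexit__ increments "exit", and because side_effect builds a fresh context manager per call, the second connect()/cleanup() cycle is counted independently. The same invariant can be checked directly against a bare AsyncExitStack, as in this small sketch:

import asyncio
from contextlib import AsyncExitStack


class CountingStreamsContextManager:
    def __init__(self, counter: dict[str, int]):
        self.counter = counter

    async def __aenter__(self):
        self.counter["enter"] += 1
        return (object(), object())

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self.counter["exit"] += 1


async def main() -> None:
    counter = {"enter": 0, "exit": 0}
    for _ in range(2):
        stack = AsyncExitStack()
        await stack.enter_async_context(CountingStreamsContextManager(counter))
        await stack.aclose()
    # Two full cycles, each entered and exited exactly once.
    assert counter == {"enter": 2, "exit": 2}


asyncio.run(main())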
12 changes: 11 additions & 1 deletion tests/test_result_cast.py
@@ -7,7 +7,7 @@

import pytest
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from agents import (
Agent,
@@ -45,6 +45,16 @@ class Foo(BaseModel):
bar: int


def test_run_result_streaming_supports_pydantic_model_rebuild() -> None:
class StreamingRunContainer(BaseModel):
query_id: str
run_stream: RunResultStreaming | None

model_config = ConfigDict(arbitrary_types_allowed=True)

StreamingRunContainer.model_rebuild()


def _create_message(text: str) -> ResponseOutputMessage:
return ResponseOutputMessage(
id="msg",