diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 0acb1345a..6978e1a8c 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -11,6 +11,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client +from mcp.client.session import MessageHandlerFnT from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client from mcp.shared.message import SessionMessage @@ -103,6 +104,7 @@ def __init__( use_structured_content: bool = False, max_retry_attempts: int = 0, retry_backoff_seconds_base: float = 1.0, + message_handler: MessageHandlerFnT | None = None, ): """ Args: @@ -124,6 +126,8 @@ def __init__( Defaults to no retries. retry_backoff_seconds_base: The base delay, in seconds, used for exponential backoff between retries. + message_handler: Optional handler invoked for session messages as delivered by the + ClientSession. """ super().__init__(use_structured_content=use_structured_content) self.session: ClientSession | None = None @@ -135,6 +139,7 @@ def __init__( self.client_session_timeout_seconds = client_session_timeout_seconds self.max_retry_attempts = max_retry_attempts self.retry_backoff_seconds_base = retry_backoff_seconds_base + self.message_handler = message_handler # The cache is always dirty at startup, so that we fetch tools at least once self._cache_dirty = True @@ -272,6 +277,7 @@ async def connect(self): timedelta(seconds=self.client_session_timeout_seconds) if self.client_session_timeout_seconds else None, + message_handler=self.message_handler, ) ) server_result = await session.initialize() @@ -394,6 +400,7 @@ def __init__( use_structured_content: bool = False, max_retry_attempts: int = 0, retry_backoff_seconds_base: float = 1.0, + message_handler: MessageHandlerFnT | None = None, ): """Create a new MCP server based on the stdio transport. @@ -421,6 +428,8 @@ def __init__( Defaults to no retries. retry_backoff_seconds_base: The base delay, in seconds, for exponential backoff between retries. + message_handler: Optional handler invoked for session messages as delivered by the + ClientSession. """ super().__init__( cache_tools_list, @@ -429,6 +438,7 @@ def __init__( use_structured_content, max_retry_attempts, retry_backoff_seconds_base, + message_handler=message_handler, ) self.params = StdioServerParameters( @@ -492,6 +502,7 @@ def __init__( use_structured_content: bool = False, max_retry_attempts: int = 0, retry_backoff_seconds_base: float = 1.0, + message_handler: MessageHandlerFnT | None = None, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -521,6 +532,8 @@ def __init__( Defaults to no retries. retry_backoff_seconds_base: The base delay, in seconds, for exponential backoff between retries. + message_handler: Optional handler invoked for session messages as delivered by the + ClientSession. """ super().__init__( cache_tools_list, @@ -529,6 +542,7 @@ def __init__( use_structured_content, max_retry_attempts, retry_backoff_seconds_base, + message_handler=message_handler, ) self.params = params @@ -592,6 +606,7 @@ def __init__( use_structured_content: bool = False, max_retry_attempts: int = 0, retry_backoff_seconds_base: float = 1.0, + message_handler: MessageHandlerFnT | None = None, ): """Create a new MCP server based on the Streamable HTTP transport. @@ -622,6 +637,8 @@ def __init__( Defaults to no retries. retry_backoff_seconds_base: The base delay, in seconds, for exponential backoff between retries. + message_handler: Optional handler invoked for session messages as delivered by the + ClientSession. """ super().__init__( cache_tools_list, @@ -630,6 +647,7 @@ def __init__( use_structured_content, max_retry_attempts, retry_backoff_seconds_base, + message_handler=message_handler, ) self.params = params diff --git a/tests/mcp/test_message_handler.py b/tests/mcp/test_message_handler.py new file mode 100644 index 000000000..c0a04ca33 --- /dev/null +++ b/tests/mcp/test_message_handler.py @@ -0,0 +1,109 @@ +import contextlib + +import anyio +import pytest +from mcp.client.session import MessageHandlerFnT +from mcp.shared.message import SessionMessage +from mcp.types import InitializeResult + +from agents.mcp.server import ( + MCPServerSse, + MCPServerStdio, + MCPServerStreamableHttp, + _MCPServerWithClientSession, +) + + +class _StubClientSession: + """Stub ClientSession that records the configured message handler.""" + + def __init__( + self, + read_stream, + write_stream, + read_timeout_seconds, + *, + message_handler=None, + **_: object, + ) -> None: + self.message_handler = message_handler + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def initialize(self) -> InitializeResult: + return InitializeResult( + protocolVersion="2024-11-05", + capabilities={}, + serverInfo={"name": "stub", "version": "1.0"}, + ) + + +class _MessageHandlerTestServer(_MCPServerWithClientSession): + def __init__(self, handler: MessageHandlerFnT | None): + super().__init__( + cache_tools_list=False, + client_session_timeout_seconds=None, + message_handler=handler, + ) + + def create_streams(self): + @contextlib.asynccontextmanager + async def _streams(): + send_stream, recv_stream = ( + anyio.create_memory_object_stream[SessionMessage | Exception](1)) + try: + yield recv_stream, send_stream, None + finally: + await recv_stream.aclose() + await send_stream.aclose() + + return _streams() + + @property + def name(self) -> str: + return "test-server" + + +@pytest.mark.asyncio +async def test_client_session_receives_message_handler(monkeypatch): + captured: dict[str, object] = {} + + def _recording_client_session(*args, **kwargs): + session = _StubClientSession(*args, **kwargs) + captured["message_handler"] = session.message_handler + return session + + monkeypatch.setattr("agents.mcp.server.ClientSession", _recording_client_session) + + async def handler(message: SessionMessage) -> None: + del message + + server = _MessageHandlerTestServer(handler) + + try: + await server.connect() + finally: + await server.cleanup() + + assert captured["message_handler"] is handler + + +@pytest.mark.parametrize( + "server_cls, params", + [ + (MCPServerSse, {"url": "https://example.com"}), + (MCPServerStreamableHttp, {"url": "https://example.com"}), + (MCPServerStdio, {"command": "python"}), + ], +) +def test_message_handler_propagates_to_server_base(server_cls, params): + def handler(message: SessionMessage) -> None: + del message + + server = server_cls(params, message_handler=handler) + + assert server.message_handler is handler