From f7391550c03454edda5d3fca47e4f8f636ce7966 Mon Sep 17 00:00:00 2001 From: Rui Fu Date: Wed, 24 Sep 2025 11:06:20 +0800 Subject: [PATCH 1/4] Expose MCP message handler configuration --- src/agents/mcp/server.py | 18 +++++ tests/mcp/test_message_handler.py | 110 ++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 tests/mcp/test_message_handler.py diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 0acb1345a..84d0fd04d 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -13,6 +13,7 @@ from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.client.session import MessageHandlerFnT from mcp.shared.message import SessionMessage from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult from typing_extensions import NotRequired, TypedDict @@ -103,6 +104,7 @@ def __init__( use_structured_content: bool = False, max_retry_attempts: int = 0, retry_backoff_seconds_base: float = 1.0, + message_handler: MessageHandlerFnT | None = None, ): """ Args: @@ -124,6 +126,8 @@ def __init__( Defaults to no retries. retry_backoff_seconds_base: The base delay, in seconds, used for exponential backoff between retries. + message_handler: Optional handler invoked for session messages as delivered by the + ClientSession. """ super().__init__(use_structured_content=use_structured_content) self.session: ClientSession | None = None @@ -135,6 +139,7 @@ def __init__( self.client_session_timeout_seconds = client_session_timeout_seconds self.max_retry_attempts = max_retry_attempts self.retry_backoff_seconds_base = retry_backoff_seconds_base + self.message_handler = message_handler # The cache is always dirty at startup, so that we fetch tools at least once self._cache_dirty = True @@ -272,6 +277,7 @@ async def connect(self): timedelta(seconds=self.client_session_timeout_seconds) if self.client_session_timeout_seconds else None, + message_handler=self.message_handler, ) ) server_result = await session.initialize() @@ -394,6 +400,7 @@ def __init__( use_structured_content: bool = False, max_retry_attempts: int = 0, retry_backoff_seconds_base: float = 1.0, + message_handler: MessageHandlerFnT | None = None, ): """Create a new MCP server based on the stdio transport. @@ -421,6 +428,8 @@ def __init__( Defaults to no retries. retry_backoff_seconds_base: The base delay, in seconds, for exponential backoff between retries. + message_handler: Optional handler invoked for session messages as delivered by the + ClientSession. """ super().__init__( cache_tools_list, @@ -429,6 +438,7 @@ def __init__( use_structured_content, max_retry_attempts, retry_backoff_seconds_base, + message_handler=message_handler, ) self.params = StdioServerParameters( @@ -492,6 +502,7 @@ def __init__( use_structured_content: bool = False, max_retry_attempts: int = 0, retry_backoff_seconds_base: float = 1.0, + message_handler: MessageHandlerFnT | None = None, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -521,6 +532,8 @@ def __init__( Defaults to no retries. retry_backoff_seconds_base: The base delay, in seconds, for exponential backoff between retries. + message_handler: Optional handler invoked for session messages as delivered by the + ClientSession. """ super().__init__( cache_tools_list, @@ -529,6 +542,7 @@ def __init__( use_structured_content, max_retry_attempts, retry_backoff_seconds_base, + message_handler=message_handler, ) self.params = params @@ -592,6 +606,7 @@ def __init__( use_structured_content: bool = False, max_retry_attempts: int = 0, retry_backoff_seconds_base: float = 1.0, + message_handler: MessageHandlerFnT | None = None, ): """Create a new MCP server based on the Streamable HTTP transport. @@ -622,6 +637,8 @@ def __init__( Defaults to no retries. retry_backoff_seconds_base: The base delay, in seconds, for exponential backoff between retries. + message_handler: Optional handler invoked for session messages as delivered by the + ClientSession. """ super().__init__( cache_tools_list, @@ -630,6 +647,7 @@ def __init__( use_structured_content, max_retry_attempts, retry_backoff_seconds_base, + message_handler=message_handler, ) self.params = params diff --git a/tests/mcp/test_message_handler.py b/tests/mcp/test_message_handler.py new file mode 100644 index 000000000..afb8d9276 --- /dev/null +++ b/tests/mcp/test_message_handler.py @@ -0,0 +1,110 @@ +import contextlib + +import anyio +import pytest +from mcp.shared.message import SessionMessage +from mcp.types import InitializeResult + +from agents.mcp.server import ( + MCPServerSse, + MCPServerStreamableHttp, + MCPServerStdio, + _MCPServerWithClientSession, +) +from mcp.client.session import MessageHandlerFnT + + +class _StubClientSession: + """Stub ClientSession that records the configured message handler.""" + + def __init__( + self, + read_stream, + write_stream, + read_timeout_seconds, + *, + message_handler=None, + **_: object, + ) -> None: + self.message_handler = message_handler + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def initialize(self) -> InitializeResult: + return InitializeResult( + protocolVersion="2024-11-05", + capabilities={}, + serverInfo={"name": "stub", "version": "1.0"}, + ) + + +class _MessageHandlerTestServer(_MCPServerWithClientSession): + def __init__(self, handler: MessageHandlerFnT | None): + super().__init__( + cache_tools_list=False, + client_session_timeout_seconds=None, + message_handler=handler, + ) + + def create_streams(self): + @contextlib.asynccontextmanager + async def _streams(): + send_stream, recv_stream = anyio.create_memory_object_stream[SessionMessage | Exception]( + 1 + ) + try: + yield recv_stream, send_stream, None + finally: + await recv_stream.aclose() + await send_stream.aclose() + + return _streams() + + @property + def name(self) -> str: + return "test-server" + + +@pytest.mark.asyncio +async def test_client_session_receives_message_handler(monkeypatch): + captured: dict[str, object] = {} + + def _recording_client_session(*args, **kwargs): + session = _StubClientSession(*args, **kwargs) + captured["message_handler"] = session.message_handler + return session + + monkeypatch.setattr("agents.mcp.server.ClientSession", _recording_client_session) + + async def handler(message: SessionMessage) -> None: + del message + + server = _MessageHandlerTestServer(handler) + + try: + await server.connect() + finally: + await server.cleanup() + + assert captured["message_handler"] is handler + + +@pytest.mark.parametrize( + "server_cls, params", + [ + (MCPServerSse, {"url": "https://example.com"}), + (MCPServerStreamableHttp, {"url": "https://example.com"}), + (MCPServerStdio, {"command": "python"}), + ], +) +def test_message_handler_propagates_to_server_base(server_cls, params): + def handler(message: SessionMessage) -> None: + del message + + server = server_cls(params, message_handler=handler) + + assert server.message_handler is handler From 24a7ecfba6fbc207181aa54bb05d63311d2c8c3c Mon Sep 17 00:00:00 2001 From: Rui Fu Date: Wed, 24 Sep 2025 13:08:32 +0800 Subject: [PATCH 2/4] fix fmt --- src/agents/mcp/server.py | 2 +- tests/mcp/test_message_handler.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 84d0fd04d..6978e1a8c 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -11,9 +11,9 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client +from mcp.client.session import MessageHandlerFnT from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client -from mcp.client.session import MessageHandlerFnT from mcp.shared.message import SessionMessage from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult from typing_extensions import NotRequired, TypedDict diff --git a/tests/mcp/test_message_handler.py b/tests/mcp/test_message_handler.py index afb8d9276..c0a04ca33 100644 --- a/tests/mcp/test_message_handler.py +++ b/tests/mcp/test_message_handler.py @@ -2,16 +2,16 @@ import anyio import pytest +from mcp.client.session import MessageHandlerFnT from mcp.shared.message import SessionMessage from mcp.types import InitializeResult from agents.mcp.server import ( MCPServerSse, - MCPServerStreamableHttp, MCPServerStdio, + MCPServerStreamableHttp, _MCPServerWithClientSession, ) -from mcp.client.session import MessageHandlerFnT class _StubClientSession: @@ -53,9 +53,8 @@ def __init__(self, handler: MessageHandlerFnT | None): def create_streams(self): @contextlib.asynccontextmanager async def _streams(): - send_stream, recv_stream = anyio.create_memory_object_stream[SessionMessage | Exception]( - 1 - ) + send_stream, recv_stream = ( + anyio.create_memory_object_stream[SessionMessage | Exception](1)) try: yield recv_stream, send_stream, None finally: From e3dbcc300f264839687ccd2d72e92c13ce7b6d76 Mon Sep 17 00:00:00 2001 From: Rui Fu Date: Wed, 24 Sep 2025 13:08:32 +0800 Subject: [PATCH 3/4] Expose MCP message handler configuration --- src/agents/mcp/server.py | 2 +- tests/mcp/test_message_handler.py | 30 ++++++++++++++++++------------ 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 84d0fd04d..6978e1a8c 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -11,9 +11,9 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client +from mcp.client.session import MessageHandlerFnT from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client -from mcp.client.session import MessageHandlerFnT from mcp.shared.message import SessionMessage from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult from typing_extensions import NotRequired, TypedDict diff --git a/tests/mcp/test_message_handler.py b/tests/mcp/test_message_handler.py index afb8d9276..4303d6098 100644 --- a/tests/mcp/test_message_handler.py +++ b/tests/mcp/test_message_handler.py @@ -2,16 +2,16 @@ import anyio import pytest +from mcp.client.session import MessageHandlerFnT from mcp.shared.message import SessionMessage -from mcp.types import InitializeResult +from mcp.types import Implementation, InitializeResult, ServerCapabilities from agents.mcp.server import ( MCPServerSse, - MCPServerStreamableHttp, MCPServerStdio, + MCPServerStreamableHttp, _MCPServerWithClientSession, ) -from mcp.client.session import MessageHandlerFnT class _StubClientSession: @@ -37,8 +37,8 @@ async def __aexit__(self, exc_type, exc, tb): async def initialize(self) -> InitializeResult: return InitializeResult( protocolVersion="2024-11-05", - capabilities={}, - serverInfo={"name": "stub", "version": "1.0"}, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="stub", version="1.0"), ) @@ -53,9 +53,9 @@ def __init__(self, handler: MessageHandlerFnT | None): def create_streams(self): @contextlib.asynccontextmanager async def _streams(): - send_stream, recv_stream = anyio.create_memory_object_stream[SessionMessage | Exception]( - 1 - ) + send_stream, recv_stream = anyio.create_memory_object_stream[ + SessionMessage | Exception + ](1) try: yield recv_stream, send_stream, None finally: @@ -80,8 +80,11 @@ def _recording_client_session(*args, **kwargs): monkeypatch.setattr("agents.mcp.server.ClientSession", _recording_client_session) - async def handler(message: SessionMessage) -> None: - del message + class _AsyncHandler: + async def __call__(self, message): + del message + + handler: MessageHandlerFnT = _AsyncHandler() server = _MessageHandlerTestServer(handler) @@ -102,8 +105,11 @@ async def handler(message: SessionMessage) -> None: ], ) def test_message_handler_propagates_to_server_base(server_cls, params): - def handler(message: SessionMessage) -> None: - del message + class _AsyncHandler: + async def __call__(self, message): + del message + + handler: MessageHandlerFnT = _AsyncHandler() server = server_cls(params, message_handler=handler) From 8e57a92700b513f59f4b9ec56301efa18277e7a0 Mon Sep 17 00:00:00 2001 From: Rui Fu Date: Sun, 5 Oct 2025 12:06:24 +0800 Subject: [PATCH 4/4] fmt --- tests/mcp/test_message_handler.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/mcp/test_message_handler.py b/tests/mcp/test_message_handler.py index 4303d6098..fdb65d91c 100644 --- a/tests/mcp/test_message_handler.py +++ b/tests/mcp/test_message_handler.py @@ -4,7 +4,15 @@ import pytest from mcp.client.session import MessageHandlerFnT from mcp.shared.message import SessionMessage -from mcp.types import Implementation, InitializeResult, ServerCapabilities +from mcp.shared.session import RequestResponder +from mcp.types import ( + ClientResult, + Implementation, + InitializeResult, + ServerCapabilities, + ServerNotification, + ServerRequest, +) from agents.mcp.server import ( MCPServerSse, @@ -14,6 +22,13 @@ ) +HandlerMessage = ( + RequestResponder[ServerRequest, ClientResult] + | ServerNotification + | Exception +) + + class _StubClientSession: """Stub ClientSession that records the configured message handler.""" @@ -35,10 +50,12 @@ async def __aexit__(self, exc_type, exc, tb): return False async def initialize(self) -> InitializeResult: + capabilities = ServerCapabilities.model_construct() + server_info = Implementation.model_construct(name="stub", version="1.0") return InitializeResult( protocolVersion="2024-11-05", - capabilities=ServerCapabilities(), - serverInfo=Implementation(name="stub", version="1.0"), + capabilities=capabilities, + serverInfo=server_info, ) @@ -81,7 +98,7 @@ def _recording_client_session(*args, **kwargs): monkeypatch.setattr("agents.mcp.server.ClientSession", _recording_client_session) class _AsyncHandler: - async def __call__(self, message): + async def __call__(self, message: HandlerMessage) -> None: del message handler: MessageHandlerFnT = _AsyncHandler() @@ -106,7 +123,7 @@ async def __call__(self, message): ) def test_message_handler_propagates_to_server_base(server_cls, params): class _AsyncHandler: - async def __call__(self, message): + async def __call__(self, message: HandlerMessage) -> None: del message handler: MessageHandlerFnT = _AsyncHandler()