From 38f307ca3b184fc574f22038b2025a1b2e8e624f Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Sun, 2 Nov 2025 19:56:01 +0100 Subject: [PATCH 01/24] clean all remaining stream ends on close --- src/mcp/server/sse.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index b7ff332803..eaef666b30 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -196,7 +196,19 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): tg.start_soon(response_wrapper, scope, receive, send) logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + try: + yield (read_stream, write_stream) + finally: + # Close all remaining stream ends + for stream, name in [ + (read_stream, "read_stream"), + (write_stream, "write_stream"), + (sse_stream_reader, "sse_stream_reader"), + ]: + try: + await stream.aclose() + except Exception as e: + logger.debug(f"Error closing {name}: {e}") async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") From 7f241e78dfdbb40977d2706d8460f5d323b6ac53 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Sun, 2 Nov 2025 19:59:34 +0100 Subject: [PATCH 02/24] add disconnect_event to handle closure from client/task_group --- src/mcp/server/streaming_asgi_transport.py | 27 ++++++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py index a74751312c..dd70201de1 100644 --- a/src/mcp/server/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -10,6 +10,7 @@ import typing from typing import Any, cast +from typing import Callable, Awaitable import anyio import anyio.abc @@ -65,6 +66,8 @@ async def handle_async_request( ) -> Response: assert isinstance(request.stream, AsyncByteStream) + disconnect_event = anyio.Event() + # ASGI scope. scope = { "type": "http", @@ -97,11 +100,17 @@ async def handle_async_request( content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100) # ASGI callables. 
+ async def send_disconnect() -> None: + disconnect_event.set() + async def receive() -> dict[str, Any]: nonlocal request_complete + if disconnect_event.is_set(): + return {"type": "http.disconnect"} + if request_complete: - await response_complete.wait() + await disconnect_event.wait() return {"type": "http.disconnect"} try: @@ -140,7 +149,9 @@ async def process_messages() -> None: async with asgi_receive_channel: async for message in asgi_receive_channel: if message["type"] == "http.response.start": - assert not response_started + if response_started: + # Ignore duplicate response.start from ASGI app during SSE disconnect + continue status_code = message["status"] response_headers = message.get("headers", []) response_started = True @@ -163,7 +174,7 @@ async def process_messages() -> None: # Ensure events are set even if there's an error initial_response_ready.set() response_complete.set() - await content_send_channel.aclose() + # Create tasks for running the app and processing messages self.task_group.start_soon(run_app) @@ -176,7 +187,7 @@ async def process_messages() -> None: return Response( status_code, headers=response_headers, - stream=StreamingASGIResponseStream(content_receive_channel), + stream = StreamingASGIResponseStream(content_receive_channel, send_disconnect), ) @@ -192,12 +203,18 @@ class StreamingASGIResponseStream(AsyncByteStream): def __init__( self, receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes], + send_disconnect: Callable[[], Awaitable[None]], ) -> None: self.receive_channel = receive_channel + self.send_disconnect = send_disconnect async def __aiter__(self) -> typing.AsyncIterator[bytes]: try: async for chunk in self.receive_channel: yield chunk finally: - await self.receive_channel.aclose() + await self.aclose() + + async def aclose(self) -> None: + await self.receive_channel.aclose() + await self.send_disconnect() From 7250289addc216af186ac9b4c2e91ba81682e2e7 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Sun, 2 Nov 2025 20:01:59 +0100 Subject: [PATCH 03/24] create context_app using StreamingASGITransport and update test_request_context_propogation to apply this methodology --- tests/shared/test_sse.py | 77 ++++++++++++++++++++++++++++++---------- 1 file changed, 58 insertions(+), 19 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index fdb6ccfd8e..e2b813dd4b 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -22,6 +22,7 @@ from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings +from mcp.server.streaming_asgi_transport import StreamingASGITransport from mcp.shared.exceptions import McpError from mcp.types import ( EmptyResult, @@ -367,9 +368,32 @@ def context_server(server_port: int) -> Generator[None, None, None]: if proc.is_alive(): print("context server process failed to terminate") +@pytest.fixture() +async def context_app() -> Starlette: + """Fixture that provides the context server app""" + security_settings = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*", "testserver"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*", "http://testserver"] + ) + sse = SseServerTransport("/messages/", security_settings=security_settings) + context_server = RequestContextServer() + + async def handle_sse(request: Request) -> Response: + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: + await context_server.run(streams[0], 
streams[1], context_server.create_initialization_options()) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ] + ) + return app + @pytest.mark.anyio -async def test_request_context_propagation(context_server: None, server_url: str) -> None: +async def test_request_context_propagation(context_app: Starlette) -> None: """Test that request context is properly propagated through SSE transport.""" # Test with custom headers custom_headers = { @@ -378,27 +402,42 @@ async def test_request_context_propagation(context_server: None, server_url: str "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - - # Call the tool that echoes headers back - tool_result = await session.call_tool("echo_headers", {}) + async with anyio.create_task_group() as tg: + def create_test_client( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + transport = StreamingASGITransport(app=context_app, task_group=tg) + return httpx.AsyncClient( + transport=transport, + base_url="http://testserver", + headers=headers, + timeout=timeout, + auth=auth, + follow_redirects=True, + ) + + async with sse_client("http://testserver/sse", headers=custom_headers, httpx_client_factory=create_test_client) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Parse the JSON response + # Call the tool that echoes headers back + tool_result = await session.call_tool("echo_headers", {}) - assert len(tool_result.content) == 1 - headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") + # Parse the JSON response + assert len(tool_result.content) == 1 + headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") - # Verify headers were propagated - assert headers_data.get("authorization") == "Bearer test-token" - assert headers_data.get("x-custom-header") == "test-value" - assert headers_data.get("x-trace-id") == "trace-123" + # Verify headers were propagated + assert headers_data.get("authorization") == "Bearer test-token" + assert headers_data.get("x-custom-header") == "test-value" + assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio From 8a8e14f66a6a628988362bd33a8ef72cf12ef980 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Sun, 2 Nov 2025 20:06:20 +0100 Subject: [PATCH 04/24] add back await content_send_channel.aclose() --- src/mcp/server/streaming_asgi_transport.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py index dd70201de1..498833d4e0 100644 --- a/src/mcp/server/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -174,6 +174,7 @@ async def process_messages() -> None: # Ensure events are set even if there's an error initial_response_ready.set() response_complete.set() + await content_send_channel.aclose() # Create tasks for running the app and processing messages From 
7a830248a296b519f076f069f4798a48808154ec Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Sun, 2 Nov 2025 20:07:21 +0100 Subject: [PATCH 05/24] revert spaces --- src/mcp/server/streaming_asgi_transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py index 498833d4e0..3591819dab 100644 --- a/src/mcp/server/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -188,7 +188,7 @@ async def process_messages() -> None: return Response( status_code, headers=response_headers, - stream = StreamingASGIResponseStream(content_receive_channel, send_disconnect), + stream=StreamingASGIResponseStream(content_receive_channel, send_disconnect), ) From c880cc1b0b9c95e24ef5ee00fde882759239b0a5 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Mon, 3 Nov 2025 02:23:44 +0100 Subject: [PATCH 06/24] remove try finally block for testing and just add sse_stream_reader.aclose() --- src/mcp/server/sse.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index eaef666b30..e106cf3c29 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -190,25 +190,14 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): ) await read_stream_writer.aclose() await write_stream_reader.aclose() + await sse_stream_reader.aclose() logging.debug(f"Client session disconnected {session_id}") logger.debug("Starting SSE response task") tg.start_soon(response_wrapper, scope, receive, send) logger.debug("Yielding read and write streams") - try: - yield (read_stream, write_stream) - finally: - # Close all remaining stream ends - for stream, name in [ - (read_stream, "read_stream"), - (write_stream, "write_stream"), - (sse_stream_reader, "sse_stream_reader"), - ]: - try: - await stream.aclose() - except Exception as e: - logger.debug(f"Error closing {name}: {e}") + yield (read_stream, write_stream) async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") From 0aa5a45500eb121789dc31bc48b0bb41175f1d43 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Mon, 3 Nov 2025 19:12:25 +0100 Subject: [PATCH 07/24] fix ruff errors --- src/mcp/server/streaming_asgi_transport.py | 2 +- tests/shared/test_sse.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py index 3591819dab..068d277be3 100644 --- a/src/mcp/server/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -9,8 +9,8 @@ """ import typing +from collections.abc import Awaitable, Callable from typing import Any, cast -from typing import Callable, Awaitable import anyio import anyio.abc diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index e2b813dd4b..acbc5a5740 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -21,8 +21,8 @@ from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport -from mcp.server.transport_security import TransportSecuritySettings from mcp.server.streaming_asgi_transport import StreamingASGITransport +from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import McpError from mcp.types import ( EmptyResult, @@ -418,7 +418,10 @@ def create_test_client( follow_redirects=True, ) - async with 
sse_client("http://testserver/sse", headers=custom_headers, httpx_client_factory=create_test_client) as ( + async with sse_client("http://testserver/sse", + headers=custom_headers, + httpx_client_factory=create_test_client, + sse_read_timeout=0.5) as ( read_stream, write_stream, ): @@ -432,7 +435,8 @@ def create_test_client( # Parse the JSON response assert len(tool_result.content) == 1 - headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") + content_item = tool_result.content[0] + headers_data = json.loads(content_item.text if content_item.type == "text" else "{}") # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" From b489d30aef38bf3cb958c0e4cdce37f280621498 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Mon, 3 Nov 2025 19:19:08 +0100 Subject: [PATCH 08/24] run precommit --- src/mcp/server/streaming_asgi_transport.py | 1 - tests/shared/test_sse.py | 16 ++++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py index 068d277be3..690a0c392d 100644 --- a/src/mcp/server/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -176,7 +176,6 @@ async def process_messages() -> None: response_complete.set() await content_send_channel.aclose() - # Create tasks for running the app and processing messages self.task_group.start_soon(run_app) self.task_group.start_soon(process_messages) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index acbc5a5740..76a34e025b 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -368,12 +368,13 @@ def context_server(server_port: int) -> Generator[None, None, None]: if proc.is_alive(): print("context server process failed to terminate") + @pytest.fixture() async def context_app() -> Starlette: """Fixture that provides the context server app""" security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*", "testserver"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*", "http://testserver"] + allowed_hosts=["127.0.0.1:*", "localhost:*", "testserver"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*", "http://testserver"], ) sse = SseServerTransport("/messages/", security_settings=security_settings) context_server = RequestContextServer() @@ -403,6 +404,7 @@ async def test_request_context_propagation(context_app: Starlette) -> None: } async with anyio.create_task_group() as tg: + def create_test_client( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, @@ -418,10 +420,12 @@ def create_test_client( follow_redirects=True, ) - async with sse_client("http://testserver/sse", - headers=custom_headers, - httpx_client_factory=create_test_client, - sse_read_timeout=0.5) as ( + async with sse_client( + "http://testserver/sse", + headers=custom_headers, + httpx_client_factory=create_test_client, + sse_read_timeout=0.5, + ) as ( read_stream, write_stream, ): From bd70a51d622966a86f0e276af07fc83e30eee40c Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 6 Nov 2025 15:46:24 +0100 Subject: [PATCH 09/24] update all sse tests that use uvicorn to use StreamingASGITransport instead. 
DRY implementation --- tests/shared/test_sse.py | 367 ++++++++++++++------------------------- 1 file changed, 127 insertions(+), 240 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 76a34e025b..cf17d523e3 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,14 +1,12 @@ import json -import multiprocessing -import socket -import time -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator +from anyio.abc import TaskGroup from typing import Any import anyio import httpx +from mcp.shared._httpx_utils import McpHttpClientFactory import pytest -import uvicorn from inline_snapshot import snapshot from pydantic import AnyUrl from starlette.applications import Starlette @@ -33,21 +31,10 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server SERVER_NAME = "test_server_for_SSE" - - -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" +TEST_SERVER_HOST = "testserver" +TEST_SERVER_BASE_URL = f"http://{TEST_SERVER_HOST}" # Test server implementation @@ -80,117 +67,116 @@ async def handle_list_tools() -> list[Tool]: async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: return [TextContent(type="text", text=f"Called {name}")] - -# Test fixtures -def make_server_app() -> Starlette: - """Create test Starlette app with SSE transport""" - # Configure security with allowed hosts/origins for testing +def create_asgi_client_factory(app: Starlette, tg: TaskGroup) -> McpHttpClientFactory: + """Factory function to create httpx clients with StreamingASGITransport""" + def asgi_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + transport = StreamingASGITransport(app=app, task_group=tg) + return httpx.AsyncClient( + transport=transport, + base_url=TEST_SERVER_BASE_URL, + headers=headers, + timeout=timeout, + auth=auth + ) + return asgi_client_factory + +def create_sse_app(server: Server) -> Starlette: + """Helper to create SSE app with given server""" security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + allowed_hosts=[TEST_SERVER_HOST], + allowed_origins=[TEST_SERVER_BASE_URL], ) sse = SseServerTransport("/messages/", security_settings=security_settings) - server = ServerTest() async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: await server.run(streams[0], streams[1], server.create_initialization_options()) return Response() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - return app - - -def run_server(server_port: int) -> None: - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) +# Test fixtures @pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, 
kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield +def server_app() -> Starlette: + """Create test Starlette app with SSE transport""" + app = create_sse_app(ServerTest()) + return app - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") +@pytest.fixture() +async def tg() -> AsyncGenerator[TaskGroup, None]: + async with anyio.create_task_group() as tg: + yield tg @pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: - """Create test client""" - async with httpx.AsyncClient(base_url=server_url) as client: +async def http_client(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create test client using StreamingASGITransport""" + transport = StreamingASGITransport(app=server_app, task_group=tg) + async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client: yield client +@pytest.fixture() +async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[ClientSession, None]: + asgi_client_factory = create_asgi_client_factory(server_app, tg) + + async with sse_client(f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory) as streams: + async with ClientSession(*streams) as session: + yield session + # Tests @pytest.mark.anyio async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: """Test the SSE connection establishment simply with an HTTP client.""" - async with anyio.create_task_group(): - - async def connection_test() -> None: - async with http_client.stream("GET", "/sse") as response: - assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" - - line_number = 0 - async for line in response.aiter_lines(): - if line_number == 0: - assert line == "event: endpoint" - elif line_number == 1: - assert line.startswith("data: /messages/?session_id=") - else: - return - line_number += 1 - - # Add timeout to prevent test from hanging if it fails - with anyio.fail_after(3): - await connection_test() + + async def connection_test() -> None: + async with http_client.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + line_number = 0 + async for line in response.aiter_lines(): + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + assert line.startswith("data: /messages/?session_id=") + else: + return + line_number += 1 + + # Add timeout to prevent test from hanging if it fails + with anyio.fail_after(3): + await connection_test() @pytest.mark.anyio -async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse") as streams: - async with ClientSession(*streams) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == SERVER_NAME +async def test_sse_client_basic_connection(sse_client_session: ClientSession) -> None: + # Test initialization + result = await sse_client_session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME - # Test 
ping - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) + # Test ping + ping_result = await sse_client_session.send_ping() + assert isinstance(ping_result, EmptyResult) @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - yield session +async def initialized_sse_client_session(sse_client_session: ClientSession) -> AsyncGenerator[ClientSession, None]: + session = sse_client_session + await session.initialize() + yield session @pytest.mark.anyio @@ -233,51 +219,32 @@ async def test_sse_client_timeout( pytest.fail("the client should have timed out and returned an error already") -def run_mounted_server(server_port: int) -> None: - app = make_server_app() - main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - - @pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield +async def mounted_server_app(server_app: Starlette) -> Starlette: + """Create a mounted server app""" + app = Starlette(routes=[Mount("/mounted_app", app=server_app)]) + return app - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") +@pytest.fixture() +async def sse_client_mounted_server_app_session(tg: TaskGroup, mounted_server_app: Starlette) -> AsyncGenerator[ClientSession, None]: + asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg) + + async with sse_client(f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory) as streams: + async with ClientSession(*streams) as session: + yield session @pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: - async with ClientSession(*streams) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == SERVER_NAME +async def test_sse_client_basic_connection_mounted_app(sse_client_mounted_server_app_session: ClientSession) -> None: + session = sse_client_mounted_server_app_session + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME - # Test ping - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) + # Test ping + ping_result = await session.send_ping() + assert isinstance(ping_result, EmptyResult) # Test server with request context that returns headers in the response @@ -323,78 +290,14 @@ async def handle_list_tools() -> list[Tool]: ] -def run_context_server(server_port: int) -> None: - 
"""Run a server that captures request context""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = RequestContextServer() - - async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) - return Response() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ] - ) - - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() - - @pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: - """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() - - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("context server process failed to terminate") - - -@pytest.fixture() -async def context_app() -> Starlette: +async def context_server_app() -> Starlette: """Fixture that provides the context server app""" - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*", "testserver"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*", "http://testserver"], - ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = RequestContextServer() - - async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) - return Response() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ] - ) + app = create_sse_app(RequestContextServer()) return app - @pytest.mark.anyio -async def test_request_context_propagation(context_app: Starlette) -> None: +async def test_request_context_propagation(tg: TaskGroup, context_server_app: Starlette) -> None: """Test that request context is properly propagated through SSE transport.""" # Test with custom headers custom_headers = { @@ -403,61 +306,45 @@ async def test_request_context_propagation(context_app: Starlette) -> None: "X-Trace-Id": "trace-123", } - async with anyio.create_task_group() as tg: + asgi_client_factory = create_asgi_client_factory(context_server_app, tg) - def create_test_client( - headers: dict[str, str] | None = None, - timeout: httpx.Timeout | None = None, - auth: httpx.Auth | None = None, - ) -> httpx.AsyncClient: - transport = StreamingASGITransport(app=context_app, task_group=tg) - return httpx.AsyncClient( - transport=transport, - base_url="http://testserver", - headers=headers, - timeout=timeout, - auth=auth, - follow_redirects=True, - ) - - async with sse_client( - "http://testserver/sse", - headers=custom_headers, - 
httpx_client_factory=create_test_client, - sse_read_timeout=0.5, - ) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", + headers=custom_headers, + httpx_client_factory=asgi_client_factory, + sse_read_timeout=0.5, + ) as streams: + async with ClientSession(*streams) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Call the tool that echoes headers back - tool_result = await session.call_tool("echo_headers", {}) + # Call the tool that echoes headers back + tool_result = await session.call_tool("echo_headers", {}) - # Parse the JSON response - assert len(tool_result.content) == 1 - content_item = tool_result.content[0] - headers_data = json.loads(content_item.text if content_item.type == "text" else "{}") + # Parse the JSON response + assert len(tool_result.content) == 1 + content_item = tool_result.content[0] + headers_data = json.loads(content_item.text if content_item.type == "text" else "{}") - # Verify headers were propagated - assert headers_data.get("authorization") == "Bearer test-token" - assert headers_data.get("x-custom-header") == "test-value" - assert headers_data.get("x-trace-id") == "trace-123" + # Verify headers were propagated + assert headers_data.get("authorization") == "Bearer test-token" + assert headers_data.get("x-custom-header") == "test-value" + assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: +async def test_request_context_isolation(tg: TaskGroup, context_server_app: Starlette) -> None: """Test that request contexts are isolated between different SSE clients.""" contexts: list[dict[str, Any]] = [] + + asgi_client_factory = create_asgi_client_factory(context_server_app, tg) # Create multiple clients with different headers for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( + async with sse_client(f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory) as ( read_stream, write_stream, ): From 5cc10faa014efe6b981f0a4068afdd84d57c4b05 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 6 Nov 2025 15:48:51 +0100 Subject: [PATCH 10/24] run precommit --- tests/shared/test_sse.py | 43 ++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index cf17d523e3..829c268b3f 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,12 +1,11 @@ import json from collections.abc import AsyncGenerator -from anyio.abc import TaskGroup from typing import Any import anyio import httpx -from mcp.shared._httpx_utils import McpHttpClientFactory import pytest +from anyio.abc import TaskGroup from inline_snapshot import snapshot from pydantic import AnyUrl from starlette.applications import Starlette @@ -21,6 +20,7 @@ from mcp.server.sse import SseServerTransport from mcp.server.streaming_asgi_transport import StreamingASGITransport from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._httpx_utils import McpHttpClientFactory from mcp.shared.exceptions import McpError from mcp.types import ( 
EmptyResult, @@ -67,8 +67,10 @@ async def handle_list_tools() -> list[Tool]: async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: return [TextContent(type="text", text=f"Called {name}")] + def create_asgi_client_factory(app: Starlette, tg: TaskGroup) -> McpHttpClientFactory: """Factory function to create httpx clients with StreamingASGITransport""" + def asgi_client_factory( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, @@ -76,14 +78,12 @@ def asgi_client_factory( ) -> httpx.AsyncClient: transport = StreamingASGITransport(app=app, task_group=tg) return httpx.AsyncClient( - transport=transport, - base_url=TEST_SERVER_BASE_URL, - headers=headers, - timeout=timeout, - auth=auth + transport=transport, base_url=TEST_SERVER_BASE_URL, headers=headers, timeout=timeout, auth=auth ) + return asgi_client_factory + def create_sse_app(server: Server) -> Starlette: """Helper to create SSE app with given server""" security_settings = TransportSecuritySettings( @@ -107,12 +107,14 @@ async def handle_sse(request: Request) -> Response: # Test fixtures + @pytest.fixture() def server_app() -> Starlette: """Create test Starlette app with SSE transport""" app = create_sse_app(ServerTest()) return app + @pytest.fixture() async def tg() -> AsyncGenerator[TaskGroup, None]: async with anyio.create_task_group() as tg: @@ -126,11 +128,14 @@ async def http_client(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[ht async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client: yield client + @pytest.fixture() async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[ClientSession, None]: asgi_client_factory = create_asgi_client_factory(server_app, tg) - - async with sse_client(f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory) as streams: + + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory + ) as streams: async with ClientSession(*streams) as session: yield session @@ -139,7 +144,7 @@ async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGener @pytest.mark.anyio async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: """Test the SSE connection establishment simply with an HTTP client.""" - + async def connection_test() -> None: async with http_client.stream("GET", "/sse") as response: assert response.status_code == 200 @@ -227,13 +232,18 @@ async def mounted_server_app(server_app: Starlette) -> Starlette: @pytest.fixture() -async def sse_client_mounted_server_app_session(tg: TaskGroup, mounted_server_app: Starlette) -> AsyncGenerator[ClientSession, None]: +async def sse_client_mounted_server_app_session( + tg: TaskGroup, mounted_server_app: Starlette +) -> AsyncGenerator[ClientSession, None]: asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg) - - async with sse_client(f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory) as streams: + + async with sse_client( + f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory + ) as streams: async with ClientSession(*streams) as session: yield session + @pytest.mark.anyio async def test_sse_client_basic_connection_mounted_app(sse_client_mounted_server_app_session: ClientSession) -> None: session = sse_client_mounted_server_app_session @@ -296,6 +306,7 @@ async def 
context_server_app() -> Starlette: app = create_sse_app(RequestContextServer()) return app + @pytest.mark.anyio async def test_request_context_propagation(tg: TaskGroup, context_server_app: Starlette) -> None: """Test that request context is properly propagated through SSE transport.""" @@ -337,14 +348,16 @@ async def test_request_context_propagation(tg: TaskGroup, context_server_app: St async def test_request_context_isolation(tg: TaskGroup, context_server_app: Starlette) -> None: """Test that request contexts are isolated between different SSE clients.""" contexts: list[dict[str, Any]] = [] - + asgi_client_factory = create_asgi_client_factory(context_server_app, tg) # Create multiple clients with different headers for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory) as ( + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory + ) as ( read_stream, write_stream, ): From e1745f86a71d9171da768be6044422d6aeed39bd Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 6 Nov 2025 15:57:50 +0100 Subject: [PATCH 11/24] remove taskgroup fixture due to premature closing --- tests/shared/test_sse.py | 145 +++++++++++++++++++-------------------- 1 file changed, 71 insertions(+), 74 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 829c268b3f..8acba05b4b 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -116,28 +116,24 @@ def server_app() -> Starlette: @pytest.fixture() -async def tg() -> AsyncGenerator[TaskGroup, None]: - async with anyio.create_task_group() as tg: - yield tg - - -@pytest.fixture() -async def http_client(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: +async def http_client(server_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: """Create test client using StreamingASGITransport""" - transport = StreamingASGITransport(app=server_app, task_group=tg) - async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client: - yield client + async with anyio.create_task_group() as tg: + transport = StreamingASGITransport(app=server_app, task_group=tg) + async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client: + yield client @pytest.fixture() -async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[ClientSession, None]: - asgi_client_factory = create_asgi_client_factory(server_app, tg) +async def sse_client_session(server_app: Starlette) -> AsyncGenerator[ClientSession, None]: + async with anyio.create_task_group() as tg: + asgi_client_factory = create_asgi_client_factory(server_app, tg) - async with sse_client( - f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory - ) as streams: - async with ClientSession(*streams) as session: - yield session + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory + ) as streams: + async with ClientSession(*streams) as session: + yield session # Tests @@ -232,16 +228,15 @@ async def mounted_server_app(server_app: Starlette) -> Starlette: @pytest.fixture() -async def sse_client_mounted_server_app_session( - tg: TaskGroup, mounted_server_app: Starlette -) -> AsyncGenerator[ClientSession, None]: - asgi_client_factory = 
create_asgi_client_factory(mounted_server_app, tg) +async def sse_client_mounted_server_app_session(mounted_server_app: Starlette) -> AsyncGenerator[ClientSession, None]: + async with anyio.create_task_group() as tg: + asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg) - async with sse_client( - f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory - ) as streams: - async with ClientSession(*streams) as session: - yield session + async with sse_client( + f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory + ) as streams: + async with ClientSession(*streams) as session: + yield session @pytest.mark.anyio @@ -308,7 +303,7 @@ async def context_server_app() -> Starlette: @pytest.mark.anyio -async def test_request_context_propagation(tg: TaskGroup, context_server_app: Starlette) -> None: +async def test_request_context_propagation(context_server_app: Starlette) -> None: """Test that request context is properly propagated through SSE transport.""" # Test with custom headers custom_headers = { @@ -317,61 +312,63 @@ async def test_request_context_propagation(tg: TaskGroup, context_server_app: St "X-Trace-Id": "trace-123", } - asgi_client_factory = create_asgi_client_factory(context_server_app, tg) - - async with sse_client( - f"{TEST_SERVER_BASE_URL}/sse", - headers=custom_headers, - httpx_client_factory=asgi_client_factory, - sse_read_timeout=0.5, - ) as streams: - async with ClientSession(*streams) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - - # Call the tool that echoes headers back - tool_result = await session.call_tool("echo_headers", {}) + async with anyio.create_task_group() as tg: + asgi_client_factory = create_asgi_client_factory(context_server_app, tg) - # Parse the JSON response - assert len(tool_result.content) == 1 - content_item = tool_result.content[0] - headers_data = json.loads(content_item.text if content_item.type == "text" else "{}") + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", + headers=custom_headers, + httpx_client_factory=asgi_client_factory, + sse_read_timeout=0.5, + ) as streams: + async with ClientSession(*streams) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the tool that echoes headers back + tool_result = await session.call_tool("echo_headers", {}) + + # Parse the JSON response + assert len(tool_result.content) == 1 + content_item = tool_result.content[0] + headers_data = json.loads(content_item.text if content_item.type == "text" else "{}") - # Verify headers were propagated - assert headers_data.get("authorization") == "Bearer test-token" - assert headers_data.get("x-custom-header") == "test-value" - assert headers_data.get("x-trace-id") == "trace-123" + # Verify headers were propagated + assert headers_data.get("authorization") == "Bearer test-token" + assert headers_data.get("x-custom-header") == "test-value" + assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio -async def test_request_context_isolation(tg: TaskGroup, context_server_app: Starlette) -> None: +async def test_request_context_isolation(context_server_app: Starlette) -> None: """Test that request contexts are isolated between different SSE clients.""" contexts: list[dict[str, Any]] = [] - asgi_client_factory = create_asgi_client_factory(context_server_app, tg) - - # Create 
multiple clients with different headers - for i in range(3): - headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - - async with sse_client( - f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory - ) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Call the tool that echoes context - tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) - - assert len(tool_result.content) == 1 - context_data = json.loads( - tool_result.content[0].text if tool_result.content[0].type == "text" else "{}" - ) - contexts.append(context_data) + async with anyio.create_task_group() as tg: + asgi_client_factory = create_asgi_client_factory(context_server_app, tg) + + # Create multiple clients with different headers + for i in range(3): + headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} + + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory + ) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Call the tool that echoes context + tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) + + assert len(tool_result.content) == 1 + context_data = json.loads( + tool_result.content[0].text if tool_result.content[0].type == "text" else "{}" + ) + contexts.append(context_data) # Verify each request had its own context assert len(contexts) == 3 From 8fc64737a580a9b198b29361a134b7f26814f67a Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 6 Nov 2025 17:23:06 +0100 Subject: [PATCH 12/24] prevent sse_client from cancelling external task groups --- src/mcp/client/sse.py | 18 +++-- tests/shared/test_sse.py | 139 ++++++++++++++++++++------------------- 2 files changed, 84 insertions(+), 73 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 791c602cdd..c16b7f6362 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,11 +1,11 @@ import logging -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, AsyncExitStack from typing import Any from urllib.parse import urljoin, urlparse import anyio import httpx -from anyio.abc import TaskStatus +from anyio.abc import TaskStatus, TaskGroup from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse from httpx_sse._exceptions import SSEError @@ -29,6 +29,7 @@ async def sse_client( sse_read_timeout: float = 60 * 5, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, + maybe_task_group: TaskGroup | None = None, ): """ Client transport for SSE. 
@@ -52,7 +53,15 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async with anyio.create_task_group() as tg: + async with AsyncExitStack() as stack: + # Only create a task group if one wasn't provided + if maybe_task_group is None: + tg = await stack.enter_async_context(anyio.create_task_group()) + else: + tg = maybe_task_group + + owns_task_group = maybe_task_group is None + try: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( @@ -142,7 +151,8 @@ async def post_writer(endpoint_url: str): try: yield read_stream, write_stream finally: - tg.cancel_scope.cancel() + if owns_task_group: + tg.cancel_scope.cancel() finally: await read_stream_writer.aclose() await write_stream.aclose() diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 8acba05b4b..3759e51dc0 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -116,24 +116,28 @@ def server_app() -> Starlette: @pytest.fixture() -async def http_client(server_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: - """Create test client using StreamingASGITransport""" +async def tg() -> AsyncGenerator[TaskGroup, None]: async with anyio.create_task_group() as tg: - transport = StreamingASGITransport(app=server_app, task_group=tg) - async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client: - yield client + yield tg @pytest.fixture() -async def sse_client_session(server_app: Starlette) -> AsyncGenerator[ClientSession, None]: - async with anyio.create_task_group() as tg: - asgi_client_factory = create_asgi_client_factory(server_app, tg) +async def http_client(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create test client using StreamingASGITransport""" + transport = StreamingASGITransport(app=server_app, task_group=tg) + async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client: + yield client - async with sse_client( - f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory - ) as streams: - async with ClientSession(*streams) as session: - yield session + +@pytest.fixture() +async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[ClientSession, None]: + asgi_client_factory = create_asgi_client_factory(server_app, tg) + + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory, + ) as streams: + async with ClientSession(*streams) as session: + yield session # Tests @@ -228,15 +232,16 @@ async def mounted_server_app(server_app: Starlette) -> Starlette: @pytest.fixture() -async def sse_client_mounted_server_app_session(mounted_server_app: Starlette) -> AsyncGenerator[ClientSession, None]: - async with anyio.create_task_group() as tg: - asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg) +async def sse_client_mounted_server_app_session( + tg: TaskGroup, mounted_server_app: Starlette +) -> AsyncGenerator[ClientSession, None]: + asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg) - async with sse_client( - f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory - ) as streams: - async with ClientSession(*streams) as session: - yield session + async with sse_client( + f"{TEST_SERVER_BASE_URL}/mounted_app/sse", 
sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory, + ) as streams: + async with ClientSession(*streams) as session: + yield session @pytest.mark.anyio @@ -303,7 +308,7 @@ async def context_server_app() -> Starlette: @pytest.mark.anyio -async def test_request_context_propagation(context_server_app: Starlette) -> None: +async def test_request_context_propagation(tg: TaskGroup, context_server_app: Starlette) -> None: """Test that request context is properly propagated through SSE transport.""" # Test with custom headers custom_headers = { @@ -312,63 +317,59 @@ async def test_request_context_propagation(context_server_app: Starlette) -> Non "X-Trace-Id": "trace-123", } - async with anyio.create_task_group() as tg: - asgi_client_factory = create_asgi_client_factory(context_server_app, tg) + asgi_client_factory = create_asgi_client_factory(context_server_app, tg) - async with sse_client( - f"{TEST_SERVER_BASE_URL}/sse", - headers=custom_headers, - httpx_client_factory=asgi_client_factory, - sse_read_timeout=0.5, - ) as streams: - async with ClientSession(*streams) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", + headers=custom_headers, + httpx_client_factory=asgi_client_factory, + sse_read_timeout=0.5, + + ) as streams: + async with ClientSession(*streams) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Call the tool that echoes headers back - tool_result = await session.call_tool("echo_headers", {}) + # Call the tool that echoes headers back + tool_result = await session.call_tool("echo_headers", {}) - # Parse the JSON response - assert len(tool_result.content) == 1 - content_item = tool_result.content[0] - headers_data = json.loads(content_item.text if content_item.type == "text" else "{}") + # Parse the JSON response + assert len(tool_result.content) == 1 + content_item = tool_result.content[0] + headers_data = json.loads(content_item.text if content_item.type == "text" else "{}") - # Verify headers were propagated - assert headers_data.get("authorization") == "Bearer test-token" - assert headers_data.get("x-custom-header") == "test-value" - assert headers_data.get("x-trace-id") == "trace-123" + # Verify headers were propagated + assert headers_data.get("authorization") == "Bearer test-token" + assert headers_data.get("x-custom-header") == "test-value" + assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio -async def test_request_context_isolation(context_server_app: Starlette) -> None: +async def test_request_context_isolation(tg: TaskGroup, context_server_app: Starlette) -> None: """Test that request contexts are isolated between different SSE clients.""" contexts: list[dict[str, Any]] = [] - async with anyio.create_task_group() as tg: - asgi_client_factory = create_asgi_client_factory(context_server_app, tg) - - # Create multiple clients with different headers - for i in range(3): - headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - - async with sse_client( - f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory - ) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Call the tool that echoes context - tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) - - assert 
len(tool_result.content) == 1 - context_data = json.loads( - tool_result.content[0].text if tool_result.content[0].type == "text" else "{}" - ) - contexts.append(context_data) + asgi_client_factory = create_asgi_client_factory(context_server_app, tg) + + # Create multiple clients with different headers + for i in range(3): + headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} + + async with sse_client( + f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory, + ) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + + # Call the tool that echoes context + tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) + + assert len(tool_result.content) == 1 + context_data = json.loads( + tool_result.content[0].text if tool_result.content[0].type == "text" else "{}" + ) + contexts.append(context_data) # Verify each request had its own context assert len(contexts) == 3 From fabf6c5c0397afc6927c8debbeb00aa678214aba Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 6 Nov 2025 17:23:43 +0100 Subject: [PATCH 13/24] run precommit --- src/mcp/client/sse.py | 8 ++++---- tests/shared/test_sse.py | 13 +++++++++---- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index c16b7f6362..141e82c3d7 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,11 +1,11 @@ import logging -from contextlib import asynccontextmanager, AsyncExitStack +from contextlib import AsyncExitStack, asynccontextmanager from typing import Any from urllib.parse import urljoin, urlparse import anyio import httpx -from anyio.abc import TaskStatus, TaskGroup +from anyio.abc import TaskGroup, TaskStatus from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse from httpx_sse._exceptions import SSEError @@ -59,9 +59,9 @@ async def sse_client( tg = await stack.enter_async_context(anyio.create_task_group()) else: tg = maybe_task_group - + owns_task_group = maybe_task_group is None - + try: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 3759e51dc0..597f9edc29 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -134,7 +134,9 @@ async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGener asgi_client_factory = create_asgi_client_factory(server_app, tg) async with sse_client( - f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory, + f"{TEST_SERVER_BASE_URL}/sse", + sse_read_timeout=0.5, + httpx_client_factory=asgi_client_factory, ) as streams: async with ClientSession(*streams) as session: yield session @@ -238,7 +240,9 @@ async def sse_client_mounted_server_app_session( asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg) async with sse_client( - f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory, + f"{TEST_SERVER_BASE_URL}/mounted_app/sse", + sse_read_timeout=0.5, + httpx_client_factory=asgi_client_factory, ) as streams: async with ClientSession(*streams) as session: yield session @@ -324,7 +328,6 @@ async def test_request_context_propagation(tg: TaskGroup, context_server_app: St headers=custom_headers, httpx_client_factory=asgi_client_factory, sse_read_timeout=0.5, - ) as streams: async with 
ClientSession(*streams) as session: # Initialize the session @@ -357,7 +360,9 @@ async def test_request_context_isolation(tg: TaskGroup, context_server_app: Star headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} async with sse_client( - f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory, + f"{TEST_SERVER_BASE_URL}/sse", + headers=headers, + httpx_client_factory=asgi_client_factory, ) as streams: async with ClientSession(*streams) as session: await session.initialize() From e8a3b0e3ccf2013196f33eb16dbcfc6038b2f024 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 6 Nov 2025 17:35:41 +0100 Subject: [PATCH 14/24] revert sse_client and add cleanup to outer task group in tests --- src/mcp/client/sse.py | 18 ++++-------------- tests/shared/test_sse.py | 5 ++++- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 141e82c3d7..791c602cdd 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,11 +1,11 @@ import logging -from contextlib import AsyncExitStack, asynccontextmanager +from contextlib import asynccontextmanager from typing import Any from urllib.parse import urljoin, urlparse import anyio import httpx -from anyio.abc import TaskGroup, TaskStatus +from anyio.abc import TaskStatus from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse from httpx_sse._exceptions import SSEError @@ -29,7 +29,6 @@ async def sse_client( sse_read_timeout: float = 60 * 5, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, - maybe_task_group: TaskGroup | None = None, ): """ Client transport for SSE. @@ -53,15 +52,7 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async with AsyncExitStack() as stack: - # Only create a task group if one wasn't provided - if maybe_task_group is None: - tg = await stack.enter_async_context(anyio.create_task_group()) - else: - tg = maybe_task_group - - owns_task_group = maybe_task_group is None - + async with anyio.create_task_group() as tg: try: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( @@ -151,8 +142,7 @@ async def post_writer(endpoint_url: str): try: yield read_stream, write_stream finally: - if owns_task_group: - tg.cancel_scope.cancel() + tg.cancel_scope.cancel() finally: await read_stream_writer.aclose() await write_stream.aclose() diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 597f9edc29..3a01f0f206 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -118,7 +118,10 @@ def server_app() -> Starlette: @pytest.fixture() async def tg() -> AsyncGenerator[TaskGroup, None]: async with anyio.create_task_group() as tg: - yield tg + try: + yield tg + finally: + tg.cancel_scope.cancel() @pytest.fixture() From 8188bf29b35afcac841a9b8b64cafcb88f982591 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 6 Nov 2025 18:02:37 +0100 Subject: [PATCH 15/24] remove timeout for sse_client --- tests/shared/test_sse.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 3a01f0f206..4761f196be 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -138,7 +138,6 @@ async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGener 
async with sse_client( f"{TEST_SERVER_BASE_URL}/sse", - sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory, ) as streams: async with ClientSession(*streams) as session: From 86a377fd34316fd2dbec3887ee2ca61e60635e7e Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 6 Nov 2025 23:53:34 +0100 Subject: [PATCH 16/24] add reset_sse_app_status workaround for sse_starlette quirk --- tests/shared/test_sse.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 4761f196be..64a8d7dc0a 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -8,6 +8,7 @@ from anyio.abc import TaskGroup from inline_snapshot import snapshot from pydantic import AnyUrl +from sse_starlette.sse import AppStatus from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -32,6 +33,19 @@ Tool, ) + +@pytest.fixture(autouse=True) +def reset_sse_app_status(): + """Reset sse-starlette's global AppStatus singleton before each test. + + This is necessary because AppStatus.should_exit_event (a global anyio.Event) gets bound + to one event loop but accessed from others during parallel test execution (xdist workers), + causing RuntimeError("bound to a different event loop"), which prevents the SSE server + from responding (leaving status at 499) and causes ClosedResourceError during teardown. + """ + AppStatus.should_exit_event = anyio.Event() + + SERVER_NAME = "test_server_for_SSE" TEST_SERVER_HOST = "testserver" TEST_SERVER_BASE_URL = f"http://{TEST_SERVER_HOST}" @@ -106,8 +120,6 @@ async def handle_sse(request: Request) -> Response: # Test fixtures - - @pytest.fixture() def server_app() -> Starlette: """Create test Starlette app with SSE transport""" @@ -243,7 +255,6 @@ async def sse_client_mounted_server_app_session( async with sse_client( f"{TEST_SERVER_BASE_URL}/mounted_app/sse", - sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory, ) as streams: async with ClientSession(*streams) as session: @@ -329,7 +340,6 @@ async def test_request_context_propagation(tg: TaskGroup, context_server_app: St f"{TEST_SERVER_BASE_URL}/sse", headers=custom_headers, httpx_client_factory=asgi_client_factory, - sse_read_timeout=0.5, ) as streams: async with ClientSession(*streams) as session: # Initialize the session From f1748a0b3cf6a1c8c653f192fe8df6b69175116f Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Fri, 7 Nov 2025 00:13:33 +0100 Subject: [PATCH 17/24] improve workaround --- tests/shared/test_sse.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 64a8d7dc0a..c6df145e86 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -5,10 +5,11 @@ import anyio import httpx import pytest +import sse_starlette from anyio.abc import TaskGroup from inline_snapshot import snapshot +from packaging import version from pydantic import AnyUrl -from sse_starlette.sse import AppStatus from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -33,17 +34,40 @@ Tool, ) +SSE_STARLETTE_VERSION = version.parse(sse_starlette.__version__) +NEEDS_RESET = SSE_STARLETTE_VERSION < version.parse("3.0.0") + @pytest.fixture(autouse=True) def reset_sse_app_status(): """Reset sse-starlette's global AppStatus singleton before each test. 
- This is necessary because AppStatus.should_exit_event (a global anyio.Event) gets bound - to one event loop but accessed from others during parallel test execution (xdist workers), - causing RuntimeError("bound to a different event loop"), which prevents the SSE server - from responding (leaving status at 499) and causes ClosedResourceError during teardown. + AppStatus.should_exit_event is a global asyncio.Event that gets bound to + an event loop. This ensures each test gets a fresh Event and prevents + RuntimeError("bound to a different event loop") during parallel test + execution with pytest-xdist. + + NOTE: This fixture is only necessary for sse-starlette < 3.0.0. + Version 3.0+ eliminated the global state issue entirely by using + context-local events instead of module-level singletons, providing + automatic test isolation without manual cleanup. + + See for more details. """ - AppStatus.should_exit_event = anyio.Event() + if not NEEDS_RESET: + yield + return + + # lazy import to avoid import errors + from sse_starlette.sse import AppStatus + + # Setup: Reset before test + AppStatus.should_exit_event = anyio.Event() # type: ignore[attr-defined] + + yield + + # Teardown: Reset after test to prevent contamination + AppStatus.should_exit_event = anyio.Event() # type: ignore[attr-defined] SERVER_NAME = "test_server_for_SSE" From 749d506c03b26614face7069d9432f40c905679e Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Fri, 7 Nov 2025 15:02:38 +0100 Subject: [PATCH 18/24] move workaround to conftest for other test files --- tests/conftest.py | 39 +++++++++++++++++++++++++++++++++++++++ tests/shared/test_sse.py | 38 -------------------------------------- 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index af7e479932..75da636b02 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,45 @@ +import anyio import pytest +import sse_starlette +from packaging import version @pytest.fixture def anyio_backend(): return "asyncio" + + +SSE_STARLETTE_VERSION = version.parse(sse_starlette.__version__) +NEEDS_RESET = SSE_STARLETTE_VERSION < version.parse("3.0.0") + + +@pytest.fixture(autouse=True) +def reset_sse_app_status(): + """Reset sse-starlette's global AppStatus singleton before each test. + + AppStatus.should_exit_event is a global asyncio.Event that gets bound to + an event loop. This ensures each test gets a fresh Event and prevents + RuntimeError("bound to a different event loop") during parallel test + execution with pytest-xdist. + + NOTE: This fixture is only necessary for sse-starlette < 3.0.0. + Version 3.0+ eliminated the global state issue entirely by using + context-local events instead of module-level singletons, providing + automatic test isolation without manual cleanup. + + See for more details. 
+ """ + if not NEEDS_RESET: + yield + return + + # lazy import to avoid import errors + from sse_starlette.sse import AppStatus + + # Setup: Reset before test + AppStatus.should_exit_event = anyio.Event() # type: ignore[attr-defined] + + yield + + # Teardown: Reset after test to prevent contamination + AppStatus.should_exit_event = anyio.Event() # type: ignore[attr-defined] diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index c6df145e86..8a4438eb1e 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -5,10 +5,8 @@ import anyio import httpx import pytest -import sse_starlette from anyio.abc import TaskGroup from inline_snapshot import snapshot -from packaging import version from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request @@ -34,42 +32,6 @@ Tool, ) -SSE_STARLETTE_VERSION = version.parse(sse_starlette.__version__) -NEEDS_RESET = SSE_STARLETTE_VERSION < version.parse("3.0.0") - - -@pytest.fixture(autouse=True) -def reset_sse_app_status(): - """Reset sse-starlette's global AppStatus singleton before each test. - - AppStatus.should_exit_event is a global asyncio.Event that gets bound to - an event loop. This ensures each test gets a fresh Event and prevents - RuntimeError("bound to a different event loop") during parallel test - execution with pytest-xdist. - - NOTE: This fixture is only necessary for sse-starlette < 3.0.0. - Version 3.0+ eliminated the global state issue entirely by using - context-local events instead of module-level singletons, providing - automatic test isolation without manual cleanup. - - See for more details. - """ - if not NEEDS_RESET: - yield - return - - # lazy import to avoid import errors - from sse_starlette.sse import AppStatus - - # Setup: Reset before test - AppStatus.should_exit_event = anyio.Event() # type: ignore[attr-defined] - - yield - - # Teardown: Reset after test to prevent contamination - AppStatus.should_exit_event = anyio.Event() # type: ignore[attr-defined] - - SERVER_NAME = "test_server_for_SSE" TEST_SERVER_HOST = "testserver" TEST_SERVER_BASE_URL = f"http://{TEST_SERVER_HOST}" From 044728b72bd25b4625db4fcc43a46f92aa5f56e3 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Fri, 7 Nov 2025 19:32:34 +0100 Subject: [PATCH 19/24] add version check into reset_sse_app_status --- tests/conftest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 75da636b02..335db5240c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,10 +9,6 @@ def anyio_backend(): return "asyncio" -SSE_STARLETTE_VERSION = version.parse(sse_starlette.__version__) -NEEDS_RESET = SSE_STARLETTE_VERSION < version.parse("3.0.0") - - @pytest.fixture(autouse=True) def reset_sse_app_status(): """Reset sse-starlette's global AppStatus singleton before each test. @@ -29,6 +25,10 @@ def reset_sse_app_status(): See for more details. 
""" + + SSE_STARLETTE_VERSION = version.parse(sse_starlette.__version__) + NEEDS_RESET = SSE_STARLETTE_VERSION < version.parse("3.0.0") + if not NEEDS_RESET: yield return From 2ec89074640dd32d3754d9f94104fe32044dbf2e Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Sat, 8 Nov 2025 01:40:33 +0100 Subject: [PATCH 20/24] revert duplicate http.response.start handling --- src/mcp/server/streaming_asgi_transport.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py index 690a0c392d..628cf132f2 100644 --- a/src/mcp/server/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -149,9 +149,7 @@ async def process_messages() -> None: async with asgi_receive_channel: async for message in asgi_receive_channel: if message["type"] == "http.response.start": - if response_started: - # Ignore duplicate response.start from ASGI app during SSE disconnect - continue + assert not response_started status_code = message["status"] response_headers = message.get("headers", []) response_started = True From b5b4e6f3cfcfbe072c877dfb88b41641adc636a7 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Sat, 8 Nov 2025 01:41:16 +0100 Subject: [PATCH 21/24] add NoopASGI --- tests/test_helpers.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index a4b4146e91..d561617c88 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -3,6 +3,8 @@ import socket import time +from starlette.types import Receive, Scope, Send + def wait_for_server(port: int, timeout: float = 5.0) -> None: """Wait for server to be ready to accept connections. @@ -29,3 +31,16 @@ def wait_for_server(port: int, timeout: float = 5.0) -> None: # Server not ready yet, retry quickly time.sleep(0.01) raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") + + +class NoopASGI: + """ + This helper exists only for test SSE handlers. Production MCP servers + would normally expose an ASGI endpoint directly. We return this no-op + ASGI app instead of Response() so Starlette does not send a second + http.response.start, which breaks httpx.ASGITransport and + StreamingASGITransport. 
+ """ + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + return From 6857129b8637651de0ebef75bc4192513ae17394 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Sat, 8 Nov 2025 01:42:39 +0100 Subject: [PATCH 22/24] add NoopASGI and clean tests --- tests/shared/test_sse.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 8a4438eb1e..5d6a129e69 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -10,7 +10,6 @@ from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Mount, Route import mcp.types as types @@ -31,6 +30,7 @@ TextResourceContents, Tool, ) +from tests.test_helpers import NoopASGI SERVER_NAME = "test_server_for_SSE" TEST_SERVER_HOST = "testserver" @@ -92,10 +92,10 @@ def create_sse_app(server: Server) -> Starlette: ) sse = SseServerTransport("/messages/", security_settings=security_settings) - async def handle_sse(request: Request) -> Response: + async def handle_sse(request: Request) -> NoopASGI: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: await server.run(streams[0], streams[1], server.create_initialization_options()) - return Response() + return NoopASGI() return Starlette( routes=[ @@ -135,7 +135,7 @@ async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGener asgi_client_factory = create_asgi_client_factory(server_app, tg) async with sse_client( - f"{TEST_SERVER_BASE_URL}/sse", + "/sse", httpx_client_factory=asgi_client_factory, ) as streams: async with ClientSession(*streams) as session: @@ -240,7 +240,7 @@ async def sse_client_mounted_server_app_session( asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg) async with sse_client( - f"{TEST_SERVER_BASE_URL}/mounted_app/sse", + "/mounted_app/sse", httpx_client_factory=asgi_client_factory, ) as streams: async with ClientSession(*streams) as session: @@ -323,7 +323,7 @@ async def test_request_context_propagation(tg: TaskGroup, context_server_app: St asgi_client_factory = create_asgi_client_factory(context_server_app, tg) async with sse_client( - f"{TEST_SERVER_BASE_URL}/sse", + "/sse", headers=custom_headers, httpx_client_factory=asgi_client_factory, ) as streams: @@ -358,7 +358,7 @@ async def test_request_context_isolation(tg: TaskGroup, context_server_app: Star headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} async with sse_client( - f"{TEST_SERVER_BASE_URL}/sse", + "/sse", headers=headers, httpx_client_factory=asgi_client_factory, ) as streams: From 3a9e5769d935de8cbdb4df223ede7111f2857368 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Sat, 8 Nov 2025 01:43:22 +0100 Subject: [PATCH 23/24] remove uvicorn and DRY --- tests/server/test_sse_security.py | 335 ++++++++++++++---------------- 1 file changed, 157 insertions(+), 178 deletions(-) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 7a8e52bdab..9c7f0d7d47 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,39 +1,30 @@ """Tests for SSE server DNS rebinding protection.""" import logging -import multiprocessing -import socket +from collections.abc import AsyncGenerator +import anyio import httpx import pytest -import uvicorn +from anyio.abc import TaskGroup from starlette.applications import Starlette from 
starlette.requests import Request
-from starlette.responses import Response
 from starlette.routing import Mount, Route
 
 from mcp.server import Server
 from mcp.server.sse import SseServerTransport
+from mcp.server.streaming_asgi_transport import StreamingASGITransport
 from mcp.server.transport_security import TransportSecuritySettings
 from mcp.types import Tool
-from tests.test_helpers import wait_for_server
+from tests.test_helpers import NoopASGI
 
 logger = logging.getLogger(__name__)
 
 SERVER_NAME = "test_sse_security_server"
+TEST_SERVER_HOST = "testserver"
+TEST_SERVER_BASE_URL = f"http://{TEST_SERVER_HOST}"
 
 
-@pytest.fixture
-def server_port() -> int:
-    with socket.socket() as s:
-        s.bind(("127.0.0.1", 0))
-        return s.getsockname()[1]
-
-
-@pytest.fixture
-def server_url(server_port: int) -> str:
-    return f"http://127.0.0.1:{server_port}"
-
-
+# Test server implementation
 class SecurityTestServer(Server):
     def __init__(self):
         super().__init__(SERVER_NAME)
@@ -42,12 +33,22 @@ async def on_list_tools(self) -> list[Tool]:
         return []
 
 
-def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None):
-    """Run the SSE server with specified security settings."""
+@pytest.fixture()
+async def tg() -> AsyncGenerator[TaskGroup, None]:
+    """Create a task group for the server."""
+    async with anyio.create_task_group() as tg:
+        try:
+            yield tg
+        finally:
+            tg.cancel_scope.cancel()
+
+
+def create_server_app_with_settings(security_settings: TransportSecuritySettings | None = None):
+    """Create the test Starlette SSE app with the specified security settings."""
     app = SecurityTestServer()
     sse_transport = SseServerTransport("/messages/", security_settings)
 
-    async def handle_sse(request: Request):
+    async def handle_sse(request: Request) -> NoopASGI:
         try:
             async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
                 if streams:
@@ -55,7 +56,7 @@ async def handle_sse(request: Request):
         except ValueError as e:
             # Validation error was already handled inside connect_sse
             logger.debug(f"SSE connection failed validation: {e}")
-        return Response()
+        return NoopASGI()
 
     routes = [
         Route("/sse", endpoint=handle_sse),
@@ -63,231 +64,209 @@ async def handle_sse(request: Request):
     ]
 
     starlette_app = Starlette(routes=routes)
-    uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error")
+    return starlette_app
+
+
+def make_client(transport: httpx.AsyncBaseTransport) -> httpx.AsyncClient:
+    return httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL, timeout=5.0)
 
 
-def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None):
-    """Start server in a separate process."""
-    process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings))
-    process.start()
-    # Wait for server to be ready to accept connections
-    wait_for_server(port)
-    return process
+
+async def close_client_streaming_response(response: httpx.Response):
+    """Read the first SSE event, then close the streaming response."""
+    # consume the first non-empty line / event, then stop
+    async for line in response.aiter_lines():
+        if line and line.strip():  # skip empty keepalive lines
+            break
+    # close the streaming response cleanly
+    await response.aclose()
 
 
 @pytest.mark.anyio
-async def test_sse_security_default_settings(server_port: int):
+async def test_sse_security_default_settings(tg: TaskGroup):
     """Test SSE with default security settings (protection disabled)."""
-    process = start_server_process(server_port)
+    server_app = create_server_app_with_settings()
+    transport = StreamingASGITransport(app=server_app,
task_group=tg) - try: - headers = {"Host": "evil.com", "Origin": "http://evil.com"} + headers = {"Host": "evil.com", "Origin": "http://evil.com"} - async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with make_client(transport) as client: + async with client.stream("GET", "/sse", headers=headers) as response: + assert response.status_code == 200 + await close_client_streaming_response(response) @pytest.mark.anyio -async def test_sse_security_invalid_host_header(server_port: int): +async def test_sse_security_invalid_host_header(): """Test SSE with invalid Host header.""" # Enable security by providing settings with an empty allowed_hosts list security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid host header - headers = {"Host": "evil.com"} + server_app = create_server_app_with_settings(security_settings) + transport = httpx.ASGITransport(app=server_app, raise_app_exceptions=True) - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" + # Test with invalid host header + headers = {"Host": "evil.com"} - finally: - process.terminate() - process.join() + response = await make_client(transport).get("/sse", headers=headers) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): +async def test_sse_security_invalid_origin_header(tg: TaskGroup): """Test SSE with invalid Origin header.""" # Configure security to allow the host but restrict origins security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] + enable_dns_rebinding_protection=True, allowed_hosts=[TEST_SERVER_HOST], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = {"Origin": "http://evil.com"} + server_app = create_server_app_with_settings(security_settings) + transport = StreamingASGITransport(app=server_app, task_group=tg) - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" + # Test with invalid origin header + headers = {"Origin": "http://evil.com"} - finally: - process.terminate() - process.join() + async with make_client(transport) as client: + response = await client.get("/sse", headers=headers) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): +async def test_sse_security_post_invalid_content_type(tg: TaskGroup): """Test POST endpoint with invalid Content-Type header.""" # Configure security to allow the host security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] + enable_dns_rebinding_protection=True, 
allowed_hosts=[TEST_SERVER_HOST], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": "text/plain"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - finally: - process.terminate() - process.join() + server_app = create_server_app_with_settings(security_settings) + transport = StreamingASGITransport(app=server_app, task_group=tg) + + async with make_client(transport) as client: + # Test POST with invalid content type + fake_session_id = "12345678123456781234567812345678" + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": "text/plain"}, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + # Test POST with missing content type + response = await client.post(f"/messages/?session_id={fake_session_id}", content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): +async def test_sse_security_disabled(tg: TaskGroup): """Test SSE with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = {"Host": "evil.com"} + server_app = create_server_app_with_settings(settings) + transport = StreamingASGITransport(app=server_app, task_group=tg) - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully even with invalid host - assert response.status_code == 200 + # Test with invalid host header - should still work + headers = {"Host": "evil.com"} - finally: - process.terminate() - process.join() + async with make_client(transport) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", "/sse", headers=headers) as response: + # Should connect successfully even with invalid host + assert response.status_code == 200 + await close_client_streaming_response(response) @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): +async def test_sse_security_custom_allowed_hosts(tg: TaskGroup): """Test SSE with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, - allowed_hosts=["localhost", "127.0.0.1", "custom.host"], + allowed_hosts=[TEST_SERVER_HOST, "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) + server_app = create_server_app_with_settings(settings) + transport = StreamingASGITransport(app=server_app, 
task_group=tg) - try: - # Test with custom allowed host - headers = {"Host": "custom.host"} + # Test with custom allowed host + headers = {"Host": "custom.host"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with custom host - assert response.status_code == 200 + async with make_client(transport) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", "/sse", headers=headers) as response: + # Should connect successfully with custom host + assert response.status_code == 200 + await close_client_streaming_response(response) - # Test with non-allowed host - headers = {"Host": "evil.com"} + # Test with non-allowed host + headers = {"Host": "evil.com"} - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + async with make_client(transport) as client: + response = await client.get("/sse", headers=headers) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): +async def test_sse_security_wildcard_ports(tg: TaskGroup): """Test SSE with wildcard port patterns.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, - allowed_hosts=["localhost:*", "127.0.0.1:*"], + allowed_hosts=[TEST_SERVER_HOST, "localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) - - try: - # Test with various port numbers - for test_port in [8080, 3000, 9999]: - headers = {"Host": f"localhost:{test_port}"} + server_app = create_server_app_with_settings(settings) + transport = StreamingASGITransport(app=server_app, task_group=tg) - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 + # Test with various port numbers + for test_port in [8080, 3000, 9999]: + headers = {"Host": f"localhost:{test_port}"} - headers = {"Origin": f"http://localhost:{test_port}"} + async with make_client(transport) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", "/sse", headers=headers) as response: + # Should connect successfully with any port + assert response.status_code == 200 + await close_client_streaming_response(response) - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 + headers = {"Origin": f"http://localhost:{test_port}"} - finally: - process.terminate() - process.join() + async with make_client(transport) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", "/sse", headers=headers) as response: + # Should connect successfully with any port + 
assert response.status_code == 200 + await close_client_streaming_response(response) @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): +async def test_sse_security_post_valid_content_type(tg: TaskGroup): """Test POST endpoint with valid Content-Type headers.""" # Configure security to allow the host security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] + enable_dns_rebinding_protection=True, + allowed_hosts=[TEST_SERVER_HOST, "127.0.0.1:*"], + allowed_origins=["http://127.0.0.1:*"], ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient() as client: - # Test with various valid content types - valid_content_types = [ - "application/json", - "application/json; charset=utf-8", - "application/json;charset=utf-8", - "APPLICATION/JSON", # Case insensitive - ] - - for content_type in valid_content_types: - # Use a valid UUID format (even though session won't exist) - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": content_type}, - json={"test": "data"}, - ) - # Will get 404 because session doesn't exist, but that's OK - # We're testing that it passes the content-type check - assert response.status_code == 404 - assert response.text == "Could not find session" - - finally: - process.terminate() - process.join() + server_app = create_server_app_with_settings(security_settings) + transport = StreamingASGITransport(app=server_app, task_group=tg) + + async with make_client(transport) as client: + # Test with various valid content types + valid_content_types = [ + "application/json", + "application/json; charset=utf-8", + "application/json;charset=utf-8", + "APPLICATION/JSON", # Case insensitive + ] + + for content_type in valid_content_types: + # Use a valid UUID format (even though session won't exist) + fake_session_id = "12345678123456781234567812345678" + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": content_type}, + json={"test": "data"}, + ) + # Will get 404 because session doesn't exist, but that's OK + # We're testing that it passes the content-type check + assert response.status_code == 404 + assert response.text == "Could not find session" From 3d2fea9daf2cb452117ae90b236555d6c8850fb5 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Sat, 8 Nov 2025 02:05:24 +0100 Subject: [PATCH 24/24] add short comment explaining NoopASGI usage --- tests/server/test_sse_security.py | 1 + tests/shared/test_sse.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 9c7f0d7d47..1852643f58 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -56,6 +56,7 @@ async def handle_sse(request: Request) -> NoopASGI: except ValueError as e: # Validation error was already handled inside connect_sse logger.debug(f"SSE connection failed validation: {e}") + # connect_sse already responded; return a no-op ASGI endpoint return NoopASGI() routes = [ diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 5d6a129e69..9a0556c212 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -95,6 +95,7 @@ def create_sse_app(server: Server) -> Starlette: async def handle_sse(request: Request) -> NoopASGI: 
async with sse.connect_sse(request.scope, request.receive, request._send) as streams: await server.run(streams[0], streams[1], server.create_initialization_options()) + # connect_sse already responded; return a no-op ASGI endpoint return NoopASGI() return Starlette(