diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index b7ff33280..e106cf3c2 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -190,6 +190,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): ) await read_stream_writer.aclose() await write_stream_reader.aclose() + await sse_stream_reader.aclose() logging.debug(f"Client session disconnected {session_id}") logger.debug("Starting SSE response task") diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py index a74751312..628cf132f 100644 --- a/src/mcp/server/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -9,6 +9,7 @@ """ import typing +from collections.abc import Awaitable, Callable from typing import Any, cast import anyio @@ -65,6 +66,8 @@ async def handle_async_request( ) -> Response: assert isinstance(request.stream, AsyncByteStream) + disconnect_event = anyio.Event() + # ASGI scope. scope = { "type": "http", @@ -97,11 +100,17 @@ async def handle_async_request( content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100) # ASGI callables. + async def send_disconnect() -> None: + disconnect_event.set() + async def receive() -> dict[str, Any]: nonlocal request_complete + if disconnect_event.is_set(): + return {"type": "http.disconnect"} + if request_complete: - await response_complete.wait() + await disconnect_event.wait() return {"type": "http.disconnect"} try: @@ -176,7 +185,7 @@ async def process_messages() -> None: return Response( status_code, headers=response_headers, - stream=StreamingASGIResponseStream(content_receive_channel), + stream=StreamingASGIResponseStream(content_receive_channel, send_disconnect), ) @@ -192,12 +201,18 @@ class StreamingASGIResponseStream(AsyncByteStream): def __init__( self, receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes], + send_disconnect: Callable[[], Awaitable[None]], ) -> None: self.receive_channel = receive_channel + self.send_disconnect = send_disconnect async def __aiter__(self) -> typing.AsyncIterator[bytes]: try: async for chunk in self.receive_channel: yield chunk finally: - await self.receive_channel.aclose() + await self.aclose() + + async def aclose(self) -> None: + await self.receive_channel.aclose() + await self.send_disconnect() diff --git a/tests/conftest.py b/tests/conftest.py index af7e47993..335db5240 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,45 @@ +import anyio import pytest +import sse_starlette +from packaging import version @pytest.fixture def anyio_backend(): return "asyncio" + + +@pytest.fixture(autouse=True) +def reset_sse_app_status(): + """Reset sse-starlette's global AppStatus singleton before each test. + + AppStatus.should_exit_event is a global asyncio.Event that gets bound to + an event loop. This ensures each test gets a fresh Event and prevents + RuntimeError("bound to a different event loop") during parallel test + execution with pytest-xdist. + + NOTE: This fixture is only necessary for sse-starlette < 3.0.0. + Version 3.0+ eliminated the global state issue entirely by using + context-local events instead of module-level singletons, providing + automatic test isolation without manual cleanup. + + See for more details. + """ + + SSE_STARLETTE_VERSION = version.parse(sse_starlette.__version__) + NEEDS_RESET = SSE_STARLETTE_VERSION < version.parse("3.0.0") + + if not NEEDS_RESET: + yield + return + + # lazy import to avoid import errors + from sse_starlette.sse import AppStatus + + # Setup: Reset before test + AppStatus.should_exit_event = anyio.Event() # type: ignore[attr-defined] + + yield + + # Teardown: Reset after test to prevent contamination + AppStatus.should_exit_event = anyio.Event() # type: ignore[attr-defined] diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 7a8e52bda..1852643f5 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,39 +1,30 @@ """Tests for SSE server DNS rebinding protection.""" import logging -import multiprocessing -import socket +from collections.abc import AsyncGenerator +import anyio import httpx import pytest -import uvicorn +from anyio.abc import TaskGroup from starlette.applications import Starlette from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Mount, Route from mcp.server import Server from mcp.server.sse import SseServerTransport +from mcp.server.streaming_asgi_transport import StreamingASGITransport from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool -from tests.test_helpers import wait_for_server +from tests.test_helpers import NoopASGI logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" +TEST_SERVER_HOST = "testserver" +TEST_SERVER_BASE_URL = f"http://{TEST_SERVER_HOST}" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" - - +# Test server implementation class SecurityTestServer(Server): def __init__(self): super().__init__(SERVER_NAME) @@ -42,12 +33,22 @@ async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): +@pytest.fixture() +async def tg() -> AsyncGenerator[TaskGroup, None]: + """Create a task group for the server.""" + async with anyio.create_task_group() as tg: + try: + yield tg + finally: + tg.cancel_scope.cancel() + + +def create_server_app_with_settings(security_settings: TransportSecuritySettings | None = None): """Run the SSE server with specified security settings.""" app = SecurityTestServer() sse_transport = SseServerTransport("/messages/", security_settings) - async def handle_sse(request: Request): + async def handle_sse(request: Request) -> NoopASGI: try: async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: if streams: @@ -55,7 +56,8 @@ async def handle_sse(request: Request): except ValueError as e: # Validation error was already handled inside connect_sse logger.debug(f"SSE connection failed validation: {e}") - return Response() + # connect_sse already responded; return a no-op ASGI endpoint + return NoopASGI() routes = [ Route("/sse", endpoint=handle_sse), @@ -63,231 +65,209 @@ async def handle_sse(request: Request): ] starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") + return starlette_app + + +def make_client(transport: httpx.AsyncBaseTransport) -> httpx.AsyncClient: + return httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL, timeout=5.0) -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process +async def close_client_streaming_response(response: httpx.Response): + """Close the client streaming response.""" + # consume the first non-empty line / event, then stop + async for line in response.aiter_lines(): + if line and line.strip(): # skip empty keepalive lines + break + # close the streaming response cleanly + await response.aclose() @pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int): +async def test_sse_security_default_settings(tg: TaskGroup): """Test SSE with default security settings (protection disabled).""" - process = start_server_process(server_port) + server_app = create_server_app_with_settings() + transport = StreamingASGITransport(app=server_app, task_group=tg) - try: - headers = {"Host": "evil.com", "Origin": "http://evil.com"} + headers = {"Host": "evil.com", "Origin": "http://evil.com"} - async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with make_client(transport) as client: + async with client.stream("GET", "/sse", headers=headers) as response: + assert response.status_code == 200 + await close_client_streaming_response(response) @pytest.mark.anyio -async def test_sse_security_invalid_host_header(server_port: int): +async def test_sse_security_invalid_host_header(): """Test SSE with invalid Host header.""" # Enable security by providing settings with an empty allowed_hosts list security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid host header - headers = {"Host": "evil.com"} + server_app = create_server_app_with_settings(security_settings) + transport = httpx.ASGITransport(app=server_app, raise_app_exceptions=True) - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" + # Test with invalid host header + headers = {"Host": "evil.com"} - finally: - process.terminate() - process.join() + response = await make_client(transport).get("/sse", headers=headers) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): +async def test_sse_security_invalid_origin_header(tg: TaskGroup): """Test SSE with invalid Origin header.""" # Configure security to allow the host but restrict origins security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] + enable_dns_rebinding_protection=True, allowed_hosts=[TEST_SERVER_HOST], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = {"Origin": "http://evil.com"} + server_app = create_server_app_with_settings(security_settings) + transport = StreamingASGITransport(app=server_app, task_group=tg) - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" + # Test with invalid origin header + headers = {"Origin": "http://evil.com"} - finally: - process.terminate() - process.join() + async with make_client(transport) as client: + response = await client.get("/sse", headers=headers) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): +async def test_sse_security_post_invalid_content_type(tg: TaskGroup): """Test POST endpoint with invalid Content-Type header.""" # Configure security to allow the host security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] + enable_dns_rebinding_protection=True, allowed_hosts=[TEST_SERVER_HOST], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": "text/plain"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - finally: - process.terminate() - process.join() + server_app = create_server_app_with_settings(security_settings) + transport = StreamingASGITransport(app=server_app, task_group=tg) + + async with make_client(transport) as client: + # Test POST with invalid content type + fake_session_id = "12345678123456781234567812345678" + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": "text/plain"}, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + # Test POST with missing content type + response = await client.post(f"/messages/?session_id={fake_session_id}", content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): +async def test_sse_security_disabled(tg: TaskGroup): """Test SSE with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = {"Host": "evil.com"} + server_app = create_server_app_with_settings(settings) + transport = StreamingASGITransport(app=server_app, task_group=tg) - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully even with invalid host - assert response.status_code == 200 + # Test with invalid host header - should still work + headers = {"Host": "evil.com"} - finally: - process.terminate() - process.join() + async with make_client(transport) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", "/sse", headers=headers) as response: + # Should connect successfully even with invalid host + assert response.status_code == 200 + await close_client_streaming_response(response) @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): +async def test_sse_security_custom_allowed_hosts(tg: TaskGroup): """Test SSE with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, - allowed_hosts=["localhost", "127.0.0.1", "custom.host"], + allowed_hosts=[TEST_SERVER_HOST, "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) + server_app = create_server_app_with_settings(settings) + transport = StreamingASGITransport(app=server_app, task_group=tg) - try: - # Test with custom allowed host - headers = {"Host": "custom.host"} + # Test with custom allowed host + headers = {"Host": "custom.host"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with custom host - assert response.status_code == 200 + async with make_client(transport) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", "/sse", headers=headers) as response: + # Should connect successfully with custom host + assert response.status_code == 200 + await close_client_streaming_response(response) - # Test with non-allowed host - headers = {"Host": "evil.com"} + # Test with non-allowed host + headers = {"Host": "evil.com"} - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + async with make_client(transport) as client: + response = await client.get("/sse", headers=headers) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): +async def test_sse_security_wildcard_ports(tg: TaskGroup): """Test SSE with wildcard port patterns.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, - allowed_hosts=["localhost:*", "127.0.0.1:*"], + allowed_hosts=[TEST_SERVER_HOST, "localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) - - try: - # Test with various port numbers - for test_port in [8080, 3000, 9999]: - headers = {"Host": f"localhost:{test_port}"} + server_app = create_server_app_with_settings(settings) + transport = StreamingASGITransport(app=server_app, task_group=tg) - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 + # Test with various port numbers + for test_port in [8080, 3000, 9999]: + headers = {"Host": f"localhost:{test_port}"} - headers = {"Origin": f"http://localhost:{test_port}"} + async with make_client(transport) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", "/sse", headers=headers) as response: + # Should connect successfully with any port + assert response.status_code == 200 + await close_client_streaming_response(response) - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 + headers = {"Origin": f"http://localhost:{test_port}"} - finally: - process.terminate() - process.join() + async with make_client(transport) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", "/sse", headers=headers) as response: + # Should connect successfully with any port + assert response.status_code == 200 + await close_client_streaming_response(response) @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): +async def test_sse_security_post_valid_content_type(tg: TaskGroup): """Test POST endpoint with valid Content-Type headers.""" # Configure security to allow the host security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] + enable_dns_rebinding_protection=True, + allowed_hosts=[TEST_SERVER_HOST, "127.0.0.1:*"], + allowed_origins=["http://127.0.0.1:*"], ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient() as client: - # Test with various valid content types - valid_content_types = [ - "application/json", - "application/json; charset=utf-8", - "application/json;charset=utf-8", - "APPLICATION/JSON", # Case insensitive - ] - - for content_type in valid_content_types: - # Use a valid UUID format (even though session won't exist) - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": content_type}, - json={"test": "data"}, - ) - # Will get 404 because session doesn't exist, but that's OK - # We're testing that it passes the content-type check - assert response.status_code == 404 - assert response.text == "Could not find session" - - finally: - process.terminate() - process.join() + server_app = create_server_app_with_settings(security_settings) + transport = StreamingASGITransport(app=server_app, task_group=tg) + + async with make_client(transport) as client: + # Test with various valid content types + valid_content_types = [ + "application/json", + "application/json; charset=utf-8", + "application/json;charset=utf-8", + "APPLICATION/JSON", # Case insensitive + ] + + for content_type in valid_content_types: + # Use a valid UUID format (even though session won't exist) + fake_session_id = "12345678123456781234567812345678" + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": content_type}, + json={"test": "data"}, + ) + # Will get 404 because session doesn't exist, but that's OK + # We're testing that it passes the content-type check + assert response.status_code == 404 + assert response.text == "Could not find session" diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index fdb6ccfd8..9a0556c21 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,19 +1,15 @@ import json -import multiprocessing -import socket -import time -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator from typing import Any import anyio import httpx import pytest -import uvicorn +from anyio.abc import TaskGroup from inline_snapshot import snapshot from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Mount, Route import mcp.types as types @@ -21,7 +17,9 @@ from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport +from mcp.server.streaming_asgi_transport import StreamingASGITransport from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._httpx_utils import McpHttpClientFactory from mcp.shared.exceptions import McpError from mcp.types import ( EmptyResult, @@ -32,21 +30,11 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.test_helpers import NoopASGI SERVER_NAME = "test_server_for_SSE" - - -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" +TEST_SERVER_HOST = "testserver" +TEST_SERVER_BASE_URL = f"http://{TEST_SERVER_HOST}" # Test server implementation @@ -80,116 +68,123 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] return [TextContent(type="text", text=f"Called {name}")] -# Test fixtures -def make_server_app() -> Starlette: - """Create test Starlette app with SSE transport""" - # Configure security with allowed hosts/origins for testing +def create_asgi_client_factory(app: Starlette, tg: TaskGroup) -> McpHttpClientFactory: + """Factory function to create httpx clients with StreamingASGITransport""" + + def asgi_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + transport = StreamingASGITransport(app=app, task_group=tg) + return httpx.AsyncClient( + transport=transport, base_url=TEST_SERVER_BASE_URL, headers=headers, timeout=timeout, auth=auth + ) + + return asgi_client_factory + + +def create_sse_app(server: Server) -> Starlette: + """Helper to create SSE app with given server""" security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + allowed_hosts=[TEST_SERVER_HOST], + allowed_origins=[TEST_SERVER_BASE_URL], ) sse = SseServerTransport("/messages/", security_settings=security_settings) - server = ServerTest() - async def handle_sse(request: Request) -> Response: + async def handle_sse(request: Request) -> NoopASGI: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: await server.run(streams[0], streams[1], server.create_initialization_options()) - return Response() + # connect_sse already responded; return a no-op ASGI endpoint + return NoopASGI() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - return app - -def run_server(server_port: int) -> None: - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) +# Test fixtures +@pytest.fixture() +def server_app() -> Starlette: + """Create test Starlette app with SSE transport""" + app = create_sse_app(ServerTest()) + return app @pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) +async def tg() -> AsyncGenerator[TaskGroup, None]: + async with anyio.create_task_group() as tg: + try: + yield tg + finally: + tg.cancel_scope.cancel() - yield - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") +@pytest.fixture() +async def http_client(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create test client using StreamingASGITransport""" + transport = StreamingASGITransport(app=server_app, task_group=tg) + async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client: + yield client @pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: - """Create test client""" - async with httpx.AsyncClient(base_url=server_url) as client: - yield client +async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[ClientSession, None]: + asgi_client_factory = create_asgi_client_factory(server_app, tg) + + async with sse_client( + "/sse", + httpx_client_factory=asgi_client_factory, + ) as streams: + async with ClientSession(*streams) as session: + yield session # Tests @pytest.mark.anyio async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: """Test the SSE connection establishment simply with an HTTP client.""" - async with anyio.create_task_group(): - async def connection_test() -> None: - async with http_client.stream("GET", "/sse") as response: - assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + async def connection_test() -> None: + async with http_client.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" - line_number = 0 - async for line in response.aiter_lines(): - if line_number == 0: - assert line == "event: endpoint" - elif line_number == 1: - assert line.startswith("data: /messages/?session_id=") - else: - return - line_number += 1 + line_number = 0 + async for line in response.aiter_lines(): + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + assert line.startswith("data: /messages/?session_id=") + else: + return + line_number += 1 - # Add timeout to prevent test from hanging if it fails - with anyio.fail_after(3): - await connection_test() + # Add timeout to prevent test from hanging if it fails + with anyio.fail_after(3): + await connection_test() @pytest.mark.anyio -async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse") as streams: - async with ClientSession(*streams) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == SERVER_NAME +async def test_sse_client_basic_connection(sse_client_session: ClientSession) -> None: + # Test initialization + result = await sse_client_session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME - # Test ping - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) + # Test ping + ping_result = await sse_client_session.send_ping() + assert isinstance(ping_result, EmptyResult) @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - yield session +async def initialized_sse_client_session(sse_client_session: ClientSession) -> AsyncGenerator[ClientSession, None]: + session = sse_client_session + await session.initialize() + yield session @pytest.mark.anyio @@ -232,51 +227,38 @@ async def test_sse_client_timeout( pytest.fail("the client should have timed out and returned an error already") -def run_mounted_server(server_port: int) -> None: - app = make_server_app() - main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - - @pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) +async def mounted_server_app(server_app: Starlette) -> Starlette: + """Create a mounted server app""" + app = Starlette(routes=[Mount("/mounted_app", app=server_app)]) + return app - yield - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") +@pytest.fixture() +async def sse_client_mounted_server_app_session( + tg: TaskGroup, mounted_server_app: Starlette +) -> AsyncGenerator[ClientSession, None]: + asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg) + + async with sse_client( + "/mounted_app/sse", + httpx_client_factory=asgi_client_factory, + ) as streams: + async with ClientSession(*streams) as session: + yield session @pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: - async with ClientSession(*streams) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == SERVER_NAME +async def test_sse_client_basic_connection_mounted_app(sse_client_mounted_server_app_session: ClientSession) -> None: + session = sse_client_mounted_server_app_session + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME - # Test ping - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) + # Test ping + ping_result = await session.send_ping() + assert isinstance(ping_result, EmptyResult) # Test server with request context that returns headers in the response @@ -322,54 +304,15 @@ async def handle_list_tools() -> list[Tool]: ] -def run_context_server(server_port: int) -> None: - """Run a server that captures request context""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = RequestContextServer() - - async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) - return Response() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ] - ) - - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() - - @pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: - """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() - - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("context server process failed to terminate") +async def context_server_app() -> Starlette: + """Fixture that provides the context server app""" + app = create_sse_app(RequestContextServer()) + return app @pytest.mark.anyio -async def test_request_context_propagation(context_server: None, server_url: str) -> None: +async def test_request_context_propagation(tg: TaskGroup, context_server_app: Starlette) -> None: """Test that request context is properly propagated through SSE transport.""" # Test with custom headers custom_headers = { @@ -378,11 +321,14 @@ async def test_request_context_propagation(context_server: None, server_url: str "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: + asgi_client_factory = create_asgi_client_factory(context_server_app, tg) + + async with sse_client( + "/sse", + headers=custom_headers, + httpx_client_factory=asgi_client_factory, + ) as streams: + async with ClientSession(*streams) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) @@ -391,9 +337,9 @@ async def test_request_context_propagation(context_server: None, server_url: str tool_result = await session.call_tool("echo_headers", {}) # Parse the JSON response - assert len(tool_result.content) == 1 - headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") + content_item = tool_result.content[0] + headers_data = json.loads(content_item.text if content_item.type == "text" else "{}") # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" @@ -402,19 +348,22 @@ async def test_request_context_propagation(context_server: None, server_url: str @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: +async def test_request_context_isolation(tg: TaskGroup, context_server_app: Starlette) -> None: """Test that request contexts are isolated between different SSE clients.""" contexts: list[dict[str, Any]] = [] + asgi_client_factory = create_asgi_client_factory(context_server_app, tg) + # Create multiple clients with different headers for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: + async with sse_client( + "/sse", + headers=headers, + httpx_client_factory=asgi_client_factory, + ) as streams: + async with ClientSession(*streams) as session: await session.initialize() # Call the tool that echoes context diff --git a/tests/test_helpers.py b/tests/test_helpers.py index a4b4146e9..d561617c8 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -3,6 +3,8 @@ import socket import time +from starlette.types import Receive, Scope, Send + def wait_for_server(port: int, timeout: float = 5.0) -> None: """Wait for server to be ready to accept connections. @@ -29,3 +31,16 @@ def wait_for_server(port: int, timeout: float = 5.0) -> None: # Server not ready yet, retry quickly time.sleep(0.01) raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") + + +class NoopASGI: + """ + This helper exists only for test SSE handlers. Production MCP servers + would normally expose an ASGI endpoint directly. We return this no-op + ASGI app instead of Response() so Starlette does not send a second + http.response.start, which breaks httpx.ASGITransport and + StreamingASGITransport. + """ + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + return