diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index f9393e5e3b..c4dd6ba656 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -1307,7 +1307,7 @@ async def connect_mcp( StdioMcpConnection, validate_mcp_command, ) - from chainlit.session import WebsocketSession + from chainlit.session import WebsocketSession, safe_mcp_exit_stack_close session = WebsocketSession.get_by_id(payload.sessionId) context = init_ws_context(session) @@ -1327,14 +1327,20 @@ async def connect_mcp( if payload.name in session.mcp_sessions: old_client_session, old_exit_stack = session.mcp_sessions[payload.name] if on_mcp_disconnect := config.code.on_mcp_disconnect: - await on_mcp_disconnect(payload.name, old_client_session) - try: - await old_exit_stack.aclose() - except Exception: - pass + try: + await on_mcp_disconnect(payload.name, old_client_session) + except Exception: + logger.debug( + "Error in on_mcp_disconnect callback for %s", + payload.name, + exc_info=True, + ) + await safe_mcp_exit_stack_close(old_exit_stack) + del session.mcp_sessions[payload.name] + exit_stack = AsyncExitStack() + exit_stack_stored = False try: - exit_stack = AsyncExitStack() mcp_connection: McpConnection if payload.clientType == "sse": @@ -1411,18 +1417,22 @@ async def connect_mcp( # Initialize the session await mcp_session.initialize() - # Store the session - session.mcp_sessions[mcp_connection.name] = (mcp_session, exit_stack) - # Call the callback if config.code.on_mcp_connect: await config.code.on_mcp_connect(mcp_connection, mcp_session) + # Store the session + session.mcp_sessions[mcp_connection.name] = (mcp_session, exit_stack) + exit_stack_stored = True + except Exception as e: raise HTTPException( status_code=400, detail=f"Could not connect to the MCP: {e!s}", ) + finally: + if not exit_stack_stored: + await safe_mcp_exit_stack_close(exit_stack) else: raise HTTPException( status_code=400, @@ -1459,7 +1469,7 @@ async def disconnect_mcp( current_user: UserParam, ): from chainlit.context import init_ws_context - from chainlit.session import WebsocketSession + from chainlit.session import WebsocketSession, safe_mcp_exit_stack_close session = WebsocketSession.get_by_id(payload.sessionId) context = init_ws_context(session) @@ -1480,10 +1490,7 @@ async def disconnect_mcp( if callback: await callback(payload.name, client_session) - try: - await exit_stack.aclose() - except Exception: - pass + await safe_mcp_exit_stack_close(exit_stack) del session.mcp_sessions[payload.name] except Exception as e: diff --git a/backend/chainlit/session.py b/backend/chainlit/session.py index d6bd3f6214..53bf2bed62 100644 --- a/backend/chainlit/session.py +++ b/backend/chainlit/session.py @@ -1,4 +1,5 @@ import asyncio +import builtins import json import mimetypes import re @@ -12,6 +13,62 @@ from chainlit.logger import logger from chainlit.types import AskFileSpec, FileReference +_BASE_EXCEPTION_GROUP = getattr(builtins, "BaseExceptionGroup", None) + + +async def safe_mcp_exit_stack_close(exit_stack: AsyncExitStack) -> None: + """Close an MCP exit stack, suppressing cross-task cancel scope errors. + + AnyIO raises RuntimeError when an AsyncExitStack that was entered in one + asyncio task is closed from a different task (e.g., during HTTP request + handling for disconnect, session deletion, or reconnection). + + The MCP SDK's streamable-http transport wraps this in a + BaseExceptionGroup via its internal TaskGroup, so both forms are caught. + + This helper catches the error so MCP cleanup never propagates a cross-task + cancel scope exception, which would otherwise leave orphaned resources and + can cause 100% CPU spin loops. + + See: https://github.com/Chainlit/chainlit/issues/2182 + """ + try: + await exit_stack.aclose() + except RuntimeError as exc: + if _is_cancel_scope_error(exc): + logger.debug( + "Suppressed cross-task cancel scope error during MCP cleanup: %s", + exc, + ) + else: + logger.warning("Error closing MCP exit stack: %s", exc, exc_info=True) + except Exception as exc: + if _is_cancel_scope_error(exc): + logger.debug( + "Suppressed cross-task cancel scope error during MCP cleanup: %s", + exc, + ) + else: + logger.debug("Error closing MCP exit stack", exc_info=True) + + +def _is_cancel_scope_error(exc: BaseException) -> bool: + """Check whether an exception is an anyio cancel-scope cross-task error. + + Handles both a bare RuntimeError and a BaseExceptionGroup wrapping one + (as produced by anyio's TaskGroup when the streamable-http transport + tears down). Only treats a group as a cancel-scope error when *all* + contained exceptions match, so mixed groups surface real failures. + """ + if isinstance(exc, RuntimeError): + return "cancel scope" in str(exc) + if _BASE_EXCEPTION_GROUP and isinstance(exc, _BASE_EXCEPTION_GROUP): + return bool(exc.exceptions) and all( + _is_cancel_scope_error(e) for e in exc.exceptions + ) + return False + + if TYPE_CHECKING: from mcp import ClientSession @@ -322,10 +379,7 @@ async def delete(self): ws_sessions_id.pop(self.id, None) for _, exit_stack in self.mcp_sessions.values(): - try: - await exit_stack.aclose() - except Exception: - pass + await safe_mcp_exit_stack_close(exit_stack) async def flush_method_queue(self): for method_name, queue in self.thread_queues.items(): diff --git a/backend/tests/test_session.py b/backend/tests/test_session.py index e98b7a0994..0532fb2dd3 100644 --- a/backend/tests/test_session.py +++ b/backend/tests/test_session.py @@ -1,3 +1,4 @@ +import builtins import json import tempfile import uuid @@ -11,10 +12,19 @@ HTTPSession, JSONEncoderIgnoreNonSerializable, WebsocketSession, + _is_cancel_scope_error, clean_metadata, + safe_mcp_exit_stack_close, ) +def make_exception_group(message: str, exceptions: list[BaseException]): + base_exception_group = getattr(builtins, "BaseExceptionGroup", None) + if not base_exception_group: + pytest.skip("BaseExceptionGroup is unavailable on this Python version") + return base_exception_group(message, exceptions) + + class TestJSONEncoderIgnoreNonSerializable: """Test suite for JSONEncoderIgnoreNonSerializable.""" @@ -620,3 +630,111 @@ async def test_websocket_session_delete_with_mcp_sessions(self): await session.delete() mock_exit_stack.aclose.assert_called_once() + + @pytest.mark.asyncio + async def test_websocket_session_delete_with_cancel_scope_error(self): + """Test that session delete handles cancel scope RuntimeError gracefully.""" + + with tempfile.TemporaryDirectory() as tmpdir: + with patch("chainlit.config.FILES_DIRECTORY", Path(tmpdir)): + session = WebsocketSession( + id="ws_id", + socket_id="socket_123", + emit=Mock(), + emit_call=Mock(), + user_env={}, + client_type="webapp", + ) + + # Mock MCP session with exit stack that raises cancel scope error + mock_exit_stack = AsyncMock() + mock_exit_stack.aclose.side_effect = RuntimeError( + "Attempted to exit cancel scope in a different task" + ) + session.mcp_sessions["mcp1"] = (Mock(), mock_exit_stack) + + # Should not raise + await session.delete() + + mock_exit_stack.aclose.assert_called_once() + + +class TestSafeMcpExitStackClose: + """Test suite for safe_mcp_exit_stack_close helper.""" + + @pytest.mark.asyncio + async def test_closes_exit_stack_normally(self): + """Test normal exit stack close succeeds.""" + mock_exit_stack = AsyncMock() + await safe_mcp_exit_stack_close(mock_exit_stack) + mock_exit_stack.aclose.assert_called_once() + + @pytest.mark.asyncio + async def test_suppresses_cancel_scope_runtime_error(self): + """Test that cancel scope RuntimeError is suppressed.""" + mock_exit_stack = AsyncMock() + mock_exit_stack.aclose.side_effect = RuntimeError( + "Attempted to exit cancel scope in a different task than it was entered in" + ) + # Should not raise + await safe_mcp_exit_stack_close(mock_exit_stack) + mock_exit_stack.aclose.assert_called_once() + + @pytest.mark.asyncio + async def test_logs_warning_for_non_cancel_scope_runtime_error(self): + """Test that non-cancel-scope RuntimeErrors are logged as warnings.""" + mock_exit_stack = AsyncMock() + mock_exit_stack.aclose.side_effect = RuntimeError("something else") + # Should not raise + await safe_mcp_exit_stack_close(mock_exit_stack) + mock_exit_stack.aclose.assert_called_once() + + @pytest.mark.asyncio + async def test_suppresses_other_exceptions(self): + """Test that other exceptions during close are suppressed.""" + mock_exit_stack = AsyncMock() + mock_exit_stack.aclose.side_effect = OSError("connection reset") + # Should not raise + await safe_mcp_exit_stack_close(mock_exit_stack) + mock_exit_stack.aclose.assert_called_once() + + @pytest.mark.asyncio + async def test_suppresses_cancel_scope_wrapped_in_exception_group(self): + """Test that a BaseExceptionGroup wrapping a cancel scope error is suppressed.""" + mock_exit_stack = AsyncMock() + inner = RuntimeError("Attempted to exit cancel scope in a different task") + mock_exit_stack.aclose.side_effect = make_exception_group("errors", [inner]) + # Should not raise + await safe_mcp_exit_stack_close(mock_exit_stack) + mock_exit_stack.aclose.assert_called_once() + + +class TestIsCancelScopeError: + """Test suite for _is_cancel_scope_error helper.""" + + def test_matches_cancel_scope_runtime_error(self): + assert _is_cancel_scope_error( + RuntimeError("Attempted to exit cancel scope in a different task") + ) + + def test_rejects_unrelated_runtime_error(self): + assert not _is_cancel_scope_error(RuntimeError("something unrelated")) + + def test_rejects_non_runtime_error(self): + assert not _is_cancel_scope_error(ValueError("cancel scope in message")) + + def test_matches_wrapped_in_exception_group(self): + inner = RuntimeError("Attempted to exit cancel scope in a different task") + assert _is_cancel_scope_error(make_exception_group("errors", [inner])) + + def test_rejects_exception_group_without_cancel_scope(self): + assert not _is_cancel_scope_error( + make_exception_group("errors", [RuntimeError("unrelated")]) + ) + + def test_rejects_mixed_exception_group(self): + cancel = RuntimeError("Attempted to exit cancel scope in a different task") + other = RuntimeError("unrelated failure") + assert not _is_cancel_scope_error( + make_exception_group("errors", [cancel, other]) + )