Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,7 +1307,7 @@ async def connect_mcp(
StdioMcpConnection,
validate_mcp_command,
)
from chainlit.session import WebsocketSession
from chainlit.session import WebsocketSession, safe_mcp_exit_stack_close

session = WebsocketSession.get_by_id(payload.sessionId)
context = init_ws_context(session)
Expand All @@ -1327,14 +1327,20 @@ async def connect_mcp(
if payload.name in session.mcp_sessions:
old_client_session, old_exit_stack = session.mcp_sessions[payload.name]
if on_mcp_disconnect := config.code.on_mcp_disconnect:
await on_mcp_disconnect(payload.name, old_client_session)
try:
await old_exit_stack.aclose()
except Exception:
pass
try:
await on_mcp_disconnect(payload.name, old_client_session)
except Exception:
logger.debug(
"Error in on_mcp_disconnect callback for %s",
payload.name,
exc_info=True,
)
await safe_mcp_exit_stack_close(old_exit_stack)
del session.mcp_sessions[payload.name]

exit_stack = AsyncExitStack()
exit_stack_stored = False
try:
exit_stack = AsyncExitStack()
mcp_connection: McpConnection

if payload.clientType == "sse":
Expand Down Expand Up @@ -1411,18 +1417,22 @@ async def connect_mcp(
# Initialize the session
await mcp_session.initialize()

# Store the session
session.mcp_sessions[mcp_connection.name] = (mcp_session, exit_stack)

# Call the callback
if config.code.on_mcp_connect:
await config.code.on_mcp_connect(mcp_connection, mcp_session)

# Store the session
session.mcp_sessions[mcp_connection.name] = (mcp_session, exit_stack)
exit_stack_stored = True

except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Could not connect to the MCP: {e!s}",
)
finally:
if not exit_stack_stored:
await safe_mcp_exit_stack_close(exit_stack)
else:
raise HTTPException(
status_code=400,
Expand Down Expand Up @@ -1459,7 +1469,7 @@ async def disconnect_mcp(
current_user: UserParam,
):
from chainlit.context import init_ws_context
from chainlit.session import WebsocketSession
from chainlit.session import WebsocketSession, safe_mcp_exit_stack_close

session = WebsocketSession.get_by_id(payload.sessionId)
context = init_ws_context(session)
Expand All @@ -1480,10 +1490,7 @@ async def disconnect_mcp(
if callback:
await callback(payload.name, client_session)

try:
await exit_stack.aclose()
except Exception:
pass
await safe_mcp_exit_stack_close(exit_stack)
del session.mcp_sessions[payload.name]

except Exception as e:
Expand Down
62 changes: 58 additions & 4 deletions backend/chainlit/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import builtins
import json
import mimetypes
import re
Expand All @@ -12,6 +13,62 @@
from chainlit.logger import logger
from chainlit.types import AskFileSpec, FileReference

_BASE_EXCEPTION_GROUP = getattr(builtins, "BaseExceptionGroup", None)


async def safe_mcp_exit_stack_close(exit_stack: AsyncExitStack) -> None:
"""Close an MCP exit stack, suppressing cross-task cancel scope errors.

AnyIO raises RuntimeError when an AsyncExitStack that was entered in one
asyncio task is closed from a different task (e.g., during HTTP request
handling for disconnect, session deletion, or reconnection).

The MCP SDK's streamable-http transport wraps this in a
BaseExceptionGroup via its internal TaskGroup, so both forms are caught.

This helper catches the error so MCP cleanup never propagates a cross-task
cancel scope exception, which would otherwise leave orphaned resources and
can cause 100% CPU spin loops.

See: https://github.com/Chainlit/chainlit/issues/2182
"""
try:
await exit_stack.aclose()
except RuntimeError as exc:
if _is_cancel_scope_error(exc):
logger.debug(
"Suppressed cross-task cancel scope error during MCP cleanup: %s",
exc,
)
else:
logger.warning("Error closing MCP exit stack: %s", exc, exc_info=True)
except Exception as exc:
if _is_cancel_scope_error(exc):
logger.debug(
"Suppressed cross-task cancel scope error during MCP cleanup: %s",
exc,
)
else:
logger.debug("Error closing MCP exit stack", exc_info=True)


def _is_cancel_scope_error(exc: BaseException) -> bool:
"""Check whether an exception is an anyio cancel-scope cross-task error.

Handles both a bare RuntimeError and a BaseExceptionGroup wrapping one
(as produced by anyio's TaskGroup when the streamable-http transport
tears down). Only treats a group as a cancel-scope error when *all*
contained exceptions match, so mixed groups surface real failures.
"""
if isinstance(exc, RuntimeError):
return "cancel scope" in str(exc)
if _BASE_EXCEPTION_GROUP and isinstance(exc, _BASE_EXCEPTION_GROUP):
return bool(exc.exceptions) and all(
_is_cancel_scope_error(e) for e in exc.exceptions
)
return False


if TYPE_CHECKING:
from mcp import ClientSession

Expand Down Expand Up @@ -322,10 +379,7 @@ async def delete(self):
ws_sessions_id.pop(self.id, None)

for _, exit_stack in self.mcp_sessions.values():
try:
await exit_stack.aclose()
except Exception:
pass
await safe_mcp_exit_stack_close(exit_stack)

async def flush_method_queue(self):
for method_name, queue in self.thread_queues.items():
Expand Down
118 changes: 118 additions & 0 deletions backend/tests/test_session.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import builtins
import json
import tempfile
import uuid
Expand All @@ -11,10 +12,19 @@
HTTPSession,
JSONEncoderIgnoreNonSerializable,
WebsocketSession,
_is_cancel_scope_error,
clean_metadata,
safe_mcp_exit_stack_close,
)


def make_exception_group(message: str, exceptions: list[BaseException]):
base_exception_group = getattr(builtins, "BaseExceptionGroup", None)
if not base_exception_group:
pytest.skip("BaseExceptionGroup is unavailable on this Python version")
return base_exception_group(message, exceptions)


class TestJSONEncoderIgnoreNonSerializable:
"""Test suite for JSONEncoderIgnoreNonSerializable."""

Expand Down Expand Up @@ -620,3 +630,111 @@ async def test_websocket_session_delete_with_mcp_sessions(self):
await session.delete()

mock_exit_stack.aclose.assert_called_once()

@pytest.mark.asyncio
async def test_websocket_session_delete_with_cancel_scope_error(self):
"""Test that session delete handles cancel scope RuntimeError gracefully."""

with tempfile.TemporaryDirectory() as tmpdir:
with patch("chainlit.config.FILES_DIRECTORY", Path(tmpdir)):
session = WebsocketSession(
id="ws_id",
socket_id="socket_123",
emit=Mock(),
emit_call=Mock(),
user_env={},
client_type="webapp",
)

# Mock MCP session with exit stack that raises cancel scope error
mock_exit_stack = AsyncMock()
mock_exit_stack.aclose.side_effect = RuntimeError(
"Attempted to exit cancel scope in a different task"
)
session.mcp_sessions["mcp1"] = (Mock(), mock_exit_stack)

# Should not raise
await session.delete()

mock_exit_stack.aclose.assert_called_once()


class TestSafeMcpExitStackClose:
"""Test suite for safe_mcp_exit_stack_close helper."""

@pytest.mark.asyncio
async def test_closes_exit_stack_normally(self):
"""Test normal exit stack close succeeds."""
mock_exit_stack = AsyncMock()
await safe_mcp_exit_stack_close(mock_exit_stack)
mock_exit_stack.aclose.assert_called_once()

@pytest.mark.asyncio
async def test_suppresses_cancel_scope_runtime_error(self):
"""Test that cancel scope RuntimeError is suppressed."""
mock_exit_stack = AsyncMock()
mock_exit_stack.aclose.side_effect = RuntimeError(
"Attempted to exit cancel scope in a different task than it was entered in"
)
# Should not raise
await safe_mcp_exit_stack_close(mock_exit_stack)
mock_exit_stack.aclose.assert_called_once()

@pytest.mark.asyncio
async def test_logs_warning_for_non_cancel_scope_runtime_error(self):
"""Test that non-cancel-scope RuntimeErrors are logged as warnings."""
mock_exit_stack = AsyncMock()
mock_exit_stack.aclose.side_effect = RuntimeError("something else")
# Should not raise
await safe_mcp_exit_stack_close(mock_exit_stack)
mock_exit_stack.aclose.assert_called_once()

@pytest.mark.asyncio
async def test_suppresses_other_exceptions(self):
"""Test that other exceptions during close are suppressed."""
mock_exit_stack = AsyncMock()
mock_exit_stack.aclose.side_effect = OSError("connection reset")
# Should not raise
await safe_mcp_exit_stack_close(mock_exit_stack)
mock_exit_stack.aclose.assert_called_once()

@pytest.mark.asyncio
async def test_suppresses_cancel_scope_wrapped_in_exception_group(self):
"""Test that a BaseExceptionGroup wrapping a cancel scope error is suppressed."""
mock_exit_stack = AsyncMock()
inner = RuntimeError("Attempted to exit cancel scope in a different task")
mock_exit_stack.aclose.side_effect = make_exception_group("errors", [inner])
# Should not raise
await safe_mcp_exit_stack_close(mock_exit_stack)
mock_exit_stack.aclose.assert_called_once()


class TestIsCancelScopeError:
"""Test suite for _is_cancel_scope_error helper."""

def test_matches_cancel_scope_runtime_error(self):
assert _is_cancel_scope_error(
RuntimeError("Attempted to exit cancel scope in a different task")
)

def test_rejects_unrelated_runtime_error(self):
assert not _is_cancel_scope_error(RuntimeError("something unrelated"))

def test_rejects_non_runtime_error(self):
assert not _is_cancel_scope_error(ValueError("cancel scope in message"))

def test_matches_wrapped_in_exception_group(self):
inner = RuntimeError("Attempted to exit cancel scope in a different task")
assert _is_cancel_scope_error(make_exception_group("errors", [inner]))

def test_rejects_exception_group_without_cancel_scope(self):
assert not _is_cancel_scope_error(
make_exception_group("errors", [RuntimeError("unrelated")])
)

def test_rejects_mixed_exception_group(self):
cancel = RuntimeError("Attempted to exit cancel scope in a different task")
other = RuntimeError("unrelated failure")
assert not _is_cancel_scope_error(
make_exception_group("errors", [cancel, other])
)
Loading