From 805c4a554b2284665b5ef1216667c856564fb997 Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Thu, 18 Sep 2025 18:49:17 -0400 Subject: [PATCH] fix: Fix mcp timeout issue --- src/strands/tools/mcp/mcp_client.py | 12 +++++++++++- tests_integ/mcp/test_mcp_client.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 96e80385f..dec8ec313 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -18,6 +18,7 @@ from types import TracebackType from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union, cast +import anyio from mcp import ClientSession, ListToolsResult from mcp.types import CallToolResult as MCPCallToolResult from mcp.types import GetPromptResult, ListPromptsResult @@ -378,6 +379,13 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes return result + # Raise an exception if the underlying client raises an exception in a message + # This happens when the underlying client has an http timeout error + async def _handle_error_message(self, message: Exception | Any) -> None: + if isinstance(message, Exception): + raise message + await anyio.lowlevel.checkpoint() + async def _async_background_thread(self) -> None: """Asynchronous method that runs in the background thread to manage the MCP connection. @@ -388,7 +396,9 @@ async def _async_background_thread(self) -> None: try: async with self._transport_callable() as (read_stream, write_stream, *_): self._log_debug_with_thread("transport connection established") - async with ClientSession(read_stream, write_stream) as session: + async with ClientSession( + read_stream, write_stream, message_handler=self._handle_error_message + ) as session: self._log_debug_with_thread("initializing MCP session") await session.initialize() diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 5e1dc958b..9d5ab5f13 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -31,6 +31,11 @@ def start_comprehensive_mcp_server(transport: Literal["sse", "streamable-http"], mcp = FastMCP("Comprehensive MCP Server", port=port) + @mcp.tool(description="Tool that will timeout") + def timeout_tool() -> str: + time.sleep(10) + return "This tool has timed out" + @mcp.tool(description="Calculator tool which performs calculations") def calculator(x: int, y: int) -> int: return x + y @@ -297,3 +302,27 @@ def slow_transport(): with client: tools = client.list_tools_sync() assert len(tools) >= 0 # Should work now + + +@pytest.mark.skipif( + condition=os.environ.get("GITHUB_ACTIONS") == "true", + reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", +) +@pytest.mark.asyncio +async def test_streamable_http_mcp_client_times_out_before_tool(): + """Test an mcp server that timesout before the tool is able to respond.""" + server_thread = threading.Thread( + target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True + ) + server_thread.start() + time.sleep(2) # wait for server to startup completely + + def transport_callback() -> MCPTransport: + return streamablehttp_client(sse_read_timeout=2, url="http://127.0.0.1:8001/mcp") + + streamable_http_client = MCPClient(transport_callback) + with streamable_http_client: + # Test tools + result = await streamable_http_client.call_tool_async(tool_use_id="123", name="timeout_tool") + assert result["status"] == "error" + assert result["content"][0]["text"] == "Tool execution failed: Connection closed"