diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 558d3e298..6f4fe2587 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -286,7 +286,8 @@ def __init__(self, tool_result: ToolResult) -> None: @property def tool_use_id(self) -> str: """The toolUseId associated with this result.""" - return cast(str, cast(ToolResult, self.get("tool_result")).get("toolUseId")) + tool_result = cast(ToolResult, self.get("tool_result")) + return tool_result.get("toolUseId") @property def tool_result(self) -> ToolResult: @@ -314,7 +315,8 @@ def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None: @property def tool_use_id(self) -> str: """The toolUseId associated with this stream.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId")) + tool_use = cast(dict, self.get("tool_stream_event")).get("tool_use") + return cast(ToolUse, tool_use).get("toolUseId") class ToolCancelEvent(TypedEvent): @@ -332,7 +334,8 @@ def __init__(self, tool_use: ToolUse, message: str) -> None: @property def tool_use_id(self) -> str: """The id of the tool cancelled.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId")) + tool_use = cast(dict, self.get("tool_cancel_event")).get("tool_use") + return cast(ToolUse, tool_use).get("toolUseId") @property def message(self) -> str: @@ -350,7 +353,8 @@ def __init__(self, tool_use: ToolUse, interrupts: list[Interrupt]) -> None: @property def tool_use_id(self) -> str: """The id of the tool interrupted.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_interrupt_event")).get("tool_use")).get("toolUseId")) + tool_use = cast(dict, self.get("tool_interrupt_event")).get("tool_use") + return cast(ToolUse, tool_use).get("toolUseId") @property def interrupts(self) -> list[Interrupt]: diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index ec77b48a2..97640a5ee 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -1,4 +1,6 @@ +import asyncio import base64 +import threading import time from unittest.mock import AsyncMock, MagicMock, patch @@ -723,3 +725,51 @@ async def test_handle_error_message_non_exception(): # This should not raise an exception await client._handle_error_message("normal message") + + +def _start_background_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread]: + """Spin up a background asyncio loop for tests that exercise thread hop logic.""" + loop = asyncio.new_event_loop() + ready = threading.Event() + + def run_loop() -> None: + asyncio.set_event_loop(loop) + ready.set() + loop.run_forever() + + thread = threading.Thread(target=run_loop, daemon=True) + thread.start() + ready.wait() + return loop, thread + + +def _create_resolved_future(loop: asyncio.AbstractEventLoop) -> asyncio.Future: + """Create a future on the given loop that is already resolved.""" + + async def make_future() -> asyncio.Future: + fut: asyncio.Future = asyncio.Future() + fut.set_result(None) + return fut + + return asyncio.run_coroutine_threadsafe(make_future(), loop).result() + + +def test_invoke_on_background_thread_aborts_when_connection_closes() -> None: + """Invoke should fail fast when the MCP connection has already collapsed (close_future resolved).""" + client = MCPClient(MagicMock()) + loop, thread = _start_background_loop() + try: + # Simulate an initialized background session with a resolved close future + client._background_thread_session = MagicMock() + client._background_thread_event_loop = loop + client._close_future = _create_resolved_future(loop) + + async def slow_coro() -> str: + await asyncio.sleep(1) + return "ok" + + with pytest.raises(RuntimeError, match="Connection to the MCP server was closed"): + client._invoke_on_background_thread(slow_coro()).result(timeout=1) + finally: + loop.call_soon_threadsafe(loop.stop) + thread.join(timeout=1) diff --git a/tests/strands/tools/test_watcher.py b/tests/strands/tools/test_watcher.py index 75a5616fe..2cb8f708a 100644 --- a/tests/strands/tools/test_watcher.py +++ b/tests/strands/tools/test_watcher.py @@ -2,6 +2,7 @@ Tests for the SDK tool watcher module. """ +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -10,6 +11,20 @@ from strands.tools.watcher import ToolWatcher +@pytest.fixture(autouse=True) +def reset_tool_watcher_state(): + """Reset ToolWatcher shared state between tests to avoid cross-test leakage.""" + ToolWatcher._shared_observer = None + ToolWatcher._watched_dirs = set() + ToolWatcher._observer_started = False + ToolWatcher._registry_handlers = {} + yield + ToolWatcher._shared_observer = None + ToolWatcher._watched_dirs = set() + ToolWatcher._observer_started = False + ToolWatcher._registry_handlers = {} + + def test_tool_watcher_initialization(): """Test that the handler initializes with the correct tool registry.""" tool_registry = ToolRegistry() @@ -96,3 +111,34 @@ def test_on_modified_error_handling(mock_reload_tool): # Verify that reload_tool was called mock_reload_tool.assert_called_once_with("test_tool") + + +@patch("strands.tools.watcher.Observer") +def test_master_handler_routes_events_to_all_registries(mock_observer_cls): + """Master handler should fan out file changes to all registry handlers for the same directory.""" + mock_observer = MagicMock() + mock_observer_cls.return_value = mock_observer + + tools_dir = Path("/tmp/tools") + registry_a = MagicMock(spec=ToolRegistry) + registry_b = MagicMock(spec=ToolRegistry) + registry_a.get_tools_dirs.return_value = [tools_dir] + registry_b.get_tools_dirs.return_value = [tools_dir] + + ToolWatcher(registry_a) + ToolWatcher(registry_b) + + # Only one observer/schedule/start for the shared directory + mock_observer.schedule.assert_called_once() + mock_observer.start.assert_called_once() + assert len(ToolWatcher._registry_handlers[str(tools_dir)]) == 2 + + event = MagicMock() + event.src_path = str(tools_dir / "my_tool.py") + event.is_directory = False + + master_handler = ToolWatcher.MasterChangeHandler(str(tools_dir)) + master_handler.on_modified(event) + + registry_a.reload_tool.assert_called_once_with("my_tool") + registry_b.reload_tool.assert_called_once_with("my_tool")