Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/strands/types/_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ def __init__(self, tool_result: ToolResult) -> None:
@property
def tool_use_id(self) -> str:
"""The toolUseId associated with this result."""
return cast(str, cast(ToolResult, self.get("tool_result")).get("toolUseId"))
tool_result = cast(ToolResult, self.get("tool_result"))
return tool_result.get("toolUseId")

@property
def tool_result(self) -> ToolResult:
Expand Down Expand Up @@ -314,7 +315,8 @@ def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None:
@property
def tool_use_id(self) -> str:
"""The toolUseId associated with this stream."""
return cast(str, cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId"))
tool_use = cast(dict, self.get("tool_stream_event")).get("tool_use")
return cast(ToolUse, tool_use).get("toolUseId")


class ToolCancelEvent(TypedEvent):
Expand All @@ -332,7 +334,8 @@ def __init__(self, tool_use: ToolUse, message: str) -> None:
@property
def tool_use_id(self) -> str:
"""The id of the tool cancelled."""
return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId"))
tool_use = cast(dict, self.get("tool_cancel_event")).get("tool_use")
return cast(ToolUse, tool_use).get("toolUseId")

@property
def message(self) -> str:
Expand All @@ -350,7 +353,8 @@ def __init__(self, tool_use: ToolUse, interrupts: list[Interrupt]) -> None:
@property
def tool_use_id(self) -> str:
"""The id of the tool interrupted."""
return cast(str, cast(ToolUse, cast(dict, self.get("tool_interrupt_event")).get("tool_use")).get("toolUseId"))
tool_use = cast(dict, self.get("tool_interrupt_event")).get("tool_use")
return cast(ToolUse, tool_use).get("toolUseId")

@property
def interrupts(self) -> list[Interrupt]:
Expand Down
50 changes: 50 additions & 0 deletions tests/strands/tools/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import base64
import threading
import time
from unittest.mock import AsyncMock, MagicMock, patch

Expand Down Expand Up @@ -723,3 +725,51 @@ async def test_handle_error_message_non_exception():

# This should not raise an exception
await client._handle_error_message("normal message")


def _start_background_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread]:
"""Spin up a background asyncio loop for tests that exercise thread hop logic."""
loop = asyncio.new_event_loop()
ready = threading.Event()

def run_loop() -> None:
asyncio.set_event_loop(loop)
ready.set()
loop.run_forever()

thread = threading.Thread(target=run_loop, daemon=True)
thread.start()
ready.wait()
return loop, thread


def _create_resolved_future(loop: asyncio.AbstractEventLoop) -> asyncio.Future:
"""Create a future on the given loop that is already resolved."""

async def make_future() -> asyncio.Future:
fut: asyncio.Future = asyncio.Future()
fut.set_result(None)
return fut

return asyncio.run_coroutine_threadsafe(make_future(), loop).result()


def test_invoke_on_background_thread_aborts_when_connection_closes() -> None:
"""Invoke should fail fast when the MCP connection has already collapsed (close_future resolved)."""
client = MCPClient(MagicMock())
loop, thread = _start_background_loop()
try:
# Simulate an initialized background session with a resolved close future
client._background_thread_session = MagicMock()
client._background_thread_event_loop = loop
client._close_future = _create_resolved_future(loop)

async def slow_coro() -> str:
await asyncio.sleep(1)
return "ok"

with pytest.raises(RuntimeError, match="Connection to the MCP server was closed"):
client._invoke_on_background_thread(slow_coro()).result(timeout=1)
finally:
loop.call_soon_threadsafe(loop.stop)
thread.join(timeout=1)
46 changes: 46 additions & 0 deletions tests/strands/tools/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Tests for the SDK tool watcher module.
"""

from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -10,6 +11,20 @@
from strands.tools.watcher import ToolWatcher


@pytest.fixture(autouse=True)
def reset_tool_watcher_state():
"""Reset ToolWatcher shared state between tests to avoid cross-test leakage."""
ToolWatcher._shared_observer = None
ToolWatcher._watched_dirs = set()
ToolWatcher._observer_started = False
ToolWatcher._registry_handlers = {}
yield
ToolWatcher._shared_observer = None
ToolWatcher._watched_dirs = set()
ToolWatcher._observer_started = False
ToolWatcher._registry_handlers = {}


def test_tool_watcher_initialization():
"""Test that the handler initializes with the correct tool registry."""
tool_registry = ToolRegistry()
Expand Down Expand Up @@ -96,3 +111,34 @@ def test_on_modified_error_handling(mock_reload_tool):

# Verify that reload_tool was called
mock_reload_tool.assert_called_once_with("test_tool")


@patch("strands.tools.watcher.Observer")
def test_master_handler_routes_events_to_all_registries(mock_observer_cls):
"""Master handler should fan out file changes to all registry handlers for the same directory."""
mock_observer = MagicMock()
mock_observer_cls.return_value = mock_observer

tools_dir = Path("/tmp/tools")
registry_a = MagicMock(spec=ToolRegistry)
registry_b = MagicMock(spec=ToolRegistry)
registry_a.get_tools_dirs.return_value = [tools_dir]
registry_b.get_tools_dirs.return_value = [tools_dir]

ToolWatcher(registry_a)
ToolWatcher(registry_b)

# Only one observer/schedule/start for the shared directory
mock_observer.schedule.assert_called_once()
mock_observer.start.assert_called_once()
assert len(ToolWatcher._registry_handlers[str(tools_dir)]) == 2

event = MagicMock()
event.src_path = str(tools_dir / "my_tool.py")
event.is_directory = False

master_handler = ToolWatcher.MasterChangeHandler(str(tools_dir))
master_handler.on_modified(event)

registry_a.reload_tool.assert_called_once_with("my_tool")
registry_b.reload_tool.assert_called_once_with("my_tool")