diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 57986062e..c28a4adaf 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -21,8 +21,7 @@ from .models.model import BidiModel from .models.nova_sonic import BidiNovaSonicModel -# Built-in tools -from .tools import stop_conversation + # Event types - For type hints and event handling from .types.events import ( @@ -50,8 +49,6 @@ "BidiAudioIO", # Model providers "BidiNovaSonicModel", - # Built-in tools - "stop_conversation", # Input Event types "BidiTextInputEvent", "BidiAudioInputEvent", diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 2b883cf73..cd7f8412f 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -248,6 +248,10 @@ async def _run_tool(self, tool_use: ToolUse) -> None: tool_results: list[ToolResult] = [] + # Ensure request_state exists for tools like strands_tools.stop + if "request_state" not in self._invocation_state: + self._invocation_state["request_state"] = {} + invocation_state: dict[str, Any] = { **self._invocation_state, "agent": self._agent, @@ -273,7 +277,6 @@ async def _run_tool(self, tool_use: ToolUse) -> None: await self._event_queue.put(tool_event) - # Normal flow for all tools (including stop_conversation) tool_result_event = cast(ToolResultEvent, tool_event) tool_use_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} @@ -282,16 +285,17 @@ async def _run_tool(self, tool_use: ToolUse) -> None: await self._event_queue.put(ToolResultMessageEvent(tool_result_message)) - # Check for stop_conversation before sending to model - if tool_use["name"] == "stop_conversation": - logger.info("tool_name=<%s> | conversation stop requested, skipping model send", tool_use["name"]) + # Check for stop_event_loop flag (set by strands_tools.stop or any tool) + request_state = invocation_state.get("request_state", {}) + if request_state.get("stop_event_loop", False): + logger.info("stop_event_loop= | stopping conversation") connection_id = getattr(self._agent.model, "_connection_id", "unknown") await self._event_queue.put( BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request") ) - return # Skip the model send + return # Skip sending result to model - # Send result to model (all tools except stop_conversation) + # Send result to model await self.send(tool_result_event) except Exception as error: diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index f575c5606..00d999818 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -42,7 +42,7 @@ async def __call__(self, event: BidiOutputEvent) -> None: elif isinstance(event, BidiConnectionCloseEvent): if event.reason == "user_request": - print("user requested connection close using the stop_conversation tool.") + print("user requested connection close using the stop tool.") logger.debug("connection_id=<%s> | user requested connection close", event.connection_id) elif isinstance(event, BidiTranscriptStreamEvent): text = event["text"] diff --git a/src/strands/experimental/bidi/tools/__init__.py b/src/strands/experimental/bidi/tools/__init__.py index c665dc65a..7ff8dd758 100644 --- a/src/strands/experimental/bidi/tools/__init__.py +++ b/src/strands/experimental/bidi/tools/__init__.py @@ -1,5 +1,10 @@ -"""Built-in tools for bidirectional agents.""" +"""Built-in tools for bidirectional agents. -from .stop_conversation import stop_conversation +Note: To stop a bidirectional conversation, use the standard `stop` tool from strands_tools: -__all__ = ["stop_conversation"] + from strands_tools import stop + agent = BidiAgent(tools=[stop, ...]) + +The stop tool sets `request_state["stop_event_loop"] = True`, which signals the +BidiAgent to gracefully close the connection. +""" diff --git a/src/strands/experimental/bidi/tools/stop_conversation.py b/src/strands/experimental/bidi/tools/stop_conversation.py deleted file mode 100644 index 9c7e1c6cd..000000000 --- a/src/strands/experimental/bidi/tools/stop_conversation.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Tool to gracefully stop a bidirectional connection.""" - -from ....tools.decorator import tool - - -@tool -def stop_conversation() -> str: - """Stop the bidirectional conversation gracefully. - - Use ONLY when user says "stop conversation" exactly. - Do NOT use for: "stop", "goodbye", "bye", "exit", "quit", "end" or other farewells or phrases. - - Returns: - Success message confirming the conversation will end - """ - return "Ending conversation" diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py index 0ce8d6658..2e52047cc 100644 --- a/tests/strands/experimental/bidi/agent/test_loop.py +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -95,3 +95,116 @@ async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): assert tru_messages == exp_messages agent.model.send.assert_called_with(tool_result_event) + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_request_state_initialized_for_tools(loop, agent, agenerator): + """Test that request_state is initialized before tool execution. + + This ensures tools that access request_state (like strands_tools.stop) + work correctly even when invocation_state is not provided by the user. + """ + + @tool(name="check_request_state") + async def check_request_state(request_state: dict) -> str: + # Verify request_state exists and is writable + request_state["test_key"] = "test_value" + return f"keys: {list(request_state.keys())}" + + agent.tool_registry.register_tool(check_request_state) + + tool_use = {"toolUseId": "t2", "name": "check_request_state", "input": {}} + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + # Start without providing invocation_state + await loop.start() + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + if len(tru_events) >= 3: + break + + # Verify tool executed successfully (request_state was available) + tool_result_event = tru_events[1] + assert isinstance(tool_result_event, ToolResultEvent) + assert tool_result_event.tool_result["status"] == "success" + assert "test_key" in tool_result_event.tool_result["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_stop_event_loop_flag(loop, agent, agenerator): + """Test that tools can set stop_event_loop flag to gracefully close connection.""" + + @tool(name="stop_tool") + async def stop_tool(request_state: dict) -> str: + request_state["stop_event_loop"] = True + return "stopping" + + agent.tool_registry.register_tool(stop_tool) + + tool_use = {"toolUseId": "t3", "name": "stop_tool", "input": {}} + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + await loop.start() + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + + # Should receive: tool_use_event, tool_result_event, tool_result_message, connection_close + assert len(tru_events) == 4 + + # Verify tool executed successfully + tool_result_event = tru_events[1] + assert isinstance(tool_result_event, ToolResultEvent) + assert tool_result_event.tool_result["status"] == "success" + + # Verify connection close event was emitted + from strands.experimental.bidi.types.events import BidiConnectionCloseEvent + + connection_close_event = tru_events[3] + assert isinstance(connection_close_event, BidiConnectionCloseEvent) + assert connection_close_event["reason"] == "user_request" + + # Verify model.send was NOT called (tool result not sent to model) + agent.model.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_request_state_preserved_with_invocation_state(agent, agenerator): + """Test that existing invocation_state is preserved when request_state is initialized.""" + + @tool(name="check_invocation_state") + async def check_invocation_state(custom_key: str) -> str: + return f"custom_key: {custom_key}" + + agent.tool_registry.register_tool(check_invocation_state) + + tool_use = {"toolUseId": "t4", "name": "check_invocation_state", "input": {"custom_key": "from_state"}} + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + loop = agent._loop + # Start with custom invocation_state but no request_state + await loop.start(invocation_state={"custom_data": "preserved"}) + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + if len(tru_events) >= 3: + break + + # Verify tool executed successfully + tool_result_event = tru_events[1] + assert isinstance(tool_result_event, ToolResultEvent) + assert tool_result_event.tool_result["status"] == "success" + + # Verify request_state was added without removing custom_data + assert "request_state" in loop._invocation_state + assert loop._invocation_state.get("custom_data") == "preserved"