Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/strands/experimental/bidi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from .models.model import BidiModel
from .models.nova_sonic import BidiNovaSonicModel

# Built-in tools
from .tools import stop_conversation


# Event types - For type hints and event handling
from .types.events import (
Expand Down Expand Up @@ -50,8 +49,6 @@
"BidiAudioIO",
# Model providers
"BidiNovaSonicModel",
# Built-in tools
"stop_conversation",
# Input Event types
"BidiTextInputEvent",
"BidiAudioInputEvent",
Expand Down
16 changes: 10 additions & 6 deletions src/strands/experimental/bidi/agent/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ async def _run_tool(self, tool_use: ToolUse) -> None:

tool_results: list[ToolResult] = []

# Ensure request_state exists for tools like strands_tools.stop
if "request_state" not in self._invocation_state:
self._invocation_state["request_state"] = {}

invocation_state: dict[str, Any] = {
**self._invocation_state,
"agent": self._agent,
Expand All @@ -273,7 +277,6 @@ async def _run_tool(self, tool_use: ToolUse) -> None:

await self._event_queue.put(tool_event)

# Normal flow for all tools (including stop_conversation)
tool_result_event = cast(ToolResultEvent, tool_event)

tool_use_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]}
Expand All @@ -282,16 +285,17 @@ async def _run_tool(self, tool_use: ToolUse) -> None:

await self._event_queue.put(ToolResultMessageEvent(tool_result_message))

# Check for stop_conversation before sending to model
if tool_use["name"] == "stop_conversation":
logger.info("tool_name=<%s> | conversation stop requested, skipping model send", tool_use["name"])
# Check for stop_event_loop flag (set by strands_tools.stop or any tool)
request_state = invocation_state.get("request_state", {})
if request_state.get("stop_event_loop", False):
logger.info("stop_event_loop=<True> | stopping conversation")
connection_id = getattr(self._agent.model, "_connection_id", "unknown")
await self._event_queue.put(
BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request")
)
return # Skip the model send
return # Skip sending result to model

# Send result to model (all tools except stop_conversation)
# Send result to model
await self.send(tool_result_event)

except Exception as error:
Expand Down
2 changes: 1 addition & 1 deletion src/strands/experimental/bidi/io/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def __call__(self, event: BidiOutputEvent) -> None:

elif isinstance(event, BidiConnectionCloseEvent):
if event.reason == "user_request":
print("user requested connection close using the stop_conversation tool.")
print("user requested connection close using the stop tool.")
logger.debug("connection_id=<%s> | user requested connection close", event.connection_id)
elif isinstance(event, BidiTranscriptStreamEvent):
text = event["text"]
Expand Down
11 changes: 8 additions & 3 deletions src/strands/experimental/bidi/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""Built-in tools for bidirectional agents."""
"""Built-in tools for bidirectional agents.

from .stop_conversation import stop_conversation
Note: To stop a bidirectional conversation, use the standard `stop` tool from strands_tools:

__all__ = ["stop_conversation"]
from strands_tools import stop
agent = BidiAgent(tools=[stop, ...])

The stop tool sets `request_state["stop_event_loop"] = True`, which signals the
BidiAgent to gracefully close the connection.
"""
16 changes: 0 additions & 16 deletions src/strands/experimental/bidi/tools/stop_conversation.py

This file was deleted.

113 changes: 113 additions & 0 deletions tests/strands/experimental/bidi/agent/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,116 @@ async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator):
assert tru_messages == exp_messages

agent.model.send.assert_called_with(tool_result_event)


@pytest.mark.asyncio
async def test_bidi_agent_loop_request_state_initialized_for_tools(loop, agent, agenerator):
"""Test that request_state is initialized before tool execution.

This ensures tools that access request_state (like strands_tools.stop)
work correctly even when invocation_state is not provided by the user.
"""

@tool(name="check_request_state")
async def check_request_state(request_state: dict) -> str:
# Verify request_state exists and is writable
request_state["test_key"] = "test_value"
return f"keys: {list(request_state.keys())}"

agent.tool_registry.register_tool(check_request_state)

tool_use = {"toolUseId": "t2", "name": "check_request_state", "input": {}}
tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="")

agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event]))

# Start without providing invocation_state
await loop.start()

tru_events = []
async for event in loop.receive():
tru_events.append(event)
if len(tru_events) >= 3:
break

# Verify tool executed successfully (request_state was available)
tool_result_event = tru_events[1]
assert isinstance(tool_result_event, ToolResultEvent)
assert tool_result_event.tool_result["status"] == "success"
assert "test_key" in tool_result_event.tool_result["content"][0]["text"]


@pytest.mark.asyncio
async def test_bidi_agent_loop_stop_event_loop_flag(loop, agent, agenerator):
"""Test that tools can set stop_event_loop flag to gracefully close connection."""

@tool(name="stop_tool")
async def stop_tool(request_state: dict) -> str:
request_state["stop_event_loop"] = True
return "stopping"

agent.tool_registry.register_tool(stop_tool)

tool_use = {"toolUseId": "t3", "name": "stop_tool", "input": {}}
tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="")

agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event]))

await loop.start()

tru_events = []
async for event in loop.receive():
tru_events.append(event)

# Should receive: tool_use_event, tool_result_event, tool_result_message, connection_close
assert len(tru_events) == 4

# Verify tool executed successfully
tool_result_event = tru_events[1]
assert isinstance(tool_result_event, ToolResultEvent)
assert tool_result_event.tool_result["status"] == "success"

# Verify connection close event was emitted
from strands.experimental.bidi.types.events import BidiConnectionCloseEvent

connection_close_event = tru_events[3]
assert isinstance(connection_close_event, BidiConnectionCloseEvent)
assert connection_close_event["reason"] == "user_request"

# Verify model.send was NOT called (tool result not sent to model)
agent.model.send.assert_not_called()


@pytest.mark.asyncio
async def test_bidi_agent_loop_request_state_preserved_with_invocation_state(agent, agenerator):
"""Test that existing invocation_state is preserved when request_state is initialized."""

@tool(name="check_invocation_state")
async def check_invocation_state(custom_key: str) -> str:
return f"custom_key: {custom_key}"

agent.tool_registry.register_tool(check_invocation_state)

tool_use = {"toolUseId": "t4", "name": "check_invocation_state", "input": {"custom_key": "from_state"}}
tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="")

agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event]))

loop = agent._loop
# Start with custom invocation_state but no request_state
await loop.start(invocation_state={"custom_data": "preserved"})

tru_events = []
async for event in loop.receive():
tru_events.append(event)
if len(tru_events) >= 3:
break

# Verify tool executed successfully
tool_result_event = tru_events[1]
assert isinstance(tool_result_event, ToolResultEvent)
assert tool_result_event.tool_result["status"] == "success"

# Verify request_state was added without removing custom_data
assert "request_state" in loop._invocation_state
assert loop._invocation_state.get("custom_data") == "preserved"
Loading