diff --git a/src/strands/experimental/hooks/multiagent/events.py b/src/strands/experimental/hooks/multiagent/events.py index 87066dc81..fa881bf32 100644 --- a/src/strands/experimental/hooks/multiagent/events.py +++ b/src/strands/experimental/hooks/multiagent/events.py @@ -5,10 +5,14 @@ is used—hooks read from the orchestrator directly. """ +import uuid from dataclasses import dataclass from typing import TYPE_CHECKING, Any +from typing_extensions import override + from ....hooks import BaseHookEvent +from ....types.interrupt import _Interruptible if TYPE_CHECKING: from ....multiagent.base import MultiAgentBase @@ -28,7 +32,7 @@ class MultiAgentInitializedEvent(BaseHookEvent): @dataclass -class BeforeNodeCallEvent(BaseHookEvent): +class BeforeNodeCallEvent(BaseHookEvent, _Interruptible): """Event triggered before individual node execution starts. Attributes: @@ -48,6 +52,20 @@ class BeforeNodeCallEvent(BaseHookEvent): def _can_write(self, name: str) -> bool: return name in ["cancel_node"] + @override + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. + """ + node_id = uuid.uuid5(uuid.NAMESPACE_OID, self.node_id) + call_id = uuid.uuid5(uuid.NAMESPACE_OID, name) + return f"v1:before_node_call:{node_id}:{call_id}" + @dataclass class AfterNodeCallEvent(BaseHookEvent): diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 9e3b92ea5..f163d05b5 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -12,6 +12,7 @@ from .._async import run_async from ..agent import AgentResult +from ..interrupt import Interrupt from ..types.event_loop import Metrics, Usage from ..types.multiagent import MultiAgentInput from ..types.traces import AttributeValue @@ -20,22 +21,26 @@ class Status(Enum): - """Execution status for both graphs and nodes.""" + """Execution status for both graphs and nodes. + + Attributes: + PENDING: Task has not started execution yet. + EXECUTING: Task is currently running. + COMPLETED: Task finished successfully. + FAILED: Task encountered an error and could not complete. + INTERRUPTED: Task was interrupted by user. + """ PENDING = "pending" EXECUTING = "executing" COMPLETED = "completed" FAILED = "failed" + INTERRUPTED = "interrupted" @dataclass class NodeResult: - """Unified result from node execution - handles both Agent and nested MultiAgentBase results. - - The status field represents the semantic outcome of the node's work: - - COMPLETED: The node's task was successfully accomplished - - FAILED: The node's task failed or produced an error - """ + """Unified result from node execution - handles both Agent and nested MultiAgentBase results.""" # Core result data - single AgentResult, nested MultiAgentResult, or Exception result: Union[AgentResult, "MultiAgentResult", Exception] @@ -48,6 +53,7 @@ class NodeResult: accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) execution_count: int = 0 + interrupts: list[Interrupt] = field(default_factory=list) def get_agent_results(self) -> list[AgentResult]: """Get all AgentResult objects from this node, flattened if nested.""" @@ -79,6 +85,7 @@ def to_dict(self) -> dict[str, Any]: "accumulated_usage": self.accumulated_usage, "accumulated_metrics": self.accumulated_metrics, "execution_count": self.execution_count, + "interrupts": [interrupt.to_dict() for interrupt in self.interrupts], } @classmethod @@ -101,6 +108,10 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": usage = _parse_usage(data.get("accumulated_usage", {})) metrics = _parse_metrics(data.get("accumulated_metrics", {})) + interrupts = [] + for interrupt_data in data.get("interrupts", []): + interrupts.append(Interrupt(**interrupt_data)) + return cls( result=result, execution_time=int(data.get("execution_time", 0)), @@ -108,17 +119,13 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": accumulated_usage=usage, accumulated_metrics=metrics, execution_count=int(data.get("execution_count", 0)), + interrupts=interrupts, ) @dataclass class MultiAgentResult: - """Result from multi-agent execution with accumulated metrics. - - The status field represents the outcome of the MultiAgentBase execution: - - COMPLETED: The execution was successfully accomplished - - FAILED: The execution failed or produced an error - """ + """Result from multi-agent execution with accumulated metrics.""" status: Status = Status.PENDING results: dict[str, NodeResult] = field(default_factory=lambda: {}) @@ -126,6 +133,7 @@ class MultiAgentResult: accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) execution_count: int = 0 execution_time: int = 0 + interrupts: list[Interrupt] = field(default_factory=list) @classmethod def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": @@ -137,6 +145,10 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": usage = _parse_usage(data.get("accumulated_usage", {})) metrics = _parse_metrics(data.get("accumulated_metrics", {})) + interrupts = [] + for interrupt_data in data.get("interrupts", []): + interrupts.append(Interrupt(**interrupt_data)) + multiagent_result = cls( status=Status(data["status"]), results=results, @@ -144,6 +156,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": accumulated_metrics=metrics, execution_count=int(data.get("execution_count", 0)), execution_time=int(data.get("execution_time", 0)), + interrupts=interrupts, ) return multiagent_result @@ -157,6 +170,7 @@ def to_dict(self) -> dict[str, Any]: "accumulated_metrics": self.accumulated_metrics, "execution_count": self.execution_count, "execution_time": self.execution_time, + "interrupts": [interrupt.to_dict() for interrupt in self.interrupts], } diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index e87b9592d..6156d332c 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -979,7 +979,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: if isinstance(self.state.task, str): return [ContentBlock(text=self.state.task)] else: - return self.state.task + return cast(list[ContentBlock], self.state.task) # Combine task with dependency outputs node_input = [] @@ -990,7 +990,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: else: # Add task content blocks with a prefix node_input.append(ContentBlock(text="Original Task:")) - node_input.extend(self.state.task) + node_input.extend(cast(list[ContentBlock], self.state.task)) # Add dependency outputs node_input.append(ContentBlock(text="\nInputs from previous nodes:")) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 6970e0426..076c6ab1a 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -33,12 +33,14 @@ MultiAgentInitializedEvent, ) from ..hooks import HookProvider, HookRegistry +from ..interrupt import Interrupt, _InterruptState from ..session import SessionManager from ..telemetry import get_tracer from ..tools.decorator import tool from ..types._events import ( MultiAgentHandoffEvent, MultiAgentNodeCancelEvent, + MultiAgentNodeInterruptEvent, MultiAgentNodeStartEvent, MultiAgentNodeStopEvent, MultiAgentNodeStreamEvent, @@ -61,6 +63,7 @@ class SwarmNode: node_id: str executor: Agent + swarm: Optional["Swarm"] = None _initial_messages: Messages = field(default_factory=list, init=False) _initial_state: AgentState = field(default_factory=AgentState, init=False) @@ -89,7 +92,17 @@ def __repr__(self) -> str: return f"SwarmNode(node_id='{self.node_id}')" def reset_executor_state(self) -> None: - """Reset SwarmNode executor state to initial state when swarm was created.""" + """Reset SwarmNode executor state to initial state when swarm was created. + + If Swarm is resuming from an interrupt, we reset the executor state from the interrupt context. + """ + if self.swarm and self.swarm._interrupt_state.activated: + context = self.swarm._interrupt_state.context[self.node_id] + self.executor.messages = context["messages"] + self.executor.state = AgentState(context["state"]) + self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"]) + return + self.executor.messages = copy.deepcopy(self._initial_messages) self.executor.state = AgentState(self._initial_state.get()) @@ -260,11 +273,14 @@ def __init__( self.shared_context = SharedContext() self.nodes: dict[str, SwarmNode] = {} + self.state = SwarmState( current_node=None, # Placeholder, will be set properly task="", completion_status=Status.PENDING, ) + self._interrupt_state = _InterruptState() + self.tracer = get_tracer() self.trace_attributes: dict[str, AttributeValue] = self._parse_trace_attributes(trace_attributes) @@ -340,6 +356,8 @@ async def stream_async( - multi_agent_node_stop: When a node stops execution - result: Final swarm result """ + self._interrupt_state.resume(task) + if invocation_state is None: invocation_state = {} @@ -347,7 +365,10 @@ async def stream_async( logger.debug("starting swarm execution") - if not self._resume_from_session: + if self._resume_from_session or self._interrupt_state.activated: + self.state.completion_status = Status.EXECUTING + self.state.start_time = time.time() + else: # Initialize swarm state with configuration initial_node = self._initial_node() @@ -357,12 +378,11 @@ async def stream_async( completion_status=Status.EXECUTING, shared_context=self.shared_context, ) - else: - self.state.completion_status = Status.EXECUTING - self.state.start_time = time.time() span = self.tracer.start_multiagent_span(task, "swarm", custom_trace_attributes=self.trace_attributes) with trace_api.use_span(span, end_on_exit=True): + interrupts = [] + try: current_node = cast(SwarmNode, self.state.current_node) logger.debug("current_node=<%s> | starting swarm execution with node", current_node.node_id) @@ -374,6 +394,9 @@ async def stream_async( ) async for event in self._execute_swarm(invocation_state): + if isinstance(event, MultiAgentNodeInterruptEvent): + interrupts = event.interrupts + yield event.as_dict() except Exception: @@ -386,7 +409,7 @@ async def stream_async( self._resume_from_session = False # Yield final result after execution_time is set - result = self._build_result() + result = self._build_result(interrupts) yield MultiAgentResultEvent(result=result).as_dict() async def _stream_with_timeout( @@ -450,7 +473,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None: if node_id in self.nodes: raise ValueError(f"Node ID '{node_id}' is not unique. Each agent must have a unique name.") - self.nodes[node_id] = SwarmNode(node_id=node_id, executor=node) + self.nodes[node_id] = SwarmNode(node_id, node, swarm=self) # Validate entry point if specified if self.entry_point is not None: @@ -650,6 +673,31 @@ def _build_node_input(self, target_node: SwarmNode) -> str: return context_text + def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent: + """Activate the interrupt state. + + Args: + node: The interrupted node. + interrupts: The interrupts raised by the user. + + Returns: + MultiAgentNodeInterruptEvent + """ + logger.debug("node=<%s> | node interrupted", node.node_id) + self.state.completion_status = Status.INTERRUPTED + + self._interrupt_state.context[node.node_id] = { + "activated": node.executor._interrupt_state.activated, + "interrupt_state": node.executor._interrupt_state.to_dict(), + "state": node.executor.state.get(), + "messages": node.executor.messages, + } + + self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) + self._interrupt_state.activate() + + return MultiAgentNodeInterruptEvent(node.node_id, interrupts) + async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute swarm and yield TypedEvent objects.""" try: @@ -684,12 +732,16 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato len(self.state.node_history) + 1, ) - before_event, _ = await self.hooks.invoke_callbacks_async( + before_event, interrupts = await self.hooks.invoke_callbacks_async( BeforeNodeCallEvent(self, current_node.node_id, invocation_state) ) # TODO: Implement cancellation token to stop _execute_node from continuing try: + if interrupts: + yield self._activate_interrupt(current_node, interrupts) + break + if before_event.cancel_node: cancel_message = ( before_event.cancel_node @@ -709,6 +761,14 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato async for event in node_stream: yield event + stop_event = cast(MultiAgentNodeStopEvent, event) + node_result = stop_event["node_result"] + if node_result.status == Status.INTERRUPTED: + yield self._activate_interrupt(current_node, node_result.interrupts) + break + + self._interrupt_state.deactivate() + self.state.node_history.append(current_node) except Exception: @@ -772,16 +832,20 @@ async def _execute_node( yield start_event try: - # Prepare context for node - context_text = self._build_node_input(node) - node_input = [ContentBlock(text=f"Context:\n{context_text}\n\n")] + if self._interrupt_state.activated and self._interrupt_state.context[node_name]["activated"]: + node_input = self._interrupt_state.context["responses"] - # Clear handoff message after it's been included in context - self.state.handoff_message = None + else: + # Prepare context for node + context_text = self._build_node_input(node) + node_input = [ContentBlock(text=f"Context:\n{context_text}\n\n")] - if not isinstance(task, str): - # Include additional ContentBlocks in node input - node_input = node_input + task + # Clear handoff message after it's been included in context + self.state.handoff_message = None + + if not isinstance(task, str): + # Include additional ContentBlocks in node input + node_input = node_input + cast(list[ContentBlock], task) # Execute node with streaming node.reset_executor_state() @@ -799,13 +863,8 @@ async def _execute_node( if result is None: raise ValueError(f"Node '{node_name}' did not produce a result event") - if result.stop_reason == "interrupt": - node.executor.messages.pop() # remove interrupted tool use message - node.executor._interrupt_state.deactivate() - - raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in swarms") - execution_time = round((time.time() - start_time) * 1000) + status = Status.INTERRUPTED if result.stop_reason == "interrupt" else Status.COMPLETED # Create NodeResult with extracted metrics result_metrics = getattr(result, "metrics", None) @@ -815,10 +874,11 @@ async def _execute_node( node_result = NodeResult( result=result, execution_time=execution_time, - status=Status.COMPLETED, + status=status, accumulated_usage=usage, accumulated_metrics=metrics, execution_count=1, + interrupts=result.interrupts or [], ) # Store result in state @@ -867,7 +927,7 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None: self.state.accumulated_usage["totalTokens"] += node_result.accumulated_usage.get("totalTokens", 0) self.state.accumulated_metrics["latencyMs"] += node_result.accumulated_metrics.get("latencyMs", 0) - def _build_result(self) -> SwarmResult: + def _build_result(self, interrupts: list[Interrupt]) -> SwarmResult: """Build swarm result from current state.""" return SwarmResult( status=self.state.completion_status, @@ -877,15 +937,18 @@ def _build_result(self) -> SwarmResult: execution_count=len(self.state.node_history), execution_time=self.state.execution_time, node_history=self.state.node_history, + interrupts=interrupts, ) def serialize_state(self) -> dict[str, Any]: """Serialize the current swarm state to a dictionary.""" status_str = self.state.completion_status.value - if self.state.handoff_node: - next_nodes = [self.state.handoff_node.node_id] - elif self.state.completion_status == Status.EXECUTING and self.state.current_node: + if self.state.completion_status == Status.EXECUTING and self.state.current_node: + next_nodes = [self.state.current_node.node_id] + elif self.state.completion_status == Status.INTERRUPTED and self.state.current_node: next_nodes = [self.state.current_node.node_id] + elif self.state.handoff_node: + next_nodes = [self.state.handoff_node.node_id] else: next_nodes = [] @@ -899,8 +962,12 @@ def serialize_state(self) -> dict[str, Any]: "current_task": self.state.task, "context": { "shared_context": getattr(self.state.shared_context, "context", {}) or {}, + "handoff_node": self.state.handoff_node.node_id if self.state.handoff_node else None, "handoff_message": self.state.handoff_message, }, + "_internal_state": { + "interrupt_state": self._interrupt_state.to_dict(), + }, } def deserialize_state(self, payload: dict[str, Any]) -> None: @@ -916,19 +983,23 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: payload: Dictionary containing persisted state data including status, completed nodes, results, and next nodes to execute. """ - if not payload.get("next_nodes_to_execute"): - for node in self.nodes.values(): - node.reset_executor_state() - self.state = SwarmState( - current_node=SwarmNode("", Agent()), - task="", - completion_status=Status.PENDING, - ) - self._resume_from_session = False - return - else: + if "_internal_state" in payload: + internal_state = payload["_internal_state"] + self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"]) + + self._resume_from_session = "next_nodes_to_execute" in payload + if self._resume_from_session: self._from_dict(payload) - self._resume_from_session = True + return + + for node in self.nodes.values(): + node.reset_executor_state() + + self.state = SwarmState( + current_node=SwarmNode("", Agent(), swarm=self), + task="", + completion_status=Status.PENDING, + ) def _from_dict(self, payload: dict[str, Any]) -> None: self.state.completion_status = Status(payload["status"]) @@ -936,6 +1007,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None: context = payload["context"] or {} self.shared_context.context = context.get("shared_context") or {} self.state.handoff_message = context.get("handoff_message") + self.state.handoff_node = self.nodes[context["handoff_node"]] if context.get("handoff_node") else None self.state.node_history = [self.nodes[nid] for nid in (payload.get("node_history") or []) if nid in self.nodes] diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 558d3e298..36193dac8 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -543,3 +543,27 @@ def __init__(self, node_id: str, message: str) -> None: "message": message, } ) + + +class MultiAgentNodeInterruptEvent(TypedEvent): + """Event emitted when a node is interrupted.""" + + def __init__(self, node_id: str, interrupts: list[Interrupt]) -> None: + """Set interrupt in the event payload. + + Args: + node_id: Unique identifier for the node generating the event. + interrupts: Interrupts raised by user. + """ + super().__init__( + { + "type": "multiagent_node_interrupt", + "node_id": node_id, + "interrupts": interrupts, + } + ) + + @property + def interrupts(self) -> list[Interrupt]: + """The interrupt instances.""" + return cast(list[Interrupt], self["interrupts"]) diff --git a/src/strands/types/multiagent.py b/src/strands/types/multiagent.py index d9487dbd2..a8fcd4844 100644 --- a/src/strands/types/multiagent.py +++ b/src/strands/types/multiagent.py @@ -3,5 +3,6 @@ from typing import TypeAlias from .content import ContentBlock +from .interrupt import InterruptResponseContent -MultiAgentInput: TypeAlias = str | list[ContentBlock] +MultiAgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] diff --git a/tests/strands/multiagent/conftest.py b/tests/strands/multiagent/conftest.py new file mode 100644 index 000000000..85e0ef7fc --- /dev/null +++ b/tests/strands/multiagent/conftest.py @@ -0,0 +1,16 @@ +import pytest + +from strands.experimental.hooks.multiagent import BeforeNodeCallEvent +from strands.hooks import HookProvider + + +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.interrupt) + + def interrupt(self, event): + return event.interrupt("test_name", reason="test_reason") + + return Hook() diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 66850fa6f..f2abed9f7 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -8,6 +8,7 @@ from strands.agent.state import AgentState from strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks.registry import HookRegistry +from strands.interrupt import Interrupt, _InterruptState from strands.multiagent.base import Status from strands.multiagent.swarm import SharedContext, Swarm, SwarmNode, SwarmResult, SwarmState from strands.session.file_session_manager import FileSessionManager @@ -23,6 +24,7 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.id = agent_id or f"{name}_id" agent.messages = [] agent.state = AgentState() # Add state attribute + agent._interrupt_state = _InterruptState() # Add interrupt state agent.tool_registry = Mock() agent.tool_registry.registry = {} agent.tool_registry.process_tools = Mock() @@ -1117,6 +1119,9 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): state = swarm.serialize_state() assert state["type"] == "swarm" assert state["id"] == "default_swarm" + assert state["_internal_state"] == { + "interrupt_state": {"activated": False, "context": {}, "interrupts": {}}, + } assert "status" in state assert "node_history" in state assert "node_results" in state @@ -1130,12 +1135,30 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): "current_task": "persisted task", "next_nodes_to_execute": ["test_agent"], "context": {"shared_context": {"test_agent": {"key": "value"}}, "handoff_message": "test handoff"}, + "_internal_state": { + "interrupt_state": { + "activated": False, + "context": {"a": 1}, + "interrupts": { + "i1": { + "id": "i1", + "name": "test_name", + "reason": "test_reason", + }, + }, + }, + }, } - swarm._from_dict(persisted_state) + swarm.deserialize_state(persisted_state) assert swarm.state.task == "persisted task" assert swarm.state.handoff_message == "test handoff" assert swarm.shared_context.context["test_agent"]["key"] == "value" + assert swarm._interrupt_state == _InterruptState( + activated=False, + context={"a": 1}, + interrupts={"i1": Interrupt(id="i1", name="test_name", reason="test_reason")}, + ) # Execute swarm to test persistence integration result = await swarm.invoke_async("Test persistence") @@ -1212,3 +1235,115 @@ def cancel_callback(event): tru_status = swarm.state.completion_status exp_status = Status.FAILED assert tru_status == exp_status + + +def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): + agent = create_mock_agent("test_agent", "Task completed") + swarm = Swarm([agent], hooks=[interrupt_hook]) + + multiagent_result = swarm("Test task") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_name", + reason="test_reason", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = swarm(responses) + + tru_status = multiagent_result.status + exp_status = Status.COMPLETED + assert tru_status == exp_status + + assert len(multiagent_result.results) == 1 + agent_result = multiagent_result.results["test_agent"] + + tru_message = agent_result.result.message["content"][0]["text"] + exp_message = "Task completed" + assert tru_message == exp_message + + +def test_swarm_interrupt_on_agent(agenerator): + exp_interrupts = [ + Interrupt( + id="test_id", + name="test_name", + reason="test_reason", + ), + ] + + agent = create_mock_agent("test_agent", "Task completed") + + swarm = Swarm([agent]) + + agent.stream_async = Mock() + agent.stream_async.return_value = agenerator( + [ + { + "result": AgentResult( + message={}, + stop_reason="interrupt", + state={}, + metrics=None, + interrupts=exp_interrupts, + ), + }, + ], + ) + multiagent_result = swarm("Test task") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + assert tru_interrupts == exp_interrupts + + agent.stream_async = Mock() + agent.stream_async.return_value = agenerator( + [ + { + "result": AgentResult( + message={}, + stop_reason="end_turn", + state={}, + metrics=None, + ), + }, + ], + ) + swarm._interrupt_state.context["test_agent"]["activated"] = True + + interrupt = multiagent_result.interrupts[0] + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = swarm(responses) + + tru_status = multiagent_result.status + exp_status = Status.COMPLETED + assert tru_status == exp_status + + agent.stream_async.assert_called_once_with(responses, invocation_state={}) diff --git a/tests_integ/interrupts/multiagent/__init__.py b/tests_integ/interrupts/multiagent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/interrupts/multiagent/test_agent.py b/tests_integ/interrupts/multiagent/test_agent.py new file mode 100644 index 000000000..36fcfef27 --- /dev/null +++ b/tests_integ/interrupts/multiagent/test_agent.py @@ -0,0 +1,67 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.interrupt import Interrupt +from strands.multiagent import Swarm +from strands.multiagent.base import Status +from strands.types.tools import ToolContext + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool", context=True) + def func(tool_context: ToolContext) -> str: + response = tool_context.interrupt("test_interrupt", reason="need weather") + return response + + return func + + +@pytest.fixture +def swarm(weather_tool): + weather_agent = Agent(name="weather", tools=[weather_tool]) + + return Swarm([weather_agent]) + + +def test_swarm_interrupt_agent(swarm): + multiagent_result = swarm("What is the weather?") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need weather", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "sunny", + }, + }, + ] + multiagent_result = swarm(responses) + + tru_status = multiagent_result.status + exp_status = Status.COMPLETED + assert tru_status == exp_status + + assert len(multiagent_result.results) == 1 + weather_result = multiagent_result.results["weather"] + + weather_message = json.dumps(weather_result.result.message).lower() + assert "sunny" in weather_message diff --git a/tests_integ/interrupts/multiagent/test_hook.py b/tests_integ/interrupts/multiagent/test_hook.py new file mode 100644 index 000000000..be7682082 --- /dev/null +++ b/tests_integ/interrupts/multiagent/test_hook.py @@ -0,0 +1,133 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.experimental.hooks.multiagent import BeforeNodeCallEvent +from strands.hooks import HookProvider +from strands.interrupt import Interrupt +from strands.multiagent import Swarm +from strands.multiagent.base import Status + + +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.interrupt) + + def interrupt(self, event): + if event.node_id == "info": + return + + response = event.interrupt("test_interrupt", reason="need approval") + if response != "APPROVE": + event.cancel_node = "node rejected" + + return Hook() + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool") + def func(): + return "sunny" + + return func + + +@pytest.fixture +def swarm(interrupt_hook, weather_tool): + info_agent = Agent(name="info") + weather_agent = Agent(name="weather", tools=[weather_tool]) + + return Swarm([info_agent, weather_agent], hooks=[interrupt_hook]) + + +def test_swarm_interrupt(swarm): + multiagent_result = swarm("What is the weather?") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + }, + ] + multiagent_result = swarm(responses) + + tru_status = multiagent_result.status + exp_status = Status.COMPLETED + assert tru_status == exp_status + + assert len(multiagent_result.results) == 2 + weather_result = multiagent_result.results["weather"] + + weather_message = json.dumps(weather_result.result.message).lower() + assert "sunny" in weather_message + + +@pytest.mark.asyncio +async def test_swarm_interrupt_reject(swarm): + multiagent_result = swarm("What is the weather?") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "REJECT", + }, + }, + ] + tru_cancel_id = None + async for event in swarm.stream_async(responses): + if event.get("type") == "multiagent_node_cancel": + tru_cancel_id = event["node_id"] + + multiagent_result = event["result"] + + exp_cancel_id = "weather" + assert tru_cancel_id == exp_cancel_id + + tru_status = multiagent_result.status + exp_status = Status.FAILED + assert tru_status == exp_status + + assert len(multiagent_result.node_history) == 1 + tru_node_id = multiagent_result.node_history[0].node_id + exp_node_id = "info" + assert tru_node_id == exp_node_id diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py new file mode 100644 index 000000000..d6e8cdbf8 --- /dev/null +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -0,0 +1,77 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.interrupt import Interrupt +from strands.multiagent import Swarm +from strands.multiagent.base import Status +from strands.session import FileSessionManager +from strands.types.tools import ToolContext + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool", context=True) + def func(tool_context: ToolContext) -> str: + response = tool_context.interrupt("test_interrupt", reason="need weather") + return response + + return func + + +@pytest.fixture +def swarm(weather_tool): + weather_agent = Agent(name="weather", tools=[weather_tool]) + return Swarm([weather_agent]) + + +def test_swarm_interrupt_session(weather_tool, tmpdir): + weather_agent = Agent(name="weather", tools=[weather_tool]) + summarizer_agent = Agent(name="summarizer") + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + swarm = Swarm([weather_agent, summarizer_agent], session_manager=session_manager) + + multiagent_result = swarm("Can you check the weather and then summarize the results?") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need weather", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + weather_agent = Agent(name="weather", tools=[weather_tool]) + summarizer_agent = Agent(name="summarizer") + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + swarm = Swarm([weather_agent, summarizer_agent], session_manager=session_manager) + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "sunny", + }, + }, + ] + multiagent_result = swarm(responses) + + tru_status = multiagent_result.status + exp_status = Status.COMPLETED + assert tru_status == exp_status + + assert len(multiagent_result.results) == 2 + summarizer_result = multiagent_result.results["summarizer"] + + summarizer_message = json.dumps(summarizer_result.result.message).lower() + assert "sunny" in summarizer_message