28 changes: 27 additions & 1 deletion src/strands/multiagent/base.py
@@ -8,7 +8,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Union
from typing import Any, AsyncIterator, Union

from ..agent import AgentResult
from ..types.content import ContentBlock
@@ -97,6 +97,32 @@ async def invoke_async(
"""
raise NotImplementedError("invoke_async not implemented")

async def stream_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during multi-agent execution.

This default implementation provides backward compatibility by executing
invoke_async and yielding a single result event. Subclasses can override
this method to provide true streaming capabilities.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Additional keyword arguments passed to underlying agents.

Yields:
Dictionary events containing multi-agent execution information including:
- Multi-agent coordination events (node start/complete, handoffs)
- Forwarded single-agent events with node context
- Final result event
"""
# Default implementation for backward compatibility
# Execute invoke_async and yield the result as a single event
result = await self.invoke_async(task, invocation_state, **kwargs)
yield {"result": result}

def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> MultiAgentResult:
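For context on the new base-class method: because the default stream_async simply wraps invoke_async and yields one result event, existing orchestrators remain usable through the streaming API without overriding anything. A minimal consumer sketch (the orchestrator and task below are illustrative placeholders, not code from this PR):

import asyncio

from strands.multiagent.base import MultiAgentBase


async def consume(orchestrator: MultiAgentBase) -> None:
    # With the base-class default, this loop sees exactly one event:
    # {"result": MultiAgentResult}. Subclasses that override stream_async
    # may yield many intermediate events before the final result.
    final = None
    async for event in orchestrator.stream_async("Summarize the design doc"):
        if "result" in event:
            final = event["result"]
    print(final)


# asyncio.run(consume(my_orchestrator))  # my_orchestrator: any MultiAgentBase subclass
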
237 changes: 198 additions & 39 deletions src/strands/multiagent/graph.py

Large diffs are not rendered by default.

163 changes: 128 additions & 35 deletions src/strands/multiagent/swarm.py
@@ -19,14 +19,20 @@
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Tuple
from typing import Any, AsyncIterator, Callable, Tuple, cast

from opentelemetry import trace as trace_api

from ..agent import Agent, AgentResult
from ..agent import Agent
from ..agent.state import AgentState
from ..telemetry import get_tracer
from ..tools.decorator import tool
from ..types._events import (
MultiAgentHandoffEvent,
MultiAgentNodeCompleteEvent,
MultiAgentNodeStartEvent,
MultiAgentNodeStreamEvent,
)
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
@@ -266,12 +272,39 @@ async def invoke_async(
) -> SwarmResult:
"""Invoke the swarm asynchronously.

This method uses stream_async internally and consumes all events until completion,
following the same pattern as the Agent class.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
events = self.stream_async(task, invocation_state, **kwargs)
async for event in events:
_ = event

return cast(SwarmResult, event["result"])

async def stream_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during swarm execution.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues - a new empty dict
is created if None is provided.
Defaults to None to avoid mutable default argument issues.
**kwargs: Keyword arguments allowing backward compatible future changes.

Yields:
Dictionary events containing swarm execution information including:
- MultiAgentNodeStartEvent: When an agent begins execution
- MultiAgentNodeStreamEvent: Forwarded agent events with node context
- MultiAgentHandoffEvent: When control is handed off between agents
- MultiAgentNodeCompleteEvent: When an agent completes execution
- Final result event
"""
if invocation_state is None:
invocation_state = {}
@@ -282,7 +315,7 @@ async def invoke_async(
if self.entry_point:
initial_node = self.nodes[str(self.entry_point.name)]
else:
initial_node = next(iter(self.nodes.values())) # First SwarmNode
initial_node = next(iter(self.nodes.values()))

self.state = SwarmState(
current_node=initial_node,
@@ -303,15 +336,32 @@ async def invoke_async(
self.execution_timeout,
)

await self._execute_swarm(invocation_state)
async for event in self._execute_swarm(invocation_state):
yield event

# Yield final result (consistent with Agent's AgentResultEvent format)
result = self._build_result()
yield {"result": result}
Member: Is this an AgentResult as it is for single-agent streaming? If not, can we rename it so that it doesn't conflict with different types?


except Exception:
logger.exception("swarm execution failed")
self.state.completion_status = Status.FAILED
raise
finally:
self.state.execution_time = round((time.time() - start_time) * 1000)

return self._build_result()
async def _stream_with_timeout(
self, async_generator: AsyncIterator[dict[str, Any]], timeout: float, timeout_message: str
) -> AsyncIterator[dict[str, Any]]:
"""Wrap an async generator with timeout functionality."""
while True:
Member: Same nit as in graph.

try:
event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout)
yield event
except StopAsyncIteration:
break
except asyncio.TimeoutError:
raise Exception(timeout_message) from None
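As a standalone illustration of the pattern used by _stream_with_timeout (not code from this PR; the generator and timeout below are made up), asyncio.wait_for is applied to each __anext__ call, so the timeout bounds the gap between consecutive events rather than the node's total run time:

import asyncio
from typing import Any, AsyncIterator


async def with_timeout(gen: AsyncIterator[Any], timeout: float) -> AsyncIterator[Any]:
    # Each await of __anext__ is bounded individually; a generator that keeps
    # yielding frequently can run longer than `timeout` overall.
    while True:
        try:
            yield await asyncio.wait_for(gen.__anext__(), timeout=timeout)
        except StopAsyncIteration:
            break


async def slow_events() -> AsyncIterator[int]:
    for i in range(3):
        await asyncio.sleep(0.1)
        yield i


async def main() -> None:
    async for item in with_timeout(slow_events(), timeout=1.0):
        print(item)


asyncio.run(main())
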

def _setup_swarm(self, nodes: list[Agent]) -> None:
"""Initialize swarm configuration."""
@@ -533,14 +583,14 @@ def _build_node_input(self, target_node: SwarmNode) -> str:

return context_text

async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
"""Shared execution logic used by execute_async."""
async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]:
"""Execute swarm and yield events."""
try:
# Main execution loop
while True:
if self.state.completion_status != Status.EXECUTING:
reason = f"Completion status is: {self.state.completion_status}"
logger.debug("reason=<%s> | stopping execution", reason)
logger.debug("reason=<%s> | stopping streaming execution", reason)
break

should_continue, reason = self.state.should_continue(
@@ -568,36 +618,58 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
len(self.state.node_history) + 1,
)

# Store the current node before execution to detect handoffs
previous_node = current_node

# Execute node with timeout protection
# TODO: Implement cancellation token to stop _execute_node from continuing
try:
await asyncio.wait_for(
self._execute_node(current_node, self.state.task, invocation_state),
timeout=self.node_timeout,
# Execute with timeout wrapper for async generator streaming
node_stream = (
self._stream_with_timeout(
self._execute_node(current_node, self.state.task, invocation_state),
self.node_timeout,
f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s",
)
if self.node_timeout is not None
else self._execute_node(current_node, self.state.task, invocation_state)
)
async for event in node_stream:
yield event

self.state.node_history.append(current_node)

logger.debug("node=<%s> | node execution completed", current_node.node_id)

# Check if the current node is still the same after execution
# If it is, then no handoff occurred and we consider the swarm complete
if self.state.current_node == current_node:
# Check if handoff occurred during execution
if self.state.current_node != previous_node:
# Emit handoff event
handoff_event = MultiAgentHandoffEvent(
from_node=previous_node.node_id,
to_node=self.state.current_node.node_id,
message=self.state.handoff_message or "Agent handoff occurred",
)
yield handoff_event.as_dict()
Member: As noted for graphs, can we do the as_dict() conversion higher up? I'd like to be strongly typed at the lower levels and only do dict conversion at the API layer.

logger.debug(
"from_node=<%s>, to_node=<%s> | handoff detected",
previous_node.node_id,
self.state.current_node.node_id,
)
else:
# No handoff occurred, mark swarm as complete
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
self.state.completion_status = Status.COMPLETED
break

except asyncio.TimeoutError:
logger.exception(
"node=<%s>, timeout=<%s>s | node execution timed out after timeout",
current_node.node_id,
self.node_timeout,
)
self.state.completion_status = Status.FAILED
break

except Exception:
logger.exception("node=<%s> | node execution failed", current_node.node_id)
except Exception as e:
Member: Why can't we use an exception type for this? This seems hacky.

# Check if this is a timeout exception
if "timed out after" in str(e):
Member: Nit: can we create a variable for this so we don't accidentally change the message in the exception?

logger.exception(
"node=<%s>, timeout=<%s>s | node execution timed out",
current_node.node_id,
self.node_timeout,
)
else:
logger.exception("node=<%s> | node execution failed", current_node.node_id)
self.state.completion_status = Status.FAILED
break

@@ -615,11 +687,15 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:

async def _execute_node(
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
) -> AgentResult:
) -> AsyncIterator[dict[str, Any]]:
"""Execute swarm node."""
start_time = time.time()
node_name = node.node_id

# Emit node start event
start_event = MultiAgentNodeStartEvent(node_id=node_name, node_type="agent")
yield start_event.as_dict()

try:
# Prepare context for node
context_text = self._build_node_input(node)
@@ -632,11 +708,22 @@ async def _execute_node(
# Include additional ContentBlocks in node input
node_input = node_input + task

# Execute node
result = None
# Execute node with streaming
node.reset_executor_state()
# Unpacking since this is the agent class. Other executors should not unpack
result = await node.executor.invoke_async(node_input, **invocation_state)

# Stream agent events with node context and capture final result
result = None
async for event in node.executor.stream_async(node_input, **invocation_state):
# Forward agent events with node context
wrapped_event = MultiAgentNodeStreamEvent(node_name, event)
yield wrapped_event.as_dict()
# Capture the final result event
if "result" in event:
result = event["result"]

# Use the captured result from streaming to avoid double execution
if result is None:
raise ValueError(f"Node '{node_name}' did not produce a result event")

execution_time = round((time.time() - start_time) * 1000)

@@ -664,15 +751,17 @@ async def _execute_node(
# Accumulate metrics
self._accumulate_metrics(node_result)

return result
# Emit node complete event
complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time)
yield complete_event.as_dict()

except Exception as e:
execution_time = round((time.time() - start_time) * 1000)
logger.exception("node=<%s> | node execution failed", node_name)

# Create a NodeResult for the failed node
node_result = NodeResult(
result=e, # Store exception as result
result=e,
execution_time=execution_time,
status=Status.FAILED,
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
Expand All @@ -683,6 +772,10 @@ async def _execute_node(
# Store result in state
self.state.results[node_name] = node_result

# Still emit complete event for failures
complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time)
yield complete_event.as_dict()

raise

def _accumulate_metrics(self, node_result: NodeResult) -> None:
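Taken together, the swarm changes let a caller observe node lifecycle events, forwarded agent events, and handoffs as they happen. A hedged usage sketch (agent and swarm construction details are illustrative; only stream_async and the event keys come from this diff):

import asyncio

from strands.agent import Agent
from strands.multiagent.swarm import Swarm


async def main() -> None:
    # Hypothetical agents; real construction depends on your model/tool setup.
    researcher = Agent(name="researcher")
    writer = Agent(name="writer")
    swarm = Swarm([researcher, writer])

    async for event in swarm.stream_async("Draft a short report on solar panels"):
        if event.get("multi_agent_node_start"):
            print("node started:", event["node_id"])
        elif event.get("multi_agent_handoff"):
            print("handoff:", event["from_node"], "->", event["to_node"])
        elif event.get("multi_agent_node_complete"):
            print("node finished in", event["execution_time"], "ms")
        elif "result" in event and "multi_agent_node_stream" not in event:
            # Forwarded per-node stream events also carry a "result" key on the
            # agent's final event, so exclude them to pick out the swarm-level result.
            print("final status:", event["result"].status)


# asyncio.run(main())

Note that the last branch has to disambiguate the swarm-level result from a forwarded agent result, which is the naming conflict the reviewer raises above.
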
59 changes: 59 additions & 0 deletions src/strands/types/_events.py
@@ -351,3 +351,62 @@ def __init__(self, reason: str | Exception) -> None:
class AgentResultEvent(TypedEvent):
def __init__(self, result: "AgentResult"):
super().__init__({"result": result})


class MultiAgentNodeStartEvent(TypedEvent):
"""Event emitted when a node begins execution in multi-agent context."""

def __init__(self, node_id: str, node_type: str) -> None:
"""Initialize with node information.

Args:
node_id: Unique identifier for the node
node_type: Type of node ("agent", "swarm", "graph")
"""
super().__init__({"multi_agent_node_start": True, "node_id": node_id, "node_type": node_type})


class MultiAgentNodeCompleteEvent(TypedEvent):
"""Event emitted when a node completes execution."""

def __init__(self, node_id: str, execution_time: int) -> None:
"""Initialize with completion information.

Args:
node_id: Unique identifier for the node
execution_time: Execution time in milliseconds
"""
super().__init__({"multi_agent_node_complete": True, "node_id": node_id, "execution_time": execution_time})


class MultiAgentHandoffEvent(TypedEvent):
"""Event emitted during agent handoffs in Swarm."""

def __init__(self, from_node: str, to_node: str, message: str) -> None:
"""Initialize with handoff information.

Args:
from_node: Node ID handing off control
to_node: Node ID receiving control
message: Handoff message explaining the transfer
"""
super().__init__({"multi_agent_handoff": True, "from_node": from_node, "to_node": to_node, "message": message})


class MultiAgentNodeStreamEvent(TypedEvent):
"""Event emitted during node execution - forwards agent events with node context."""

def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None:
"""Initialize with node context and agent event.

Args:
node_id: Unique identifier for the node generating the event
agent_event: The original agent event data
"""
super().__init__(
{
"multi_agent_node_stream": True,
"node_id": node_id,
**agent_event, # Forward all original agent event data
}
)
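
Because each of these event classes flattens to a plain dict with a boolean marker key, consumers can dispatch on the keys without importing the classes. A small sketch of the expected payload shapes (this assumes TypedEvent exposes the as_dict() conversion that the swarm code above calls and that it returns the dict passed to the initializer):

from strands.types._events import (
    MultiAgentHandoffEvent,
    MultiAgentNodeCompleteEvent,
    MultiAgentNodeStartEvent,
)

start = MultiAgentNodeStartEvent(node_id="researcher", node_type="agent")
handoff = MultiAgentHandoffEvent(from_node="researcher", to_node="writer", message="Draft needed")
complete = MultiAgentNodeCompleteEvent(node_id="researcher", execution_time=1250)

# Expected shapes, based on the initializers above:
#   {"multi_agent_node_start": True, "node_id": "researcher", "node_type": "agent"}
#   {"multi_agent_handoff": True, "from_node": "researcher", "to_node": "writer", "message": "Draft needed"}
#   {"multi_agent_node_complete": True, "node_id": "researcher", "execution_time": 1250}
for event in (start, handoff, complete):
    print(event.as_dict())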