feat(multiagent): Add stream_async #961
@@ -19,14 +19,20 @@
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Tuple
from typing import Any, AsyncIterator, Callable, Tuple, cast

from opentelemetry import trace as trace_api

from ..agent import Agent, AgentResult
from ..agent import Agent
from ..agent.state import AgentState
from ..telemetry import get_tracer
from ..tools.decorator import tool
from ..types._events import (
    MultiAgentHandoffEvent,
    MultiAgentNodeCompleteEvent,
    MultiAgentNodeStartEvent,
    MultiAgentNodeStreamEvent,
)
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status

@@ -266,12 +272,39 @@ async def invoke_async(
    ) -> SwarmResult:
        """Invoke the swarm asynchronously.

        This method uses stream_async internally and consumes all events until completion,
        following the same pattern as the Agent class.

        Args:
            task: The task to execute
            invocation_state: Additional state/context passed to underlying agents.
                Defaults to None to avoid mutable default argument issues.
            **kwargs: Keyword arguments allowing backward compatible future changes.
        """
        events = self.stream_async(task, invocation_state, **kwargs)
        async for event in events:
            _ = event

        return cast(SwarmResult, event["result"])

    async def stream_async(
        self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
    ) -> AsyncIterator[dict[str, Any]]:
        """Stream events during swarm execution.

        Args:
            task: The task to execute
            invocation_state: Additional state/context passed to underlying agents.
                Defaults to None to avoid mutable default argument issues - a new empty dict
                is created if None is provided.
                Defaults to None to avoid mutable default argument issues.
            **kwargs: Keyword arguments allowing backward compatible future changes.

        Yields:
            Dictionary events containing swarm execution information including:
            - MultiAgentNodeStartEvent: When an agent begins execution
            - MultiAgentNodeStreamEvent: Forwarded agent events with node context
            - MultiAgentHandoffEvent: When control is handed off between agents
            - MultiAgentNodeCompleteEvent: When an agent completes execution
            - Final result event
        """
        if invocation_state is None:
            invocation_state = {}

@@ -282,7 +315,7 @@ async def invoke_async(
        if self.entry_point:
            initial_node = self.nodes[str(self.entry_point.name)]
        else:
            initial_node = next(iter(self.nodes.values()))  # First SwarmNode
            initial_node = next(iter(self.nodes.values()))

        self.state = SwarmState(
            current_node=initial_node,

@@ -303,15 +336,32 @@
                self.execution_timeout,
            )

            await self._execute_swarm(invocation_state)
            async for event in self._execute_swarm(invocation_state):
                yield event

            # Yield final result (consistent with Agent's AgentResultEvent format)
            result = self._build_result()
            yield {"result": result}

        except Exception:
            logger.exception("swarm execution failed")
            self.state.completion_status = Status.FAILED
            raise
        finally:
            self.state.execution_time = round((time.time() - start_time) * 1000)

        return self._build_result()
    async def _stream_with_timeout(
        self, async_generator: AsyncIterator[dict[str, Any]], timeout: float, timeout_message: str
    ) -> AsyncIterator[dict[str, Any]]:
        """Wrap an async generator with timeout functionality."""
        while True:

Review comment: same nit as in graph

            try:
                event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout)
                yield event
            except StopAsyncIteration:
                break
            except asyncio.TimeoutError:
                raise Exception(timeout_message) from None

    def _setup_swarm(self, nodes: list[Agent]) -> None:
        """Initialize swarm configuration."""

@@ -533,14 +583,14 @@ def _build_node_input(self, target_node: SwarmNode) -> str:

        return context_text

    async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
        """Shared execution logic used by execute_async."""
    async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]:
        """Execute swarm and yield events."""
        try:
            # Main execution loop
            while True:
                if self.state.completion_status != Status.EXECUTING:
                    reason = f"Completion status is: {self.state.completion_status}"
                    logger.debug("reason=<%s> | stopping execution", reason)
                    logger.debug("reason=<%s> | stopping streaming execution", reason)
                    break

                should_continue, reason = self.state.should_continue(

@@ -568,36 +618,58 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
                    len(self.state.node_history) + 1,
                )

                # Store the current node before execution to detect handoffs
                previous_node = current_node

                # Execute node with timeout protection
                # TODO: Implement cancellation token to stop _execute_node from continuing
                try:
                    await asyncio.wait_for(
                        self._execute_node(current_node, self.state.task, invocation_state),
                        timeout=self.node_timeout,
                    # Execute with timeout wrapper for async generator streaming
                    node_stream = (
                        self._stream_with_timeout(
                            self._execute_node(current_node, self.state.task, invocation_state),
                            self.node_timeout,
                            f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s",
                        )
                        if self.node_timeout is not None
                        else self._execute_node(current_node, self.state.task, invocation_state)
                    )
                    async for event in node_stream:
                        yield event

                    self.state.node_history.append(current_node)

                    logger.debug("node=<%s> | node execution completed", current_node.node_id)

                    # Check if the current node is still the same after execution
                    # If it is, then no handoff occurred and we consider the swarm complete
                    if self.state.current_node == current_node:
                    # Check if handoff occurred during execution
                    if self.state.current_node != previous_node:
                        # Emit handoff event
                        handoff_event = MultiAgentHandoffEvent(
                            from_node=previous_node.node_id,
                            to_node=self.state.current_node.node_id,
                            message=self.state.handoff_message or "Agent handoff occurred",
                        )
                        yield handoff_event.as_dict()

Review comment: As noted for graphs, can we do the as_dict() stuff higher up? I'd like to be strongly typed at the lower levels and only do dict conversion at the API layer.
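For illustration, a minimal sketch of the pattern this comment describes: internal generators yield typed event objects, and only the public streaming layer converts them to dicts. All names here (NodeEvent, _run, stream) are hypothetical and not taken from this PR.

```python
import asyncio
from dataclasses import asdict, dataclass
from typing import Any, AsyncIterator


@dataclass
class NodeEvent:
    """Stand-in for a typed multi-agent event with an as_dict() helper."""

    node_id: str
    payload: dict[str, Any]

    def as_dict(self) -> dict[str, Any]:
        return asdict(self)


async def _run() -> AsyncIterator[NodeEvent]:
    # Lower level stays strongly typed: yield event objects, not dicts.
    yield NodeEvent(node_id="researcher", payload={"status": "started"})
    yield NodeEvent(node_id="researcher", payload={"status": "completed"})


async def stream() -> AsyncIterator[dict[str, Any]]:
    # Dict conversion happens only at the public API boundary.
    async for event in _run():
        yield event.as_dict()


async def main() -> None:
    async for event in stream():
        print(event)


asyncio.run(main())
```

Keeping the typed objects below the API layer would also let invoke_async read the final result off an event object directly instead of indexing into a dict.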
                        logger.debug(
                            "from_node=<%s>, to_node=<%s> | handoff detected",
                            previous_node.node_id,
                            self.state.current_node.node_id,
                        )
                    else:
                        # No handoff occurred, mark swarm as complete
                        logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
                        self.state.completion_status = Status.COMPLETED
                        break

                except asyncio.TimeoutError:
                    logger.exception(
                        "node=<%s>, timeout=<%s>s | node execution timed out after timeout",
                        current_node.node_id,
                        self.node_timeout,
                    )
                    self.state.completion_status = Status.FAILED
                    break

                except Exception:
                    logger.exception("node=<%s> | node execution failed", current_node.node_id)
                except Exception as e:

Review comment: Why can't we use an exception type for this? This seems hacky

                    # Check if this is a timeout exception
                    if "timed out after" in str(e):

Review comment: nit: can we create a variable for this so we don't accidentally change the message in the exception
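A minimal sketch of the alternative both comments point at: raise a dedicated exception type from the timeout wrapper and catch it by type, so no message string has to be matched. The name SwarmNodeTimeoutError is hypothetical and not part of this PR.

```python
import asyncio
from typing import Any, AsyncIterator


class SwarmNodeTimeoutError(Exception):
    """Raised when a node's event stream exceeds its timeout."""


async def _stream_with_timeout(
    agen: AsyncIterator[dict[str, Any]], timeout: float, timeout_message: str
) -> AsyncIterator[dict[str, Any]]:
    while True:
        try:
            yield await asyncio.wait_for(agen.__anext__(), timeout=timeout)
        except StopAsyncIteration:
            break
        except asyncio.TimeoutError:
            # Typed error: callers no longer need to check "timed out after" in str(e).
            raise SwarmNodeTimeoutError(timeout_message) from None


async def _slow_node() -> AsyncIterator[dict[str, Any]]:
    yield {"status": "started"}
    await asyncio.sleep(10)
    yield {"status": "completed"}


async def main() -> None:
    try:
        async for event in _stream_with_timeout(_slow_node(), 0.1, "node timed out after 0.1s"):
            print(event)
    except SwarmNodeTimeoutError:
        # Caught by type, no string matching needed.
        print("node timed out")


asyncio.run(main())
```

Catching by type would also make the message-string constant the second comment asks for unnecessary, since the message is no longer load-bearing.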
                        logger.exception(
                            "node=<%s>, timeout=<%s>s | node execution timed out",
                            current_node.node_id,
                            self.node_timeout,
                        )
                    else:
                        logger.exception("node=<%s> | node execution failed", current_node.node_id)
                    self.state.completion_status = Status.FAILED
                    break

@@ -615,11 +687,15 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:

    async def _execute_node(
        self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
    ) -> AgentResult:
    ) -> AsyncIterator[dict[str, Any]]:
        """Execute swarm node."""
        start_time = time.time()
        node_name = node.node_id

        # Emit node start event
        start_event = MultiAgentNodeStartEvent(node_id=node_name, node_type="agent")
        yield start_event.as_dict()

        try:
            # Prepare context for node
            context_text = self._build_node_input(node)

@@ -632,11 +708,22 @@ async def _execute_node(
            # Include additional ContentBlocks in node input
            node_input = node_input + task

            # Execute node
            result = None
            # Execute node with streaming
            node.reset_executor_state()
            # Unpacking since this is the agent class. Other executors should not unpack
            result = await node.executor.invoke_async(node_input, **invocation_state)

            # Stream agent events with node context and capture final result
            result = None
            async for event in node.executor.stream_async(node_input, **invocation_state):
                # Forward agent events with node context
                wrapped_event = MultiAgentNodeStreamEvent(node_name, event)
                yield wrapped_event.as_dict()
                # Capture the final result event
                if "result" in event:
                    result = event["result"]

            # Use the captured result from streaming to avoid double execution
            if result is None:
                raise ValueError(f"Node '{node_name}' did not produce a result event")

            execution_time = round((time.time() - start_time) * 1000)

@@ -664,15 +751,17 @@ async def _execute_node(
            # Accumulate metrics
            self._accumulate_metrics(node_result)

            return result
            # Emit node complete event
            complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time)
            yield complete_event.as_dict()

        except Exception as e:
            execution_time = round((time.time() - start_time) * 1000)
            logger.exception("node=<%s> | node execution failed", node_name)

            # Create a NodeResult for the failed node
            node_result = NodeResult(
                result=e,  # Store exception as result
                result=e,
                execution_time=execution_time,
                status=Status.FAILED,
                accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),

@@ -683,6 +772,10 @@
            # Store result in state
            self.state.results[node_name] = node_result

            # Still emit complete event for failures
            complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time)
            yield complete_event.as_dict()

            raise

    def _accumulate_metrics(self, node_result: NodeResult) -> None:

Review comment: Is this an AgentResult, as it is for single-agent streaming? If not, can we rename it so that it doesn't conflict with different types?
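For context, a rough sketch of how a caller might consume the new streaming API, mirroring what invoke_async now does internally. The import paths, the Agent/Swarm construction, and a configured default model are assumptions about the surrounding SDK, not something verified from this diff.

```python
import asyncio

from strands import Agent
from strands.multiagent import Swarm


async def main() -> None:
    # Hypothetical two-agent swarm; constructor arguments are assumptions.
    researcher = Agent(name="researcher")
    writer = Agent(name="writer")
    swarm = Swarm([researcher, writer])

    result = None
    async for event in swarm.stream_async("Summarize the latest release notes"):
        # Intermediate events: node start/stream/handoff/complete dicts.
        print(event)
        # The final event carries the result, the same object invoke_async returns.
        if "result" in event:
            result = event["result"]

    print(result)


asyncio.run(main())
```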