-
Notifications
You must be signed in to change notification settings - Fork 505
feat(a2a): add A2AAgent class as an implementation of the agent interface for remote A2A protocol based agents #1174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a7b4036
5034f18
64520b7
90124cf
182f9a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,253 @@ | ||
| """A2A Agent client for Strands Agents. | ||
|
|
||
| This module provides the A2AAgent class, which acts as a client wrapper for remote A2A agents, | ||
| allowing them to be used standalone or as part of multi-agent patterns. | ||
|
|
||
| A2AAgent can be used to get the Agent Card and interact with the agent. | ||
| """ | ||
|
|
||
| import logging | ||
| from typing import Any, AsyncIterator | ||
|
|
||
| import httpx | ||
| from a2a.client import A2ACardResolver, Client, ClientConfig, ClientFactory | ||
| from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent | ||
|
|
||
| from .._async import run_async | ||
| from ..multiagent.a2a.converters import convert_input_to_message, convert_response_to_agent_result | ||
| from ..types._events import AgentResultEvent | ||
| from ..types.a2a import A2AResponse, A2AStreamEvent | ||
| from ..types.agent import AgentInput | ||
| from .agent_result import AgentResult | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| DEFAULT_TIMEOUT = 300 | ||
|
|
||
|
|
||
| class A2AAgent: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this work when the client is configured for streaming and when it is not (both cases)? |
||
| """Client wrapper for remote A2A agents.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| endpoint: str, | ||
| *, | ||
| name: str | None = None, | ||
| description: str = "", | ||
| timeout: int = DEFAULT_TIMEOUT, | ||
| a2a_client_factory: ClientFactory | None = None, | ||
| ): | ||
| """Initialize A2A agent. | ||
|
|
||
| Args: | ||
| endpoint: The base URL of the remote A2A agent. | ||
| name: Agent name. If not provided, will be populated from agent card. | ||
| description: Agent description. If empty, will be populated from agent card. | ||
| timeout: Timeout for HTTP operations in seconds (defaults to 300). | ||
| a2a_client_factory: Optional pre-configured A2A ClientFactory. If provided, | ||
| it will be used to create the A2A client after discovering the agent card. | ||
| """ | ||
| self.endpoint = endpoint | ||
| self.name = name | ||
| self.description = description | ||
| self.timeout = timeout | ||
| self._httpx_client: httpx.AsyncClient | None = None | ||
| self._owns_client = a2a_client_factory is None | ||
| self._agent_card: AgentCard | None = None | ||
| self._a2a_client: Client | None = None | ||
| self._a2a_client_factory: ClientFactory | None = a2a_client_factory | ||
|
|
||
| def _get_httpx_client(self) -> httpx.AsyncClient: | ||
| """Get or create the httpx client for this agent. | ||
|
|
||
| Returns: | ||
| Configured httpx.AsyncClient instance. | ||
| """ | ||
| if self._httpx_client is None: | ||
| self._httpx_client = httpx.AsyncClient(timeout=self.timeout) | ||
| return self._httpx_client | ||
|
|
||
| async def _get_agent_card(self) -> AgentCard: | ||
| """Discover and cache the agent card from the remote endpoint. | ||
|
|
||
| Returns: | ||
| The discovered AgentCard. | ||
| """ | ||
| if self._agent_card is not None: | ||
| return self._agent_card | ||
|
|
||
| httpx_client = self._get_httpx_client() | ||
| resolver = A2ACardResolver(httpx_client=httpx_client, base_url=self.endpoint) | ||
| self._agent_card = await resolver.get_agent_card() | ||
|
|
||
| # Populate name from card if not set | ||
| if self.name is None and self._agent_card.name: | ||
| self.name = self._agent_card.name | ||
|
|
||
| # Populate description from card if not set | ||
| if not self.description and self._agent_card.description: | ||
| self.description = self._agent_card.description | ||
|
|
||
| logger.info("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be debug I think? I'm not sure we really use info anywhere throughout |
||
| return self._agent_card | ||
|
|
||
| async def _get_a2a_client(self) -> Client: | ||
| """Get or create the A2A client for this agent. | ||
|
|
||
| Returns: | ||
| Configured A2A client instance. | ||
| """ | ||
| if self._a2a_client is None: | ||
| agent_card = await self._get_agent_card() | ||
|
|
||
| if self._a2a_client_factory is not None: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thoughts on simplifying to: |
||
| # Use provided factory | ||
| factory = self._a2a_client_factory | ||
| else: | ||
| # Create default factory | ||
| httpx_client = self._get_httpx_client() | ||
| config = ClientConfig(httpx_client=httpx_client, streaming=False) | ||
| factory = ClientFactory(config) | ||
|
|
||
| self._a2a_client = factory.create(agent_card) | ||
| return self._a2a_client | ||
|
|
||
| async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: | ||
| """Send message to A2A agent. | ||
|
|
||
| Args: | ||
| prompt: Input to send to the agent. | ||
|
|
||
| Returns: | ||
| Async iterator of A2A events. | ||
|
|
||
| Raises: | ||
| ValueError: If prompt is None. | ||
| """ | ||
| if prompt is None: | ||
| raise ValueError("prompt is required for A2AAgent") | ||
|
|
||
| client = await self._get_a2a_client() | ||
| message = convert_input_to_message(prompt) | ||
|
|
||
| logger.info("agent=<%s>, endpoint=<%s> | sending message", self.name, self.endpoint) | ||
| return client.send_message(message) | ||
|
|
||
| def _is_complete_event(self, event: A2AResponse) -> bool: | ||
| """Check if an A2A event represents a complete response. | ||
|
|
||
| Args: | ||
| event: A2A event. | ||
|
|
||
| Returns: | ||
| True if the event represents a complete response. | ||
| """ | ||
| # Direct Message is always complete | ||
| if isinstance(event, Message): | ||
| return True | ||
|
|
||
| # Handle tuple responses (Task, UpdateEvent | None) | ||
| if isinstance(event, tuple) and len(event) == 2: | ||
| task, update_event = event | ||
|
|
||
| # Initial task response (no update event) | ||
| if update_event is None: | ||
| return True | ||
|
|
||
| # Artifact update with last_chunk flag | ||
| if isinstance(update_event, TaskArtifactUpdateEvent): | ||
| if hasattr(update_event, "last_chunk") and update_event.last_chunk is not None: | ||
| return update_event.last_chunk | ||
| return False | ||
|
|
||
| # Status update with completed state | ||
| if isinstance(update_event, TaskStatusUpdateEvent): | ||
| if update_event.status and hasattr(update_event.status, "state"): | ||
| return update_event.status.state == TaskState.completed | ||
|
|
||
| return False | ||
|
|
||
| async def invoke_async( | ||
| self, | ||
| prompt: AgentInput = None, | ||
| **kwargs: Any, | ||
| ) -> AgentResult: | ||
| """Asynchronously invoke the remote A2A agent. | ||
|
|
||
| Args: | ||
| prompt: Input to the agent (string, message list, or content blocks). | ||
| **kwargs: Additional arguments (ignored). | ||
|
|
||
| Returns: | ||
| AgentResult containing the agent's response. | ||
|
|
||
| Raises: | ||
| ValueError: If prompt is None. | ||
| RuntimeError: If no response received from agent. | ||
| """ | ||
| async for event in await self._send_message(prompt): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we delegate this implementation to |
||
| return convert_response_to_agent_result(event) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This assumes that first event is the complete response though, right? What if it's not? |
||
|
|
||
| raise RuntimeError("No response received from A2A agent") | ||
|
|
||
| def __call__( | ||
| self, | ||
| prompt: AgentInput = None, | ||
| **kwargs: Any, | ||
| ) -> AgentResult: | ||
| """Synchronously invoke the remote A2A agent. | ||
|
|
||
| Args: | ||
| prompt: Input to the agent (string, message list, or content blocks). | ||
| **kwargs: Additional arguments (ignored). | ||
|
|
||
| Returns: | ||
| AgentResult containing the agent's response. | ||
|
|
||
| Raises: | ||
| ValueError: If prompt is None. | ||
| RuntimeError: If no response received from agent. | ||
| """ | ||
| return run_async(lambda: self.invoke_async(prompt, **kwargs)) | ||
|
|
||
| async def stream_async( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Stream async needs to return AgentResult for it to work with graph/swarm https://github.com/strands-agents/sdk-python/blob/main/src/strands/multiagent/graph.py#L810
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have this documented anywhere? Or is it only applicable internally? |
||
| self, | ||
| prompt: AgentInput = None, | ||
| **kwargs: Any, | ||
| ) -> AsyncIterator[Any]: | ||
| """Stream agent execution asynchronously. | ||
|
|
||
| Args: | ||
| prompt: Input to the agent (string, message list, or content blocks). | ||
| **kwargs: Additional arguments (ignored). | ||
|
|
||
| Yields: | ||
| A2A events and a final AgentResult event. | ||
|
|
||
| Raises: | ||
| ValueError: If prompt is None. | ||
| """ | ||
| last_event = None | ||
| last_complete_event = None | ||
|
|
||
| async for event in await self._send_message(prompt): | ||
| last_event = event | ||
| if self._is_complete_event(event): | ||
| last_complete_event = event | ||
| yield A2AStreamEvent(event) | ||
|
|
||
| # Use the last complete event if available, otherwise fall back to last event | ||
| final_event = last_complete_event if last_complete_event is not None else last_event | ||
|
|
||
| if final_event is not None: | ||
| result = convert_response_to_agent_result(final_event) | ||
| yield AgentResultEvent(result) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to yield one final event that becomes the "result" of the tool: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/tools/python-tools/#tool-streaming I think as is, this results in a stringified version of AgentResultEvent - is that what we want? |
||
|
|
||
| def __del__(self) -> None: | ||
| """Clean up resources when agent is garbage collected.""" | ||
| if self._owns_client and self._httpx_client is not None: | ||
| try: | ||
| client = self._httpx_client | ||
| run_async(lambda: client.aclose()) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you check to ensure that this works correctly? IIRC, @dbschmigelski indicated there were problems trying to do async stuff from del |
||
| except Exception: | ||
| pass # Best effort cleanup, ignore errors in __del__ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| """Conversion functions between Strands and A2A types.""" | ||
|
|
||
| from typing import cast | ||
| from uuid import uuid4 | ||
|
|
||
| from a2a.types import Message as A2AMessage | ||
| from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart | ||
|
|
||
| from ...agent.agent_result import AgentResult | ||
| from ...telemetry.metrics import EventLoopMetrics | ||
| from ...types.a2a import A2AResponse | ||
| from ...types.agent import AgentInput | ||
| from ...types.content import ContentBlock, Message | ||
|
|
||
|
|
||
| def convert_input_to_message(prompt: AgentInput) -> A2AMessage: | ||
| """Convert AgentInput to A2A Message. | ||
|
|
||
| Args: | ||
| prompt: Input in various formats (string, message list, or content blocks). | ||
|
|
||
| Returns: | ||
| A2AMessage ready to send to the remote agent. | ||
|
|
||
| Raises: | ||
| ValueError: If prompt format is unsupported. | ||
| """ | ||
| message_id = uuid4().hex | ||
|
|
||
| if isinstance(prompt, str): | ||
| return A2AMessage( | ||
| kind="message", | ||
| role=Role.user, | ||
| parts=[Part(TextPart(kind="text", text=prompt))], | ||
| message_id=message_id, | ||
| ) | ||
|
|
||
| if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)): | ||
| if "role" in prompt[0]: | ||
| for msg in reversed(prompt): | ||
| if msg.get("role") == "user": | ||
| content = cast(list[ContentBlock], msg.get("content", [])) | ||
| parts = convert_content_blocks_to_parts(content) | ||
| return A2AMessage( | ||
| kind="message", | ||
| role=Role.user, | ||
| parts=parts, | ||
| message_id=message_id, | ||
| ) | ||
| else: | ||
| parts = convert_content_blocks_to_parts(cast(list[ContentBlock], prompt)) | ||
| return A2AMessage( | ||
| kind="message", | ||
| role=Role.user, | ||
| parts=parts, | ||
| message_id=message_id, | ||
| ) | ||
|
|
||
| raise ValueError(f"Unsupported input type: {type(prompt)}") | ||
|
|
||
|
|
||
| def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[Part]: | ||
| """Convert Strands ContentBlocks to A2A Parts. | ||
|
|
||
| Args: | ||
| content_blocks: List of Strands content blocks. | ||
|
|
||
| Returns: | ||
| List of A2A Part objects. | ||
| """ | ||
| parts = [] | ||
| for block in content_blocks: | ||
| if "text" in block: | ||
| parts.append(Part(TextPart(kind="text", text=block["text"]))) | ||
| return parts | ||
|
|
||
|
|
||
| def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: | ||
| """Convert A2A response to AgentResult. | ||
|
|
||
| Args: | ||
| response: A2A response (either A2AMessage or tuple of task and update event). | ||
|
|
||
| Returns: | ||
| AgentResult with extracted content and metadata. | ||
| """ | ||
| content: list[ContentBlock] = [] | ||
|
|
||
| if isinstance(response, tuple) and len(response) == 2: | ||
| task, update_event = response | ||
|
|
||
| # Handle artifact updates | ||
| if isinstance(update_event, TaskArtifactUpdateEvent): | ||
| if update_event.artifact and hasattr(update_event.artifact, "parts"): | ||
| for part in update_event.artifact.parts: | ||
| if hasattr(part, "root") and hasattr(part.root, "text"): | ||
| content.append({"text": part.root.text}) | ||
| # Handle status updates with messages | ||
| elif isinstance(update_event, TaskStatusUpdateEvent): | ||
| if update_event.status and hasattr(update_event.status, "message") and update_event.status.message: | ||
| for part in update_event.status.message.parts: | ||
| if hasattr(part, "root") and hasattr(part.root, "text"): | ||
| content.append({"text": part.root.text}) | ||
| # Handle initial task or task without update event | ||
| elif update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None: | ||
| for artifact in task.artifacts: | ||
| if hasattr(artifact, "parts"): | ||
| for part in artifact.parts: | ||
| if hasattr(part, "root") and hasattr(part.root, "text"): | ||
| content.append({"text": part.root.text}) | ||
| elif isinstance(response, A2AMessage): | ||
| for part in response.parts: | ||
| if hasattr(part, "root") and hasattr(part.root, "text"): | ||
| content.append({"text": part.root.text}) | ||
|
|
||
| message: Message = { | ||
| "role": "assistant", | ||
| "content": content, | ||
| } | ||
|
|
||
| return AgentResult( | ||
| stop_reason="end_turn", | ||
| message=message, | ||
| metrics=EventLoopMetrics(), | ||
| state={}, | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Callout that I appreciate this PR description & code sample.