Skip to content

Commit a7b4036

Browse files
committed
feat(a2a): add A2AAgent class as an implementation of the agent interface for remote A2A protocol based agents
1 parent 95ac650 commit a7b4036

File tree

3 files changed

+571
-7
lines changed

3 files changed

+571
-7
lines changed

src/strands/agent/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,4 @@
1-
"""This package provides the core Agent interface and supporting components for building AI agents with the SDK.
2-
3-
It includes:
4-
5-
- Agent: The main interface for interacting with AI models and tools
6-
- ConversationManager: Classes for managing conversation history and context windows
7-
"""
1+
"""This package provides the core Agent interface and supporting components for building AI agents with the SDK."""
82

93
from .agent import Agent
104
from .agent_result import AgentResult

src/strands/agent/a2a_agent.py

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
"""A2A Agent client for Strands Agents.
2+
3+
This module provides the A2AAgent class, which acts as a client wrapper for remote A2A agents,
4+
allowing them to be used in graphs, swarms, and other multi-agent patterns.
5+
"""
6+
7+
import logging
8+
from typing import Any, AsyncIterator, cast
9+
from uuid import uuid4
10+
11+
import httpx
12+
from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
13+
from a2a.types import AgentCard, Part, Role, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart
14+
from a2a.types import Message as A2AMessage
15+
16+
from .._async import run_async
17+
from ..telemetry.metrics import EventLoopMetrics
18+
from ..types.agent import AgentInput
19+
from ..types.content import ContentBlock, Message
20+
from .agent_result import AgentResult
21+
22+
logger = logging.getLogger(__name__)
23+
24+
DEFAULT_TIMEOUT = 300
25+
26+
27+
class A2AAgent:
28+
"""Client wrapper for remote A2A agents.
29+
30+
Implements the AgentBase protocol to enable remote A2A agents to be used
31+
in graphs, swarms, and other multi-agent patterns.
32+
"""
33+
34+
def __init__(
35+
self,
36+
endpoint: str,
37+
timeout: int = DEFAULT_TIMEOUT,
38+
httpx_client_args: dict[str, Any] | None = None,
39+
):
40+
"""Initialize A2A agent client.
41+
42+
Args:
43+
endpoint: The base URL of the remote A2A agent
44+
timeout: Timeout for HTTP operations in seconds (defaults to 300)
45+
httpx_client_args: Optional dictionary of arguments to pass to httpx.AsyncClient
46+
constructor. Allows custom auth, headers, proxies, etc.
47+
Example: {"headers": {"Authorization": "Bearer token"}}
48+
"""
49+
self.endpoint = endpoint
50+
self.timeout = timeout
51+
self._httpx_client_args: dict[str, Any] = httpx_client_args or {}
52+
53+
if "timeout" not in self._httpx_client_args:
54+
self._httpx_client_args["timeout"] = self.timeout
55+
56+
self._agent_card: AgentCard | None = None
57+
58+
def _get_httpx_client(self) -> httpx.AsyncClient:
59+
"""Get a fresh httpx client for the current operation.
60+
61+
Returns:
62+
Configured httpx.AsyncClient instance.
63+
"""
64+
return httpx.AsyncClient(**self._httpx_client_args)
65+
66+
def _get_client_factory(self, streaming: bool = False) -> ClientFactory:
67+
"""Get a ClientFactory for the current operation.
68+
69+
Args:
70+
streaming: Whether to enable streaming mode.
71+
72+
Returns:
73+
Configured ClientFactory instance.
74+
"""
75+
httpx_client = self._get_httpx_client()
76+
config = ClientConfig(
77+
httpx_client=httpx_client,
78+
streaming=streaming,
79+
)
80+
return ClientFactory(config)
81+
82+
async def _discover_agent_card(self) -> AgentCard:
83+
"""Discover and cache the agent card from the remote endpoint.
84+
85+
Returns:
86+
The discovered AgentCard.
87+
"""
88+
if self._agent_card is not None:
89+
return self._agent_card
90+
91+
httpx_client = self._get_httpx_client()
92+
resolver = A2ACardResolver(httpx_client=httpx_client, base_url=self.endpoint)
93+
self._agent_card = await resolver.get_agent_card()
94+
logger.info("endpoint=<%s> | discovered agent card", self.endpoint)
95+
return self._agent_card
96+
97+
def _convert_input_to_message(self, prompt: AgentInput) -> A2AMessage:
98+
"""Convert AgentInput to A2A Message.
99+
100+
Args:
101+
prompt: Input in various formats (string, message list, or content blocks).
102+
103+
Returns:
104+
A2AMessage ready to send to the remote agent.
105+
106+
Raises:
107+
ValueError: If prompt format is unsupported.
108+
"""
109+
message_id = uuid4().hex
110+
111+
if isinstance(prompt, str):
112+
return A2AMessage(
113+
kind="message",
114+
role=Role.user,
115+
parts=[Part(TextPart(kind="text", text=prompt))],
116+
message_id=message_id,
117+
)
118+
119+
if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)):
120+
if "role" in prompt[0]:
121+
# Message list - extract last user message
122+
for msg in reversed(prompt):
123+
if msg.get("role") == "user":
124+
content = cast(list[ContentBlock], msg.get("content", []))
125+
parts = self._convert_content_blocks_to_parts(content)
126+
return A2AMessage(
127+
kind="message",
128+
role=Role.user,
129+
parts=parts,
130+
message_id=message_id,
131+
)
132+
else:
133+
# ContentBlock list
134+
parts = self._convert_content_blocks_to_parts(cast(list[ContentBlock], prompt))
135+
return A2AMessage(
136+
kind="message",
137+
role=Role.user,
138+
parts=parts,
139+
message_id=message_id,
140+
)
141+
142+
raise ValueError(f"Unsupported input type: {type(prompt)}")
143+
144+
def _convert_content_blocks_to_parts(self, content_blocks: list[ContentBlock]) -> list[Part]:
145+
"""Convert Strands ContentBlocks to A2A Parts.
146+
147+
Args:
148+
content_blocks: List of Strands content blocks.
149+
150+
Returns:
151+
List of A2A Part objects.
152+
"""
153+
parts = []
154+
for block in content_blocks:
155+
if "text" in block:
156+
parts.append(Part(TextPart(kind="text", text=block["text"])))
157+
return parts
158+
159+
def _convert_response_to_agent_result(self, response: Any) -> AgentResult:
160+
"""Convert A2A response to AgentResult.
161+
162+
Args:
163+
response: A2A response (either A2AMessage or tuple of task and update event).
164+
165+
Returns:
166+
AgentResult with extracted content and metadata.
167+
"""
168+
content: list[ContentBlock] = []
169+
170+
if isinstance(response, tuple) and len(response) == 2:
171+
task, update_event = response
172+
if update_event is None and task and hasattr(task, "artifacts"):
173+
# Non-streaming response: extract from task artifacts
174+
for artifact in task.artifacts:
175+
if hasattr(artifact, "parts"):
176+
for part in artifact.parts:
177+
if hasattr(part, "root") and hasattr(part.root, "text"):
178+
content.append({"text": part.root.text})
179+
elif isinstance(response, A2AMessage):
180+
# Direct message response
181+
for part in response.parts:
182+
if hasattr(part, "root") and hasattr(part.root, "text"):
183+
content.append({"text": part.root.text})
184+
185+
message: Message = {
186+
"role": "assistant",
187+
"content": content,
188+
}
189+
190+
return AgentResult(
191+
stop_reason="end_turn",
192+
message=message,
193+
metrics=EventLoopMetrics(),
194+
state={},
195+
)
196+
197+
async def _send_message(
198+
self, prompt: AgentInput, streaming: bool
199+
) -> AsyncIterator[tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | A2AMessage]:
200+
"""Send message to A2A agent.
201+
202+
Args:
203+
prompt: Input to send to the agent.
204+
streaming: Whether to use streaming mode.
205+
206+
Returns:
207+
Async iterator of A2A events.
208+
209+
Raises:
210+
ValueError: If prompt is None.
211+
"""
212+
if prompt is None:
213+
raise ValueError("prompt is required for A2AAgent")
214+
215+
agent_card = await self._discover_agent_card()
216+
client = self._get_client_factory(streaming=streaming).create(agent_card)
217+
message = self._convert_input_to_message(prompt)
218+
219+
logger.info("endpoint=<%s> | %s message", self.endpoint, "streaming" if streaming else "sending")
220+
return client.send_message(message)
221+
222+
async def invoke_async(
223+
self,
224+
prompt: AgentInput = None,
225+
**kwargs: Any,
226+
) -> AgentResult:
227+
"""Asynchronously invoke the remote A2A agent.
228+
229+
Args:
230+
prompt: Input to the agent (string, message list, or content blocks).
231+
**kwargs: Additional arguments (ignored).
232+
233+
Returns:
234+
AgentResult containing the agent's response.
235+
236+
Raises:
237+
ValueError: If prompt is None.
238+
RuntimeError: If no response received from agent.
239+
"""
240+
async for event in await self._send_message(prompt, streaming=False):
241+
return self._convert_response_to_agent_result(event)
242+
243+
raise RuntimeError("No response received from A2A agent")
244+
245+
def __call__(
246+
self,
247+
prompt: AgentInput = None,
248+
**kwargs: Any,
249+
) -> AgentResult:
250+
"""Synchronously invoke the remote A2A agent.
251+
252+
Args:
253+
prompt: Input to the agent (string, message list, or content blocks).
254+
**kwargs: Additional arguments (ignored).
255+
256+
Returns:
257+
AgentResult containing the agent's response.
258+
259+
Raises:
260+
ValueError: If prompt is None.
261+
RuntimeError: If no response received from agent.
262+
"""
263+
return run_async(lambda: self.invoke_async(prompt, **kwargs))
264+
265+
async def stream_async(
266+
self,
267+
prompt: AgentInput = None,
268+
**kwargs: Any,
269+
) -> AsyncIterator[Any]:
270+
"""Stream agent execution asynchronously.
271+
272+
Args:
273+
prompt: Input to the agent (string, message list, or content blocks).
274+
**kwargs: Additional arguments (ignored).
275+
276+
Yields:
277+
A2A events wrapped in dictionaries with an 'a2a_event' key.
278+
279+
Raises:
280+
ValueError: If prompt is None.
281+
"""
282+
async for event in await self._send_message(prompt, streaming=True):
283+
yield {"a2a_event": event}

0 commit comments

Comments
 (0)