From fbb30a3d7e9d45cc00de511bf0e499fd900d6a53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Je=C5=BEek?= Date: Thu, 26 Mar 2026 16:02:16 +0100 Subject: [PATCH] feat: replace custom ContextStore with A2A-native task history MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use A2A Tasks as the single source of truth for conversation history, eliminating the duplicate ContextStore abstraction. adk-server: - Implement on_list_tasks in A2A proxy (forwards to agent transport) - Remove context history endpoints (POST/GET/DELETE /contexts/{id}/history) - Remove title generation job and worker - Remove history-related service, repository, and schema code adk-py: - Delete ContextStore, InMemoryContextStore, PlatformContextStore - Rewire RunContext to use TaskStore (load_history reads from task store) - Remove store(), store_sync(), delete_history_from_id() from RunContext - Remove context_store parameter from create_app() and Server.serve() - Remove ContextHistoryItem and history methods from platform SDK adk-ui: - Add listTasks to A2A JSON-RPC client - Create fetchTasksForContext for server-side task fetching - Add convertTasksToUIMessages for task-to-UI conversion - Rewire AgentRun/PlatformContextProvider/MessagesProvider to use tasks - Remove useListContextHistory, context history API functions and types Closes #116 Assisted-By: Claude (Anthropic AI) Signed-off-by: Radek Ježek --- agents/canvas/src/canvas/agent.py | 7 - agents/chat/src/chat/agent.py | 22 +-- .../src/content_builder/agent.py | 4 - agents/rag/src/rag/agent.py | 5 - apps/adk-py/examples/canvas_ui_code_agent.py | 6 - apps/adk-py/examples/canvas_ui_test_agent.py | 6 - apps/adk-py/examples/citation_agent.py | 7 +- .../examples/citation_agent_artifact.py | 8 +- apps/adk-py/examples/history.py | 6 - apps/adk-py/examples/history_framework.py | 5 - apps/adk-py/examples/oauth.py | 5 - apps/adk-py/examples/trajectory_agent.py | 20 +-- .../src/kagenti_adk/platform/context.py 
| 114 ++------------- apps/adk-py/src/kagenti_adk/server/agent.py | 14 +- apps/adk-py/src/kagenti_adk/server/app.py | 7 +- apps/adk-py/src/kagenti_adk/server/context.py | 60 ++------ apps/adk-py/src/kagenti_adk/server/server.py | 4 - .../kagenti_adk/server/store/context_store.py | 36 ----- .../server/store/memory_context_store.py | 61 -------- .../server/store/platform_context_store.py | 53 ------- apps/adk-py/tests/e2e/conftest.py | 11 +- apps/adk-py/tests/e2e/test_history.py | 108 ++++---------- .../src/adk_server/api/routes/contexts.py | 36 +---- .../src/adk_server/api/schema/contexts.py | 5 +- .../adk_server/domain/repositories/context.py | 16 +- .../persistence/repositories/context.py | 108 +------------- .../src/adk_server/jobs/procrastinate.py | 2 - apps/adk-server/src/adk_server/jobs/queues.py | 1 - .../src/adk_server/jobs/tasks/context.py | 18 --- apps/adk-server/src/adk_server/run_workers.py | 5 - .../adk_server/service_layer/services/a2a.py | 3 +- .../service_layer/services/contexts.py | 135 +---------------- apps/adk-server/tests/e2e/agents/conftest.py | 5 - .../tests/e2e/agents/test_context_store.py | 123 ---------------- .../tests/e2e/routes/test_contexts.py | 138 +----------------- apps/adk-ui/src/api/a2a/jsonrpc-client.ts | 19 +++ apps/adk-ui/src/api/a2a/list-tasks.ts | 79 ++++++++++ apps/adk-ui/src/modules/history/utils.ts | 50 ++++++- .../contexts/Messages/MessagesProvider.tsx | 64 ++------ .../contexts/Messages/messages-context.ts | 9 +- .../modules/platform-context/api/constants.ts | 4 +- .../src/modules/platform-context/api/index.ts | 13 -- .../src/modules/platform-context/api/keys.ts | 5 +- .../api/queries/useListContextHistory.ts | 60 -------- .../src/modules/platform-context/api/types.ts | 5 - .../src/modules/platform-context/api/utils.ts | 10 -- .../contexts/PlatformContextProvider.tsx | 8 +- .../contexts/platform-context.ts | 9 +- .../src/modules/runs/components/AgentRun.tsx | 16 +- .../contexts/agent-run/AgentRunProvider.tsx | 1 - 
docs/development/agent-integration/canvas.mdx | 3 - .../agent-integration/multi-turn.mdx | 121 ++------------- docs/development/agent-integration/rag.mdx | 44 +++--- .../src/canvas_with_llm/agent.py | 3 - .../src/advanced_history/agent.py | 5 - .../basic-history/src/basic_history/agent.py | 6 - .../src/streaming_agent_history/agent.py | 18 +-- .../src/conversation_rag_agent/agent.py | 28 ++-- skills/kagenti-adk-wrapper/SKILL.md | 9 +- .../references/wrapper-entrypoint.md | 5 +- 60 files changed, 310 insertions(+), 1448 deletions(-) delete mode 100644 apps/adk-py/src/kagenti_adk/server/store/context_store.py delete mode 100644 apps/adk-py/src/kagenti_adk/server/store/memory_context_store.py delete mode 100644 apps/adk-py/src/kagenti_adk/server/store/platform_context_store.py delete mode 100644 apps/adk-server/tests/e2e/agents/test_context_store.py create mode 100644 apps/adk-ui/src/api/a2a/list-tasks.ts delete mode 100644 apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts delete mode 100644 apps/adk-ui/src/modules/platform-context/api/utils.ts diff --git a/agents/canvas/src/canvas/agent.py b/agents/canvas/src/canvas/agent.py index 7c9ba8b1..a4e9308f 100644 --- a/agents/canvas/src/canvas/agent.py +++ b/agents/canvas/src/canvas/agent.py @@ -50,7 +50,6 @@ async def canvas_agent( yield "Can't run without a LLM." 
return - await context.store(message) edit_request = await canvas.parse_canvas_edit_request(message=message) user_text_content = _get_text(message) @@ -132,12 +131,6 @@ async def canvas_agent( parts=[TextPart(text=content_delta)], ) - final_artifact = AgentArtifact( - artifact_id=artifact.artifact_id, - name=artifact.name, - parts=[TextPart(text=buffer)], - ) - await context.store(final_artifact) def serve(): diff --git a/agents/chat/src/chat/agent.py b/agents/chat/src/chat/agent.py index 66a868cb..b7321f84 100644 --- a/agents/chat/src/chat/agent.py +++ b/agents/chat/src/chat/agent.py @@ -24,11 +24,10 @@ PlatformApiExtensionServer, PlatformApiExtensionSpec, ) -from kagenti_adk.a2a.types import AgentArtifact, AgentMessage +from kagenti_adk.a2a.types import AgentArtifact from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext from kagenti_adk.server.middleware.platform_auth_backend import PlatformAuthBackend -from kagenti_adk.server.store.platform_context_store import PlatformContextStore from beeai_framework.adapters.agentstack.backend.chat import AgentStackChatModel from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.events import ( @@ -36,7 +35,7 @@ RequirementAgentSuccessEvent, ) from beeai_framework.agents.requirement.utils._tool import FinalAnswerTool -from beeai_framework.backend import AssistantMessage, ChatModelParameters +from beeai_framework.backend import ChatModelParameters from beeai_framework.errors import FrameworkError from beeai_framework.middleware.trajectory import GlobalTrajectoryMiddleware from beeai_framework.tools import AnyTool, Tool @@ -45,7 +44,6 @@ from beeai_framework.tools.weather import OpenMeteoTool from openinference.instrumentation.beeai import BeeAIInstrumentor -from chat.helpers.citations import extract_citations from chat.helpers.trajectory import TrajectoryContent from chat.tools.files.file_creator import FileCreatorTool, FileCreatorToolOutput 
from chat.tools.files.file_reader import FileReaderTool @@ -149,8 +147,6 @@ async def chat( _p: Annotated[PlatformApiExtensionServer, PlatformApiExtensionSpec()], ): """Agent with memory and access to web search, Wikipedia, and weather.""" - await context.store(input) - # Send initial trajectory yield trajectory.trajectory_metadata(title="Starting", content="Received your request") @@ -220,7 +216,6 @@ async def chat( middlewares=[GlobalTrajectoryMiddleware(included=[Tool])], ) - final_answer: AssistantMessage | None = None new_messages = [to_framework_message(item, extracted_files) for item in history] try: @@ -244,8 +239,6 @@ async def chat( case RequirementAgentFinalAnswerEvent(delta=delta): yield delta case RequirementAgentSuccessEvent(state=state): - final_answer = state.answer - last_step = state.steps[-1] if last_step.tool and last_step.tool.name == FinalAnswerTool.name: # internal tool continue @@ -259,7 +252,6 @@ async def chat( group_id=last_step.id, ) yield metadata - await context.store(AgentMessage(metadata=metadata)) if isinstance(last_step.output, FileCreatorToolOutput): for file_info in last_step.output.result.files: @@ -267,16 +259,7 @@ async def chat( part.filename = file_info.display_filename artifact = AgentArtifact(name=file_info.display_filename, parts=[part]) yield artifact - await context.store(artifact) - - if final_answer: - citations, clean_text = extract_citations(final_answer.text) - message = AgentMessage( - text=clean_text, - metadata=(citation.citation_metadata(citations=citations) if citations else None), - ) - await context.store(message) except FrameworkError as err: raise RuntimeError(err.explain()) @@ -287,7 +270,6 @@ def serve(): host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)), configure_telemetry=True, - context_store=PlatformContextStore(), auth_backend=PlatformAuthBackend(), ) except KeyboardInterrupt: diff --git a/agents/deepagents_content_builder/src/content_builder/agent.py 
b/agents/deepagents_content_builder/src/content_builder/agent.py index 82fac179..e713d395 100644 --- a/agents/deepagents_content_builder/src/content_builder/agent.py +++ b/agents/deepagents_content_builder/src/content_builder/agent.py @@ -91,7 +91,6 @@ async def content_builder_agent( return started_at = datetime.now(timezone.utc) - await context.store(data=message) subagents: list[SubAgent] = [] for sub_agent in AVAILABLE_SUBAGENTS: @@ -140,7 +139,6 @@ async def content_builder_agent( title=data["name"], content=json.dumps(obj=data["args"]) ) yield tool_call_metadata - await context.store(data=AgentMessage(metadata=tool_call_metadata)) tool_calls.clear() elif last_msg.tool_call_chunks: @@ -151,12 +149,10 @@ async def content_builder_agent( tool_calls[tc_id]["args"] += tc.get("args") or "" elif last_msg.text: yield AgentMessage(text=last_msg.text) - await context.store(AgentMessage(text=last_msg.text)) elif isinstance(last_msg, ToolMessage) and last_msg.name and last_msg.text: tool_message_metadata = trajectory.trajectory_metadata(title=last_msg.name, content=last_msg.text) yield tool_message_metadata - await context.store(data=AgentMessage(metadata=tool_message_metadata)) updated_files = await agent_stack_backend.alist(order_by="created_at", order="asc", created_after=started_at) for updated_file in updated_files: diff --git a/agents/rag/src/rag/agent.py b/agents/rag/src/rag/agent.py index 0199a396..36e35828 100644 --- a/agents/rag/src/rag/agent.py +++ b/agents/rag/src/rag/agent.py @@ -28,7 +28,6 @@ from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext from kagenti_adk.server.middleware.platform_auth_backend import PlatformAuthBackend -from kagenti_adk.server.store.platform_context_store import PlatformContextStore from beeai_framework.adapters.agentstack.backend.chat import AgentStackChatModel from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.utils._tool import FinalAnswerTool 
@@ -116,7 +115,6 @@ async def rag( _: Annotated[PlatformApiExtensionServer, PlatformApiExtensionSpec()], ): """RAG agent that retrieves and generates text based on user queries.""" - await context.store(input) llm, embedding = _get_clients(llm_ext, embedding_ext) history = [m async for m in context.load_history()] @@ -181,7 +179,6 @@ async def rag( phase="end", ).metadata(trajectory) yield vector_store_create_metadata - await context.store(AgentMessage(metadata=vector_store_create_metadata)) tools.append(cast(Tool, VectorSearchTool(vector_store_id=vector_store_id, embedding_function=embedding))) async for item in embed_all_files( @@ -300,7 +297,6 @@ async def handle_tool_success(event, meta): metadata=(citation.citation_metadata(citations=citations) if citations else None), ) yield message - await context.store(message) def _get_clients( @@ -331,7 +327,6 @@ def serve(): host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)), configure_telemetry=True, - context_store=PlatformContextStore(), auth_backend=PlatformAuthBackend(), ) except KeyboardInterrupt: diff --git a/apps/adk-py/examples/canvas_ui_code_agent.py b/apps/adk-py/examples/canvas_ui_code_agent.py index 22c25e45..2ea7afcb 100644 --- a/apps/adk-py/examples/canvas_ui_code_agent.py +++ b/apps/adk-py/examples/canvas_ui_code_agent.py @@ -84,8 +84,6 @@ async def artifacts_agent( ): """Works with artifacts""" - await context.store(input) - canvas_edit_request = await canvas.parse_canvas_edit_request(message=input) if canvas_edit_request: @@ -126,7 +124,6 @@ async def artifacts_agent( if pre_text := response[: match.start()]: message = AgentMessage(text=pre_text) yield message - await context.store(message) await asyncio.sleep(1) @@ -153,7 +150,6 @@ async def artifacts_agent( name=artifact_name, parts=[TextPart(text=code_content)], ) - await context.store(artifact) # Send first chunk with artifact_id to establish the artifact first_artifact = AgentArtifact( @@ -171,13 +167,11 @@ async def 
artifacts_agent( parts=[TextPart(text=chunk)], ) yield chunk_artifact - await context.store(chunk_artifact) await asyncio.sleep(0.3) if post_text := response[match.end() :]: message = AgentMessage(text=post_text) yield message - await context.store(message) if __name__ == "__main__": diff --git a/apps/adk-py/examples/canvas_ui_test_agent.py b/apps/adk-py/examples/canvas_ui_test_agent.py index 87383535..7f407694 100644 --- a/apps/adk-py/examples/canvas_ui_test_agent.py +++ b/apps/adk-py/examples/canvas_ui_test_agent.py @@ -67,8 +67,6 @@ async def artifacts_agent( ): """Works with artifacts""" - await context.store(input) - canvas_edit_request = await canvas.parse_canvas_edit_request(message=input) if canvas_edit_request: @@ -106,7 +104,6 @@ async def artifacts_agent( if pre_text := response[: match.start()].strip(): message = AgentMessage(text=pre_text) yield message - await context.store(message) await asyncio.sleep(1) @@ -137,7 +134,6 @@ async def artifacts_agent( name=artifact_name, parts=[TextPart(text=recipe_content)], ) - await context.store(artifact) # Send first chunk with artifact_id to establish the artifact first_artifact = AgentArtifact( @@ -155,13 +151,11 @@ async def artifacts_agent( parts=[TextPart(text=chunk)], ) yield chunk_artifact - await context.store(chunk_artifact) await asyncio.sleep(0.3) if post_text := response[match.end() :]: message = AgentMessage(text=post_text) yield message - await context.store(message) if __name__ == "__main__": diff --git a/apps/adk-py/examples/citation_agent.py b/apps/adk-py/examples/citation_agent.py index 8929c053..7857f555 100644 --- a/apps/adk-py/examples/citation_agent.py +++ b/apps/adk-py/examples/citation_agent.py @@ -12,7 +12,6 @@ from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -27,9 +26,6 @@ async def example_agent( ): 
"""Agent that demonstrates citation extension usage""" - # Store the current message in the context store - await context.store(input) - # Simulate researching multiple sources research_text = """Based on recent research, artificial intelligence has made significant progress in natural language processing. Studies show that transformer models have revolutionized the field, and @@ -60,12 +56,11 @@ async def example_agent( metadata=citation.citation_metadata(citations=citations), ) yield message - await context.store(message) def run(): server.run( - host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)), context_store=PlatformContextStore() + host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)) ) diff --git a/apps/adk-py/examples/citation_agent_artifact.py b/apps/adk-py/examples/citation_agent_artifact.py index af089064..a92524cb 100644 --- a/apps/adk-py/examples/citation_agent_artifact.py +++ b/apps/adk-py/examples/citation_agent_artifact.py @@ -12,7 +12,6 @@ from kagenti_adk.a2a.types import AgentArtifact from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -27,9 +26,6 @@ async def example_agent( ): """Agent that demonstrates citation extension usage""" - # Store the current message in the context store - await context.store(input) - # Simulate researching multiple sources research_text = """Based on recent research, artificial intelligence has made significant progress in natural language processing. 
Studies show that transformer models have revolutionized the field, and @@ -62,12 +58,10 @@ async def example_agent( ) yield artifact - await context.store(artifact) - def run(): server.run( - host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8002)), context_store=PlatformContextStore() + host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8002)) ) diff --git a/apps/adk-py/examples/history.py b/apps/adk-py/examples/history.py index 27c0c3ca..76562e3b 100644 --- a/apps/adk-py/examples/history.py +++ b/apps/adk-py/examples/history.py @@ -19,9 +19,6 @@ async def example_agent(input: Message, context: RunContext): """Agent that demonstrates conversation history access""" - # Store the current message in the context store - await context.store(input) - # Get the current user message current_message = get_message_text(input) print(f"Current message: {current_message}") @@ -36,9 +33,6 @@ async def example_agent(input: Message, context: RunContext): message = AgentMessage(text=f"Hello! 
I can see we have {len(history)} messages in our conversation.") yield message - # Store the message in the context store - await context.store(message) - def run(): server.run(host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000))) diff --git a/apps/adk-py/examples/history_framework.py b/apps/adk-py/examples/history_framework.py index 711fa0ab..46607fe6 100644 --- a/apps/adk-py/examples/history_framework.py +++ b/apps/adk-py/examples/history_framework.py @@ -18,7 +18,6 @@ from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -44,8 +43,6 @@ async def multi_turn_chat_agent( llm: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()], ): """Multi-turn chat agent with conversation memory and LLM integration""" - await context.store(input) - # Load conversation history history = [message async for message in context.load_history() if isinstance(message, Message) and message.parts] @@ -81,14 +78,12 @@ async def multi_turn_chat_agent( response = AgentMessage(text=step.input["response"]) yield response - await context.store(response) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), # Enable persistent storage ) diff --git a/apps/adk-py/examples/oauth.py b/apps/adk-py/examples/oauth.py index 52edeab3..0e6c89d5 100644 --- a/apps/adk-py/examples/oauth.py +++ b/apps/adk-py/examples/oauth.py @@ -30,7 +30,6 @@ from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -57,8 +56,6 @@ async def oauth_agent( oauth: Annotated[OAuthExtensionServer, OAuthExtensionSpec.single_demand()], ): """Multi-turn 
chat agent with conversation memory and LLM integration""" - await context.store(input) - # pyrefly: ignore [deprecated] -- TODO: upgrade mcp_client = streamablehttp_client( url="https://mcp.stripe.com", @@ -104,14 +101,12 @@ async def oauth_agent( response = AgentMessage(text=step.input["response"]) yield response - await context.store(response) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), # Enable persistent storage ) diff --git a/apps/adk-py/examples/trajectory_agent.py b/apps/adk-py/examples/trajectory_agent.py index 00c67226..291a59d0 100644 --- a/apps/adk-py/examples/trajectory_agent.py +++ b/apps/adk-py/examples/trajectory_agent.py @@ -13,7 +13,6 @@ from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -28,15 +27,11 @@ async def example_agent( ): """Agent that demonstrates conversation history access""" - # Store the current message in the context store - await context.store(input) - metadata = trajectory.trajectory_metadata( title="Initializing...", content="Initializing...", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(2.5) @@ -46,7 +41,6 @@ async def example_agent( content="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(0.3) for i in range(4, 7): @@ -55,7 +49,6 @@ async def example_agent( content="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 
Lorem ipsum dolor sit amet, consectetur adipiscing elit.", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(0.8) await asyncio.sleep(1) @@ -65,7 +58,6 @@ async def example_agent( content="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit.Lorem ipsum dolor sit amet, consectetur adipiscing elit. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 
Lorem ipsum dolor sit amet, consectetur adipiscing elit.", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(2) metadata = trajectory.trajectory_metadata( @@ -134,7 +126,6 @@ def extract_entities(text): """, ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(2) @@ -169,7 +160,6 @@ def extract_entities(text): }""", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(1) @@ -177,25 +167,21 @@ def extract_entities(text): title="Web search", content="Querying search engines...", group_id="websearch" ) yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(4) metadata = trajectory.trajectory_metadata(content="Found 8 results.", group_id="websearch") yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(1) metadata = trajectory.trajectory_metadata(content="Found 8 results\nAnalyzed 3/8 results", group_id="websearch") yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(2) metadata = trajectory.trajectory_metadata(content="Found 8 results\nAnalyzed 8/8 results", group_id="websearch") yield metadata - await context.store(AgentMessage(metadata=metadata)) await asyncio.sleep(4) @@ -205,7 +191,6 @@ def extract_entities(text): group_id="websearch", ) yield metadata - await context.store(AgentMessage(metadata=metadata)) # Your agent logic here - you can now reference all messages in the conversation message = AgentMessage( @@ -213,13 +198,10 @@ def extract_entities(text): ) yield message - # Store the message in the context store - await context.store(message) - def run(): server.run( - host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)), context_store=PlatformContextStore() + host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000)) ) diff --git a/apps/adk-py/src/kagenti_adk/platform/context.py 
b/apps/adk-py/src/kagenti_adk/platform/context.py index cbf5ea56..379af275 100644 --- a/apps/adk-py/src/kagenti_adk/platform/context.py +++ b/apps/adk-py/src/kagenti_adk/platform/context.py @@ -5,37 +5,15 @@ from __future__ import annotations import builtins -from collections.abc import AsyncIterator -from typing import Any, Literal, Self -from uuid import UUID, uuid4 - +from typing import Literal import pydantic -from a2a.types import Artifact, Message -from google.protobuf.json_format import MessageToDict, ParseDict -from pydantic import AwareDatetime, BaseModel, Field, SerializeAsAny, computed_field +from pydantic import SerializeAsAny from kagenti_adk.platform.client import PlatformClient, get_platform_client from kagenti_adk.platform.common import PaginatedResult from kagenti_adk.platform.provider import Provider from kagenti_adk.platform.types import Metadata, MetadataPatch -from kagenti_adk.util.utils import filter_dict, utc_now - - -class ContextHistoryItem(BaseModel, arbitrary_types_allowed=True): - id: UUID = Field(default_factory=uuid4) - data: Artifact | Message - created_at: AwareDatetime = Field(default_factory=utc_now) - context_id: str - - @computed_field - @property - def kind(self) -> Literal["message", "artifact"]: - return getattr(self.data, "kind", "artifact") - - @pydantic.field_validator("data", mode="before") - @classmethod - def parse_data(cls: Self, value: dict[str, Any]) -> Artifact | Message: - return ParseDict(value, Artifact() if "artifact_id" in value or "artifactId" in value else Message()) +from kagenti_adk.util.utils import filter_dict class ContextToken(pydantic.BaseModel): @@ -80,7 +58,7 @@ async def create( metadata: Metadata | None = None, provider_id: str | None = None, client: PlatformClient | None = None, - ) -> Context: + ) -> "Context": async with client or get_platform_client() as client: return pydantic.TypeAdapter(Context).validate_python( ( @@ -103,7 +81,7 @@ async def list( order_by: Literal["created_at"] | 
Literal["updated_at"] | None = None, include_empty: bool = True, provider_id: str | None = None, - ) -> PaginatedResult[Context]: + ) -> PaginatedResult["Context"]: # `self` has a weird type so that you can call both `instance.get()` to update an instance, or `File.get("123")` to obtain a new instance async with client or get_platform_client() as client: return pydantic.TypeAdapter(PaginatedResult[Context]).validate_python( @@ -127,10 +105,10 @@ async def list( ) async def get( - self: Context | str, + self: "Context" | str, *, client: PlatformClient | None = None, - ) -> Context: + ) -> "Context": # `self` has a weird type so that you can call both `instance.get()` to update an instance, or `File.get("123")` to obtain a new instance context_id = self if isinstance(self, str) else self.id async with client or get_platform_client() as client: @@ -139,11 +117,11 @@ async def get( ) async def update( - self: Context | str, + self: "Context" | str, *, metadata: Metadata | None, client: PlatformClient | None = None, - ) -> Context: + ) -> "Context": # `self` has a weird type so that you can call both `instance.get()` to update an instance, or `File.get("123")` to obtain a new instance context_id = self if isinstance(self, str) else self.id async with client or get_platform_client() as client: @@ -158,11 +136,11 @@ async def update( return result async def patch_metadata( - self: Context | str, + self: "Context" | str, *, metadata: MetadataPatch | None, client: PlatformClient | None = None, - ) -> Context: + ) -> "Context": # `self` has a weird type so that you can call both `instance.get()` to update an instance, or `File.get("123")` to obtain a new instance context_id = self if isinstance(self, str) else self.id async with client or get_platform_client() as client: @@ -177,7 +155,7 @@ async def patch_metadata( return result async def delete( - self: Context | str, + self: "Context" | str, *, client: PlatformClient | None = None, ) -> None: @@ -187,7 +165,7 @@ async def 
delete( _ = (await client.delete(url=f"/api/v1/contexts/{context_id}")).raise_for_status() async def generate_token( - self: Context | str, + self: "Context" | str, *, providers: builtins.list[str] | builtins.list[Provider] | None = None, client: PlatformClient | None = None, @@ -231,69 +209,3 @@ async def generate_token( .json() ) return pydantic.TypeAdapter(ContextToken).validate_python({**token_response, "context_id": context_id}) - - async def add_history_item( - self: Context | str, - *, - data: Message | Artifact, - client: PlatformClient | None = None, - ) -> None: - """Add a Message or Artifact to the context history (append-only)""" - target_context_id = self if isinstance(self, str) else self.id - async with client or get_platform_client() as platform_client: - _ = ( - await platform_client.post( - url=f"/api/v1/contexts/{target_context_id}/history", json=MessageToDict(data) - ) - ).raise_for_status() - - async def delete_history_from_id( - self: Context | str, - *, - from_id: UUID | str, - client: PlatformClient | None = None, - ) -> None: - """Delete all history items from a specific item onwards (inclusive)""" - target_context_id = self if isinstance(self, str) else self.id - async with client or get_platform_client() as platform_client: - _ = ( - await platform_client.delete( - url=f"/api/v1/contexts/{target_context_id}/history", params={"from_id": str(from_id)} - ) - ).raise_for_status() - - async def list_history( - self: Context | str, - *, - page_token: str | None = None, - limit: int | None = None, - order: Literal["asc"] | Literal["desc"] | None = "asc", - order_by: Literal["created_at"] | Literal["updated_at"] | None = None, - client: PlatformClient | None = None, - ) -> PaginatedResult[ContextHistoryItem]: - """List all history items for this context in chronological order""" - target_context_id = self if isinstance(self, str) else self.id - async with client or get_platform_client() as platform_client: - return 
pydantic.TypeAdapter(PaginatedResult[ContextHistoryItem]).validate_python( - ( - await platform_client.get( - url=f"/api/v1/contexts/{target_context_id}/history", - params=filter_dict( - {"page_token": page_token, "limit": limit, "order": order, "order_by": order_by} - ), - ) - ) - .raise_for_status() - .json() - ) - - async def list_all_history( - self: Context | str, client: PlatformClient | None = None - ) -> AsyncIterator[ContextHistoryItem]: - result = await Context.list_history(self, client=client) - for item in result.items: - yield item - while result.has_more: - result = await Context.list_history(self, page_token=result.next_page_token, client=client) - for item in result.items: - yield item diff --git a/apps/adk-py/src/kagenti_adk/server/agent.py b/apps/adk-py/src/kagenti_adk/server/agent.py index 4df0393a..5e3944b4 100644 --- a/apps/adk-py/src/kagenti_adk/server/agent.py +++ b/apps/adk-py/src/kagenti_adk/server/agent.py @@ -51,13 +51,12 @@ from kagenti_adk.server.context import RunContext from kagenti_adk.server.dependencies import Dependency, Depends, extract_dependencies from kagenti_adk.server.exceptions import InvalidYieldError -from kagenti_adk.server.store.context_store import ContextStore from kagenti_adk.server.utils import cancel_task, merge_messages from kagenti_adk.types import A2ASecurity, JsonPatch from kagenti_adk.util.logging import logger AgentFunction: TypeAlias = Callable[[], AsyncGenerator[RunYield, RunYieldResume]] -AgentFunctionFactory: TypeAlias = Callable[[RequestContext, ContextStore], AbstractAsyncContextManager[AgentFunction]] +AgentFunctionFactory: TypeAlias = Callable[[RequestContext, TaskStore], AbstractAsyncContextManager[AgentFunction]] OriginalFnType = TypeVar("OriginalFnType", bound=Callable[..., Any]) @@ -356,7 +355,7 @@ async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None: class AgentRun: - def __init__(self, agent: Agent, context_store: ContextStore, on_finish: Callable[[], None] | None = None) -> None: + 
def __init__(self, agent: Agent, task_store: TaskStore, on_finish: Callable[[], None] | None = None) -> None: self._agent: Agent = agent self._task: asyncio.Task[None] | None = None self.last_invocation: datetime = datetime.now() @@ -364,7 +363,7 @@ def __init__(self, agent: Agent, context_store: ContextStore, on_finish: Callabl self._run_context: RunContext | None = None self._request_context: RequestContext | None = None self._task_updater: TaskUpdater | None = None - self._context_store: ContextStore = context_store + self._task_store: TaskStore = task_store self._lock: asyncio.Lock = asyncio.Lock() self._on_finish: Callable[[], None] | None = on_finish self._working: bool = False @@ -403,14 +402,13 @@ async def start(self, request_context: RequestContext, event_queue: EventQueue): raise RuntimeError("Attempting to start a run that is already executing or done") task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message assert task_id and context_id and message - context_store = await self._context_store.create(context_id) self._run_context = RunContext( configuration=request_context.configuration, context_id=context_id, task_id=task_id, current_task=request_context.current_task, related_tasks=request_context.related_tasks, - _store=context_store, + _task_store=self._task_store, ) self._request_context = request_context self._task_updater = TaskUpdater(event_queue, task_id, context_id) @@ -599,14 +597,12 @@ def __init__( self, agent: Agent, queue_manager: QueueManager, - context_store: ContextStore, task_timeout: timedelta, task_store: TaskStore, ) -> None: self._agent: Agent = agent self._running_tasks: dict[str, AgentRun] = {} self._scheduled_cleanups: dict[str, asyncio.Task[None]] = {} - self._context_store: ContextStore = context_store self._task_timeout: timedelta = task_timeout self._task_store: TaskStore = task_store @@ -618,7 +614,7 @@ async def execute(self, context: RequestContext, event_queue: 
EventQueue) -> Non agent_run: AgentRun | None = None try: if not context.current_task: - agent_run = AgentRun(self._agent, self._context_store, lambda: self._handle_finish(task_id)) + agent_run = AgentRun(self._agent, self._task_store, lambda: self._handle_finish(task_id)) self._running_tasks[task_id] = agent_run await self._schedule_run_cleanup(request_context=context) await agent_run.start(request_context=context, event_queue=event_queue) diff --git a/apps/adk-py/src/kagenti_adk/server/app.py b/apps/adk-py/src/kagenti_adk/server/app.py index 3c5e0bc4..a3e08ec6 100644 --- a/apps/adk-py/src/kagenti_adk/server/app.py +++ b/apps/adk-py/src/kagenti_adk/server/app.py @@ -25,8 +25,6 @@ from kagenti_adk.a2a.extensions import BaseExtensionServer from kagenti_adk.server.agent import Agent, Executor from kagenti_adk.server.constants import DEFAULT_IMPLICIT_EXTENSIONS -from kagenti_adk.server.store.context_store import ContextStore -from kagenti_adk.server.store.memory_context_store import InMemoryContextStore from kagenti_adk.types import SdkAuthenticationBackend @@ -34,7 +32,6 @@ def create_app( agent: Agent, url: str, task_store: TaskStore | None = None, - context_store: ContextStore | None = None, implicit_extensions: dict[str, BaseExtensionServer] = DEFAULT_IMPLICIT_EXTENSIONS, required_extensions: set[str] | None = None, auth_backend: SdkAuthenticationBackend | None = None, @@ -49,12 +46,10 @@ def create_app( ) -> FastAPI: queue_manager = queue_manager or InMemoryQueueManager() task_store = task_store or InMemoryTaskStore() - context_store = context_store or InMemoryContextStore() http_handler = DefaultRequestHandler( agent_executor=Executor( agent, queue_manager, - context_store=context_store, task_timeout=task_timeout, task_store=task_store, ), @@ -73,7 +68,7 @@ def create_app( AgentInterface(url=url + "/jsonrpc/", protocol_binding="JSONRPC", protocol_version=protocol_version), ], implicit_extensions=implicit_extensions, - required_extensions=(required_extensions or 
set()) | context_store.required_extensions, + required_extensions=required_extensions or set(), ) jsonrpc_app = A2AFastAPIApplication(agent_card=agent.card, http_handler=http_handler).build( diff --git a/apps/adk-py/src/kagenti_adk/server/context.py b/apps/adk-py/src/kagenti_adk/server/context.py index 01e85eb2..bc0f7567 100644 --- a/apps/adk-py/src/kagenti_adk/server/context.py +++ b/apps/adk-py/src/kagenti_adk/server/context.py @@ -4,21 +4,17 @@ from __future__ import annotations from collections.abc import AsyncGenerator -from typing import Literal, overload -from uuid import UUID import janus +from a2a.server.tasks import TaskStore from a2a.types import ( Artifact, Message, Task, ) -from asgiref.sync import async_to_sync from pydantic import BaseModel, PrivateAttr from kagenti_adk.a2a.types import RunYield, RunYieldResume -from kagenti_adk.platform.context import ContextHistoryItem -from kagenti_adk.server.store.context_store import ContextStoreInstance class RunContextSettings(BaseModel): @@ -32,49 +28,25 @@ class RunContext(BaseModel, arbitrary_types_allowed=True): related_tasks: list[Task] | None = None strict: bool = False # TODO: explain strict mode - what yields will stop message etc. 
Use in match/case - _store: ContextStoreInstance + _task_store: TaskStore _yield_queue: janus.Queue[RunYield] = PrivateAttr(default_factory=janus.Queue) _yield_resume_queue: janus.Queue[RunYieldResume | Exception] = PrivateAttr(default_factory=janus.Queue) - def __init__(self, _store: ContextStoreInstance, **data): + def __init__(self, _task_store: TaskStore, **data): super().__init__(**data) - self._store = _store - - def _prepare_store_data(self, data: Message | Artifact) -> Message | Artifact: - if not self._store: - raise RuntimeError("Context store is not initialized") - if isinstance(data, Message): - msg = Message() - msg.CopyFrom(data) - msg.context_id = self.context_id - msg.task_id = self.task_id - return msg - return data - - async def store(self, data: Message | Artifact): - await self._store.store(self._prepare_store_data(data)) - - def store_sync(self, data: Message | Artifact): - async_to_sync(self._store.store)(self._prepare_store_data(data)) - - @overload - def load_history(self, load_history_items: Literal[False] = False) -> AsyncGenerator[Message | Artifact, None]: ... - - @overload - def load_history(self, load_history_items: Literal[True]) -> AsyncGenerator[ContextHistoryItem, None]: ... - - async def load_history( - self, load_history_items: bool = False - ) -> AsyncGenerator[ContextHistoryItem | Message | Artifact]: - if not self._store: - raise RuntimeError("Context store is not initialized") - async for item in self._store.load_history(load_history_items=load_history_items): - yield item - - async def delete_history_from_id(self, from_id: UUID) -> None: - if not self._store: - raise RuntimeError("Context store is not initialized") - await self._store.delete_history_from_id(from_id) + self._task_store = _task_store + + async def load_history(self) -> AsyncGenerator[Message | Artifact, None]: + """Load conversation history from the A2A TaskStore. + + Yields messages and artifacts from the current task's history. 
+ """ + task = await self._task_store.get(self.task_id) + if task: + for msg in task.history: + yield msg + for artifact in task.artifacts: + yield artifact def yield_sync(self, value: RunYield) -> RunYieldResume: self._yield_queue.sync_q.put(value) diff --git a/apps/adk-py/src/kagenti_adk/server/server.py b/apps/adk-py/src/kagenti_adk/server/server.py index 1a301611..3ff3d1e8 100644 --- a/apps/adk-py/src/kagenti_adk/server/server.py +++ b/apps/adk-py/src/kagenti_adk/server/server.py @@ -36,7 +36,6 @@ from kagenti_adk.server.agent import Agent from kagenti_adk.server.agent import agent as agent_decorator from kagenti_adk.server.constants import DEFAULT_IMPLICIT_EXTENSIONS -from kagenti_adk.server.store.context_store import ContextStore from kagenti_adk.server.telemetry import configure_telemetry as configure_telemetry_func from kagenti_adk.server.utils import cancel_task from kagenti_adk.types import SdkAuthenticationBackend @@ -48,7 +47,6 @@ class Server: def __init__(self) -> None: self._agent: Agent | None = None self.server: uvicorn.Server | None = None - self._context_store: ContextStore | None = None self._self_registration_client: PlatformClient | None = None self._self_registration_id: str | None = None self._provider_id: str | None = None @@ -72,7 +70,6 @@ async def serve( self_registration: bool = True, self_registration_id: str | None = None, task_store: TaskStore | None = None, - context_store: ContextStore | None = None, queue_manager: QueueManager | None = None, task_timeout: timedelta = timedelta(minutes=10), push_config_store: PushNotificationConfigStore | None = None, @@ -194,7 +191,6 @@ async def _lifespan_fn(app: FastAPI) -> AsyncGenerator[None, None]: lifespan=_lifespan_fn, implicit_extensions=implicit_extensions, task_store=task_store, - context_store=context_store, queue_manager=queue_manager, push_config_store=push_config_store, push_sender=push_sender, diff --git a/apps/adk-py/src/kagenti_adk/server/store/context_store.py 
b/apps/adk-py/src/kagenti_adk/server/store/context_store.py deleted file mode 100644 index 88e76c58..00000000 --- a/apps/adk-py/src/kagenti_adk/server/store/context_store.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2026 © IBM Corp. -# SPDX-License-Identifier: Apache-2.0 - - -from __future__ import annotations - -import abc -from collections.abc import AsyncIterator -from typing import Protocol -from uuid import UUID - -from a2a.types import Artifact, Message - -from kagenti_adk.platform.context import ContextHistoryItem - -__all__ = [ - "ContextStore", - "ContextStoreInstance", -] - - -class ContextStoreInstance(Protocol): - def load_history( - self, load_history_items: bool = False - ) -> AsyncIterator[ContextHistoryItem | Message | Artifact]: ... - async def store(self, data: Message | Artifact) -> None: ... - async def delete_history_from_id(self, from_id: UUID) -> None: ... - - -class ContextStore(abc.ABC): - @property - def required_extensions(self) -> set[str]: - return set() - - @abc.abstractmethod - async def create(self, context_id: str) -> ContextStoreInstance: ... diff --git a/apps/adk-py/src/kagenti_adk/server/store/memory_context_store.py b/apps/adk-py/src/kagenti_adk/server/store/memory_context_store.py deleted file mode 100644 index 29342639..00000000 --- a/apps/adk-py/src/kagenti_adk/server/store/memory_context_store.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2026 © IBM Corp. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections.abc import AsyncIterator -from datetime import timedelta -from uuid import UUID - -from a2a.types import Artifact, Message -from cachetools import TTLCache -from google.protobuf.json_format import MessageToDict - -from kagenti_adk.platform.context import ContextHistoryItem -from kagenti_adk.server.store.context_store import ContextStore, ContextStoreInstance - - -class MemoryContextStoreInstance(ContextStoreInstance): - def __init__(self, context_id: str): - self.context_id = context_id - self._history: list[ContextHistoryItem] = [] - - async def load_history( - self, load_history_items: bool = False - ) -> AsyncIterator[ContextHistoryItem | Message | Artifact]: - for item in self._history.copy(): - if load_history_items: - yield item - else: - yield item.data - - async def store(self, data: Message | Artifact) -> None: - self._history.append(ContextHistoryItem(data=MessageToDict(data), context_id=self.context_id)) - - async def delete_history_from_id(self, from_id: UUID) -> None: - # Does not allow to delete from an artifact onwards - index = next( - (i for i, item in enumerate(self._history) if item.id == from_id), - None, - ) - if index is not None: - self._history = self._history[:index] - - -class InMemoryContextStore(ContextStore): - def __init__(self, max_contexts: int = 1000, context_ttl: timedelta = timedelta(hours=1)): - """ - Initialize in-memory context store with TTL cache. 
- - Args: - max_contexts: Maximum number of contexts to keep in memory - ttl_seconds: Time-to-live for context instances in seconds (default: 1 hour) - """ - self._instances: TTLCache[str, MemoryContextStoreInstance] = TTLCache( - maxsize=max_contexts, ttl=context_ttl.total_seconds() - ) - - async def create(self, context_id: str) -> ContextStoreInstance: - if context_id not in self._instances: - self._instances[context_id] = MemoryContextStoreInstance(context_id) - return self._instances[context_id] diff --git a/apps/adk-py/src/kagenti_adk/server/store/platform_context_store.py b/apps/adk-py/src/kagenti_adk/server/store/platform_context_store.py deleted file mode 100644 index d1b20d85..00000000 --- a/apps/adk-py/src/kagenti_adk/server/store/platform_context_store.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2026 © IBM Corp. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from uuid import UUID - -from a2a.types import Artifact, Message - -from kagenti_adk.a2a.extensions.services.platform import PlatformApiExtensionServer, PlatformApiExtensionSpec -from kagenti_adk.platform.context import Context, ContextHistoryItem -from kagenti_adk.server.store.context_store import ContextStore, ContextStoreInstance - - -class PlatformContextStore(ContextStore): - @property - def required_extensions(self) -> set[str]: - return {PlatformApiExtensionSpec.URI} - - async def create(self, context_id: str) -> ContextStoreInstance: - return PlatformContextStoreInstance(context_id=context_id) - - -class PlatformContextStoreInstance(ContextStoreInstance): - def __init__(self, context_id: str): - self._context_id = context_id - - @asynccontextmanager - async def client(self): - if not (ext := PlatformApiExtensionServer.current()): - raise RuntimeError("PlatformApiExtensionServer is not initialized") - async with ext.use_client(): - yield - - async def load_history( - 
self, load_history_items: bool = False - ) -> AsyncIterator[ContextHistoryItem | Message | Artifact]: - async with self.client(): - async for history_item in Context.list_all_history(self._context_id): - if load_history_items: - yield history_item - else: - yield history_item.data - - async def store(self, data: Message | Artifact) -> None: - async with self.client(): - await Context.add_history_item(self._context_id, data=data) - - async def delete_history_from_id(self, from_id: UUID) -> None: - async with self.client(): - await Context.delete_history_from_id(self._context_id, from_id=from_id) diff --git a/apps/adk-py/tests/e2e/conftest.py b/apps/adk-py/tests/e2e/conftest.py index 82df111f..abbefc2e 100644 --- a/apps/adk-py/tests/e2e/conftest.py +++ b/apps/adk-py/tests/e2e/conftest.py @@ -6,7 +6,7 @@ import asyncio import base64 import socket -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, AsyncIterator from contextlib import asynccontextmanager, closing from datetime import timedelta @@ -30,7 +30,6 @@ from kagenti_adk.a2a.types import AgentArtifact, AgentMessage, ArtifactChunk, InputRequired, RunYield, RunYieldResume from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.context_store import ContextStore pytestmark = pytest.mark.e2e @@ -52,7 +51,6 @@ def make_extension_context(extensions: list[str] | None = None) -> ClientCallCon async def run_server( server: Server, port: int, - context_store: ContextStore | None = None, task_timeout: timedelta | None = None, ) -> AsyncGenerator[tuple[Server, Client]]: async with asyncio.TaskGroup() as tg: @@ -61,7 +59,6 @@ async def run_server( server.run, self_registration=False, port=port, - context_store=context_store, task_timeout=task_timeout or timedelta(minutes=5), ) ) @@ -91,14 +88,14 @@ def create_server_with_agent(): @asynccontextmanager async def _create_server( - agent_fn, context_store: ContextStore | None = 
None, task_timeout: timedelta | None = None - ) -> AsyncGenerator[tuple[Server, Client]]: + agent_fn, + task_timeout: timedelta | None = None, + ) -> AsyncIterator[tuple[Server, Client]]: server = Server() server.agent(detail=AgentDetail(interaction_mode="multi-turn"))(agent_fn) async with run_server( server, get_free_port(), - context_store=context_store, task_timeout=task_timeout, ) as (server, client): yield server, client diff --git a/apps/adk-py/tests/e2e/test_history.py b/apps/adk-py/tests/e2e/test_history.py index 786e8b4a..e67254b9 100644 --- a/apps/adk-py/tests/e2e/test_history.py +++ b/apps/adk-py/tests/e2e/test_history.py @@ -8,16 +8,15 @@ import pytest from a2a.client import Client, ClientEvent, create_text_message_object from a2a.types import ( + Artifact, Message, - Role, SendMessageRequest, Task, ) -from kagenti_adk.a2a.types import RunYield +from kagenti_adk.a2a.types import AgentMessage, RunYield from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.memory_context_store import InMemoryContextStore pytestmark = pytest.mark.e2e @@ -44,86 +43,31 @@ async def send_message_get_response( @pytest.fixture -async def history_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - """Agent that tests context.store.load_history() functionality.""" - context_store = InMemoryContextStore() - - async def history_agent(input: Message, context: RunContext) -> AsyncGenerator[RunYield, None]: - await context.store(input) - async for message in context.load_history(): - message.role = Role.ROLE_AGENT - yield message - await context.store(message) - - async with create_server_with_agent(history_agent, context_store=context_store) as (server, client): +async def history_reader_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: + """Agent that reads history from the task store via RunContext.load_history().""" + + async def history_reader(input: Message, context: 
RunContext) -> AsyncGenerator[RunYield, None]: + # Load history from the task store (will contain messages from previous interactions in same task) + history_items: list[str] = [] + async for item in context.load_history(): + if isinstance(item, Message) and item.parts: + history_items.append(item.parts[0].text) + elif isinstance(item, Artifact) and item.parts: + history_items.append(f"artifact:{item.parts[0].text}") + + # Echo back what we found in history plus the current input + if history_items: + yield AgentMessage(text=f"history={','.join(history_items)}") + yield AgentMessage(text=f"input={input.parts[0].text}") + + async with create_server_with_agent(history_reader) as (server, client): yield server, client -@pytest.fixture -async def history_deleting_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - """Agent that tests context.store.load_history() functionality.""" - context_store = InMemoryContextStore() - - async def history_agent(input: Message, context: RunContext) -> AsyncGenerator[RunYield, None]: - await context.store(input) - n_messages = 0 - async for message in context.load_history(load_history_items=True): - n_messages += 1 - if n_messages == 1: - delete_id = message.id - if n_messages > 3: - # pyrefly: ignore [unbound-name] - await context.delete_history_from_id(delete_id) - break - - async for message in context.load_history(): - message.role = Role.ROLE_AGENT - yield message - - async with create_server_with_agent(history_agent, context_store=context_store) as (server, client): - yield server, client - - -async def test_agent_history(history_agent): - """Test that history starts empty.""" - _, client = history_agent - - agent_messages, context_id = await send_message_get_response(client, "first message") - assert agent_messages == ["first message"] - - agent_messages, context_id = await send_message_get_response(client, "second message", context_id=context_id) - assert agent_messages == ["first message", "first 
message", "second message"] - - agent_messages, context_id = await send_message_get_response(client, "third message", context_id=context_id) - assert agent_messages == [ - # first run - "first message", - # second run - "first message", - "second message", - # third run - "first message", - "first message", - "second message", - "third message", - ] - - -async def test_agent_deleting_history(history_deleting_agent): - """Test that history starts empty.""" - _, client = history_deleting_agent - - agent_messages, context_id = await send_message_get_response(client, "first message") - assert agent_messages == ["first message"] - - agent_messages, context_id = await send_message_get_response(client, "second message", context_id=context_id) - assert agent_messages == ["first message", "second message"] - - agent_messages, context_id = await send_message_get_response(client, "third message", context_id=context_id) - assert agent_messages == ["first message", "second message", "third message"] - - agent_messages, context_id = await send_message_get_response(client, "delete message", context_id=context_id) - assert agent_messages == [] +async def test_load_history_from_task_store(history_reader_agent): + """Test that RunContext.load_history() reads from the A2A task store.""" + _, client = history_reader_agent - agent_messages, context_id = await send_message_get_response(client, "first message") - assert agent_messages == ["first message"] + # First message — no history yet + agent_messages, context_id = await send_message_get_response(client, "hello") + assert any("input=hello" in msg for msg in agent_messages) diff --git a/apps/adk-server/src/adk_server/api/routes/contexts.py b/apps/adk-server/src/adk_server/api/routes/contexts.py index 5a501a8d..4e2d95eb 100644 --- a/apps/adk-server/src/adk_server/api/routes/contexts.py +++ b/apps/adk-server/src/adk_server/api/routes/contexts.py @@ -14,13 +14,11 @@ from adk_server.api.dependencies import ( ConfigurationDependency, 
ContextServiceDependency, - RequiresContextPermissionsPath, RequiresPermissions, ) -from adk_server.api.schema.common import EntityModel, PaginationQuery +from adk_server.api.schema.common import EntityModel from adk_server.api.schema.contexts import ( ContextCreateRequest, - ContextHistoryItemCreateRequest, ContextListQuery, ContextPatchMetadataRequest, ContextTokenCreateRequest, @@ -28,7 +26,7 @@ ContextUpdateRequest, ) from adk_server.domain.models.common import PaginatedResult -from adk_server.domain.models.context import Context, ContextHistoryItem +from adk_server.domain.models.context import Context from adk_server.domain.models.permissions import AuthorizedUser, Permissions logger = logging.getLogger(__name__) @@ -136,33 +134,3 @@ async def generate_context_token( configuration=configuration, ) return ContextTokenResponse(token=token, expires_at=expires_at) - - -@router.post("/{context_id}/history", status_code=status.HTTP_201_CREATED) -async def add_context_history_item( - context_id: UUID, - history_item_data: ContextHistoryItemCreateRequest, - context_service: ContextServiceDependency, - user: Annotated[AuthorizedUser, Depends(RequiresContextPermissionsPath(context_data={"write"}))], -) -> None: - await context_service.add_history_item(context_id=context_id, data=history_item_data.root, user=user.user) - - -@router.get("/{context_id}/history") -async def list_context_history( - context_id: UUID, - context_service: ContextServiceDependency, - user: Annotated[AuthorizedUser, Depends(RequiresContextPermissionsPath(context_data={"read"}))], - pagination: Annotated[PaginationQuery, Query()], -) -> PaginatedResult[ContextHistoryItem]: - return await context_service.list_history(context_id=context_id, user=user.user, pagination=pagination) - - -@router.delete("/{context_id}/history", status_code=status.HTTP_204_NO_CONTENT) -async def delete_context_history_from_id( - context_id: UUID, - from_id: Annotated[UUID, Query()], - context_service: 
ContextServiceDependency, - user: Annotated[AuthorizedUser, Depends(RequiresContextPermissionsPath(context_data={"read", "write"}))], -) -> None: - await context_service.delete_history_from_id(context_id=context_id, from_id=from_id, user=user.user) diff --git a/apps/adk-server/src/adk_server/api/schema/contexts.py b/apps/adk-server/src/adk_server/api/schema/contexts.py index 27c4c483..9a3a599d 100644 --- a/apps/adk-server/src/adk_server/api/schema/contexts.py +++ b/apps/adk-server/src/adk_server/api/schema/contexts.py @@ -5,11 +5,10 @@ from typing import Literal from uuid import UUID -from pydantic import AwareDatetime, BaseModel, Field, RootModel, field_validator +from pydantic import AwareDatetime, BaseModel, Field, field_validator from adk_server.api.schema.common import PaginationQuery from adk_server.domain.models.common import Metadata, MetadataPatch -from adk_server.domain.models.context import ContextHistoryItemData class ContextCreateRequest(BaseModel): @@ -88,5 +87,3 @@ class ContextTokenResponse(BaseModel): expires_at: AwareDatetime | None -class ContextHistoryItemCreateRequest(RootModel[ContextHistoryItemData]): - root: ContextHistoryItemData diff --git a/apps/adk-server/src/adk_server/domain/repositories/context.py b/apps/adk-server/src/adk_server/domain/repositories/context.py index df818ef6..738849f9 100644 --- a/apps/adk-server/src/adk_server/domain/repositories/context.py +++ b/apps/adk-server/src/adk_server/domain/repositories/context.py @@ -9,7 +9,7 @@ from uuid import UUID from adk_server.domain.models.common import PaginatedResult -from adk_server.domain.models.context import Context, ContextHistoryItem, TitleGenerationState +from adk_server.domain.models.context import Context class IContextRepository(Protocol): @@ -33,17 +33,3 @@ async def get(self, *, context_id: UUID, user_id: UUID | None = None) -> Context async def update(self, *, context: Context) -> None: ... 
async def delete(self, *, context_id: UUID, user_id: UUID | None = None) -> int: ... async def update_last_active(self, *, context_id: UUID) -> None: ... - async def update_title( - self, *, context_id: UUID, title: str | None = None, generation_state: TitleGenerationState - ) -> None: ... - async def add_history_item(self, *, context_id: UUID, history_item: ContextHistoryItem) -> None: ... - async def list_history( - self, - *, - context_id: UUID, - page_token: UUID | None = None, - limit: int = 20, - order_by: str = "created_at", - order="desc", - ) -> PaginatedResult[ContextHistoryItem]: ... - async def delete_history_from_id(self, *, context_id: UUID, from_id: UUID) -> int: ... diff --git a/apps/adk-server/src/adk_server/infrastructure/persistence/repositories/context.py b/apps/adk-server/src/adk_server/infrastructure/persistence/repositories/context.py index 119ffbf9..a975b517 100644 --- a/apps/adk-server/src/adk_server/infrastructure/persistence/repositories/context.py +++ b/apps/adk-server/src/adk_server/infrastructure/persistence/repositories/context.py @@ -5,16 +5,14 @@ from collections.abc import AsyncIterator from datetime import datetime -from uuid import UUID, uuid4 +from uuid import UUID from kink import inject -from pydantic import TypeAdapter from sqlalchemy import ( JSON, Column, DateTime, ForeignKey, - Index, Row, Table, delete, @@ -24,8 +22,8 @@ from sqlalchemy import UUID as SQL_UUID from sqlalchemy.ext.asyncio import AsyncConnection -from adk_server.domain.models.common import Metadata, PaginatedResult -from adk_server.domain.models.context import Context, ContextHistoryItem, TitleGenerationState +from adk_server.domain.models.common import PaginatedResult +from adk_server.domain.models.context import Context from adk_server.domain.repositories.context import IContextRepository from adk_server.exceptions import EntityNotFoundError from adk_server.infrastructure.persistence.repositories.db_metadata import metadata @@ -44,16 +42,6 @@ 
Column("metadata", JSON, nullable=True), ) -context_history_table = Table( - "context_history", - metadata, - Column("id", SQL_UUID, primary_key=True), - Column("context_id", ForeignKey("contexts.id", ondelete="CASCADE"), nullable=False), - Column("created_at", DateTime(timezone=True), nullable=False), - Column("data", JSON, nullable=False), - Index("idx_context_history_context_id", "context_id"), -) - @inject class SqlAlchemyContextRepository(IContextRepository): @@ -92,12 +80,8 @@ async def list_paginated( query = query.where(contexts_table.c.provider_id == provider_id) if last_active_before: query = query.where(contexts_table.c.last_active_at < last_active_before) - if not include_empty: - # Use EXISTS subquery to find contexts that have at least one history record - subquery = select(context_history_table.c.context_id).where( - context_history_table.c.context_id == contexts_table.c.id - ) - query = query.where(subquery.exists()) + # NOTE: include_empty is accepted but no longer filtered — context_history table has been removed. + # All contexts are now returned regardless of whether they have task history. 
result = await cursor_paginate( connection=self._connection, @@ -163,88 +147,6 @@ async def update_last_active(self, *, context_id: UUID) -> None: query = update(contexts_table).where(contexts_table.c.id == context_id).values(last_active_at=utc_now()) await self._connection.execute(query) - async def update_title( - self, *, context_id: UUID, title: str | None = None, generation_state: TitleGenerationState - ) -> None: - # validate length before saving to database - if title: - _ = TypeAdapter(Metadata).validate_python({"title": title}) - context = await self.get(context_id=context_id) - query = ( - contexts_table.update() - .where(contexts_table.c.id == context_id) - .values( - metadata=(context.metadata or {}) - | ({"title": title} if title else {}) - | {"title_generation_state": generation_state} - ) - ) - await self._connection.execute(query) - - async def add_history_item(self, *, context_id: UUID, history_item: ContextHistoryItem) -> None: - query = context_history_table.insert().values( - id=uuid4(), - context_id=history_item.context_id, - created_at=history_item.created_at, - data=history_item.data, - ) - await self._connection.execute(query) - - async def list_history( - self, - *, - context_id: UUID, - page_token: UUID | None = None, - limit: int = 20, - order_by: str = "created_at", - order="desc", - ) -> PaginatedResult[ContextHistoryItem]: - query = context_history_table.select().where(context_history_table.c.context_id == context_id) - result = await cursor_paginate( - connection=self._connection, - query=query, - after_cursor=page_token, - id_column=context_history_table.c.id, - order_column=getattr(context_history_table.c, order_by), - order=order, - limit=limit, - ) - return PaginatedResult( - items=[self._row_to_context_history_item(item) for item in result.items], - total_count=result.total_count, - has_more=result.has_more, - ) - - async def delete_history_from_id(self, *, context_id: UUID, from_id: UUID) -> int: - """Delete all history items 
from a specific item onwards (inclusive) in given context""" - # First, get the created_at timestamp of the item to delete from - query_item = select(context_history_table.c.created_at).where( - context_history_table.c.context_id == context_id, - context_history_table.c.id == from_id, - ) - result = await self._connection.execute(query_item) - row = result.first() - if not row: - raise EntityNotFoundError("context_history_item", from_id) - - created_at = row[0] - - # Delete all history items from the specified item onwards (created_at >= the target item's created_at) - query = delete(context_history_table).where( - context_history_table.c.context_id == context_id, - context_history_table.c.created_at >= created_at, - ) - result = await self._connection.execute(query) - return result.rowcount - - def _row_to_context_history_item(self, row: Row) -> ContextHistoryItem: - return ContextHistoryItem( - id=row.id, - data=row.data, - context_id=row.context_id, - created_at=row.created_at, - ) - def _row_to_context(self, row: Row) -> Context: return Context( id=row.id, diff --git a/apps/adk-server/src/adk_server/jobs/procrastinate.py b/apps/adk-server/src/adk_server/jobs/procrastinate.py index 5a171c05..94e30efe 100644 --- a/apps/adk-server/src/adk_server/jobs/procrastinate.py +++ b/apps/adk-server/src/adk_server/jobs/procrastinate.py @@ -13,7 +13,6 @@ from adk_server.jobs.crons.connector import blueprint as connector_crons from adk_server.jobs.crons.model_provider import blueprint as model_provider_crons from adk_server.jobs.crons.provider import blueprint as provider_crons -from adk_server.jobs.tasks.context import blueprint as context_tasks from adk_server.jobs.tasks.file import blueprint as file_tasks logger = logging.getLogger(__name__) @@ -54,7 +53,6 @@ def exit_app_on_db_error(*_args, **_kwargs): worker_defaults=WorkerOptions(install_signal_handlers=False), ) app.add_tasks_from(blueprint=file_tasks, namespace="text_extraction") - 
app.add_tasks_from(blueprint=context_tasks, namespace="context_tasks") app.add_tasks_from(blueprint=provider_crons, namespace="cron_provider") app.add_tasks_from(blueprint=model_provider_crons, namespace="cron_model_provider") app.add_tasks_from(blueprint=cleanup_crons, namespace="cron_cleanup") diff --git a/apps/adk-server/src/adk_server/jobs/queues.py b/apps/adk-server/src/adk_server/jobs/queues.py index cabf8600..2965a8e5 100644 --- a/apps/adk-server/src/adk_server/jobs/queues.py +++ b/apps/adk-server/src/adk_server/jobs/queues.py @@ -13,7 +13,6 @@ class Queues(StrEnum): CRON_MODEL_PROVIDER = "cron:model_provider" CRON_CONNECTOR = "cron:connector" # tasks - GENERATE_CONVERSATION_TITLE = "generate_conversation_title" TEXT_EXTRACTION = "text_extraction" TOOLKIT_DELETION = "toolkit_deletion" diff --git a/apps/adk-server/src/adk_server/jobs/tasks/context.py b/apps/adk-server/src/adk_server/jobs/tasks/context.py index 9d8fb7cd..7abe8ca6 100644 --- a/apps/adk-server/src/adk_server/jobs/tasks/context.py +++ b/apps/adk-server/src/adk_server/jobs/tasks/context.py @@ -1,20 +1,2 @@ # Copyright 2026 © IBM Corp. 
# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from uuid import UUID - -from kink import inject -from procrastinate import Blueprint - -from adk_server.jobs.queues import Queues -from adk_server.service_layer.services.contexts import ContextService - -blueprint = Blueprint() - - -@blueprint.task(queue=str(Queues.GENERATE_CONVERSATION_TITLE)) -@inject -async def generate_conversation_title(context_id: str, context_service: ContextService): - await context_service.generate_conversation_title(context_id=UUID(context_id)) diff --git a/apps/adk-server/src/adk_server/run_workers.py b/apps/adk-server/src/adk_server/run_workers.py index ee8b08f0..d1b3d8d7 100644 --- a/apps/adk-server/src/adk_server/run_workers.py +++ b/apps/adk-server/src/adk_server/run_workers.py @@ -29,11 +29,6 @@ async def run_workers(app: procrastinate.App): ], concurrency=10, ), - WorkerOptions( - name="generate_conversation_title_worker", - queues=[str(Queues.GENERATE_CONVERSATION_TITLE)], - concurrency=10, - ), WorkerOptions(name="text_extraction_worker", queues=[str(Queues.TEXT_EXTRACTION)], concurrency=5), ] diff --git a/apps/adk-server/src/adk_server/service_layer/services/a2a.py b/apps/adk-server/src/adk_server/service_layer/services/a2a.py index fb0ba490..df9bd934 100644 --- a/apps/adk-server/src/adk_server/service_layer/services/a2a.py +++ b/apps/adk-server/src/adk_server/service_layer/services/a2a.py @@ -378,7 +378,8 @@ async def on_list_tasks( params: ListTasksRequest, context: ServerCallContext, ) -> ListTasksResponse: - raise NotImplementedError("This is not supported by the client transport yet") + async with self._client_transport(context) as transport: + return await transport.list_tasks(params, context=self._forward_context(context)) @inject diff --git a/apps/adk-server/src/adk_server/service_layer/services/contexts.py b/apps/adk-server/src/adk_server/service_layer/services/contexts.py index 25c5a8ee..3c16eb8a 100644 --- 
a/apps/adk-server/src/adk_server/service_layer/services/contexts.py +++ b/apps/adk-server/src/adk_server/service_layer/services/contexts.py @@ -4,10 +4,8 @@ from __future__ import annotations import logging -from collections.abc import Sequence from contextlib import suppress from datetime import timedelta -from typing import Any from uuid import UUID from fastapi import status @@ -15,19 +13,12 @@ from pydantic import TypeAdapter from adk_server.api.schema.common import PaginationQuery -from adk_server.api.schema.openai import ChatCompletionRequest from adk_server.configuration import Configuration from adk_server.domain.models.common import Metadata, MetadataPatch, PaginatedResult -from adk_server.domain.models.context import ( - Context, - ContextHistoryItem, - ContextHistoryItemData, - TitleGenerationState, -) +from adk_server.domain.models.context import Context from adk_server.domain.models.user import User from adk_server.domain.repositories.file import IObjectStorageRepository from adk_server.exceptions import EntityNotFoundError, PlatformError -from adk_server.service_layer.services.model_providers import ModelProviderService from adk_server.service_layer.unit_of_work import IUnitOfWorkFactory from adk_server.utils.utils import filter_dict, utc_now @@ -41,13 +32,11 @@ def __init__( uow: IUnitOfWorkFactory, configuration: Configuration, object_storage: IObjectStorageRepository, - model_provider_service: ModelProviderService, ): self._uow = uow self._object_storage = object_storage self._configuration = configuration self._expire_resources_after = timedelta(days=configuration.context.resources_expire_after_days) - self._model_provider_service = model_provider_service async def create(self, *, user: User, metadata: Metadata, provider_id: UUID | None = None) -> Context: context = Context(created_by=user.id, metadata=metadata, provider_id=provider_id) @@ -160,125 +149,3 @@ async def update_last_active(self, *, context_id: UUID) -> None: await 
uow.contexts.update_last_active(context_id=context_id) await uow.commit() - def _extract_content_for_title(self, msg: dict[str, Any]) -> tuple[str, str | None, Sequence[dict[str, Any]]]: - title_hint: str | None = None - text_parts: list[str] = [] - files: list[dict[str, Any]] = [] - for part in msg.get("parts", []): - if "text" in part: - text_parts.append(part["text"]) - elif "data" in part: - data = part["data"] - if isinstance(data, dict): - hint = data.get("title_hint") - if isinstance(hint, str) and hint and not title_hint: - title_hint = hint - elif "file" in part: - files.append(part["file"]) - - return "".join(text_parts), title_hint, files - - async def add_history_item(self, *, context_id: UUID, data: ContextHistoryItemData, user: User) -> None: - async with self._uow() as uow: - context = await uow.contexts.get(context_id=context_id, user_id=user.id) - await uow.contexts.add_history_item( - context_id=context_id, - history_item=ContextHistoryItem(context_id=context_id, data=data), - ) - - if data.get("role") == "ROLE_USER" and not (context.metadata or {}).get("title"): - from adk_server.jobs.tasks.context import generate_conversation_title as task - - # Use simple text extraction for the initial title placeholder - title = self._extract_content_for_title(data)[0] or "Untitled" - title = f"{title[:100]}..." 
if len(title) > 100 else title - - should_generate_title = self._configuration.generate_conversation_title.enabled - state = TitleGenerationState.PENDING if should_generate_title else TitleGenerationState.COMPLETED - await uow.contexts.update_title(context_id=context_id, title=title, generation_state=state) - - if should_generate_title: - await task.configure(queueing_lock=str(context_id)).defer_async(context_id=str(context_id)) - - await uow.commit() - - async def generate_conversation_title(self, *, context_id: UUID): - from jinja2 import Template - - async with self._uow() as uow: - msg = await uow.contexts.list_history(context_id=context_id, limit=1, order="desc", order_by="created_at") - system_config = await uow.configuration.get_system_configuration() - - model = self._configuration.generate_conversation_title.model - if model == "default": - if not system_config.default_llm_model: - logger.warning(f"Cannot generate title for context {context_id}: default LLM model not set.") - return - model = system_config.default_llm_model - - if not msg.items: - logger.warning(f"Cannot generate title for context {context_id}: no history found.") - return - - raw_message = msg.items[0].data - text, title_hint, files = self._extract_content_for_title(raw_message) - if not text and not title_hint and not files: - logger.warning(f"Cannot generate title for context {context_id}: first message has no content.") - return - - try: - # Render the system prompt using Jinja2 - template = Template(self._configuration.generate_conversation_title.prompt) - prompt = template.render( - text=text, - titleHint=title_hint, - files=[{"name": f.get("name"), "mime_type": f.get("mime_type")} for f in files], - rawMessage=raw_message, - ) - resp = await self._model_provider_service.create_chat_completion( - request=ChatCompletionRequest( - model=model, - stream=False, - max_completion_tokens=100, - messages=[{"role": "user", "content": prompt}], - ) - ) - title = 
(resp.choices[0].message.content or "").strip().strip("\"'") - title = f"{title[:100]}..." if len(title) > 100 else title - if not title: - raise RuntimeError("Generated title is empty.") - async with self._uow() as uow: - await uow.contexts.update_title( - context_id=context_id, title=title, generation_state=TitleGenerationState.COMPLETED - ) - await uow.commit() - except Exception as e: - async with self._uow() as uow: - await uow.contexts.update_title( - context_id=context_id, title=None, generation_state=TitleGenerationState.FAILED - ) - await uow.commit() - logger.warning(f"Failed to generate title for context {context_id}: {e}") - raise e - - async def list_history( - self, *, context_id: UUID, user: User, pagination: PaginationQuery - ) -> PaginatedResult[ContextHistoryItem]: - async with self._uow() as uow: - await uow.contexts.get(context_id=context_id, user_id=user.id) - return await uow.contexts.list_history( - context_id=context_id, - limit=pagination.limit, - page_token=pagination.page_token, - order=pagination.order, - order_by=pagination.order_by, - ) - - async def delete_history_from_id(self, *, context_id: UUID, from_id: UUID, user: User) -> None: - """Delete all history items from a specific item onwards (inclusive)""" - async with self._uow() as uow: - # Verify user has access to this context - await uow.contexts.get(context_id=context_id, user_id=user.id) - # Delete history items from the specified ID onwards - await uow.contexts.delete_history_from_id(context_id=context_id, from_id=from_id) - await uow.commit() diff --git a/apps/adk-server/tests/e2e/agents/conftest.py b/apps/adk-server/tests/e2e/agents/conftest.py index b99ec8cb..10f925af 100644 --- a/apps/adk-server/tests/e2e/agents/conftest.py +++ b/apps/adk-server/tests/e2e/agents/conftest.py @@ -14,7 +14,6 @@ from kagenti_adk.platform import PlatformClient, Provider from kagenti_adk.platform.context import ContextToken from kagenti_adk.server import Server -from 
kagenti_adk.server.store.context_store import ContextStore from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed from tests.conftest import Configuration @@ -27,7 +26,6 @@ async def run_server( test_admin: tuple[str, str], a2a_client_factory: Callable[[AgentCard | dict[str, Any], ContextToken], AsyncIterator[Client]], context_token: ContextToken, - context_store: ContextStore | None = None, ) -> AsyncGenerator[tuple[Server, Client]]: async with asyncio.TaskGroup() as tg: tg.create_task( @@ -35,7 +33,6 @@ async def run_server( server.run, port=port, self_registration_client_factory=lambda: PlatformClient(auth=test_admin), - context_store=context_store, ) ) @@ -66,7 +63,6 @@ def create_server_with_agent( async def _create_server( agent_fn, context_token: ContextToken, - context_store: ContextStore | None = None, ): server = Server() server.agent()(agent_fn) @@ -74,7 +70,6 @@ async def _create_server( server, free_port, a2a_client_factory=a2a_client_factory, - context_store=context_store, context_token=context_token, test_admin=test_admin, ) as (server, client): diff --git a/apps/adk-server/tests/e2e/agents/test_context_store.py b/apps/adk-server/tests/e2e/agents/test_context_store.py deleted file mode 100644 index 37fb6812..00000000 --- a/apps/adk-server/tests/e2e/agents/test_context_store.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2026 © IBM Corp. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections.abc import AsyncGenerator, AsyncIterator - -import pytest -from a2a.client import Client, ClientEvent, create_text_message_object -from a2a.types import Message, Role, SendMessageRequest, Task -from kagenti_adk.a2a.extensions import PlatformApiExtensionClient, PlatformApiExtensionSpec -from kagenti_adk.a2a.types import RunYield -from kagenti_adk.platform.context import Context, ContextPermissions, ContextToken, Permissions -from kagenti_adk.server import Server -from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore - -pytestmark = pytest.mark.e2e - - -async def get_final_task_from_stream(stream: AsyncIterator[ClientEvent | Message]) -> Task | None: - final_task = None - async for event in stream: - match event: - case (_, task): - final_task = task - return final_task - - -@pytest.fixture -async def history_agent(create_server_with_agent) -> AsyncGenerator[tuple[Server, Client]]: - """Agent that tests context.store.load_history() functionality.""" - - async def history_agent(input: Message, context: RunContext) -> AsyncGenerator[RunYield]: - input.metadata = {"test": "metadata"} - await context.store(input) - async for message in context.load_history(): - message.role = Role.ROLE_AGENT - assert message.metadata == {"test": "metadata"} - yield message - await context.store(message) - - context = await Context.create() - token = await context.generate_token(grant_global_permissions=Permissions(a2a_proxy={"*"})) - async with create_server_with_agent( - history_agent, - context_token=token, - context_store=PlatformContextStore(), - ) as (server, client): - yield server, client - - -def create_message(token: ContextToken, content: str) -> Message: - api_extension_client = PlatformApiExtensionClient(PlatformApiExtensionSpec()) - message = create_text_message_object(content=content) - 
message.metadata = api_extension_client.api_auth_metadata(auth_token=token.token, expires_at=token.expires_at) - message.context_id = token.context_id - return message - - -@pytest.mark.usefixtures("clean_up", "setup_platform_client") -async def test_agent_history(history_agent, subtests): - _, client = history_agent - - with subtests.test("history repeats itself"): - context1 = await Context.create() - token = await context1.generate_token( - grant_context_permissions=ContextPermissions(context_data={"*"}), - grant_global_permissions=Permissions(a2a_proxy={"*"}), - ) - - final_task = await get_final_task_from_stream( - client.send_message(SendMessageRequest(message=create_message(token, "first message"))) - ) - agent_messages = [msg.parts[0].text for msg in final_task.history] - assert all(msg.metadata == {"test": "metadata"} for msg in final_task.history) - assert agent_messages == ["first message"] - - final_task = await get_final_task_from_stream( - client.send_message(SendMessageRequest(message=create_message(token, "second message"))) - ) - agent_messages = [msg.parts[0].text for msg in final_task.history] - assert all(msg.metadata == {"test": "metadata"} for msg in final_task.history) - assert agent_messages == ["first message", "first message", "second message"] - - final_task = await get_final_task_from_stream( - client.send_message(SendMessageRequest(message=create_message(token, "third message"))) - ) - agent_messages = [msg.parts[0].text for msg in final_task.history] - assert all(msg.metadata == {"test": "metadata"} for msg in final_task.history) - assert agent_messages == [ - # first run - "first message", - # second run - "first message", - "second message", - # third run - "first message", - "first message", - "second message", - "third message", - ] - - context1_history = await Context.list_history(context1.id) - assert context1_history.total_count == 14 - - with subtests.test("other context id does not mix history"): - context2 = await 
Context.create() - token = await context2.generate_token( - grant_context_permissions=ContextPermissions(context_data={"*"}), - grant_global_permissions=Permissions(a2a_proxy={"*"}), - ) - final_task = await get_final_task_from_stream( - client.send_message(SendMessageRequest(message=create_message(token, "first message"))) - ) - agent_messages = [msg.parts[0].text for msg in final_task.history] - assert agent_messages == ["first message"] - - context1_history = await Context.list_history(context1.id) - assert context1_history.total_count == 14 - - context2_history = await Context.list_history(context2.id) - assert context2_history.total_count == 2 diff --git a/apps/adk-server/tests/e2e/routes/test_contexts.py b/apps/adk-server/tests/e2e/routes/test_contexts.py index eeac9942..874a5675 100644 --- a/apps/adk-server/tests/e2e/routes/test_contexts.py +++ b/apps/adk-server/tests/e2e/routes/test_contexts.py @@ -6,9 +6,8 @@ import uuid import pytest -from httpx import HTTPStatusError -from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.platform.context import Context +from httpx import HTTPStatusError pytestmark = pytest.mark.e2e @@ -70,102 +69,6 @@ async def test_context_pagination(subtests): assert len(response.items) == 5 # Should return all contexts -@pytest.mark.usefixtures("clean_up", "setup_platform_client") -async def test_context_history_pagination(subtests): - """Test cursor-based pagination for context history endpoint.""" - - # Create a context for testing - context = await Context.create() - - # Create more than 40 history items (default page size) to test pagination - num_items = 45 - - with subtests.test("add multiple history items"): - for i in range(num_items): - message = AgentMessage(text=f"Test message {i}") - await context.add_history_item(data=message) - - with subtests.test("test default pagination (first page)"): - response = await Context.list_history(context.id) - assert len(response.items) == 40 # Default page size - assert 
response.has_more is True - assert response.next_page_token is not None - - # Verify items are ordered by created_at desc (newest first) - created_ats = [item.created_at for item in response.items] - assert created_ats == sorted(created_ats, reverse=False) - - with subtests.test("test pagination with custom limit"): - response = await Context.list_history(context.id, limit=10) - assert len(response.items) == 10 - assert response.has_more is True - assert response.next_page_token is not None - - with subtests.test("test cursor-based pagination"): - # Get first page with limit 20 - first_page = await Context.list_history(context.id, limit=20) - assert len(first_page.items) == 20 - assert first_page.has_more is True - - # Get second page using next_page_token as cursor - second_page = await Context.list_history(context.id, limit=20, page_token=first_page.next_page_token) - assert len(second_page.items) == 20 - assert second_page.has_more is True - - # Get third page - third_page = await Context.list_history(context.id, limit=20, page_token=second_page.next_page_token) - assert len(third_page.items) == 5 # Remaining items - assert third_page.has_more is False - - # Verify no duplicate items across pages - all_items = first_page.items + second_page.items + third_page.items - all_ids = [item.id for item in all_items if hasattr(item, "id")] - assert len(all_ids) == len(set(all_ids)) # No duplicates - - with subtests.test("test ascending order"): - response = await Context.list_history(context.id, order="asc", limit=5) - created_ats = [item.created_at for item in response.items] - assert created_ats == sorted(created_ats) # Should be ascending - - with subtests.test("test list_all_history method"): - # Test the list_all_history method that automatically iterates through all pages - all_items = [] - async for item in Context.list_all_history(context.id): - all_items.append(item) - - assert len(all_items) == num_items - - # Verify chronological order (oldest first since it 
yields in order) - created_ats = [item.created_at for item in all_items] - # Note: list_all_history should maintain the order from list_history (desc by default) - # but iterate through all pages - - -@pytest.mark.usefixtures("clean_up", "setup_platform_client") -async def test_context_empty_filtering(subtests): - """Test filtering contexts based on whether they have history records.""" - - with subtests.test("create contexts with and without history"): - # Create empty context (no history) - empty_context = await Context.create() - - # Create context with history - context_with_history = await Context.create() - message = AgentMessage(text="Test message") - await context_with_history.add_history_item(data=message) - - with subtests.test("include_empty=True returns all contexts"): - response = await Context.list(include_empty=True) - assert len(response.items) == 2 # Should include both contexts - - with subtests.test("include_empty=False returns only contexts with history"): - response = await Context.list(include_empty=False) - context_ids = [ctx.id for ctx in response.items] - assert len(context_ids) == 1 - assert context_with_history.id in context_ids - assert empty_context.id not in context_ids - - @pytest.mark.usefixtures("clean_up", "setup_platform_client") async def test_context_update_and_patch(subtests): """Test updating and patching context metadata.""" @@ -255,42 +158,3 @@ async def test_context_provider_filtering(subtests): assert fetched_context.provider_id == provider1.id -@pytest.mark.usefixtures("clean_up", "setup_platform_client") -async def test_context_delete_context_history_from_id(subtests): - """Test deleting context history from a specific item ID onwards.""" - - context = None - history_items = [] - n_messages = 3 - - with subtests.test("create context and add multiple history items"): - context = await Context.create() - for i in range(n_messages): - message = AgentMessage(text=f"Test message {i}") - await 
context.add_history_item(data=message) - - history = await context.list_history(limit=50) - history_items = history.items - assert len(history.items) == n_messages - - with subtests.test("delete history from a middle item onwards"): - await context.delete_history_from_id(from_id=history_items[1].id) - - remaining_history = await context.list_history(limit=50) - remaining_ids = [item.id for item in remaining_history.items] - assert len(remaining_history.items) == 1 - assert history_items[0].id in remaining_ids - assert history_items[1].id not in remaining_ids - assert history_items[2].id not in remaining_ids - - with subtests.test("delete with nonexistent item_id raises error"): - nonexistent_id = uuid.uuid4() - with pytest.raises(HTTPStatusError) as exc_info: - await context.delete_history_from_id(from_id=nonexistent_id) - assert exc_info.value.response.status_code == 404 - - with subtests.test("delete from first item deletes all"): - await context.delete_history_from_id(from_id=remaining_ids[0]) - # await context.delete_history_from_id(from_id=remaining_ids[0]) - remaining_history = await context.list_history(limit=50) - assert len(remaining_history.items) == 0 diff --git a/apps/adk-ui/src/api/a2a/jsonrpc-client.ts b/apps/adk-ui/src/api/a2a/jsonrpc-client.ts index 3a6c4daa..48a1d606 100644 --- a/apps/adk-ui/src/api/a2a/jsonrpc-client.ts +++ b/apps/adk-ui/src/api/a2a/jsonrpc-client.ts @@ -8,6 +8,20 @@ import { agentCardSchema, streamResponseSchema } from '@kagenti/adk'; import { EventSourceParserStream } from 'eventsource-parser/stream'; import { v4 as uuid } from 'uuid'; +export interface ListTasksParams { + contextId?: string; + status?: string; + pageSize?: number; + pageToken?: string; +} + +export interface ListTasksResponse { + tasks: Task[]; + nextPageToken?: string; + totalSize?: number; + pageSize?: number; +} + export interface A2AClient { getAgentCard(): Promise; sendMessageStream(params: { @@ -17,6 +31,7 @@ export interface A2AClient { }): 
AsyncIterable; getTask(params: { id: string }): Promise<Task>; cancelTask(params: { id: string }): Promise<Task>; + listTasks(params: ListTasksParams): Promise<ListTasksResponse>; } interface CreateClientParams { @@ -118,6 +133,10 @@ export function createA2AClient({ endpointUrl, agentCard, fetchImpl, extensions async cancelTask(params) { return jsonRpcRequest('CancelTask', params) as Promise<Task>; }, + + async listTasks(params) { + return jsonRpcRequest('ListTasks', { ...params }) as Promise<ListTasksResponse>; + }, }; } diff --git a/apps/adk-ui/src/api/a2a/list-tasks.ts b/apps/adk-ui/src/api/a2a/list-tasks.ts new file mode 100644 index 00000000..25f09a0e --- /dev/null +++ b/apps/adk-ui/src/api/a2a/list-tasks.ts @@ -0,0 +1,79 @@ +/** + * Copyright 2026 © IBM Corp. + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Task } from '@kagenti/adk'; +import { v4 as uuid } from 'uuid'; + +import { ensureToken } from '#app/(auth)/rsc.tsx'; +import { runtimeConfig } from '#contexts/App/runtime-config.ts'; +import { getBaseUrl } from '#utils/api/getBaseUrl.ts'; + +export interface ListTasksParams { + contextId?: string; + status?: string; + pageSize?: number; + pageToken?: string; +} + +export interface ListTasksResponse { + tasks: Task[]; + nextPageToken?: string; + totalSize?: number; + pageSize?: number; +} + +/** + * Server-side function to fetch tasks from the A2A proxy via JSON-RPC. + * Used in React Server Components. 
+ */ +export async function fetchTasksForContext( + providerId: string, + contextId: string, +): Promise<ListTasksResponse | undefined> { + try { + const baseUrl = getBaseUrl(); + const endpointUrl = `${baseUrl}/api/v1/a2a/${providerId}/`; + + const { isAuthEnabled } = runtimeConfig; + const headers: Record<string, string> = { 'Content-Type': 'application/json' }; + + if (isAuthEnabled) { + const token = await ensureToken(); + if (token?.accessToken) { + headers['Authorization'] = `Bearer ${token.accessToken}`; + } + } + + const response = await fetch(endpointUrl, { + method: 'POST', + headers, + body: JSON.stringify({ + jsonrpc: '2.0', + id: uuid(), + method: 'ListTasks', + params: { + contextId, + }, + }), + }); + + if (!response.ok) { + console.error(`ListTasks request failed: ${response.status} ${response.statusText}`); + return undefined; + } + + const data = await response.json(); + + if (data.error) { + console.error('ListTasks error:', data.error); + return undefined; + } + + return data.result as ListTasksResponse; + } catch (error) { + console.error('Failed to fetch tasks:', error); + return undefined; + } +} diff --git a/apps/adk-ui/src/modules/history/utils.ts b/apps/adk-ui/src/modules/history/utils.ts index dc6356e6..efcb450a 100644 --- a/apps/adk-ui/src/modules/history/utils.ts +++ b/apps/adk-ui/src/modules/history/utils.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { Artifact, ContextHistory, Message } from '@kagenti/adk'; +import type { Artifact, ContextHistory, Message, Task } from '@kagenti/adk'; import { v4 as uuid } from 'uuid'; import { processMessageMetadata, processParts } from '#api/a2a/part-processors.ts'; @@ -93,3 +93,51 @@ export function convertHistoryToUIMessages(history: ContextHistory[]): UIMessage return messages; } + +/** + * Convert A2A Tasks (with their history and artifacts) into UI messages. + * Tasks are expected in chronological order (oldest first). 
+ */ +export function convertTasksToUIMessages(tasks: Task[]): UIMessage[] { + const allMessages: UIMessage[] = []; + + for (const task of tasks) { + const taskId = task.id; + + // Process history messages + for (const msg of task.history ?? []) { + const uiMessage = processHistoryMessage(msg, taskId); + + const lastMessage = allMessages.at(-1); + const shouldGroup = lastMessage && lastMessage.role === uiMessage.role && lastMessage.taskId === uiMessage.taskId; + + if (shouldGroup) { + allMessages.splice(-1, 1, { + ...lastMessage, + parts: [...uiMessage.parts, ...lastMessage.parts], + }); + } else { + allMessages.push(uiMessage); + } + } + + // Process artifacts + for (const artifact of task.artifacts ?? []) { + const uiMessage = processHistoryArtifact(artifact, taskId); + + const lastMessage = allMessages.at(-1); + const shouldGroup = lastMessage && lastMessage.role === uiMessage.role && lastMessage.taskId === uiMessage.taskId; + + if (shouldGroup) { + allMessages.splice(-1, 1, { + ...lastMessage, + parts: [...uiMessage.parts, ...lastMessage.parts], + }); + } else { + allMessages.push(uiMessage); + } + } + } + + return allMessages; +} diff --git a/apps/adk-ui/src/modules/messages/contexts/Messages/MessagesProvider.tsx b/apps/adk-ui/src/modules/messages/contexts/Messages/MessagesProvider.tsx index 9e8a99dc..84db539e 100644 --- a/apps/adk-ui/src/modules/messages/contexts/Messages/MessagesProvider.tsx +++ b/apps/adk-ui/src/modules/messages/contexts/Messages/MessagesProvider.tsx @@ -3,70 +3,22 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { type PropsWithChildren, useCallback, useEffect, useMemo } from 'react'; +import { type PropsWithChildren, useCallback, useMemo } from 'react'; -import { useFetchNextPage } from '#hooks/useFetchNextPage.ts'; import { useImmerWithGetter } from '#hooks/useImmerWithGetter.ts'; -import { convertHistoryToUIMessages } from '#modules/history/utils.ts'; +import { convertTasksToUIMessages } from '#modules/history/utils.ts'; import type 
{ UIMessage } from '#modules/messages/types.ts'; -import { isAgentMessage } from '#modules/messages/utils.ts'; -import { LIST_CONTEXT_HISTORY_DEFAULT_QUERY } from '#modules/platform-context/api/constants.ts'; -import { useListContextHistory } from '#modules/platform-context/api/queries/useListContextHistory.ts'; -import { isHistoryMessage } from '#modules/platform-context/api/utils.ts'; import { usePlatformContext } from '#modules/platform-context/contexts/index.ts'; import { MessagesContext } from './messages-context'; export function MessagesProvider({ children }: PropsWithChildren) { - const { contextId, history: initialHistory } = usePlatformContext(); - - const { data: history, ...queryRest } = useListContextHistory({ - context_id: contextId ?? undefined, - query: LIST_CONTEXT_HISTORY_DEFAULT_QUERY, - initialData: initialHistory, - // Ensures newly created messages are not fetched from history - initialPageParam: initialHistory?.next_page_token ?? undefined, - // Ensures history is not fetched for newly created contexts, where previous rule isn't sufficient to prevent message duplication - enabled: Boolean(initialHistory), - }); + const { initialTasks } = usePlatformContext(); const [messages, getMessages, setMessages] = useImmerWithGetter( - convertHistoryToUIMessages(history ?? []), + convertTasksToUIMessages(initialTasks ?? []), ); - useEffect(() => { - if (history) { - setMessages((messages) => { - const lastMessage = messages.at(-1); - const lastMessageHistoryIndex = lastMessage - ? history.findIndex(({ data }) => - isHistoryMessage(data) - ? data.messageId === lastMessage?.id - : isAgentMessage(lastMessage) && data.artifactId === lastMessage?.artifactId, - ) - : null; - - const historyContainsLastMessage = lastMessageHistoryIndex !== null && lastMessageHistoryIndex >= 0; - const newItems = historyContainsLastMessage ? 
history.slice(lastMessageHistoryIndex) : history; - - // Remove last message and convert it again from history, because - // newly fetched history can contain subsequent trajectories of the message - if (historyContainsLastMessage) { - messages.splice(-1, 1); - } - - messages.push(...convertHistoryToUIMessages(newItems)); - }); - } - }, [history, setMessages]); - - const { fetchNextPage, isFetching, hasNextPage } = queryRest; - const { ref: fetchNextPageInViewAnchorRef } = useFetchNextPage({ - fetchNextPage, - isFetching, - hasNextPage, - }); - const isLastMessage = useCallback((message: UIMessage) => getMessages().at(0)?.id === message.id, [getMessages]); const value = useMemo( @@ -76,11 +28,13 @@ export function MessagesProvider({ children }: PropsWithChildren) { setMessages, isLastMessage, queryControl: { - ...queryRest, - fetchNextPageInViewAnchorRef, + fetchNextPageInViewAnchorRef: { current: null } as React.RefObject, + isFetching: false, + isFetchingNextPage: false, + hasNextPage: false, }, }), - [messages, getMessages, setMessages, isLastMessage, queryRest, fetchNextPageInViewAnchorRef], + [messages, getMessages, setMessages, isLastMessage], ); return {children}; diff --git a/apps/adk-ui/src/modules/messages/contexts/Messages/messages-context.ts b/apps/adk-ui/src/modules/messages/contexts/Messages/messages-context.ts index 512dff3c..6551431e 100644 --- a/apps/adk-ui/src/modules/messages/contexts/Messages/messages-context.ts +++ b/apps/adk-ui/src/modules/messages/contexts/Messages/messages-context.ts @@ -4,11 +4,11 @@ */ 'use client'; +import type { RefObject } from 'react'; import { createContext } from 'react'; import type { Updater } from '#hooks/useImmerWithGetter.ts'; import type { UIMessage } from '#modules/messages/types.ts'; -import type { useListContextHistory } from '#modules/platform-context/api/queries/useListContextHistory.ts'; export const MessagesContext = createContext(null); @@ -18,6 +18,9 @@ export interface MessagesContextValue { 
getMessages: () => UIMessage[]; setMessages: Updater; queryControl: { - fetchNextPageInViewAnchorRef: (node?: Element | null) => void; - } & Omit, 'data'>; + fetchNextPageInViewAnchorRef: RefObject; + isFetching: boolean; + isFetchingNextPage: boolean; + hasNextPage: boolean; + }; } diff --git a/apps/adk-ui/src/modules/platform-context/api/constants.ts b/apps/adk-ui/src/modules/platform-context/api/constants.ts index 056f2afb..dc3d1db9 100644 --- a/apps/adk-ui/src/modules/platform-context/api/constants.ts +++ b/apps/adk-ui/src/modules/platform-context/api/constants.ts @@ -3,8 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { ListContextHistoryRequest, ListContextsRequest } from '@kagenti/adk'; +import type { ListContextsRequest } from '@kagenti/adk'; export const LIST_CONTEXTS_DEFAULT_QUERY: ListContextsRequest['query'] = { limit: 10, include_empty: false }; - -export const LIST_CONTEXT_HISTORY_DEFAULT_QUERY: ListContextHistoryRequest['query'] = { limit: 10 }; diff --git a/apps/adk-ui/src/modules/platform-context/api/index.ts b/apps/adk-ui/src/modules/platform-context/api/index.ts index 5bdc9b2c..a4e9ef47 100644 --- a/apps/adk-ui/src/modules/platform-context/api/index.ts +++ b/apps/adk-ui/src/modules/platform-context/api/index.ts @@ -7,13 +7,11 @@ import type { CreateContextRequest, CreateContextTokenRequest, DeleteContextRequest, - ListContextHistoryRequest, ListContextsRequest, } from '@kagenti/adk'; import { type MatchModelProvidersRequest, unwrapResult } from '@kagenti/adk'; import { adkClient } from '#api/adk-client.ts'; -import { fetchEntity } from '#api/utils.ts'; import type { PatchContextMetadataRequest } from './types'; import { contextSchema, listContextsResponseSchema } from './types'; @@ -39,13 +37,6 @@ export async function deleteContext(request: DeleteContextRequest) { return result; } -export async function listContextHistory(request: ListContextHistoryRequest) { - const response = await adkClient.listContextHistory(request); - const 
result = unwrapResult(response); - - return result; -} - export async function patchContextMetadata(request: PatchContextMetadataRequest) { const response = await adkClient.patchContextMetadata(request); const result = unwrapResult(response, contextSchema); @@ -66,7 +57,3 @@ export async function createContextToken(request: CreateContextTokenRequest) { return result; } - -export async function fetchContextHistory(request: ListContextHistoryRequest) { - return await fetchEntity(() => listContextHistory(request)); -} diff --git a/apps/adk-ui/src/modules/platform-context/api/keys.ts b/apps/adk-ui/src/modules/platform-context/api/keys.ts index 0adf0aa9..c1a8031b 100644 --- a/apps/adk-ui/src/modules/platform-context/api/keys.ts +++ b/apps/adk-ui/src/modules/platform-context/api/keys.ts @@ -3,15 +3,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { ListContextHistoryRequest, ListContextsRequest } from '@kagenti/adk'; +import type { ListContextsRequest } from '@kagenti/adk'; export const contextKeys = { all: () => ['contexts'] as const, lists: () => [...contextKeys.all(), 'list'] as const, list: ({ query = {} }: ListContextsRequest) => [...contextKeys.lists(), query] as const, - histories: () => [...contextKeys.all(), 'history'] as const, - history: ({ context_id, query = {} }: ListContextHistoryRequest) => - [...contextKeys.histories(), context_id, query] as const, tokens: () => [...contextKeys.all(), 'token'] as const, token: (contextId: string, providerId: string) => [...contextKeys.tokens(), contextId, providerId] as const, }; diff --git a/apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts b/apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts deleted file mode 100644 index 395c0ea1..00000000 --- a/apps/adk-ui/src/modules/platform-context/api/queries/useListContextHistory.ts +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2026 © IBM Corp. 
- * SPDX-License-Identifier: Apache-2.0 - */ - -import type { ListContextHistoryRequest, ListContextHistoryResponse } from '@kagenti/adk'; -import { useInfiniteQuery } from '@tanstack/react-query'; - -import type { PartialBy } from '#@types/utils.ts'; -import { isNotNull } from '#utils/helpers.ts'; - -import { listContextHistory } from '..'; -import { contextKeys } from '../keys'; - -type Params = PartialBy & { - initialData?: ListContextHistoryResponse; - enabled?: boolean; - initialPageParam?: string; -}; - -export function useListContextHistory({ - context_id, - query: queryParams, - initialData, - initialPageParam, - enabled = true, -}: Params) { - const query = useInfiniteQuery({ - queryKey: contextKeys.history({ - context_id: context_id!, - query: queryParams, - }), - queryFn: ({ pageParam }: { pageParam?: string }) => { - return listContextHistory({ - context_id: context_id!, - query: { - ...queryParams, - page_token: pageParam, - }, - }); - }, - initialPageParam, - getNextPageParam: (lastPage) => { - return lastPage?.has_more && lastPage.next_page_token ? lastPage.next_page_token : undefined; - }, - select: (data) => { - if (!data) { - return undefined; - } - - const items = data.pages.flatMap((page) => page?.items).filter(isNotNull); - - return items; - }, - enabled: Boolean(context_id) && enabled, - initialData: initialData ? 
{ pages: [initialData], pageParams: [undefined] } : undefined, - }); - - return query; -} diff --git a/apps/adk-ui/src/modules/platform-context/api/types.ts b/apps/adk-ui/src/modules/platform-context/api/types.ts index 73705262..70a84929 100644 --- a/apps/adk-ui/src/modules/platform-context/api/types.ts +++ b/apps/adk-ui/src/modules/platform-context/api/types.ts @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { ContextHistory } from '@kagenti/adk'; import { contextSchema as sdkContextSchema, listContextsResponseSchema as sdkListContextsResponseSchema, @@ -41,7 +40,3 @@ export const patchContextMetadataRequestSchema = sdkPatchContextMetadataRequestS }); export type PatchContextMetadataRequest = z.infer; -// - -export type HistoryItem = ContextHistory['data']; -export type HistoryMessage = Extract; diff --git a/apps/adk-ui/src/modules/platform-context/api/utils.ts b/apps/adk-ui/src/modules/platform-context/api/utils.ts deleted file mode 100644 index 31f021f8..00000000 --- a/apps/adk-ui/src/modules/platform-context/api/utils.ts +++ /dev/null @@ -1,10 +0,0 @@ -/** - * Copyright 2026 © IBM Corp. 
- * SPDX-License-Identifier: Apache-2.0 - */ - -import type { HistoryItem, HistoryMessage } from './types'; - -export function isHistoryMessage(item: HistoryItem): item is HistoryMessage { - return 'messageId' in item; -} diff --git a/apps/adk-ui/src/modules/platform-context/contexts/PlatformContextProvider.tsx b/apps/adk-ui/src/modules/platform-context/contexts/PlatformContextProvider.tsx index 470e7096..c249b8ba 100644 --- a/apps/adk-ui/src/modules/platform-context/contexts/PlatformContextProvider.tsx +++ b/apps/adk-ui/src/modules/platform-context/contexts/PlatformContextProvider.tsx @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ 'use client'; -import type { ListContextHistoryResponse } from '@kagenti/adk'; +import type { Task } from '@kagenti/adk'; import { type PropsWithChildren, useCallback, useState } from 'react'; import type { Agent } from '#modules/agents/api/types.ts'; @@ -14,10 +14,10 @@ import { PlatformContext } from './platform-context'; interface Props { contextId?: string; - history?: ListContextHistoryResponse; + initialTasks?: Task[]; } -export function PlatformContextProvider({ history, contextId: contextIdProp, children }: PropsWithChildren) { +export function PlatformContextProvider({ initialTasks, contextId: contextIdProp, children }: PropsWithChildren) { const [contextId, setContextId] = useState(contextIdProp ?? 
null); const { mutateAsync: createContext } = useCreateContext({ @@ -61,7 +61,7 @@ export function PlatformContextProvider({ history, contextId: contextIdProp, chi ContextId; resetContext: () => void; diff --git a/apps/adk-ui/src/modules/runs/components/AgentRun.tsx b/apps/adk-ui/src/modules/runs/components/AgentRun.tsx index 45e12ed8..9d05e5e6 100644 --- a/apps/adk-ui/src/modules/runs/components/AgentRun.tsx +++ b/apps/adk-ui/src/modules/runs/components/AgentRun.tsx @@ -5,9 +5,8 @@ import { notFound } from 'next/navigation'; +import { fetchTasksForContext } from '#api/a2a/list-tasks.ts'; import { runtimeConfig } from '#contexts/App/runtime-config.ts'; -import { LIST_CONTEXT_HISTORY_DEFAULT_QUERY } from '#modules/platform-context/api/constants.ts'; -import { fetchContextHistory } from '#modules/platform-context/api/index.ts'; import { PlatformContextProvider } from '#modules/platform-context/contexts/PlatformContextProvider.tsx'; import { RunView } from '#modules/runs/components/RunView.tsx'; @@ -23,12 +22,7 @@ export async function AgentRun({ providerId, contextId }: Props) { const { featureFlags } = runtimeConfig; const agentPromise = fetchAgent(providerId); - const contextHistoryPromise = contextId - ? fetchContextHistory({ - context_id: contextId, - query: LIST_CONTEXT_HISTORY_DEFAULT_QUERY, - }) - : undefined; + const tasksPromise = contextId ? 
fetchTasksForContext(providerId, contextId) : undefined; const agent = await agentPromise; @@ -40,14 +34,14 @@ export async function AgentRun({ providerId, contextId }: Props) { } } - const contextHistory = await contextHistoryPromise; + const tasksResponse = await tasksPromise; - if (contextId && !contextHistory) { + if (contextId && !tasksResponse) { notFound(); } return ( - + ); diff --git a/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx b/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx index 6bcaa6ba..155b5d34 100644 --- a/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx +++ b/apps/adk-ui/src/modules/runs/contexts/agent-run/AgentRunProvider.tsx @@ -292,7 +292,6 @@ function AgentRunProvider({ agent, children }: PropsWithChildren) { pendingSubscription.current = undefined; queryClient.invalidateQueries({ queryKey: contextKeys.lists() }); - queryClient.invalidateQueries({ queryKey: contextKeys.history({ context_id: contextId }) }); } }, [ diff --git a/docs/development/agent-integration/canvas.mdx b/docs/development/agent-integration/canvas.mdx index e3817925..a2bbc5b1 100644 --- a/docs/development/agent-integration/canvas.mdx +++ b/docs/development/agent-integration/canvas.mdx @@ -107,7 +107,6 @@ async def code_agent( llm: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()], canvas: Annotated[CanvasExtensionServer, CanvasExtensionSpec()], ): - await context.store(message) canvas_edit = await canvas.parse_canvas_edit_request(message=message) # Adapt system prompt based on whether this is an edit or new generation @@ -116,8 +115,6 @@ async def code_agent( artifact = await call_llm(llm, system_prompt, message) yield artifact - await context.store(artifact) - if __name__ == "__main__": server.run() diff --git a/docs/development/agent-integration/multi-turn.mdx b/docs/development/agent-integration/multi-turn.mdx index 2943dbcc..de6417a4 100644 --- 
a/docs/development/agent-integration/multi-turn.mdx +++ b/docs/development/agent-integration/multi-turn.mdx @@ -9,10 +9,7 @@ When building conversational AI agents, one of the key requirements is maintaini | Operation | Purpose | | :--- | :--- | -| **await context.store(input)** | Stores current user message in conversation history. Storage of messages must be explicitly requested| -| **await context.store(response)** | Stores agent’s responses in conversation history, and must be explicitly requested | -| **context: RunContext)** | Sets up a RunContext instance for storing and accessing the conversation history | -| **context_store=PlatformContextStore()** | Configures server to use the platform’s persistent context store to maintain conversation history across agent restarts | +| **context: RunContext)** | Sets up a RunContext instance for accessing the conversation history | ## Simple History Access Example @@ -40,9 +37,6 @@ server = Server() async def basic_history_example(input: Message, context: RunContext): """Agent that demonstrates conversation history access""" - # Store the current message in the context store - await context.store(input) - # Get the current user message current_message = get_message_text(input) print(f"Current message: {current_message}") @@ -57,9 +51,6 @@ async def basic_history_example(input: Message, context: RunContext): message = AgentMessage(text=f"Hello! I can see we have {len(history)} messages in our conversation.") yield message - # Store the message in the context store - await context.store(message) - def run(): server.run(host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000))) @@ -72,10 +63,8 @@ if __name__ == "__main__": ### Steps -1. **Access conversation history:** Use `RunContext` to set up an instance of the conversation history to store and load previous messages. -1. **Store incoming messages:** Use `await context.store(input)` to store the current user message in the conversation history. +1. 
**Access conversation history:** Use `RunContext` to set up an instance of the conversation history to load previous messages. 1. **Filter and process history:** Retrieve the conversation history with `load_history()` and filter to get the messages relevant to your agent's logic. -1. **Store agent responses:** Use `await context.store(response)` to store your agent's responses for future conversation context. ## Streaming with Buffered History Example @@ -90,12 +79,11 @@ This is usecase-specific and one my opt for a combination of this and previous a import asyncio import os -from a2a.types import Message, Role +from a2a.types import Message from a2a.utils.message import get_message_text from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -108,8 +96,8 @@ async def example_tool() -> str: async def history_counter(history: list[Message]) -> str: """Create a concise conversation-state summary.""" await asyncio.sleep(.1) # doing some agent work - user_count = sum(1 for item in history if item.role == Role.ROLE_USER) - agent_count = sum(1 for item in history if item.role == Role.ROLE_AGENT) + user_count = sum(1 for item in history if item.role.value == "user") + agent_count = sum(1 for item in history if item.role.value == "agent") history_count = len(history) return f"total={history_count}, user={user_count}, agent={agent_count}" @@ -120,9 +108,6 @@ async def streaming_agent_w_single_history_write_example(input: Message, context Stream partial answers, execute tools, and persist one finalized assistant message. See other examples for actual implementation of multi-turn conversation agent with tool use. """ - # Store the user input as the first persisted item for this turn. 
- await context.store(data=input) - history = [message async for message in context.load_history() if isinstance(message, Message) and message.parts] current_message = get_message_text(input) @@ -159,20 +144,14 @@ async def streaming_agent_w_single_history_write_example(input: Message, context # This does not need to be the go-to approach in all cases, sometimes the partial outputs are of no value and one does not want them to be properly stored. # # Why not store each chunk? - # - Calling `context.store()`, PlatformContextStore saves every message as a distinct history item. - # - Storing per chunk would fragment one assistant turn into many partial messages. - # - A single aggregated write keeps replay, memory, and history semantics clean. - # aggregated_response = AgentMessage(text="\n".join(buffered_parts)) - yield "Final result check:\n" + aggregated_response.parts[0].text - await context.store(data=aggregated_response) + yield "Final result check:\n" + str(aggregated_response.text) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), ) @@ -183,9 +162,8 @@ if __name__ == "__main__": ### When to use buffering -- Use **simple yield + store** when your agent emits a single final response. -- Use **stream + buffer + single store** when your agent emits multiple partial chunks which are streamed to the user. -- With `PlatformContextStore`, each `context.store()` call creates a persisted history item, so buffering prevents chunk-level history fragmentation. +- Use **simple yield** when your agent emits a single final response. +- Use **stream + buffer** when your agent emits multiple partial chunks which are streamed to the user. 
## Advanced BeeAI Framework Example @@ -205,7 +183,6 @@ from kagenti_adk.a2a.extensions import LLMServiceExtensionServer, LLMServiceExte from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore from beeai_framework.adapters.agentstack.backend.chat import AgentStackChatModel from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ConditionalRequirement @@ -219,11 +196,11 @@ FrameworkMessage = UserMessage | AssistantMessage def to_framework_message(message: Message) -> FrameworkMessage: """Convert A2A Message to Kagenti ADK Message format""" - message_text = "".join(part.text for part in message.parts if part.text) + message_text = "".join(part.root.text for part in message.parts if part.root.kind == "text") - if message.role == Role.ROLE_AGENT: + if message.role == Role.agent: return AssistantMessage(message_text) - elif message.role == Role.ROLE_USER: + elif message.role == Role.user: return UserMessage(message_text) else: raise ValueError(f"Invalid message role: {message.role}") @@ -236,8 +213,6 @@ async def advanced_history_example( llm: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()], ): """Multi-turn chat agent with conversation memory and LLM integration""" - await context.store(input) - # Load conversation history history = [message async for message in context.load_history() if isinstance(message, Message) and message.parts] @@ -273,14 +248,12 @@ async def advanced_history_example( response = AgentMessage(text=step.input["response"]) yield response - await context.store(response) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), # Enable persistent storage ) @@ -295,31 +268,9 @@ This advanced example 
demonstrates several key concepts: - **Framework Integration:** Leverages the BeeAI Framework for sophisticated agent capabilities - **Memory Management:** Converts conversation history to framework format and loads it into agent memory - **Tool Usage:** Includes thinking tools and conditional requirements for better reasoning -- **Persistent Storage:** Uses `PlatformContextStore` for conversation persistence ## Using Content History -### Persistent Storage Example - -By default, conversation history is stored in memory and is lost when the agent process restarts. For production applications, you'll want to use persistent context storage to maintain conversation history across agent restarts. The `PlatformContextStore` automatically handles conversation persistence, ensuring that users can continue their conversations even after agent restarts or deployments. - -```python -import os -from kagenti_adk.server import Server -from kagenti_adk.server.store.platform_context_store import PlatformContextStore - -server = Server() - -def run(): - server.run( - host=os.getenv("HOST", "127.0.0.1"), - port=int(os.getenv("PORT", 8000)), - context_store=PlatformContextStore() - ) -``` - - - ### History Contents The `context.load_history()` method returns an async iterator containing all items in the conversation, including the current message. This can include: @@ -339,49 +290,7 @@ The history includes the current message, so if you want only previous messages, The history iterator returns all message types. Always filter messages using `isinstance(message, Message)` to ensure you're working with the correct message format. -### Editing and Removing Messages from History - -Sometimes you may need to edit a previous message in a conversation or remove messages that are no longer relevant. -The Kagenti ADK provides a mechanism to delete history items from a specific point onward, allowing you to effectively “rewind” the conversation and replace a message with an edited version. 
-Possible use cases include editing a previous message, clearing irrelevant exchanges, or removing messages that resulted from processing errors. - -Here's an example of a function for editing a user message in a conversation using the context API. This assumes you know the context message id, which can be obtained as an id field of an object returned by `RunContext.load_history(load_history_items=True)`, `Context.list_history` or `Context.list_all_history`. - -```python -import uuid -from uuid import UUID -from typing import Any - -from a2a.types import Message, Part, Role -from kagenti_adk.platform.context import Context -from kagenti_adk.server.context import RunContext - -async def edit_message_in_context(run_context: RunContext, id: UUID, new_text: str): - # Step 1: Delete from this message onwards - await run_context.delete_history_from_id(from_id=id) - - # Step 2: Create the corrected message - corrected_message = Message( - message_id=str(uuid.uuid4()), - parts=[Part(text=new_text)], - role=Role.ROLE_USER, - ) - - # Step 3: Store the corrected message - await run_context.store(data=corrected_message) -``` - - -When you delete history from a specific message onwards, all messages created after that point (including the message itself) are removed. This effectively creates a new conversation branch starting from the message before the deleted one. - - - -This operation is permanent. Once messages are deleted, they cannot be recovered. Consider informing users about this operation or implementing a confirmation step for important conversations. - - -### Message Storage Guidelines - -Since messages are not automatically stored, you need to explicitly call `context.store()` for any message you want to be available in future interactions. 
Here are the key guidelines: +### Message History #### Store Request Example @@ -389,15 +298,9 @@ Since messages are not automatically stored, you need to explicitly call `contex ```python @server.agent() async def my_agent(input: Message, context: RunContext): - # Store the incoming user message immediately - await context.store(input) - # Process the message and generate response response = AgentMessage(text="Your response here") yield response - - # Store the agent's response after yielding - await context.store(response) ``` #### What to Store diff --git a/docs/development/agent-integration/rag.mdx b/docs/development/agent-integration/rag.mdx index a27505cc..75543a7f 100644 --- a/docs/development/agent-integration/rag.mdx +++ b/docs/development/agent-integration/rag.mdx @@ -357,7 +357,7 @@ import os from collections.abc import AsyncGenerator from typing import Annotated -from a2a.types import Message, Part +from a2a.types import FilePart, FileWithUri, Message, TextPart from kagenti_adk.a2a.extensions import ( EmbeddingServiceExtensionServer, EmbeddingServiceExtensionSpec, @@ -410,13 +410,13 @@ async def simple_rag_agent_example( files: list[File] = [] query = "" for part in input.parts: - content_type = part.WhichOneof("content") - if content_type == "url": - files.append(await File.get(PlatformFileUrl(part.url).file_id)) - elif content_type == "text": - query = part.text - else: - raise NotImplementedError(f"Unsupported part content type: {content_type}") + match part.root: + case FilePart(file=FileWithUri(uri=uri)): + files.append(await File.get(PlatformFileUrl(uri).file_id)) + case TextPart(text=text): + query = text + case _: + raise NotImplementedError(f"Unsupported part: {type(part.root)}") if not files or not query: raise ValueError("No files or query provided") @@ -469,7 +469,7 @@ import json import os from typing import Annotated -from a2a.types import Message, Part +from a2a.types import DataPart, FilePart, FileWithUri, Message, Part, TextPart from 
kagenti_adk.a2a.extensions import ( EmbeddingServiceExtensionServer, EmbeddingServiceExtensionSpec, @@ -528,30 +528,26 @@ async def conversation_rag_agent_example( files: list[File] = [] query = "" for part in input.parts: - content_type = part.WhichOneof("content") - if content_type == "url": - files.append(await File.get(PlatformFileUrl(part.url).file_id)) - elif content_type == "text": - query = part.text - else: - raise NotImplementedError(f"Unsupported part content type: {content_type}") + match part.root: + case FilePart(file=FileWithUri(uri=uri)): + files.append(await File.get(PlatformFileUrl(uri).file_id)) + case TextPart(text=text): + query = text + case _: + raise NotImplementedError(f"Unsupported part: {type(part.root)}") # Check if vector store exists vector_store = None async for message in context.load_history(): - if isinstance(message, Message) and len(message.parts) == 1 and message.parts[0].WhichOneof("content") == "data": - data = dict(message.parts[0].data.struct_value) - vector_store = await VectorStore.get(data["vector_store_id"]) + match message: + case Message(parts=[Part(root=DataPart(data=data))]): + vector_store = await VectorStore.get(data["vector_store_id"]) # Create vector store if it does not exist if not vector_store: vector_store = await create_vector_store(embedding_client, embedding_model) # store vector store id in context for future messages - from google.protobuf.struct_pb2 import Value - - data_value = Value() - data_value.struct_value.update({"vector_store_id": vector_store.id}) - await context.store(AgentMessage(parts=[Part(data=data_value)])) + data_part = DataPart(data={"vector_store_id": vector_store.id}) # Process files, add to vector store for file in files: diff --git a/examples/agent-integration/canvas/canvas-with-llm/src/canvas_with_llm/agent.py b/examples/agent-integration/canvas/canvas-with-llm/src/canvas_with_llm/agent.py index 48629cb5..f1bb3b72 100644 --- 
a/examples/agent-integration/canvas/canvas-with-llm/src/canvas_with_llm/agent.py +++ b/examples/agent-integration/canvas/canvas-with-llm/src/canvas_with_llm/agent.py @@ -75,7 +75,6 @@ async def canvas_with_llm_example( llm: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()], canvas: Annotated[CanvasExtensionServer, CanvasExtensionSpec()], ): - await context.store(message) canvas_edit = await canvas.parse_canvas_edit_request(message=message) # Adapt system prompt based on whether this is an edit or new generation @@ -84,8 +83,6 @@ async def canvas_with_llm_example( artifact = await call_llm(llm, system_prompt, message) yield artifact - await context.store(artifact) - def run(): server.run(host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000))) diff --git a/examples/agent-integration/multi-turn/advanced-history/src/advanced_history/agent.py b/examples/agent-integration/multi-turn/advanced-history/src/advanced_history/agent.py index ae38bbde..9ff8ecb4 100644 --- a/examples/agent-integration/multi-turn/advanced-history/src/advanced_history/agent.py +++ b/examples/agent-integration/multi-turn/advanced-history/src/advanced_history/agent.py @@ -10,7 +10,6 @@ from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore from beeai_framework.adapters.agentstack.backend.chat import AgentStackChatModel from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ConditionalRequirement @@ -41,8 +40,6 @@ async def advanced_history_example( llm: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()], ): """Multi-turn chat agent with conversation memory and LLM integration""" - await context.store(input) - # Load conversation history history = [message async for message in 
context.load_history() if isinstance(message, Message) and message.parts] @@ -78,14 +75,12 @@ async def advanced_history_example( response = AgentMessage(text=step.input["response"]) yield response - await context.store(response) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), # Enable persistent storage ) diff --git a/examples/agent-integration/multi-turn/basic-history/src/basic_history/agent.py b/examples/agent-integration/multi-turn/basic-history/src/basic_history/agent.py index f17cc5e5..fc38c9df 100644 --- a/examples/agent-integration/multi-turn/basic-history/src/basic_history/agent.py +++ b/examples/agent-integration/multi-turn/basic-history/src/basic_history/agent.py @@ -17,9 +17,6 @@ async def basic_history_example(input: Message, context: RunContext): """Agent that demonstrates conversation history access""" - # Store the current message in the context store - await context.store(input) - # Get the current user message current_message = get_message_text(input) print(f"Current message: {current_message}") @@ -34,9 +31,6 @@ async def basic_history_example(input: Message, context: RunContext): message = AgentMessage(text=f"Hello! 
I can see we have {len(history)} messages in our conversation.") yield message - # Store the message in the context store - await context.store(message) - def run(): server.run(host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", 8000))) diff --git a/examples/agent-integration/multi-turn/streaming-agent-history/src/streaming_agent_history/agent.py b/examples/agent-integration/multi-turn/streaming-agent-history/src/streaming_agent_history/agent.py index 998e9133..5bf21dbc 100644 --- a/examples/agent-integration/multi-turn/streaming-agent-history/src/streaming_agent_history/agent.py +++ b/examples/agent-integration/multi-turn/streaming-agent-history/src/streaming_agent_history/agent.py @@ -3,12 +3,11 @@ import asyncio import os -from a2a.types import Message, Role +from a2a.types import Message from a2a.utils.message import get_message_text from kagenti_adk.a2a.types import AgentMessage from kagenti_adk.server import Server from kagenti_adk.server.context import RunContext -from kagenti_adk.server.store.platform_context_store import PlatformContextStore server = Server() @@ -21,8 +20,8 @@ async def example_tool() -> str: async def history_counter(history: list[Message]) -> str: """Create a concise conversation-state summary.""" await asyncio.sleep(.1) # doing some agent work - user_count = sum(1 for item in history if item.role == Role.ROLE_USER) - agent_count = sum(1 for item in history if item.role == Role.ROLE_AGENT) + user_count = sum(1 for item in history if item.role.value == "user") + agent_count = sum(1 for item in history if item.role.value == "agent") history_count = len(history) return f"total={history_count}, user={user_count}, agent={agent_count}" @@ -33,9 +32,6 @@ async def streaming_agent_w_single_history_write_example(input: Message, context Stream partial answers, execute tools, and persist one finalized assistant message. See other examples for actual implementation of multi-turn conversation agent with tool use. 
""" - # Store the user input as the first persisted item for this turn. - await context.store(data=input) - history = [message async for message in context.load_history() if isinstance(message, Message) and message.parts] current_message = get_message_text(input) @@ -72,20 +68,14 @@ async def streaming_agent_w_single_history_write_example(input: Message, context # This does not need to be the go-to approach in all cases, sometimes the partial outputs are of no value and one does not want them to be properly stored. # # Why not store each chunk? - # - Calling `context.store()`, PlatformContextStore saves every message as a distinct history item. - # - Storing per chunk would fragment one assistant turn into many partial messages. - # - A single aggregated write keeps replay, memory, and history semantics clean. - # aggregated_response = AgentMessage(text="\n".join(buffered_parts)) - yield "Final result check:\n" + aggregated_response.parts[0].text - await context.store(data=aggregated_response) + yield "Final result check:\n" + str(aggregated_response.text) def run(): server.run( host=os.getenv("HOST", "127.0.0.1"), port=int(os.getenv("PORT", "8000")), - context_store=PlatformContextStore(), ) diff --git a/examples/agent-integration/rag/conversation-rag-agent/src/conversation_rag_agent/agent.py b/examples/agent-integration/rag/conversation-rag-agent/src/conversation_rag_agent/agent.py index f392d162..9d1bc9fa 100644 --- a/examples/agent-integration/rag/conversation-rag-agent/src/conversation_rag_agent/agent.py +++ b/examples/agent-integration/rag/conversation-rag-agent/src/conversation_rag_agent/agent.py @@ -6,7 +6,7 @@ import os from typing import Annotated -from a2a.types import Message, Part +from a2a.types import DataPart, FilePart, FileWithUri, Message, Part, TextPart from kagenti_adk.a2a.extensions import ( EmbeddingServiceExtensionServer, EmbeddingServiceExtensionSpec, @@ -65,30 +65,26 @@ async def conversation_rag_agent_example( files: list[File] = [] query 
= "" for part in input.parts: - content_type = part.WhichOneof("content") - if content_type == "url": - files.append(await File.get(PlatformFileUrl(part.url).file_id)) - elif content_type == "text": - query = part.text - else: - raise NotImplementedError(f"Unsupported part content type: {content_type}") + match part.root: + case FilePart(file=FileWithUri(uri=uri)): + files.append(await File.get(PlatformFileUrl(uri).file_id)) + case TextPart(text=text): + query = text + case _: + raise NotImplementedError(f"Unsupported part: {type(part.root)}") # Check if vector store exists vector_store = None async for message in context.load_history(): - if isinstance(message, Message) and len(message.parts) == 1 and message.parts[0].WhichOneof("content") == "data": - data = dict(message.parts[0].data.struct_value) - vector_store = await VectorStore.get(data["vector_store_id"]) + match message: + case Message(parts=[Part(root=DataPart(data=data))]): + vector_store = await VectorStore.get(data["vector_store_id"]) # Create vector store if it does not exist if not vector_store: vector_store = await create_vector_store(embedding_client, embedding_model) # store vector store id in context for future messages - from google.protobuf.struct_pb2 import Value - - data_value = Value() - data_value.struct_value.update({"vector_store_id": vector_store.id}) - await context.store(AgentMessage(parts=[Part(data=data_value)])) + data_part = DataPart(data={"vector_store_id": vector_store.id}) # Process files, add to vector store for file in files: diff --git a/skills/kagenti-adk-wrapper/SKILL.md b/skills/kagenti-adk-wrapper/SKILL.md index bded669b..4675d6d5 100644 --- a/skills/kagenti-adk-wrapper/SKILL.md +++ b/skills/kagenti-adk-wrapper/SKILL.md @@ -159,7 +159,7 @@ Read the agent's code and classify it. 
This determines the `interaction_mode` va This classification determines: -- How to use `context.store()` and `context.load_history()`: persist input/response by default for all agents; `context.load_history()` is required for multi-turn, and optional for single-turn (use only when prior context is intentionally part of behavior) +- How to use `context.load_history()`: history is auto-persisted by the A2A TaskStore; `context.load_history()` is required for multi-turn, and optional for single-turn (use only when prior context is intentionally part of behavior) - Whether to define an `initial_form` for structured inputs (single-turn with named parameters) --- @@ -277,8 +277,7 @@ When building and testing the wrapper, ensure you avoid these common pitfalls: - **Never use synchronous functions for the agent handler.** Agent functions must be `async def` generators using `yield`. - **Never hide platform wiring behind abstraction layers.** Keep `@server.agent(...)`, extension parameters, and integration contracts visible in the main entrypoint so behavior is auditable. - **Never treat runtime inspection as first source.** `kagenti_adk` and `a2a` details must come from provided docs first; use installed-environment inspection only as documented fallback, then validate imports at the end. -- **Never assume history is auto-saved.** Explicitly call `await context.store(input)` and `await context.store(response)`. -- **Never assume persistent history without `PlatformContextStore`.** Without it, context storage is in-memory and lost on restart. +- **History is auto-saved by the A2A framework.** Messages and artifacts are persisted in the A2A TaskStore automatically — do not manually store them. - **Never forget to filter history.** `context.load_history()` returns Messages and Artifacts. Filter with `isinstance(message, Message)`. - **Never store individual streaming chunks.** Accumulate the full response and store once. 
- **Never treat extension data as dictionaries.** Use dot notation (e.g., `config.api_key`, not `config.get("api_key")`). @@ -343,9 +342,7 @@ After wrapping, confirm: ### Context & History -- [ ] `input` and `response` stored via `context.store()` -- [ ] `context_store=PlatformContextStore()` present if context is persisted/read -- [ ] Multi-turn uses `context.load_history()`; single-turn only if intentionally needed +- [ ] Multi-turn uses `context.load_history()` to read conversation history from the A2A TaskStore ### Forms & Files diff --git a/skills/kagenti-adk-wrapper/references/wrapper-entrypoint.md b/skills/kagenti-adk-wrapper/references/wrapper-entrypoint.md index 218538ec..0820e65a 100644 --- a/skills/kagenti-adk-wrapper/references/wrapper-entrypoint.md +++ b/skills/kagenti-adk-wrapper/references/wrapper-entrypoint.md @@ -49,18 +49,15 @@ Based on the classification in Step 2, follow exactly ONE of these workflows: - [ ] Pass necessary inputs (from forms or text) to original agent logic - [ ] Yield trajectory for meaningful intermediate activity (same rule as all agents) - [ ] Yield the final response via `AgentMessage(text=result)` -- [ ] Persist both input and response via `context.store()` ``` ### If the agent is Multi-turn ``` -- [ ] Store input: Save incoming user message immediately with `await context.store(input)` - [ ] Load history: Retrieve past conversation via `[msg async for msg in context.load_history() if isinstance(msg, Message)]` - [ ] Execute agent: Pass the filtered history to the original agent logic - [ ] Yield trajectory for meaningful intermediate activity (same rule as all agents) - [ ] Yield response: Return final answering chunks with `yield AgentMessage(text=...)` -- [ ] Store response: Save the final response with `await context.store(response)` ``` ## Entrypoint @@ -68,6 +65,6 @@ Based on the classification in Step 2, follow exactly ONE of these workflows: Create a `run()` / `serve()` function protected by an `if __name__ == 
"__main__":` guard. This function should call `server.run()`: - The server should be configured to listen on a `host` and `port` from environment variables (e.g., `host=os.getenv("HOST", "127.0.0.1")`, `port=int(os.getenv("PORT", 8000))`). -- If the agent persists or reads context history, you must pass `context_store=PlatformContextStore()` to `server.run()`. +- Conversation history is automatically persisted in the A2A TaskStore. Use `context.load_history()` to read it. - **Remove all CLI argument parsing** (`argparse`). Map required CLI inputs to the wrapper parameters instead (e.g., from Forms, Settings, or Environment variables). - Only `auth_backend` if explicitly requested.