diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index a76e3f4..a931d29 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -75,6 +75,16 @@ jobs: cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max + - name: Verify — pytest collection in runtime image + run: | + SHORT_SHA=$(echo "${{ github.sha }}" | cut -c1-7) + docker run --rm \ + -v "${{ github.workspace }}:/repo" \ + -w /repo \ + --entrypoint sh \ + "${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${SHORT_SHA}" \ + -c 'pip install -q pytest pytest-asyncio && python -m pytest tests/ --collect-only -q' + - name: Summary run: | echo "## Published" >> "$GITHUB_STEP_SUMMARY" diff --git a/graph/agent.py b/graph/agent.py index 2e26edb..355c3fc 100644 --- a/graph/agent.py +++ b/graph/agent.py @@ -29,7 +29,7 @@ def _build_middleware(config: LangGraphConfig, knowledge_store=None): if config.audit_middleware: middleware.append(AuditMiddleware()) - if config.memory_middleware and knowledge_store: + if config.memory_middleware: middleware.append(MemoryMiddleware(knowledge_store)) middleware.append(MessageCaptureMiddleware()) diff --git a/graph/middleware/memory.py b/graph/middleware/memory.py index 3871e44..1352fec 100644 --- a/graph/middleware/memory.py +++ b/graph/middleware/memory.py @@ -166,15 +166,133 @@ class MemoryMiddleware(AgentMiddleware): Also persists a session summary on session end via on_session_end. """ - def __init__(self, knowledge_store): + def __init__(self, knowledge_store=None): super().__init__() self._store = knowledge_store + self._prior_sessions_cache: str | None = None + + # --- Session memory loading (only used when no KnowledgeMiddleware is active) --- + + def _load_prior_sessions(self) -> str: + """Lazy-load prior session summaries when standalone (no KnowledgeMiddleware). + + When KnowledgeMiddleware is also in the chain it owns `` + injection. This method runs only when `self._store is None`, so there is + no double-injection risk. + + Reads from MEMORY_PATH, returns an XML block or empty string on first + run. Mirrors KnowledgeMiddleware.load_memory() but without the store + dependency — single source of truth would be cleaner but would couple + the two files. + """ + if not os.path.isdir(MEMORY_PATH): + return "" + try: + entries = [] + for fname in os.listdir(MEMORY_PATH): + if not fname.endswith(".json"): + continue + fpath = os.path.join(MEMORY_PATH, fname) + try: + entries.append((os.path.getmtime(fpath), fpath)) + except OSError: + continue + entries.sort(reverse=True) + except OSError: + return "" + if not entries: + return "" + summaries = [] + for _, fpath in entries[:10]: + try: + with open(fpath, encoding="utf-8") as fh: + summaries.append(json.load(fh)) + except (OSError, json.JSONDecodeError, ValueError): + continue + if not summaries: + return "" + lines_out = [] + for s in summaries: + ts = s.get("timestamp", "unknown") + sid = s.get("session_id", "unknown") + lines = [f''] + msgs = s.get("messages", []) or [] + if msgs: + lines.append(" ") + for m in msgs: + role = m.get("role", "unknown") + content = (m.get("content", "") or "")[:500] + lines.append(f" <{role}>{content}") + lines.append(" ") + final = (s.get("final_output") or "")[:300] + if final: + lines.append(f" {final}") + lines.append("") + lines_out.append("\n".join(lines)) + # 2K token budget — chars // 4 approx, drop oldest first + while lines_out: + joined = "\n".join(lines_out) + if max(1, len(joined) // 4) <= 2000: + break + lines_out.pop() + if not lines_out: + return "" + return "\n" + "\n".join(lines_out) + "\n" + + def before_model(self, state, runtime) -> dict | None: + """Inject `` into system prompt when running standalone. + + When KnowledgeMiddleware is present it handles this; we only act when + `self._store is None`. + """ + if self._store is not None: + return None + if self._prior_sessions_cache is None: + self._prior_sessions_cache = self._load_prior_sessions() + if not self._prior_sessions_cache: + return None + messages = state.get("messages", []) + if not messages: + return None + # Prepend as a system-adjacent HumanMessage block. LangGraph has no + # dedicated system-context append hook on state, so we piggyback on + # the first human message by modifying its content. + from langchain_core.messages import SystemMessage + first = messages[0] + if isinstance(first, SystemMessage): + # Already has a system message — append prior_sessions to it + new_content = first.content + "\n\n" + self._prior_sessions_cache + new_msgs = [SystemMessage(content=new_content)] + list(messages[1:]) + return {"messages": new_msgs} + # Otherwise prepend a new SystemMessage + new_msgs = [SystemMessage(content=self._prior_sessions_cache)] + list(messages) + return {"messages": new_msgs} + + async def abefore_model(self, state, runtime) -> dict | None: + return self.before_model(state, runtime) # --- Knowledge extraction (existing) --- def after_agent(self, state, runtime) -> dict | None: - """Queue conversation for async knowledge extraction.""" + """Queue conversation for async knowledge extraction. Persists session on terminal turn.""" messages = state.get("messages", []) + + # --- Session persistence: detect terminal turn --- + # Terminal = last message is AIMessage with content and no pending tool calls + if messages: + last_msg = messages[-1] + if ( + isinstance(last_msg, AIMessage) + and last_msg.content + and not getattr(last_msg, "tool_calls", None) + ): + import tracing + trace_id = tracing.current_trace_id() + _persist_session(state, trace_id) + + # --- Knowledge extraction (only when a store is configured) --- + if self._store is None: + return None if len(messages) < 2: return None diff --git a/pyproject.toml b/pyproject.toml index e63ed8b..b730fc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,3 +7,4 @@ requires-python = ">=3.11" [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] +pythonpath = ["."] diff --git a/requirements.txt b/requirements.txt index 35710b1..9cb6ff6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ websockets>=12.0 # LangGraph agent backend langchain>=1.2.3 +langchain-core>=0.3.0 langgraph>=1.1.0 langchain-openai>=0.3.0 diff --git a/server.py b/server.py index d5d5f4a..2221b11 100644 --- a/server.py +++ b/server.py @@ -251,6 +251,14 @@ async def _chat_langgraph(message: str, session_id: str) -> list[dict[str, Any]] AGENT_NAME = os.environ.get("AGENT_NAME", "protoagent") +def _build_security_schemes() -> dict: + """Return securitySchemes dict, adding bearer only when A2A_AUTH_TOKEN is set.""" + schemes: dict = {"apiKey": {"type": "apiKey", "in": "header", "name": "X-API-Key"}} + if os.environ.get("A2A_AUTH_TOKEN", ""): + schemes["bearer"] = {"type": "http", "scheme": "bearer"} + return schemes + + def _build_agent_card(host: str) -> dict: """Build the A2A agent card served at /.well-known/agent-card.json. @@ -308,9 +316,7 @@ def _build_agent_card(host: str) -> dict: "examples": ["hello", "what can you do?"], }, ], - "securitySchemes": { - "apiKey": {"type": "apiKey", "in": "header", "name": "X-API-Key"} - }, + "securitySchemes": _build_security_schemes(), "security": [{"apiKey": []}], } diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..165a0f6 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,19 @@ +"""Ensure deterministic import resolution for the protoagent test suite. + +Moves site-packages to the front of sys.path so installed packages +(langchain_core, langchain, etc.) are never shadowed by local directories +that pytest inserts during collection. +""" +from __future__ import annotations + +import site +import sys + + +def pytest_configure(config): # noqa: ARG001 + """Prepend site-packages to sys.path before any test imports occur.""" + site_dirs = site.getsitepackages() + for sp in reversed(site_dirs): + if sp in sys.path: + sys.path.remove(sp) + sys.path.insert(0, sp) diff --git a/tests/test_a2a_integration.py b/tests/test_a2a_integration.py index 569df07..9d77b3f 100644 --- a/tests/test_a2a_integration.py +++ b/tests/test_a2a_integration.py @@ -54,6 +54,29 @@ def test_agent_card_has_at_least_one_skill() -> None: assert "description" in skill +def test_agent_card_no_bearer_when_token_unset(monkeypatch) -> None: + """With A2A_AUTH_TOKEN unset, card must NOT advertise bearer scheme.""" + monkeypatch.delenv("A2A_AUTH_TOKEN", raising=False) + from server import _build_agent_card + + card = _build_agent_card("protoagent:7870") + schemes = card.get("securitySchemes", {}) + assert "apiKey" in schemes, "apiKey scheme must always be present" + assert "bearer" not in schemes, "bearer must not appear when A2A_AUTH_TOKEN is unset" + + +def test_agent_card_bearer_when_token_set(monkeypatch) -> None: + """With A2A_AUTH_TOKEN set, card must advertise bearer scheme.""" + monkeypatch.setenv("A2A_AUTH_TOKEN", "secret-test-token") + from server import _build_agent_card + + card = _build_agent_card("protoagent:7870") + schemes = card.get("securitySchemes", {}) + assert "apiKey" in schemes, "apiKey scheme must always be present" + assert "bearer" in schemes, "bearer must appear when A2A_AUTH_TOKEN is set" + assert schemes["bearer"] == {"type": "http", "scheme": "bearer"} + + def test_agent_card_declares_cost_v1_extension() -> None: """The runtime captures token usage on `on_chat_model_end` and the A2A handler emits a cost-v1 DataPart on every terminal task. The diff --git a/tests/test_memory_persistence.py b/tests/test_memory_persistence.py index 28bfc17..d7df6a0 100644 --- a/tests/test_memory_persistence.py +++ b/tests/test_memory_persistence.py @@ -392,3 +392,73 @@ def test_on_session_end_calls_persist_session(tmp_path): mock_persist.assert_called_once_with(state, "trace-hook") assert result is None + + +# --------------------------------------------------------------------------- +# 10. after_agent persistence on terminal turn +# --------------------------------------------------------------------------- + +def test_after_agent_persists_on_terminal_turn(tmp_path): + """after_agent must call _persist_session when last message is a terminal AIMessage.""" + mod = _reload_memory({"MEMORY_PATH": str(tmp_path), "PROTOAGENT_DISABLE_MEMORY": ""}) + + store = MagicMock() + mw = mod.MemoryMiddleware(knowledge_store=store) + + messages = [ + HumanMessage(content="Hello"), + AIMessage(content="Final answer with no pending tool calls."), + ] + state = _make_state("after-agent-terminal", messages=messages) + runtime = MagicMock() + + with patch.object(mod, "_persist_session") as mock_persist, \ + patch("tracing.current_trace_id", return_value="trace-after"): + mw.after_agent(state, runtime) + + mock_persist.assert_called_once_with(state, "trace-after") + + +def test_after_agent_does_not_persist_when_tool_calls_pending(tmp_path): + """after_agent must NOT persist when the last AIMessage has pending tool_calls.""" + mod = _reload_memory({"MEMORY_PATH": str(tmp_path), "PROTOAGENT_DISABLE_MEMORY": ""}) + + store = MagicMock() + mw = mod.MemoryMiddleware(knowledge_store=store) + + ai_msg = AIMessage(content="") + ai_msg.tool_calls = [{"id": "tc1", "name": "search", "args": {"query": "x"}}] + messages = [ + HumanMessage(content="Search for x"), + ai_msg, + ] + state = _make_state("after-agent-pending", messages=messages) + runtime = MagicMock() + + with patch.object(mod, "_persist_session") as mock_persist, \ + patch("tracing.current_trace_id", return_value="trace-pending"): + mw.after_agent(state, runtime) + + mock_persist.assert_not_called() + + +def test_after_agent_does_not_persist_when_last_msg_not_ai(tmp_path): + """after_agent must NOT persist when the last message is not an AIMessage.""" + mod = _reload_memory({"MEMORY_PATH": str(tmp_path), "PROTOAGENT_DISABLE_MEMORY": ""}) + + store = MagicMock() + mw = mod.MemoryMiddleware(knowledge_store=store) + + messages = [ + HumanMessage(content="Hello"), + AIMessage(content="Some response."), + HumanMessage(content="Follow-up question"), + ] + state = _make_state("after-agent-human-last", messages=messages) + runtime = MagicMock() + + with patch.object(mod, "_persist_session") as mock_persist, \ + patch("tracing.current_trace_id", return_value="trace-human"): + mw.after_agent(state, runtime) + + mock_persist.assert_not_called()