Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/workflows/docker-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ jobs:
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache
cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max

- name: Verify — pytest collection in runtime image
run: |
SHORT_SHA=$(echo "${{ github.sha }}" | cut -c1-7)
docker run --rm \
-v "${{ github.workspace }}:/repo" \
-w /repo \
--entrypoint sh \
"${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${SHORT_SHA}" \
-c 'pip install -q pytest pytest-asyncio && python -m pytest tests/ --collect-only -q'

- name: Summary
run: |
echo "## Published" >> "$GITHUB_STEP_SUMMARY"
Expand Down
2 changes: 1 addition & 1 deletion graph/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _build_middleware(config: LangGraphConfig, knowledge_store=None):
if config.audit_middleware:
middleware.append(AuditMiddleware())

if config.memory_middleware and knowledge_store:
if config.memory_middleware:
middleware.append(MemoryMiddleware(knowledge_store))

middleware.append(MessageCaptureMiddleware())
Expand Down
122 changes: 120 additions & 2 deletions graph/middleware/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,133 @@ class MemoryMiddleware(AgentMiddleware):
Also persists a session summary on session end via on_session_end.
"""

def __init__(self, knowledge_store=None):
    """Create the middleware.

    Args:
        knowledge_store: optional backing knowledge store. When ``None``
            the middleware runs standalone and injects prior-session
            context itself instead of deferring to KnowledgeMiddleware.
    """
    super().__init__()
    self._store = knowledge_store
    # Rendered <prior_sessions> block, loaded lazily on first use.
    # None means "not loaded yet"; "" means "nothing to inject".
    self._prior_sessions_cache: str | None = None

# --- Session memory loading (only used when no KnowledgeMiddleware is active) ---

def _load_prior_sessions(self) -> str:
    """Render stored session summaries as a `<prior_sessions>` XML block.

    Standalone-mode helper: when KnowledgeMiddleware is in the chain it
    owns `<prior_sessions>` injection, so this runs only while
    ``self._store is None`` and double-injection cannot occur. Mirrors
    KnowledgeMiddleware.load_memory() without the store dependency.

    Returns:
        "" when MEMORY_PATH is missing/unreadable (first run),
        "<prior_sessions/>" when no usable summaries exist, otherwise a
        populated block trimmed to a ~2K-token budget (chars // 4
        heuristic), dropping the oldest sessions first.
    """
    if not os.path.isdir(MEMORY_PATH):
        return ""

    # Gather (mtime, path) for every *.json summary so newest sorts first.
    try:
        candidates = []
        for name in os.listdir(MEMORY_PATH):
            if not name.endswith(".json"):
                continue
            full = os.path.join(MEMORY_PATH, name)
            try:
                candidates.append((os.path.getmtime(full), full))
            except OSError:
                continue
        candidates.sort(reverse=True)
    except OSError:
        return ""

    if not candidates:
        return "<prior_sessions/>"

    # Parse at most the ten most recent summaries, skipping corrupt files.
    loaded = []
    for _, full in candidates[:10]:
        try:
            with open(full, encoding="utf-8") as fh:
                loaded.append(json.load(fh))
        except (OSError, json.JSONDecodeError, ValueError):
            continue

    if not loaded:
        return "<prior_sessions/>"

    blocks = []
    for summary in loaded:
        session_id = summary.get("session_id", "unknown")
        timestamp = summary.get("timestamp", "unknown")
        parts = [f'<session id="{session_id}" timestamp="{timestamp}">']
        msg_list = summary.get("messages", []) or []
        if msg_list:
            parts.append(" <messages>")
            for message in msg_list:
                role = message.get("role", "unknown")
                # Cap each message at 500 chars to bound prompt growth.
                text = (message.get("content", "") or "")[:500]
                parts.append(f" <{role}>{text}</{role}>")
            parts.append(" </messages>")
        final_output = (summary.get("final_output") or "")[:300]
        if final_output:
            parts.append(f" <final_output>{final_output}</final_output>")
        parts.append("</session>")
        blocks.append("\n".join(parts))

    # Enforce the ~2K token budget; list is newest-first, so popping the
    # tail drops the oldest session each iteration.
    while blocks and max(1, len("\n".join(blocks)) // 4) > 2000:
        blocks.pop()

    if not blocks:
        return "<prior_sessions/>"
    return "<prior_sessions>\n" + "\n".join(blocks) + "\n</prior_sessions>"

def before_model(self, state, runtime) -> dict | None:
    """Standalone-mode hook: fold `<prior_sessions>` into the system prompt.

    No-op when a knowledge store is configured (KnowledgeMiddleware owns
    injection then), when there is nothing to inject, or when the state
    has no messages yet.
    """
    if self._store is not None:
        return None
    if self._prior_sessions_cache is None:
        # Load once per middleware instance; cached for later turns.
        self._prior_sessions_cache = self._load_prior_sessions()
    prior = self._prior_sessions_cache
    if not prior:
        return None
    msgs = state.get("messages", [])
    if not msgs:
        return None
    # LangGraph has no dedicated hook for appending system context to
    # state, so we either extend an existing leading SystemMessage or
    # prepend a fresh one.
    from langchain_core.messages import SystemMessage
    head = msgs[0]
    if isinstance(head, SystemMessage):
        merged = SystemMessage(content=head.content + "\n\n" + prior)
        return {"messages": [merged] + list(msgs[1:])}
    return {"messages": [SystemMessage(content=prior)] + list(msgs)}

async def abefore_model(self, state, runtime) -> dict | None:
    """Async variant of before_model; delegates to the sync implementation.

    Safe to delegate because before_model does only local file reads and
    in-memory message manipulation — no awaitable work.
    """
    return self.before_model(state, runtime)

# --- Knowledge extraction (existing) ---

def after_agent(self, state, runtime) -> dict | None:
"""Queue conversation for async knowledge extraction."""
"""Queue conversation for async knowledge extraction. Persists session on terminal turn."""
messages = state.get("messages", [])

# --- Session persistence: detect terminal turn ---
# Terminal = last message is AIMessage with content and no pending tool calls
if messages:
last_msg = messages[-1]
if (
isinstance(last_msg, AIMessage)
and last_msg.content
and not getattr(last_msg, "tool_calls", None)
):
import tracing
trace_id = tracing.current_trace_id()
_persist_session(state, trace_id)

# --- Knowledge extraction (only when a store is configured) ---
if self._store is None:
return None
if len(messages) < 2:
return None

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ requires-python = ">=3.11"
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
pythonpath = ["."]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ websockets>=12.0

# LangGraph agent backend
langchain>=1.2.3
langchain-core>=0.3.0
langgraph>=1.1.0
langchain-openai>=0.3.0

Expand Down
12 changes: 9 additions & 3 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,14 @@ async def _chat_langgraph(message: str, session_id: str) -> list[dict[str, Any]]
AGENT_NAME = os.environ.get("AGENT_NAME", "protoagent")


def _build_security_schemes() -> dict:
"""Return securitySchemes dict, adding bearer only when A2A_AUTH_TOKEN is set."""
schemes: dict = {"apiKey": {"type": "apiKey", "in": "header", "name": "X-API-Key"}}
if os.environ.get("A2A_AUTH_TOKEN", ""):
schemes["bearer"] = {"type": "http", "scheme": "bearer"}
return schemes


def _build_agent_card(host: str) -> dict:
"""Build the A2A agent card served at /.well-known/agent-card.json.

Expand Down Expand Up @@ -308,9 +316,7 @@ def _build_agent_card(host: str) -> dict:
"examples": ["hello", "what can you do?"],
},
],
"securitySchemes": {
"apiKey": {"type": "apiKey", "in": "header", "name": "X-API-Key"}
},
"securitySchemes": _build_security_schemes(),
"security": [{"apiKey": []}],
}

Expand Down
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Ensure deterministic import resolution for the protoagent test suite.

Moves site-packages to the front of sys.path so installed packages
(langchain_core, langchain, etc.) are never shadowed by local directories
that pytest inserts during collection.
"""
from __future__ import annotations

import site
import sys


def pytest_configure(config):  # noqa: ARG001
    """Move every site-packages directory to the front of sys.path.

    Runs before collection imports any test module, so installed
    distributions (langchain_core, langchain, ...) always win over
    same-named local directories that pytest inserts on sys.path.
    """
    for directory in reversed(site.getsitepackages()):
        try:
            sys.path.remove(directory)
        except ValueError:
            # Not on the path yet — prepend it anyway.
            pass
        sys.path.insert(0, directory)
23 changes: 23 additions & 0 deletions tests/test_a2a_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,29 @@ def test_agent_card_has_at_least_one_skill() -> None:
assert "description" in skill


def test_agent_card_no_bearer_when_token_unset(monkeypatch) -> None:
    """Unset A2A_AUTH_TOKEN: the card advertises apiKey only, never bearer."""
    monkeypatch.delenv("A2A_AUTH_TOKEN", raising=False)
    from server import _build_agent_card

    schemes = _build_agent_card("protoagent:7870").get("securitySchemes", {})
    assert "apiKey" in schemes, "apiKey scheme must always be present"
    assert "bearer" not in schemes, "bearer must not appear when A2A_AUTH_TOKEN is unset"


def test_agent_card_bearer_when_token_set(monkeypatch) -> None:
    """Set A2A_AUTH_TOKEN: the card advertises the bearer scheme as well."""
    monkeypatch.setenv("A2A_AUTH_TOKEN", "secret-test-token")
    from server import _build_agent_card

    schemes = _build_agent_card("protoagent:7870").get("securitySchemes", {})
    assert "apiKey" in schemes, "apiKey scheme must always be present"
    assert "bearer" in schemes, "bearer must appear when A2A_AUTH_TOKEN is set"
    assert schemes["bearer"] == {"type": "http", "scheme": "bearer"}


def test_agent_card_declares_cost_v1_extension() -> None:
"""The runtime captures token usage on `on_chat_model_end` and the
A2A handler emits a cost-v1 DataPart on every terminal task. The
Expand Down
70 changes: 70 additions & 0 deletions tests/test_memory_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,73 @@ def test_on_session_end_calls_persist_session(tmp_path):

mock_persist.assert_called_once_with(state, "trace-hook")
assert result is None


# ---------------------------------------------------------------------------
# 10. after_agent persistence on terminal turn
# ---------------------------------------------------------------------------

def test_after_agent_persists_on_terminal_turn(tmp_path):
    """after_agent must call _persist_session when last message is a terminal AIMessage."""
    mod = _reload_memory({"MEMORY_PATH": str(tmp_path), "PROTOAGENT_DISABLE_MEMORY": ""})
    mw = mod.MemoryMiddleware(knowledge_store=MagicMock())

    state = _make_state(
        "after-agent-terminal",
        messages=[
            HumanMessage(content="Hello"),
            AIMessage(content="Final answer with no pending tool calls."),
        ],
    )

    with patch.object(mod, "_persist_session") as mock_persist, \
            patch("tracing.current_trace_id", return_value="trace-after"):
        mw.after_agent(state, MagicMock())

    mock_persist.assert_called_once_with(state, "trace-after")


def test_after_agent_does_not_persist_when_tool_calls_pending(tmp_path):
    """after_agent must NOT persist when the last AIMessage has pending tool_calls."""
    mod = _reload_memory({"MEMORY_PATH": str(tmp_path), "PROTOAGENT_DISABLE_MEMORY": ""})
    mw = mod.MemoryMiddleware(knowledge_store=MagicMock())

    pending = AIMessage(content="")
    pending.tool_calls = [{"id": "tc1", "name": "search", "args": {"query": "x"}}]
    state = _make_state(
        "after-agent-pending",
        messages=[HumanMessage(content="Search for x"), pending],
    )

    with patch.object(mod, "_persist_session") as mock_persist, \
            patch("tracing.current_trace_id", return_value="trace-pending"):
        mw.after_agent(state, MagicMock())

    mock_persist.assert_not_called()


def test_after_agent_does_not_persist_when_last_msg_not_ai(tmp_path):
    """after_agent must NOT persist when the last message is not an AIMessage."""
    mod = _reload_memory({"MEMORY_PATH": str(tmp_path), "PROTOAGENT_DISABLE_MEMORY": ""})
    mw = mod.MemoryMiddleware(knowledge_store=MagicMock())

    state = _make_state(
        "after-agent-human-last",
        messages=[
            HumanMessage(content="Hello"),
            AIMessage(content="Some response."),
            HumanMessage(content="Follow-up question"),
        ],
    )

    with patch.object(mod, "_persist_session") as mock_persist, \
            patch("tracing.current_trace_id", return_value="trace-human"):
        mw.after_agent(state, MagicMock())

    mock_persist.assert_not_called()
Loading