diff --git a/README.md b/README.md index 8a6c1ccd..c1ded955 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,9 @@ JSON file: ## Architecture +For the planned Space background-worker architecture, see +[`docs/phase-3-space-background-workers.md`](docs/phase-3-space-background-workers.md). + ### Component Overview ``` diff --git a/agent/core/session_persistence.py b/agent/core/session_persistence.py index 5c125b38..e00b6427 100644 --- a/agent/core/session_persistence.py +++ b/agent/core/session_persistence.py @@ -9,7 +9,8 @@ import logging import os -from datetime import UTC, datetime +import uuid +from datetime import UTC, datetime, timedelta from typing import Any from bson import BSON @@ -86,6 +87,9 @@ async def append_event(self, *_: Any, **__: Any) -> int | None: async def load_events_after(self, *_: Any, **__: Any) -> list[dict[str, Any]]: return [] + async def latest_event_seq(self, *_: Any, **__: Any) -> int: + return 0 + async def append_trace_message(self, *_: Any, **__: Any) -> int | None: return None @@ -98,6 +102,21 @@ async def try_increment_quota(self, *_: Any, **__: Any) -> int | None: async def refund_quota(self, *_: Any, **__: Any) -> None: return None + async def enqueue_run(self, *_: Any, **__: Any) -> dict[str, Any] | None: + return None + + async def claim_next_run(self, *_: Any, **__: Any) -> dict[str, Any] | None: + return None + + async def heartbeat_run(self, *_: Any, **__: Any) -> bool: + return False + + async def finish_run(self, *_: Any, **__: Any) -> None: + return None + + async def interrupt_expired_runs(self, *_: Any, **__: Any) -> int: + return 0 + class MongoSessionStore(NoopSessionStore): """MongoDB-backed session store.""" @@ -152,6 +171,14 @@ async def _create_indexes(self) -> None: [("session_id", 1), ("seq", 1)], unique=True ) await self.db.session_trace_messages.create_index([("created_at", -1)]) + await self.db.session_runs.create_index( + [("status", 1), ("lease_until", 1), ("created_at", 1)] + ) + await 
self.db.session_runs.create_index([("session_id", 1), ("created_at", -1)]) + await self.db.session_runs.create_index([("user_id", 1), ("created_at", -1)]) + await self.db.session_runs.create_index( + [("idempotency_key", 1)], unique=True, sparse=True + ) def _ready(self) -> bool: return bool(self.enabled and self.db is not None) @@ -348,6 +375,15 @@ async def load_events_after(self, session_id: str, after_seq: int = 0) -> list[d ).sort("seq", 1) return [row async for row in cursor] + async def latest_event_seq(self, session_id: str) -> int: + if not self._ready(): + return 0 + doc = await self.db.session_events.find_one( + {"session_id": session_id}, + sort=[("seq", -1)], + ) + return int(doc.get("seq", 0)) if doc else 0 + async def append_trace_message( self, session_id: str, message: dict[str, Any], source: str = "message" ) -> int | None: @@ -410,6 +446,235 @@ async def refund_quota(self, user_id: str, day: str) -> None: {"$inc": {"count": -1}, "$set": {"updated_at": _now()}}, ) + async def enqueue_run( + self, + *, + session_id: str, + user_id: str, + operation: dict[str, Any], + idempotency_key: str | None = None, + surface: str = "space", + ) -> dict[str, Any] | None: + """Create a durable queued run and attach it to the session. + + Returns None when the session already has an active run. The caller can + surface that as a 409 rather than starting two concurrent turns against + the same context. 
+ """ + if not self._ready(): + return None + if idempotency_key: + existing = await self.db.session_runs.find_one( + {"idempotency_key": idempotency_key} + ) + if existing: + return existing + + now = _now() + run_id = str(uuid.uuid4()) + run = { + "_id": run_id, + "run_id": run_id, + "schema_version": SCHEMA_VERSION, + "session_id": session_id, + "user_id": user_id, + "surface": surface, + "operation": operation, + "status": "queued", + "idempotency_key": idempotency_key, + "lease_owner": None, + "lease_until": None, + "retry_count": 0, + "max_retries": 1, + "created_at": now, + "started_at": None, + "updated_at": now, + "finished_at": None, + "error": None, + } + try: + await self.db.session_runs.insert_one(run) + except DuplicateKeyError: + if idempotency_key: + return await self.db.session_runs.find_one( + {"idempotency_key": idempotency_key} + ) + raise + + attached = await self.db.sessions.update_one( + { + "_id": session_id, + "$or": [ + {"active_run_id": {"$exists": False}}, + {"active_run_id": None}, + ], + }, + { + "$set": { + "active_run_id": run_id, + "runtime_state": "queued", + "updated_at": now, + } + }, + ) + if attached.matched_count == 0: + await self.db.session_runs.update_one( + {"_id": run_id}, + { + "$set": { + "status": "cancelled", + "error": "session already has an active run", + "updated_at": _now(), + "finished_at": _now(), + } + }, + ) + return None + return run + + async def claim_next_run( + self, + *, + worker_id: str, + lease_seconds: int = 120, + ) -> dict[str, Any] | None: + """Atomically claim the oldest queued run for a worker.""" + if not self._ready(): + return None + now = _now() + lease_until = now + timedelta(seconds=lease_seconds) + run = await self.db.session_runs.find_one_and_update( + {"status": "queued"}, + { + "$set": { + "status": "running", + "lease_owner": worker_id, + "lease_until": lease_until, + "started_at": now, + "updated_at": now, + }, + "$inc": {"retry_count": 1}, + }, + sort=[("created_at", 1)], + 
return_document=ReturnDocument.AFTER, + ) + if not run: + return None + + guarded = await self.db.sessions.update_one( + {"_id": run["session_id"], "active_run_id": run["_id"]}, + { + "$set": { + "runtime_state": "processing", + "worker_owner": worker_id, + "worker_lease_until": lease_until, + "updated_at": now, + } + }, + ) + if guarded.matched_count == 0: + await self.finish_run( + run["_id"], + status="interrupted", + error="session guard failed while claiming run", + ) + return None + return run + + async def heartbeat_run( + self, + run_id: str, + *, + worker_id: str, + lease_seconds: int = 120, + ) -> bool: + if not self._ready(): + return False + now = _now() + lease_until = now + timedelta(seconds=lease_seconds) + result = await self.db.session_runs.update_one( + {"_id": run_id, "status": "running", "lease_owner": worker_id}, + {"$set": {"lease_until": lease_until, "updated_at": now}}, + ) + if result.matched_count: + run = await self.db.session_runs.find_one({"_id": run_id}) + if run: + await self.db.sessions.update_one( + {"_id": run["session_id"], "active_run_id": run_id}, + { + "$set": { + "worker_lease_until": lease_until, + "updated_at": now, + } + }, + ) + return bool(result.matched_count) + + async def finish_run( + self, + run_id: str, + *, + status: str, + error: str | None = None, + ) -> None: + if not self._ready(): + return + now = _now() + run = await self.db.session_runs.find_one_and_update( + {"_id": run_id}, + { + "$set": { + "status": status, + "error": error, + "updated_at": now, + "finished_at": now, + "lease_until": None, + } + }, + return_document=ReturnDocument.AFTER, + ) + if not run: + return + + runtime_state = { + "completed": "idle", + "waiting_approval": "waiting_approval", + "failed": "idle", + "cancelled": "idle", + "interrupted": "interrupted", + }.get(status, "idle") + await self.db.sessions.update_one( + {"_id": run["session_id"], "active_run_id": run_id}, + { + "$set": { + "runtime_state": runtime_state, + "active_run_id": 
None, + "worker_owner": None, + "worker_lease_until": None, + "updated_at": now, + } + }, + ) + + async def interrupt_expired_runs(self, *, before: datetime | None = None) -> int: + """Mark expired running runs interrupted instead of replaying tool calls.""" + if not self._ready(): + return 0 + now = _now() + cutoff = before or now + cursor = self.db.session_runs.find( + {"status": "running", "lease_until": {"$lt": cutoff}} + ) + count = 0 + async for run in cursor: + await self.finish_run( + run["_id"], + status="interrupted", + error="worker lease expired", + ) + count += 1 + return count + _store: NoopSessionStore | MongoSessionStore | None = None diff --git a/backend/background_worker.py b/backend/background_worker.py new file mode 100644 index 00000000..8868b5c6 --- /dev/null +++ b/backend/background_worker.py @@ -0,0 +1,230 @@ +"""Durable session-run worker for Space background execution. + +This worker consumes Mongo-backed ``session_runs``. It intentionally reuses the +existing ``SessionManager`` and agent loop instead of introducing a second agent +execution path. 
+""" + +from __future__ import annotations + +import asyncio +import logging +import os +from pathlib import Path +import socket +import uuid +from typing import Any + +from dotenv import load_dotenv + +load_dotenv(Path(__file__).parent.parent / ".env") + +from agent.core.session import OpType +from session_manager import AgentSession, Operation, SessionManager + +logger = logging.getLogger(__name__) + +TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted", "shutdown"} + + +def background_workers_enabled() -> bool: + return os.environ.get("ML_INTERN_BACKGROUND_WORKERS", "").lower() in { + "1", + "true", + "yes", + } + + +def in_process_worker_enabled() -> bool: + return os.environ.get("ML_INTERN_RUN_WORKER_IN_PROCESS", "").lower() in { + "1", + "true", + "yes", + } + + +def default_worker_id() -> str: + return os.environ.get( + "ML_INTERN_WORKER_ID", + f"{socket.gethostname()}-{uuid.uuid4().hex[:8]}", + ) + + +def operation_from_run(run: dict[str, Any]) -> Operation: + operation = run.get("operation") or {} + op_type = operation.get("type") + payload = operation.get("payload") or {} + + if op_type == OpType.USER_INPUT.value: + return Operation(op_type=OpType.USER_INPUT, data={"text": payload.get("text", "")}) + if op_type == OpType.EXEC_APPROVAL.value: + return Operation( + op_type=OpType.EXEC_APPROVAL, + data={"approvals": payload.get("approvals") or []}, + ) + if op_type == OpType.UNDO.value: + return Operation(op_type=OpType.UNDO) + if op_type == OpType.COMPACT.value: + return Operation(op_type=OpType.COMPACT) + if op_type == OpType.SHUTDOWN.value: + return Operation(op_type=OpType.SHUTDOWN) + + raise ValueError(f"Unsupported background run operation: {op_type!r}") + + +async def _wait_for_broadcaster(agent_session: AgentSession, timeout: float = 5.0): + deadline = asyncio.get_running_loop().time() + timeout + while agent_session.broadcaster is None: + if asyncio.get_running_loop().time() >= deadline: + raise TimeoutError("session 
broadcaster was not initialized") + await asyncio.sleep(0.05) + return agent_session.broadcaster + + +def _run_status_from_event(event_type: str) -> str: + if event_type == "approval_required": + return "waiting_approval" + if event_type == "error": + return "failed" + if event_type == "interrupted": + return "interrupted" + return "completed" + + +async def _heartbeat_loop( + store, + *, + run_id: str, + worker_id: str, + lease_seconds: int, + interval_seconds: int, +) -> None: + while True: + await asyncio.sleep(interval_seconds) + ok = await store.heartbeat_run( + run_id, + worker_id=worker_id, + lease_seconds=lease_seconds, + ) + if not ok: + logger.warning("Worker %s lost lease for run %s", worker_id, run_id) + return + + +async def process_run( + manager: SessionManager, + run: dict[str, Any], + *, + worker_id: str, + lease_seconds: int = 120, + heartbeat_interval_seconds: int = 30, +) -> None: + """Execute one claimed run and update its durable status.""" + store = manager._store() + run_id = str(run["_id"]) + session_id = str(run["session_id"]) + user_id = str(run.get("user_id") or "dev") + heartbeat_task: asyncio.Task | None = None + sub_id: int | None = None + broadcaster = None + + try: + agent_session = await manager.ensure_session_loaded(session_id, user_id) + if not agent_session or not agent_session.is_active: + raise RuntimeError("session not found or inactive") + + broadcaster = await _wait_for_broadcaster(agent_session) + sub_id, event_queue = broadcaster.subscribe() + operation = operation_from_run(run) + + heartbeat_task = asyncio.create_task( + _heartbeat_loop( + store, + run_id=run_id, + worker_id=worker_id, + lease_seconds=lease_seconds, + interval_seconds=heartbeat_interval_seconds, + ) + ) + + success = await manager.submit(session_id, operation) + if not success: + raise RuntimeError("session rejected background run submission") + + while True: + event = await event_queue.get() + event_type = str(event.get("event_type") or "") + if 
event_type in TERMINAL_EVENTS: + await store.finish_run( + run_id, + status=_run_status_from_event(event_type), + error=(event.get("data") or {}).get("error"), + ) + logger.info("Worker %s finished run %s as %s", worker_id, run_id, event_type) + return + except asyncio.CancelledError: + raise + except Exception as e: + logger.exception("Worker %s failed run %s", worker_id, run_id) + await store.finish_run(run_id, status="failed", error=str(e)) + finally: + if heartbeat_task is not None: + heartbeat_task.cancel() + try: + await heartbeat_task + except asyncio.CancelledError: + pass + if broadcaster is not None and sub_id is not None: + broadcaster.unsubscribe(sub_id) + + +async def run_worker_loop( + manager: SessionManager, + *, + worker_id: str | None = None, + poll_interval_seconds: float = 0.25, + idle_interval_seconds: float = 2.0, + lease_seconds: int = 120, + heartbeat_interval_seconds: int = 30, +) -> None: + """Continuously claim and process queued durable runs.""" + worker_id = worker_id or default_worker_id() + store = manager._store() + logger.info("Background worker %s starting", worker_id) + + while True: + await store.interrupt_expired_runs() + run = await store.claim_next_run( + worker_id=worker_id, + lease_seconds=lease_seconds, + ) + if not run: + await asyncio.sleep(idle_interval_seconds) + continue + + await process_run( + manager, + run, + worker_id=worker_id, + lease_seconds=lease_seconds, + heartbeat_interval_seconds=heartbeat_interval_seconds, + ) + await asyncio.sleep(poll_interval_seconds) + + +async def main() -> None: + from session_manager import session_manager + + await session_manager.start() + try: + await run_worker_loop(session_manager) + finally: + await session_manager.close() + + +if __name__ == "__main__": + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + asyncio.run(main()) diff --git a/backend/main.py b/backend/main.py index f6bc64d1..fe9aa6ec 100644 --- 
a/backend/main.py +++ b/backend/main.py @@ -1,5 +1,6 @@ """FastAPI application for HF Agent web interface.""" +import asyncio import logging import os from contextlib import asynccontextmanager @@ -24,13 +25,30 @@ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) +_background_worker_task = None @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan handler.""" + global _background_worker_task logger.info("Starting HF Agent backend...") await session_manager.start() + try: + from background_worker import ( + background_workers_enabled, + in_process_worker_enabled, + run_worker_loop, + ) + + if background_workers_enabled() and in_process_worker_enabled(): + _background_worker_task = asyncio.create_task( + run_worker_loop(session_manager), + name="ml-intern-background-worker", + ) + logger.info("Started in-process background worker") + except Exception as e: + logger.warning("Background worker failed to start: %s", e) # Start in-process hourly KPI rollup. Replaces an external cron so the # rollup lives next to the data and reuses the Space's HF token. 
try: @@ -41,6 +59,13 @@ async def lifespan(app: FastAPI): yield logger.info("Shutting down HF Agent backend...") + if _background_worker_task is not None: + _background_worker_task.cancel() + try: + await _background_worker_task + except asyncio.CancelledError: + pass + _background_worker_task = None try: import kpis_scheduler await kpis_scheduler.shutdown() diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 3067f4fd..d3172ffe 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -33,6 +33,7 @@ import user_quotas +from background_worker import background_workers_enabled from agent.core.hf_access import get_jobs_access from agent.core.hf_tokens import resolve_hf_request_token, resolve_hf_router_token from agent.core.llm_params import _resolve_llm_params @@ -653,6 +654,8 @@ async def chat_sse( # Parse body body = await request.json() + if background_workers_enabled(): + return await _chat_sse_background(session_id, request, user, agent_session, body) # Subscribe BEFORE submitting so we never miss events — even if the # agent loop processes the submission before this coroutine continues. @@ -706,6 +709,92 @@ async def chat_sse( return _sse_response(broadcaster, event_queue, sub_id) +async def _chat_sse_background( + session_id: str, + request: Request, + user: dict, + agent_session: AgentSession, + body: dict[str, Any], +) -> StreamingResponse: + """Durably enqueue chat work and stream events from Mongo. + + This is the Phase 3 path. It is opt-in so production can fall back to the + existing direct in-process execution path instantly. 
+ """ + store = session_manager._store() + if not getattr(store, "enabled", False): + raise HTTPException( + status_code=503, + detail="Background workers require Mongo session persistence", + ) + + text = body.get("text") + approvals = body.get("approvals") + if text is None and not approvals: + raise HTTPException(status_code=400, detail="Must provide 'text' or 'approvals'") + + if text is not None and not approvals: + await _enforce_claude_quota(user, agent_session) + operation = {"type": "user_input", "payload": {"text": text}} + else: + formatted = [ + { + "tool_call_id": a["tool_call_id"], + "approved": a["approved"], + "feedback": a.get("feedback"), + "edited_script": a.get("edited_script"), + "namespace": a.get("namespace"), + } + for a in approvals + ] + await _enforce_jobs_access_for_approvals(user, agent_session, formatted) + operation = {"type": "exec_approval", "payload": {"approvals": formatted}} + + idempotency_key = ( + request.headers.get("Idempotency-Key") + or request.headers.get("X-Idempotency-Key") + or body.get("idempotency_key") + ) + after_seq = max(_last_event_seq(request), await store.latest_event_seq(session_id)) + run = await store.enqueue_run( + session_id=session_id, + user_id=agent_session.user_id, + operation=operation, + idempotency_key=idempotency_key, + surface="space", + ) + if not run: + raise HTTPException( + status_code=409, + detail="Session already has a queued or running operation", + ) + + broadcaster = await _wait_for_session_broadcaster(agent_session) + sub_id, event_queue = broadcaster.subscribe() + replay_events = await store.load_events_after(session_id, after_seq) + return _sse_response( + broadcaster, + event_queue, + sub_id, + replay_events=replay_events, + after_seq=after_seq, + poll_session_id=session_id, + poll_interval_seconds=0.5, + ) + + +async def _wait_for_session_broadcaster(agent_session: AgentSession): + deadline = asyncio.get_running_loop().time() + 5 + while agent_session.broadcaster is None: + if 
asyncio.get_running_loop().time() >= deadline: + raise HTTPException( + status_code=503, + detail="Session event broadcaster is not ready", + ) + await asyncio.sleep(0.05) + return agent_session.broadcaster + + @router.post("/pro-click/{session_id}") async def record_pro_click( session_id: str, @@ -767,17 +856,23 @@ def _sse_response( *, replay_events: list[dict[str, Any]] | None = None, after_seq: int = 0, + poll_session_id: str | None = None, + poll_interval_seconds: float = 0.5, ) -> StreamingResponse: """Build a StreamingResponse that drains *event_queue* as SSE, sending keepalive comments every 15 s to prevent proxy timeouts.""" async def event_generator(): + last_seq = after_seq + last_keepalive = asyncio.get_running_loop().time() try: for doc in replay_events or []: msg = _event_doc_to_msg(doc) seq = msg.get("seq") if isinstance(seq, int) and seq <= after_seq: continue + if isinstance(seq, int): + last_seq = max(last_seq, seq) yield _format_sse(msg) if msg.get("event_type", "") in _TERMINAL_EVENTS: return @@ -785,13 +880,44 @@ async def event_generator(): while True: try: msg = await asyncio.wait_for( - event_queue.get(), timeout=_SSE_KEEPALIVE_SECONDS + event_queue.get(), + timeout=( + poll_interval_seconds + if poll_session_id + else _SSE_KEEPALIVE_SECONDS + ), ) except asyncio.TimeoutError: + if poll_session_id: + docs = await session_manager._store().load_events_after( + poll_session_id, + last_seq, + ) + if docs: + for doc in docs: + msg = _event_doc_to_msg(doc) + seq = msg.get("seq") + if isinstance(seq, int): + if seq <= last_seq: + continue + last_seq = seq + yield _format_sse(msg) + if msg.get("event_type", "") in _TERMINAL_EVENTS: + return + continue + + now = asyncio.get_running_loop().time() + if now - last_keepalive < _SSE_KEEPALIVE_SECONDS: + continue + last_keepalive = now + # SSE comment — ignored by parsers, keeps connection alive yield ": keepalive\n\n" continue event_type = msg.get("event_type", "") + seq = msg.get("seq") + if 
isinstance(seq, int): + last_seq = max(last_seq, seq) yield _format_sse(msg) if event_type in _TERMINAL_EVENTS: break @@ -826,7 +952,7 @@ async def subscribe_events( after_seq = _last_event_seq(request) replay_events = await session_manager._store().load_events_after(session_id, after_seq) - broadcaster = agent_session.broadcaster + broadcaster = await _wait_for_session_broadcaster(agent_session) sub_id, event_queue = broadcaster.subscribe() return _sse_response( broadcaster, @@ -834,6 +960,7 @@ async def subscribe_events( sub_id, replay_events=replay_events, after_seq=after_seq, + poll_session_id=session_id if background_workers_enabled() else None, ) diff --git a/backend/start.sh b/backend/start.sh index 72b35198..3b11a0c9 100755 --- a/backend/start.sh +++ b/backend/start.sh @@ -4,6 +4,11 @@ # Only the first instance can bind port 7860 — the rest must exit # with code 0 so the dev mode daemon doesn't mark the app as crashed. +if [ "${ML_INTERN_PROCESS_ROLE:-api}" = "worker" ]; then + uvicorn worker_app:app --host 0.0.0.0 --port 7860 + exit $? +fi + # Run uvicorn; if it fails due to port conflict, exit cleanly. uvicorn main:app --host 0.0.0.0 --port 7860 EXIT_CODE=$? 
diff --git a/backend/worker_app.py b/backend/worker_app.py new file mode 100644 index 00000000..53f8b1fe --- /dev/null +++ b/backend/worker_app.py @@ -0,0 +1,60 @@ +"""HTTP health wrapper for the background worker Space.""" + +from __future__ import annotations + +import asyncio +import logging +from contextlib import asynccontextmanager + +from fastapi import FastAPI + +from background_worker import default_worker_id, run_worker_loop +from session_manager import session_manager + +logger = logging.getLogger(__name__) + +_worker_task: asyncio.Task | None = None +_worker_id: str | None = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global _worker_task, _worker_id + + _worker_id = default_worker_id() + logger.info("Starting worker Space app (%s)", _worker_id) + await session_manager.start() + _worker_task = asyncio.create_task( + run_worker_loop(session_manager, worker_id=_worker_id), + name="ml-intern-worker", + ) + try: + yield + finally: + logger.info("Stopping worker Space app (%s)", _worker_id) + if _worker_task is not None: + _worker_task.cancel() + try: + await _worker_task + except asyncio.CancelledError: + pass + _worker_task = None + await session_manager.close() + + +app = FastAPI( + title="ML Intern Worker", + description="Background worker for durable ML Intern session runs", + version="1.0.0", + lifespan=lifespan, +) + + +@app.get("/") +@app.get("/health") +async def health(): + return { + "status": "ok", + "worker_id": _worker_id, + "worker_running": bool(_worker_task and not _worker_task.done()), + } diff --git a/docs/phase-3-space-background-workers.md b/docs/phase-3-space-background-workers.md new file mode 100644 index 00000000..7fcaae85 --- /dev/null +++ b/docs/phase-3-space-background-workers.md @@ -0,0 +1,529 @@ +# Phase 3 Plan: Space Background Workers + +## Summary + +Phase 3 decouples agent execution from the browser's SSE connection. The +frontend/backend Space will accept user submissions and persist them to MongoDB. 
+A long-running worker, preferably a separate Hugging Face Space, will claim those +submissions, run the agent loop, and write durable events/messages/snapshots back +to MongoDB. + +This lets a user close their laptop or lose the SSE connection while the agent +continues running in the background. When the user returns, the frontend +rehydrates the session from MongoDB and replays events after the last seen event +sequence. + +## Prerequisite + +Mongo-backed session persistence must be working in production. + +The backend must log: + +```text +Mongo session persistence enabled (db=...) +``` + +If Mongo persistence is disabled, Phase 3 must not be enabled because queued +runs, event replay, session restore, and quota state all depend on durable +storage. + +## Current State + +The current Space implementation persists the right durable primitives: + +- `sessions`: session metadata, runtime state, title, pending approvals, quota + marker, soft-delete visibility. +- `session_messages`: latest restorable runtime context, stored per message. +- `session_events`: append-only event log with per-session sequence numbers. +- `session_trace_messages`: raw trace/SFT-ready message stream. +- `claude_quotas`: Mongo-backed quota counters. + +However, active work is still driven by the API process: + +1. The browser sends `POST /api/chat/{session_id}`. +2. The backend queues the operation in an in-memory `asyncio.Queue`. +3. The in-memory agent task runs `process_submission()`. +4. SSE streams events while the browser connection is alive. + +That is good enough for restart recovery between turns, but it is not enough for +true background execution. If the API process restarts mid-turn, the in-flight +turn is lost. 
+ +## Target Architecture + +```text +Browser + POST /api/chat/{session_id} + GET /api/events/{session_id}?after= + +Frontend/backend Space + authenticates user + validates session ownership + creates durable session_runs document + serves session metadata/messages/events + does not own long-running agent execution + +MongoDB + sessions + session_messages + session_events + session_trace_messages + session_runs + claude_quotas + +Worker Space + claims queued session_runs + restores sessions from MongoDB + runs process_submission() + writes events/messages/snapshots + renews leases while running +``` + +The worker should be a shared pool, not one spawned worker Space per user +session. Start with one worker Space and scale to multiple workers once Mongo +claim/lease behavior is proven. + +## New Collection: `session_runs` + +`session_runs` is the durable queue for work that must survive browser +disconnects and backend restarts. + +Suggested document shape: + +```json +{ + "_id": "run_uuid", + "schema_version": 1, + "session_id": "session_uuid", + "user_id": "hf_user_id", + "surface": "space", + "operation": { + "type": "user_input", + "payload": { + "text": "build a demo", + "attachments": [] + } + }, + "status": "queued", + "idempotency_key": "client_generated_or_server_generated_key", + "lease_owner": null, + "lease_until": null, + "retry_count": 0, + "max_retries": 1, + "created_at": "2026-04-28T00:00:00Z", + "started_at": null, + "updated_at": "2026-04-28T00:00:00Z", + "finished_at": null, + "error": null +} +``` + +Allowed `operation.type` values: + +- `user_input` +- `exec_approval` +- `interrupt` +- `compact` +- `undo` +- `truncate` +- `shutdown` + +Allowed `status` values: + +- `queued` +- `running` +- `waiting_approval` +- `completed` +- `failed` +- `cancelled` +- `interrupted` + +Indexes: + +```text +{ status: 1, lease_until: 1, created_at: 1 } +{ session_id: 1, created_at: -1 } +{ user_id: 1, created_at: -1 } +{ idempotency_key: 1 } unique sparse +``` + 
+Optional later index for a worker pool: + +```text +{ lease_owner: 1, lease_until: 1 } +``` + +## Session Metadata Additions + +The current `sessions` collection already has most of what Phase 3 needs. Add or +standardize: + +```text +runtime_state: idle | queued | processing | waiting_approval | ended | interrupted +active_run_id: string | null +worker_owner: string | null +worker_lease_until: datetime | null +last_event_seq: int +``` + +Only one active run should exist per session in v1. This avoids concurrent turns +modifying the same context. + +## API Changes + +### `POST /api/chat/{session_id}` + +Current behavior: enqueue into an in-memory queue and stream events for the +current turn. + +Phase 3 behavior: + +1. Authenticate the user. +2. Verify session ownership via `ensure_session_loaded()` or a metadata-only + ownership path. +3. Create a `session_runs` document with `status="queued"`. +4. Set `sessions.runtime_state="queued"` and `sessions.active_run_id=`. +5. Return either: + - `202 Accepted` with `{ run_id, session_id }`, or + - the existing SSE response shape while internally streaming from the event + log. + +For lowest frontend disruption, v1 can keep the existing `POST /api/chat` +streaming contract. The important change is that the POST only creates durable +work; the worker owns execution. + +### `GET /api/events/{session_id}?after=` + +Keep this as the reconnect/replay endpoint. + +Behavior: + +1. Load persisted events after `after`. +2. Stream replayed events with their durable `seq`. +3. Subscribe to live event fanout if the current API process has one. +4. Keep the connection alive with comments. + +Longer term, the API process should also tail Mongo events for sessions whose +worker is in another process. For v1, polling Mongo every 1-2 seconds is +acceptable and simpler than change streams. + +### `GET /api/session/{session_id}` and `/messages` + +Keep these as session rehydration endpoints. 
They should not require the worker +to be in the same process. + +## Worker Space + +Create a worker entrypoint, for example: + +```text +ML_INTERN_PROCESS_ROLE=worker +``` + +The worker Space should use the same codebase and these secrets: + +- `MONGODB_URI` +- `MONGODB_DB` +- model provider secrets +- any HF/tool credentials needed for agent execution + +The worker does not need a public UI. In this repo, `ML_INTERN_PROCESS_ROLE=worker` +starts `worker_app:app`, which exposes `/health` for the Space while the worker +loop runs from the app lifespan. + +The current implementation also supports an in-process worker for the API Space: + +```text +ML_INTERN_BACKGROUND_WORKERS=true +ML_INTERN_RUN_WORKER_IN_PROCESS=true +``` + +That mode is the safe first rollout because the API process already has the +user's HF token in memory after request authentication. A separate worker Space +can claim the same durable runs, but user-scoped HF tool execution still needs an +explicit token handoff/token-broker design before it should be enabled for +production user traffic. Do not persist raw user OAuth tokens to Mongo as the +default path. + +Worker loop: + +1. Initialize Mongo session store. +2. Claim one queued or expired run atomically. +3. Restore the session context from Mongo. +4. Recreate runtime `Session`, `ToolRouter`, queues, and event persistence. +5. Mark run `running`; mark session `processing`. +6. Start a heartbeat task that renews `lease_until`. +7. Execute `process_submission()`. +8. Persist final message snapshot and runtime state. +9. Mark run `completed`, `waiting_approval`, `failed`, `cancelled`, or + `interrupted`. 
+ +Atomic claim query: + +```js +db.session_runs.findOneAndUpdate( + { + status: { $in: ["queued", "running"] }, + $or: [ + { status: "queued" }, + { lease_until: { $lt: now } } + ] + }, + { + $set: { + status: "running", + lease_owner: worker_id, + lease_until: now + lease_duration, + started_at: now, + updated_at: now + }, + $inc: { retry_count: 1 } + }, + { sort: { created_at: 1 }, returnDocument: "after" } +) +``` + +Use a session-level guard so two workers cannot run two turns for the same +session: + +```js +db.sessions.updateOne( + { + _id: session_id, + $or: [ + { active_run_id: null }, + { active_run_id: run_id }, + { worker_lease_until: { $lt: now } } + ] + }, + { + $set: { + active_run_id: run_id, + runtime_state: "processing", + worker_owner: worker_id, + worker_lease_until: now + lease_duration + } + } +) +``` + +If the session guard fails, release or requeue the run. + +## Handling Browser Close + +With Phase 3: + +1. User submits a request. +2. Backend writes `session_runs(status="queued")`. +3. Worker claims and runs it. +4. User closes the browser. +5. Nothing important happens to the run. The worker keeps executing. +6. Worker keeps appending `session_events` and saving snapshots. +7. User returns later. +8. Frontend loads `/api/sessions`, `/api/session/{id}/messages`, and + `/api/events/{id}?after=`. +9. UI shows the completed turn or current progress. + +The browser is an observer, not the owner of execution. + +## Restart Semantics + +### API Space Restart + +No active run should be lost. The worker Space continues running. Reopened +browsers reconnect through the API and replay persisted events. + +### Worker Space Restart Between Runs + +No issue. Another worker or restarted worker claims the next queued run. + +### Worker Space Restart Mid-Turn + +The worker lease expires. V1 should not blindly resume arbitrary in-flight tool +calls. Instead: + +1. Mark the run `interrupted`. +2. Mark session `runtime_state="interrupted"`. +3. 
Append an event telling the frontend the turn was interrupted. +4. Let the user continue/retry from the latest saved snapshot. + +Pending approvals are the exception: if the session was waiting for user +approval, restore pending approvals exactly and keep the session in +`waiting_approval`. + +## Tool-Call Idempotency Policy + +Do not assume tools are safe to replay. + +For v1: + +- Completed turns are durable. +- Pending approvals are durable and exactly restorable. +- In-flight non-approval tool calls are interrupted on worker crash/restart. +- Long-running external HF Jobs should persist job IDs as soon as they are + created so the restored agent can inspect status later. + +Later, individual tools can opt into idempotent resume behavior. + +## Frontend Changes + +Frontend should make SSE reconnect normal: + +- Keep `lastEventSeq` per session. +- On reconnect, call `/api/events/{session_id}?after=`. +- Treat `POST /api/chat` response as submission acknowledgement plus optional + live stream. +- Continue to merge server-side sessions into sidebar metadata. +- Hydrate titles from `sessions.title`. +- Hide `visibility="deleted"` sessions. + +The UI should not show an expired-session recovery banner just because the +browser slept while a worker continued running. + +## Implementation Phases + +### Phase 3.1: Durable Run Store + +- Add `session_runs` methods to `SessionStore`. +- Add Mongo indexes. +- Add typed run payload helpers. +- Add tests for enqueue, claim, lease renewal, completion, and expired lease + recovery. + +### Phase 3.2: API Enqueue Path + +- Change `/api/chat/{session_id}` to create a durable run. +- Preserve current SSE response shape where possible. +- Add idempotency keys to prevent duplicate submissions on browser retry. +- Add session-level single-active-run guard. + +### Phase 3.3: In-Process Worker First + +- Add a worker loop inside the existing backend process behind a feature flag. 
+- Use it to validate queue semantics without deploying a second Space yet.
+- Keep the old direct execution path available behind a rollback flag.
+
+### Phase 3.4: Separate Worker Space
+
+- Add `backend/worker.py`.
+- Add a Docker/entrypoint option for worker mode.
+- Deploy `ml-intern-worker` Space with the same Mongo/model secrets.
+- Disable in-process worker on the API Space once the worker Space is healthy.
+
+### Phase 3.5: Reconnect Polish
+
+- Ensure event replay is complete enough to rebuild visible progress.
+- Add clear interrupted/retry UI states.
+- Add observability for queued/running/failed/interrupted run counts.
+
+## Testing Plan
+
+Unit tests:
+
+- run enqueue idempotency
+- atomic claim only returns one worker winner
+- lease renewal extends `lease_until`
+- expired lease becomes retryable/interrupted
+- one active run per session
+- pending approvals survive restore
+- event replay after a given sequence number
+
+Integration tests with local Mongo:
+
+1. Submit a turn.
+2. Kill the browser/SSE client.
+3. Confirm worker completes the run.
+4. Reconnect and replay events.
+5. Restart API process mid-run; worker continues.
+6. Restart worker mid-run; run becomes interrupted after lease expiry.
+7. Submit approval after restore; worker continues from pending approval.
+
+Production smoke:
+
+1. Deploy API Space with Mongo enabled.
+2. Deploy worker Space with the same Mongo/model secrets.
+3. Submit a long-running turn.
+4. Close the browser.
+5. Reopen after completion and confirm the response is visible.
+6. Restart API Space during a worker run and confirm no run loss.
+ +## Observability + +Add logs/metrics for: + +- worker startup and worker ID +- run claimed +- run completed +- run failed +- run interrupted +- lease renewed +- lease expired +- queue depth +- oldest queued run age + +Mongoku queries should make it easy to inspect: + +```js +db.session_runs.find({ status: { $in: ["queued", "running"] } }).sort({ created_at: 1 }) +db.sessions.find({ runtime_state: { $in: ["queued", "processing", "waiting_approval"] } }) +db.session_events.find({ session_id }).sort({ seq: 1 }) +``` + +## Rollback Plan + +Keep a feature flag while migrating: + +```text +ML_INTERN_BACKGROUND_WORKERS=false +``` + +When disabled: + +- `/api/chat` uses the current direct in-process execution path. +- Mongo session persistence remains enabled. +- Event replay remains enabled. + +This lets us deploy Phase 3 code safely before routing production traffic through +the worker queue. + +Worker/process flags: + +```text +ML_INTERN_BACKGROUND_WORKERS=true # /api/chat enqueues durable runs +ML_INTERN_RUN_WORKER_IN_PROCESS=true # API Space also runs a local worker +ML_INTERN_PROCESS_ROLE=worker # run this container as worker-only +ML_INTERN_WORKER_ID=ml-intern-worker-1 # optional stable worker name +``` + +## Open Decisions + +- Exact worker deployment name and ownership: `ml-intern-worker` is the suggested + first name. +- Whether API Space should run an in-process worker as fallback when the external + worker Space is unhealthy. +- Whether to use polling or Mongo change streams for live event fanout from + external workers. Polling is simpler for v1. +- Lease duration and heartbeat interval. Suggested initial values: + `lease_duration=120s`, `heartbeat_interval=30s`. +- How many retries before marking a run permanently `interrupted` or `failed`. + Suggested v1: one retry for queued-before-start failures, no automatic replay + for mid-tool-call failures. + +## Non-Goals + +- No per-user worker Space spawning in v1. 
+- No replay of arbitrary non-idempotent in-flight tool calls. +- No CLI Mongo persistence. +- No replacement of the existing sandbox Spaces; sandbox Spaces remain for user + code execution, while the worker Space runs agent orchestration. + +## Acceptance Criteria + +- Closing the browser does not stop an active agent turn. +- Reopening the browser restores the session and displays completed or current + progress. +- API Space restart does not stop worker-owned runs. +- Worker restart mid-turn produces a clear interrupted state, not silent loss. +- Pending approvals restore exactly. +- Only one worker can run a turn for a session at a time. +- Feature flag can roll back to the current direct execution path. diff --git a/tests/unit/test_background_worker.py b/tests/unit/test_background_worker.py new file mode 100644 index 00000000..674697df --- /dev/null +++ b/tests/unit/test_background_worker.py @@ -0,0 +1,48 @@ +"""Tests for durable background worker helpers.""" + +import sys +from pathlib import Path + +import pytest + +_BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend" +if str(_BACKEND_DIR) not in sys.path: + sys.path.insert(0, str(_BACKEND_DIR)) + +from background_worker import operation_from_run # noqa: E402 +from agent.core.session import OpType # noqa: E402 + + +def test_operation_from_user_input_run(): + operation = operation_from_run( + { + "operation": { + "type": "user_input", + "payload": {"text": "build a demo"}, + } + } + ) + + assert operation.op_type == OpType.USER_INPUT + assert operation.data == {"text": "build a demo"} + + +def test_operation_from_approval_run(): + approvals = [{"tool_call_id": "call_1", "approved": True}] + + operation = operation_from_run( + { + "operation": { + "type": "exec_approval", + "payload": {"approvals": approvals}, + } + } + ) + + assert operation.op_type == OpType.EXEC_APPROVAL + assert operation.data == {"approvals": approvals} + + +def 
test_operation_from_unknown_run_rejects_unsupported_type(): + with pytest.raises(ValueError, match="Unsupported background run operation"): + operation_from_run({"operation": {"type": "truncate", "payload": {}}}) diff --git a/tests/unit/test_session_persistence.py b/tests/unit/test_session_persistence.py index 8bddb10f..8104f95b 100644 --- a/tests/unit/test_session_persistence.py +++ b/tests/unit/test_session_persistence.py @@ -22,6 +22,12 @@ async def test_noop_store_keeps_local_cli_and_tests_db_free(): assert await store.list_sessions("u1") == [] assert await store.append_event("s1", "processing", {}) is None assert await store.try_increment_quota("u1", "2099-01-01", 1) is None + assert await store.latest_event_seq("s1") == 0 + assert await store.enqueue_run(session_id="s1", user_id="u1", operation={}) is None + assert await store.claim_next_run(worker_id="w1") is None + assert await store.heartbeat_run("r1", worker_id="w1") is False + assert await store.interrupt_expired_runs() == 0 + await store.finish_run("r1", status="completed") def test_unsafe_message_payload_is_replaced_with_marker():