diff --git a/.gitignore b/.gitignore index d758b077..5b591873 100644 --- a/.gitignore +++ b/.gitignore @@ -61,6 +61,8 @@ session_logs/ hf-agent-leaderboard/ skills/ .claude/ +.omc/ +.omx/ *.jsonl *.csv diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index 630bcd26..b0919b09 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -5,7 +5,6 @@ import asyncio import json import logging -import os import time from dataclasses import dataclass, field from typing import Any @@ -917,7 +916,6 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) token_count = response.usage.total_tokens if response.usage else 0 thinking_blocks, reasoning_content = _extract_thinking_state(message) - # Build tool_calls_acc in the same format as streaming tool_calls_acc: dict[int, dict] = {} if message.tool_calls: for idx, tc in enumerate(message.tool_calls): @@ -930,7 +928,6 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) }, } - # Emit the full message as a single event if content: await session.send_event( Event(event_type="assistant_message", data={"content": content}) @@ -1306,37 +1303,40 @@ async def _exec_tool( ) return (tc, name, args, out, ok) - gather_task = asyncio.ensure_future(asyncio.gather( - *[ - _exec_tool(tc, name, args, decision, valid, err) - for tc, name, args, decision, valid, err in parsed_tools - ] - )) - cancel_task = asyncio.ensure_future(session._cancelled.wait()) - - done, _ = await asyncio.wait( - [gather_task, cancel_task], - return_when=asyncio.FIRST_COMPLETED, - ) - - if cancel_task in done: - gather_task.cancel() - try: - await gather_task - except asyncio.CancelledError: - pass - # Notify frontend that in-flight tools were cancelled - for tc, name, _args, _decision, valid, _ in parsed_tools: - if valid: - await session.send_event(Event( - event_type="tool_state_change", - data={"tool_call_id": tc.id, "tool": name, "state": "cancelled"}, - )) - await _cleanup_on_cancel(session) - break + session.is_in_tool_call = True + try: + gather_task = asyncio.ensure_future(asyncio.gather( + *[ + _exec_tool(tc, name, args, decision, valid, err) + for tc, name, args, decision, valid, err in parsed_tools + ] + )) + cancel_task = asyncio.ensure_future(session._cancelled.wait()) + + done, _ = await asyncio.wait( + [gather_task, cancel_task], + return_when=asyncio.FIRST_COMPLETED, + ) - cancel_task.cancel() - results = gather_task.result() + if cancel_task in done: + gather_task.cancel() + try: + await gather_task + except asyncio.CancelledError: + pass + for tc, name, _args, _decision, valid, _ in parsed_tools: + if valid: + await session.send_event(Event( + event_type="tool_state_change", + data={"tool_call_id": tc.id, "tool": name, "state": "cancelled"}, + )) + await _cleanup_on_cancel(session) + break + + cancel_task.cancel() + results = gather_task.result() + finally: + session.is_in_tool_call = False # 4. 
Record results and send outputs (order preserved) for tc, tool_name, tool_args, output, success in results: @@ -1610,40 +1610,44 @@ async def execute_tool(tc, tool_name, tool_args, was_edited): # Execute all approved tools concurrently (cancellable) if approved_tasks: - gather_task = asyncio.ensure_future(asyncio.gather( - *[ - execute_tool(tc, tool_name, tool_args, was_edited) - for tc, tool_name, tool_args, was_edited in approved_tasks - ], - return_exceptions=True, - )) - cancel_task = asyncio.ensure_future(session._cancelled.wait()) - - done, _ = await asyncio.wait( - [gather_task, cancel_task], - return_when=asyncio.FIRST_COMPLETED, - ) + session.is_in_tool_call = True + try: + gather_task = asyncio.ensure_future(asyncio.gather( + *[ + execute_tool(tc, tool_name, tool_args, was_edited) + for tc, tool_name, tool_args, was_edited in approved_tasks + ], + return_exceptions=True, + )) + cancel_task = asyncio.ensure_future(session._cancelled.wait()) - if cancel_task in done: - gather_task.cancel() - try: - await gather_task - except asyncio.CancelledError: - pass - # Notify frontend that approved tools were cancelled - for tc, tool_name, _args, _was_edited in approved_tasks: - await session.send_event(Event( - event_type="tool_state_change", - data={"tool_call_id": tc.id, "tool": tool_name, "state": "cancelled"}, - )) - await _cleanup_on_cancel(session) - await session.send_event(Event(event_type="interrupted")) - session.increment_turn() - await session.auto_save_if_needed() - return - - cancel_task.cancel() - results = gather_task.result() + done, _ = await asyncio.wait( + [gather_task, cancel_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + if cancel_task in done: + gather_task.cancel() + try: + await gather_task + except asyncio.CancelledError: + pass + # Notify frontend that approved tools were cancelled + for tc, tool_name, _args, _was_edited in approved_tasks: + await session.send_event(Event( + event_type="tool_state_change", + data={"tool_call_id": tc.id, "tool": tool_name, "state": "cancelled"}, + )) + await _cleanup_on_cancel(session) + await session.send_event(Event(event_type="interrupted")) + session.increment_turn() + await session.auto_save_if_needed() + return + + cancel_task.cancel() + results = gather_task.result() + finally: + session.is_in_tool_call = False # Process results and add to context for result in results: diff --git a/agent/core/session.py b/agent/core/session.py index f0874ed9..11c5ef2a 100644 --- a/agent/core/session.py +++ b/agent/core/session.py @@ -114,6 +114,7 @@ def __init__( self.config = config self.is_running = True self._cancelled = asyncio.Event() + self.is_in_tool_call: bool = False self.pending_approval: Optional[dict[str, Any]] = None self.sandbox = None self._running_job_ids: set[str] = set() # HF job IDs currently executing diff --git a/agent/core/session_persistence.py b/agent/core/session_persistence.py index f2c2d367..67099765 100644 --- a/agent/core/session_persistence.py +++ b/agent/core/session_persistence.py @@ -9,10 +9,12 @@ import logging import os -from datetime import UTC, datetime +import socket +import uuid +from datetime import UTC, datetime, timedelta from typing import Any -from bson import BSON +from bson import BSON, ObjectId from pymongo import AsyncMongoClient, DeleteMany, ReturnDocument, UpdateOne from pymongo.errors import DuplicateKeyError, InvalidDocument, PyMongoError @@ -30,6 +32,21 @@ def _doc_id(session_id: str, idx: int) -> str: return f"{session_id}:{idx}" +def make_holder_id(mode: str) -> str: + """Build a 
process holder id ``f"{mode}:{hostname}:{8-hex-suffix}"``.
+
+    Uses ``uuid7`` if available (Python ≥ 3.14) for chronological ordering;
+    falls back to ``uuid4`` otherwise. Pick once at process start; do not
+    change mid-run.
+    """
+    hostname = socket.gethostname()
+    if hasattr(uuid, "uuid7"):
+        suffix = uuid.uuid7().hex[:8]  # type: ignore[attr-defined]
+    else:
+        suffix = uuid.uuid4().hex[:8]
+    return f"{mode}:{hostname}:{suffix}"
+
+
 def _safe_message_doc(message: dict[str, Any]) -> dict[str, Any]:
     """Return a Mongo-safe message document payload.
@@ -101,6 +118,42 @@ async def refund_quota(self, *_: Any, **__: Any) -> None:
     async def mark_pro_seen(self, *_: Any, **__: Any) -> dict[str, Any] | None:
         return None
 
+    # ── Lease + pending-submission control plane (no-op) ──────────────────
+
+    async def claim_lease(self, *_: Any, **__: Any) -> dict[str, Any] | None:
+        return None
+
+    async def renew_lease(self, *_: Any, **__: Any) -> dict[str, Any] | None:
+        return None
+
+    async def release_lease(self, *_: Any, **__: Any) -> None:
+        return None
+
+    async def enqueue_pending_submission(self, *_: Any, **__: Any) -> str:
+        return ""
+
+    async def claim_pending_submission(self, *_: Any, **__: Any) -> dict[str, Any] | None:
+        return None
+
+    async def mark_submission_done(self, *_: Any, **__: Any) -> None:
+        return None
+
+    async def requeue_claimed_for(self, *_: Any, **__: Any) -> int:
+        return 0
+
+    async def change_stream_pending_submissions(self, *_: Any, **__: Any):
+        raise NotImplementedError("change streams require Mongo persistence")
+        yield  # pragma: no cover - makes this an async generator
+
+    async def change_stream_events(self, *_: Any, **__: Any):
+        raise NotImplementedError("change streams require Mongo persistence")
+        yield  # pragma: no cover - makes this an async generator
+
+    async def poll_pending_submissions_after(
+        self, *_: Any, **__: Any
+    ) -> list[dict[str, Any]]:
+        return []
+
 
 class MongoSessionStore(NoopSessionStore):
     """MongoDB-backed session store."""
@@ -121,6 +174,7 @@ async def init(self) -> None:
             await self.client.admin.command("ping")
             await self._create_indexes()
             self.enabled = True
+            await self._backfill_lease_state()
             logger.info("Mongo session persistence enabled (db=%s)", self.db_name)
         except Exception as e:
             logger.warning("Mongo session persistence disabled: %s", e)
@@ -130,6 +184,39 @@
             self.client = None
             self.db = None
 
+    async def _backfill_lease_state(self) -> None:
+        """One-shot migration for sessions predating the lease control plane.
+
+        Idempotent: the ``lease: {$exists: false}`` filter excludes already
+        migrated rows, so re-running ``init()`` is a no-op.
+        """
+        if self.db is None:
+            return
+        try:
+            cutoff = _now() - timedelta(hours=1)
+            recent = await self.db.sessions.update_many(
+                {
+                    "lease": {"$exists": False},
+                    "status": "active",
+                    "last_active_at": {"$gt": cutoff},
+                },
+                {"$set": {"lease": {"holder_id": None, "expires_at": _now()}}},
+            )
+            old = await self.db.sessions.update_many(
+                {
+                    "lease": {"$exists": False},
+                    "status": "active",
+                    "last_active_at": {"$lte": cutoff},
+                },
+                {"$set": {"runtime_state": "idle"}},
+            )
+            logger.info(
+                f"Backfilled empty lease on {recent.modified_count} sessions; "
+                f"flipped {old.modified_count} old sessions to idle."
+ ) + except PyMongoError as e: + logger.warning(f"Lease backfill skipped due to Mongo error: {e}") + async def close(self) -> None: if self.client is not None: await self.client.close() @@ -156,6 +243,9 @@ async def _create_indexes(self) -> None: ) await self.db.session_trace_messages.create_index([("created_at", -1)]) await self.db.pro_users.create_index([("first_seen_pro_at", -1)]) + await self.db.pending_submissions.create_index( + [("session_id", 1), ("status", 1), ("created_at", 1)] + ) def _ready(self) -> bool: return bool(self.enabled and self.db is not None) @@ -483,6 +573,252 @@ async def mark_pro_seen( "first_seen_at": (doc.get("first_seen_at") or now).isoformat(), } + # ── Lease control plane ─────────────────────────────────────────────── + + async def claim_lease( + self, session_id: str, holder_id: str, ttl_s: int = 30 + ) -> dict[str, Any] | None: + """Atomic CAS claim. Succeeds iff lease missing or expired. + + Returns the updated session doc, or ``None`` if another holder + currently owns an unexpired lease. + """ + if not self._ready(): + return None + now = _now() + try: + return await self.db.sessions.find_one_and_update( + { + "_id": session_id, + "$or": [ + {"lease.expires_at": {"$lt": now}}, + {"lease": {"$exists": False}}, + {"lease.holder_id": None}, + ], + }, + { + "$set": { + "lease": { + "holder_id": holder_id, + "expires_at": now + timedelta(seconds=ttl_s), + "claimed_at": now, + }, + }, + "$inc": {"lease_generation": 1}, + }, + return_document=ReturnDocument.AFTER, + ) + except PyMongoError as e: + logger.warning(f"claim_lease failed for {session_id} ({holder_id}): {e}") + return None + + async def renew_lease( + self, session_id: str, holder_id: str, ttl_s: int = 30 + ) -> dict[str, Any] | None: + """Atomic renew. Returns updated doc, or ``None`` if we lost it. + + Raises ``PyMongoError`` on transient Mongo failures so callers can + distinguish "we lost the lease" (return value ``None``) from "Mongo + flapped" (exception). The heartbeat loop catches the exception and + skips this tick for the affected session; only ``None`` triggers + ``_on_lease_lost``. + """ + if not self._ready(): + return None + now = _now() + return await self.db.sessions.find_one_and_update( + {"_id": session_id, "lease.holder_id": holder_id}, + {"$set": {"lease.expires_at": now + timedelta(seconds=ttl_s)}}, + return_document=ReturnDocument.AFTER, + ) + + async def release_lease(self, session_id: str, holder_id: str) -> None: + """Atomic release. No-op if we no longer hold the lease. + + Clears ``lease.holder_id`` in addition to expiring the lease so the + renew CAS filter (``{"lease.holder_id": holder_id}``) no longer + matches — preventing a heartbeat tick that snapshotted the session + id pre-release from re-extending the lease 30 s into the future. 
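+
+        Lifecycle sketch (hypothetical session and holder ids)::
+
+            doc = await store.claim_lease("sess-1", "main:host-a:ab12cd34")
+            doc = await store.renew_lease("sess-1", "main:host-a:ab12cd34")
+            await store.release_lease("sess-1", "main:host-a:ab12cd34")
+            # renew now returns None; the CAS filter no longer matches:
+            assert await store.renew_lease("sess-1", "main:host-a:ab12cd34") is None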
+        """
+        if not self._ready():
+            return
+        now = _now()
+        try:
+            await self.db.sessions.update_one(
+                {"_id": session_id, "lease.holder_id": holder_id},
+                {"$set": {"lease.expires_at": now, "lease.holder_id": None}},
+            )
+        except PyMongoError as e:
+            logger.warning(f"release_lease failed for {session_id} ({holder_id}): {e}")
+
+    # ── Pending submissions ───────────────────────────────────────────────
+
+    async def enqueue_pending_submission(
+        self, session_id: str, op_type: str, payload: dict[str, Any]
+    ) -> str:
+        """Insert a pending submission and return its inserted ``_id`` (str)."""
+        if not self._ready():
+            return ""
+        doc = {
+            "_id": ObjectId(),
+            "session_id": session_id,
+            "op_type": op_type,
+            "payload": payload or {},
+            "status": "pending",
+            "claimed_by": None,
+            "created_at": _now(),
+        }
+        try:
+            await self.db.pending_submissions.insert_one(doc)
+            return str(doc["_id"])
+        except PyMongoError as e:
+            logger.warning(
+                f"enqueue_pending_submission failed for {session_id}: {e}"
+            )
+            return ""
+
+    async def claim_pending_submission(
+        self, session_id: str, holder_id: str
+    ) -> dict[str, Any] | None:
+        """Atomic FIFO claim of the oldest pending submission for a session."""
+        if not self._ready():
+            return None
+        now = _now()
+        try:
+            return await self.db.pending_submissions.find_one_and_update(
+                {"session_id": session_id, "status": "pending"},
+                {
+                    "$set": {
+                        "status": "claimed",
+                        "claimed_by": holder_id,
+                        "claimed_at": now,
+                    }
+                },
+                sort=[("created_at", 1)],
+                return_document=ReturnDocument.AFTER,
+            )
+        except PyMongoError as e:
+            logger.warning(
+                f"claim_pending_submission failed for {session_id} ({holder_id}): {e}"
+            )
+            return None
+
+    async def mark_submission_done(self, submission_id: str | ObjectId) -> None:
+        """Mark a previously claimed submission as completed."""
+        if not self._ready():
+            return None
+        _id = submission_id if isinstance(submission_id, ObjectId) else ObjectId(submission_id)
+        try:
+            await self.db.pending_submissions.update_one(
+                {"_id": _id},
+                {"$set": {"status": "done", "completed_at": _now()}},
+            )
+        except PyMongoError as e:
+            logger.warning(f"mark_submission_done failed for {submission_id}: {e}")
+
+    async def requeue_claimed_for(
+        self, holder_id: str, session_id: str | None = None
+    ) -> int:
+        """Flip ``claimed`` submissions for ``holder_id`` back to ``pending``.
+
+        When ``session_id`` is provided, only submissions for that session
+        are flipped; used by ``_on_lease_lost`` and
+        ``release_session_to_background`` so releasing one session's lease
+        doesn't disturb the holder's other sessions. When ``None``, every
+        claimed submission for this holder is flipped; reserve that for
+        whole-process teardown, where no session this holder owns keeps
+        running afterwards.
+
+        Must NOT modify ``created_at`` — FIFO ordering is preserved across
+        handovers.
+        """
+        if not self._ready():
+            return 0
+        query: dict[str, Any] = {"status": "claimed", "claimed_by": holder_id}
+        if session_id is not None:
+            query["session_id"] = session_id
+        try:
+            result = await self.db.pending_submissions.update_many(
+                query,
+                {
+                    "$set": {"status": "pending", "claimed_by": None},
+                    "$unset": {"claimed_at": ""},
+                },
+            )
+            return int(result.modified_count or 0)
+        except PyMongoError as e:
+            logger.warning(f"requeue_claimed_for failed for {holder_id}: {e}")
+            return 0
+
+    # ── Change-stream tails (replica-set required) ────────────────────────
+
+    async def change_stream_pending_submissions(self, session_id: str):
+        """Yield newly inserted pending submissions for ``session_id``.
+ + Raises ``PyMongoError`` (the standard pymongo behaviour) if the + deployment isn't a replica set; callers fall back to polling. + """ + if not self._ready(): + return + pipeline = [ + { + "$match": { + "operationType": "insert", + "fullDocument.session_id": session_id, + "fullDocument.status": "pending", + } + } + ] + async with self.db.pending_submissions.watch(pipeline=pipeline) as stream: + async for change in stream: + full = change.get("fullDocument") + if full is not None: + yield full + + async def change_stream_events(self, session_id: str, after_seq: int = 0): + """Yield session_events documents with seq > ``after_seq``.""" + if not self._ready(): + return + pipeline = [ + { + "$match": { + "operationType": "insert", + "fullDocument.session_id": session_id, + "fullDocument.seq": {"$gt": int(after_seq or 0)}, + } + } + ] + async with self.db.session_events.watch(pipeline=pipeline) as stream: + async for change in stream: + full = change.get("fullDocument") + if full is not None: + yield full + + # ── Polling fallback ────────────────────────────────────────────────── + + async def poll_pending_submissions_after( + self, session_id: str, after_id: str | None + ) -> list[dict[str, Any]]: + """Return all pending submissions for ``session_id`` newer than ``after_id``. + + Used when change streams are unavailable. Sorted by ``created_at``. + """ + if not self._ready(): + return [] + query: dict[str, Any] = {"session_id": session_id, "status": "pending"} + if after_id: + try: + query["_id"] = {"$gt": ObjectId(after_id)} + except Exception: # noqa: BLE001 - bad id ⇒ start from beginning + pass + try: + cursor = self.db.pending_submissions.find(query).sort("created_at", 1) + return [row async for row in cursor] + except PyMongoError as e: + logger.warning( + f"poll_pending_submissions_after failed for {session_id}: {e}" + ) + return [] + _store: NoopSessionStore | MongoSessionStore | None = None diff --git a/agent/tools/research_tool.py b/agent/tools/research_tool.py index c4480b97..97cd1745 100644 --- a/agent/tools/research_tool.py +++ b/agent/tools/research_tool.py @@ -452,14 +452,16 @@ async def _log(text: str) -> None: continue try: - import json as _json - - args_str = _json.dumps(tool_args)[:80] + args_str = json.dumps(tool_args)[:80] await _log(f"▸ {tool_name} {args_str}") - output, _success = await session.tool_router.call_tool( - tool_name, tool_args, session=session, tool_call_id=tc.id - ) + session.is_in_tool_call = True + try: + output, _success = await session.tool_router.call_tool( + tool_name, tool_args, session=session, tool_call_id=tc.id + ) + finally: + session.is_in_tool_call = False _tool_uses += 1 await _log(f"tools:{_tool_uses}") # Truncate tool output for the research context diff --git a/backend/main.py b/backend/main.py index f6bc64d1..aa367d57 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,5 +1,6 @@ """FastAPI application for HF Agent web interface.""" +import asyncio import logging import os from contextlib import asynccontextmanager @@ -60,6 +61,26 @@ async def lifespan(app: FastAPI): logger.warning("Failed to flush session %s: %s", sid, e) except Exception as e: logger.warning("Lifespan final-flush skipped: %s", e) + + # Lease handover sweep — for sessions still mid-turn when Main shuts + # down, emit a migrating event then release the lease so a Worker can + # pick them up. Idle sessions just rehydrate normally on next request + # and don't need this dance. 
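+    # What a connected client observes during handover (illustrative,
+    # assuming a Worker process is running): the `migrating` event, a
+    # dropped SSE stream, then, on reconnect with its last seq, replayed
+    # events from Mongo followed by the Worker's live tail. No seq gap:
+    # every event is appended to Mongo before fan-out.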
+ try: + for sid, agent_session in list(session_manager.sessions.items()): + runtime_state = session_manager._runtime_state(agent_session) + if runtime_state == "processing" or agent_session.is_processing: + try: + await session_manager.release_session_to_background( + sid, + reason="main_shutdown", + ) + except Exception as e: + logger.warning( + "Lease handover sweep failed for %s: %s", sid, e + ) + except Exception as e: + logger.warning("Lifespan lease sweep skipped: %s", e) await session_manager.close() @@ -112,3 +133,73 @@ async def api_root(): port = int(os.environ.get("PORT", 7860)) uvicorn.run(app, host="0.0.0.0", port=port) + + +# ── Worker mode entrypoint ─────────────────────────────────────────────── + + +async def _worker_claim_tick() -> None: + """One pass: find sessions with pending submissions and no live lease, + claim each via ``claim_dormant_session`` so the existing + ``_consume_submissions`` loop in ``_run_session`` will pick up their + pending docs. + + Sessions already held by this process (in ``session_manager.sessions``) + are skipped — heartbeat keeps their leases alive. + """ + store = session_manager._store() + if not getattr(store, "enabled", False): + return + db = getattr(store, "db", None) + if db is None: + return + + held = set(session_manager.sessions.keys()) + cursor = db.pending_submissions.find( + {"status": "pending"}, {"session_id": 1} + ).limit(200) + candidate_session_ids: set[str] = set() + async for doc in cursor: + sid = doc.get("session_id") + if sid and sid not in held: + candidate_session_ids.add(sid) + + for sid in candidate_session_ids: + try: + # claim_dormant_session does claim_lease internally; if claim + # fails (another process holds it) we'll just try the next + # session in the next tick. + await session_manager.claim_dormant_session(sid) + except Exception as e: + logger.warning(f"Worker failed to claim {sid}: {e}") + + +async def worker_loop() -> None: + """Worker mode entrypoint. Initializes a SessionManager in worker mode, + polls Mongo for pending submissions across all sessions, claims their + leases, and runs their agent loops. + + Heartbeat (US-002) renews held leases on a 10s cadence; the grace + sweeper (US-006) auto-backgrounds inactive held sessions; idle + eviction (US-007) drops fully idle ones. + """ + logger.info("Starting ml-intern worker mode...") + # Use the global session_manager (already imported); it has read MODE + # at construction time. + await session_manager.start() + try: + # Single coordinator task: scan ``pending_submissions`` across + # all sessions and claim those with no current holder. Polling + # cadence is 1s; that's enough for v1. 
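+        # Hypothetical launch wiring (this diff defines worker_loop but no
+        # entrypoint for it; adapt the module path to the deployment):
+        #
+        #   MODE=worker python -c "import asyncio; from backend.main import worker_loop; asyncio.run(worker_loop())"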
+ while True: + try: + await _worker_claim_tick() + await asyncio.sleep(1.0) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Worker claim tick error: {e}") + await asyncio.sleep(2.0) + finally: + await session_manager.close() + logger.info("Worker shut down cleanly.") diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 7175e536..bde2e739 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -19,6 +19,7 @@ ) from fastapi.responses import StreamingResponse from litellm import acompletion +from pymongo.errors import PyMongoError from models import ( ApprovalRequest, HealthResponse, @@ -571,6 +572,24 @@ async def delete_session( return {"status": "deleted", "session_id": session_id} +@router.post("/session/{session_id}/background") +async def background_session( + session_id: str, user: dict = Depends(get_current_user) +) -> dict: + """Manually move a session to background. + + Releases the lease so a Worker can pick it up. The frontend uses this + when the user clicks "run in background". + """ + await _check_session_access(session_id, user) + success = await session_manager.release_session_to_background( + session_id, reason="user_requested" + ) + if not success: + raise HTTPException(status_code=404, detail="Session not found or inactive") + return {"status": "released", "session_id": session_id} + + @router.post("/submit") async def submit_input( request: SubmitRequest, user: dict = Depends(get_current_user) @@ -612,20 +631,19 @@ async def chat_sse( request: Request, user: dict = Depends(get_current_user), ) -> StreamingResponse: - """SSE endpoint: submit input or approval, then stream events until turn ends.""" + """SSE endpoint: submit input or approval, then stream events until turn ends. + + With the Mongo durability layer (US-001..US-004), every emitted event is + persisted via ``append_event`` before being fanned out, so the + replay-then-live pattern in ``_sse_response`` catches anything the + submission produces — there's no need to subscribe pre-submit anymore. + """ agent_session = await _check_session_access(session_id, user, request) if not agent_session or not agent_session.is_active: raise HTTPException(status_code=404, detail="Session not found or inactive") # Parse body body = await request.json() - - # Subscribe BEFORE submitting so we never miss events — even if the - # agent loop processes the submission before this coroutine continues. - broadcaster = agent_session.broadcaster - sub_id, event_queue = broadcaster.subscribe() - - # Submit the operation text = body.get("text") approvals = body.get("approvals") @@ -633,42 +651,39 @@ async def chat_sse( # continuations of an in-progress turn — the session was already charged # on its first message, so we skip the gate there. 
if text is not None and not approvals: - try: - await _enforce_gated_model_quota(user, agent_session) - except HTTPException: - broadcaster.unsubscribe(sub_id) - raise + await _enforce_gated_model_quota(user, agent_session) - try: - if approvals: - formatted = [ - { - "tool_call_id": a["tool_call_id"], - "approved": a["approved"], - "feedback": a.get("feedback"), - "edited_script": a.get("edited_script"), - "namespace": a.get("namespace"), - } - for a in approvals - ] - success = await session_manager.submit_approval(session_id, formatted) - elif text is not None: - success = await session_manager.submit_user_input(session_id, text) - else: - broadcaster.unsubscribe(sub_id) - raise HTTPException(status_code=400, detail="Must provide 'text' or 'approvals'") - - if not success: - broadcaster.unsubscribe(sub_id) - raise HTTPException(status_code=404, detail="Session not found or inactive") - except HTTPException: - broadcaster.unsubscribe(sub_id) - raise - except Exception: - broadcaster.unsubscribe(sub_id) - raise + after_seq = _last_event_seq(request) - return _sse_response(broadcaster, event_queue, sub_id) + if approvals: + formatted = [ + { + "tool_call_id": a["tool_call_id"], + "approved": a["approved"], + "feedback": a.get("feedback"), + "edited_script": a.get("edited_script"), + "namespace": a.get("namespace"), + } + for a in approvals + ] + success = await session_manager.submit_approval(session_id, formatted) + elif text is not None: + success = await session_manager.submit_user_input(session_id, text) + else: + raise HTTPException(status_code=400, detail="Must provide 'text' or 'approvals'") + + if not success: + raise HTTPException(status_code=404, detail="Session not found or inactive") + + replay_events = await session_manager._store().load_events_after( + session_id, after_seq + ) + return _sse_response( + session_id, + agent_session, + replay_events=replay_events, + after_seq=after_seq, + ) @router.post("/pro-click/{session_id}") @@ -726,42 +741,145 @@ def _event_doc_to_msg(doc: dict[str, Any]) -> dict[str, Any]: def _sse_response( - broadcaster, - event_queue, - sub_id, + session_id: str, + agent_session: AgentSession | None, *, replay_events: list[dict[str, Any]] | None = None, after_seq: int = 0, ) -> StreamingResponse: - """Build a StreamingResponse that drains *event_queue* as SSE, - sending keepalive comments every 15 s to prevent proxy timeouts.""" + """Build a StreamingResponse that: + + 1. Yields ``replay_events`` first (events from Mongo that the client missed, + filtered to ``seq > after_seq``). + 2. Picks the live transport based on lease ownership: + - Holder fast-path: subscribe to the in-process ``EventBroadcaster``. + - Non-holder slow-path: tail Mongo via ``change_stream_events`` and fall + back to a 500 ms poll loop on ``PyMongoError``. + 3. Sends keepalive comments every 15 s to prevent proxy timeouts. + + Subscriber bookkeeping (``_attach_subscriber`` / ``_detach_subscriber``) is + called from BOTH branches so US-006's grace-period sweeper can see when a + session has had no readers. 
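+
+    Illustrative stream shape (assumes ``_format_sse`` renders standard SSE
+    ``data:`` frames and that ``turn_complete`` is one of the terminal event
+    types; payloads abbreviated)::
+
+        data: {"seq": 41, "event_type": "assistant_message", ...}
+
+        : keepalive
+
+        data: {"seq": 42, "event_type": "turn_complete", ...}
+
+    The first frame is a replay (41 > after_seq), the comment frame is the
+    15 s keepalive, and the terminal event closes the stream.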
+ """ async def event_generator(): - try: - for doc in replay_events or []: - msg = _event_doc_to_msg(doc) - seq = msg.get("seq") - if isinstance(seq, int) and seq <= after_seq: - continue - yield _format_sse(msg) - if msg.get("event_type", "") in _TERMINAL_EVENTS: - return - - while True: - try: - msg = await asyncio.wait_for( - event_queue.get(), timeout=_SSE_KEEPALIVE_SECONDS - ) - except asyncio.TimeoutError: - # SSE comment — ignored by parsers, keeps connection alive - yield ": keepalive\n\n" - continue - event_type = msg.get("event_type", "") - yield _format_sse(msg) - if event_type in _TERMINAL_EVENTS: - break - finally: - broadcaster.unsubscribe(sub_id) + # ── Phase 1: replay events the client missed ──────────────────────── + last_seen_seq = after_seq + replay_count = 0 + for doc in replay_events or []: + msg = _event_doc_to_msg(doc) + seq = msg.get("seq") + if isinstance(seq, int) and seq <= after_seq: + continue + if isinstance(seq, int): + last_seen_seq = max(last_seen_seq, seq) + replay_count += 1 + yield _format_sse(msg) + if msg.get("event_type", "") in _TERMINAL_EVENTS: + logger.info( + f"replay_event_count session_id={session_id} " + f"count={replay_count} after_seq={after_seq}" + ) + return + logger.info( + f"replay_event_count session_id={session_id} " + f"count={replay_count} after_seq={after_seq}" + ) + + # ── Phase 2: live tail ────────────────────────────────────────────── + is_holder = ( + agent_session is not None + and agent_session.is_active + and agent_session.holder_id == session_manager._holder_id + and agent_session.broadcaster is not None + ) + + if is_holder: + # Fast path: in-process broadcaster fan-out. The wholesale write + # path still goes through Mongo first (Session.send_event calls + # append_event before put-on-event_queue), so this is purely a + # latency win for the holder process. + broadcaster = agent_session.broadcaster + session_manager._attach_subscriber(session_id) + sub_id, queue = broadcaster.subscribe() + try: + while True: + try: + msg = await asyncio.wait_for( + queue.get(), timeout=_SSE_KEEPALIVE_SECONDS + ) + except asyncio.TimeoutError: + yield ": keepalive\n\n" + continue + yield _format_sse(msg) + if msg.get("event_type", "") in _TERMINAL_EVENTS: + break + finally: + broadcaster.unsubscribe(sub_id) + session_manager._detach_subscriber(session_id) + else: + # Slow path: cross-process — tail Mongo. On stream open failure + # (no replica set, etc.) fall back to 500 ms polling. + session_manager._attach_subscriber(session_id) + try: + store = session_manager._store() + use_stream = bool(getattr(store, "enabled", False)) + if use_stream: + try: + async for doc in store.change_stream_events( + session_id, after_seq=last_seen_seq + ): + msg = _event_doc_to_msg(doc) + seq = msg.get("seq") + if isinstance(seq, int): + last_seen_seq = max(last_seen_seq, seq) + yield _format_sse(msg) + if msg.get("event_type", "") in _TERMINAL_EVENTS: + return + # Stream closed without yielding terminal — fall + # through to the polling loop below to keep tailing. + use_stream = False + except PyMongoError as e: + logger.warning( + f"Change stream failed for {session_id}, " + f"falling back to polling: {e}" + ) + use_stream = False + except NotImplementedError: + # NoopSessionStore (or any store without watch()). + use_stream = False + if not use_stream: + # 500 ms poll loop — emits a keepalive every 15 s of silence. 
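+                    # Gap-free across transport switches: the change stream
+                    # and this poll both filter on seq > last_seen_seq, and
+                    # every event is appended to Mongo before fan-out, so
+                    # falling back here neither drops nor re-emits events.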
+ silence_since = asyncio.get_event_loop().time() + while True: + if not getattr(store, "enabled", False): + await asyncio.sleep(0.5) + now = asyncio.get_event_loop().time() + if now - silence_since >= _SSE_KEEPALIVE_SECONDS: + yield ": keepalive\n\n" + silence_since = now + continue + new_events = await store.load_events_after( + session_id, last_seen_seq + ) + if not new_events: + await asyncio.sleep(0.5) + now = asyncio.get_event_loop().time() + if now - silence_since >= _SSE_KEEPALIVE_SECONDS: + yield ": keepalive\n\n" + silence_since = now + continue + silence_since = asyncio.get_event_loop().time() + for doc in new_events: + msg = _event_doc_to_msg(doc) + seq = msg.get("seq") + if isinstance(seq, int): + last_seen_seq = max(last_seen_seq, seq) + yield _format_sse(msg) + if msg.get("event_type", "") in _TERMINAL_EVENTS: + return + finally: + session_manager._detach_subscriber(session_id) return StreamingResponse( event_generator(), @@ -791,12 +909,9 @@ async def subscribe_events( after_seq = _last_event_seq(request) replay_events = await session_manager._store().load_events_after(session_id, after_seq) - broadcaster = agent_session.broadcaster - sub_id, event_queue = broadcaster.subscribe() return _sse_response( - broadcaster, - event_queue, - sub_id, + session_id, + agent_session, replay_events=replay_events, after_seq=after_seq, ) diff --git a/backend/session_manager.py b/backend/session_manager.py index 7f85cf76..0ba1accc 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -3,17 +3,21 @@ import asyncio import json import logging +import os +import time import uuid from dataclasses import dataclass, field -from datetime import datetime +from datetime import UTC, datetime from pathlib import Path from typing import Any, Optional +from pymongo.errors import PyMongoError + from agent.config import load_config from agent.core.agent_loop import process_submission from agent.messaging.gateway import NotificationGateway from agent.core.session import Event, OpType, Session -from agent.core.session_persistence import get_session_store +from agent.core.session_persistence import get_session_store, make_holder_id from agent.core.tools import ToolRouter # Get project root (parent of backend directory) @@ -80,12 +84,19 @@ async def run(self) -> None: @dataclass class AgentSession: - """Wrapper for an agent session with its associated resources.""" + """Wrapper for an agent session with its associated resources. + + ``session`` and ``tool_router`` are ``Optional`` to support cross-process + *stubs* — lightweight placeholders for sessions held by another holder + (Worker). A stub carries enough identity (``session_id``, ``user_id``, + ``holder_id``) to satisfy access checks and the SSE non-holder slow + path, but no live runtime resources. Only the actual lease holder ever + constructs a fully populated ``AgentSession``. + """ session_id: str - session: Session - tool_router: ToolRouter - submission_queue: asyncio.Queue + session: Session | None + tool_router: ToolRouter | None user_id: str = "dev" # Owner of this session hf_username: str | None = None # HF namespace used for personal trace uploads hf_token: str | None = None # User's HF OAuth token for tool execution @@ -99,6 +110,17 @@ class AgentSession: # Claude quota. Guards double-counting when the user re-selects an # Anthropic model mid-session. claude_counted: bool = False + # Wall-clock timestamp of the last submission processed for this + # session — used by US-007's idle eviction. 
Updated in one place + # inside ``_drain_and_process``. + last_submission_at: float = field(default_factory=lambda: 0.0) + # Holder identity that owns this session's lease. Set when the lease + # is claimed in ``create_session`` / ``ensure_session_loaded`` to + # ``SessionManager._holder_id``. The SSE fast path branches on + # ``holder_id == session_manager._holder_id`` to decide whether to + # subscribe to the in-process broadcaster (this process holds it) or + # tail the Mongo change stream (a different process holds it). + holder_id: str | None = None class SessionCapacityError(Exception): @@ -129,18 +151,351 @@ def __init__(self, config_path: str | None = None) -> None: self._lock = asyncio.Lock() self.persistence_store = None + # Holder identity — pick once at process start, never recompute. + # MODE controls which lane this process owns ("main" = synchronous + # frontend handler, "worker" = background submission consumer). + raw_mode = os.environ.get("MODE", "main").lower().strip() + if raw_mode not in {"main", "worker"}: + logger.warning( + f"Unknown MODE={raw_mode!r}; falling back to 'main'" + ) + raw_mode = "main" + self.mode: str = raw_mode + self._holder_id: str = make_holder_id(self.mode) + self._heartbeat_task: asyncio.Task | None = None + self._grace_sweep_task: asyncio.Task | None = None + self._idle_eviction_task: asyncio.Task | None = None + # SSE subscriber bookkeeping — used by the US-006 grace-period + # sweeper to decide when a session has had no readers for long + # enough to be evicted. Populated by ``_attach_subscriber`` / + # ``_detach_subscriber`` from both SSE transport branches. + self._subscriber_counts: dict[str, int] = {} + # Wall-clock timestamp (``time.time()``) of when ``session_id``'s + # subscriber count last hit zero. Cleared the moment a new + # subscriber attaches. + self._no_subscriber_since: dict[str, float] = {} + logger.info( + "SessionManager init: mode=%s holder_id=%s", + self.mode, + self._holder_id, + ) + + def _attach_subscriber(self, session_id: str) -> None: + """Increment the subscriber count for ``session_id``. + + Called from both SSE transport branches (holder fast-path and + non-holder slow-path) when a stream attaches. 
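+
+        Bookkeeping sketch (``mgr`` is a ``SessionManager``)::
+
+            mgr._attach_subscriber("s1")  # counts: {"s1": 1}; zero-point cleared
+            mgr._attach_subscriber("s1")  # counts: {"s1": 2}
+            mgr._detach_subscriber("s1")  # counts: {"s1": 1}
+            mgr._detach_subscriber("s1")  # counts: {}; _no_subscriber_since["s1"] set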
+ """ + self._subscriber_counts[session_id] = ( + self._subscriber_counts.get(session_id, 0) + 1 + ) + self._no_subscriber_since.pop(session_id, None) + + def _detach_subscriber(self, session_id: str) -> None: + """Decrement the subscriber count; record the zero-point on transitions.""" + n = self._subscriber_counts.get(session_id, 0) + n = max(0, n - 1) + if n == 0: + self._subscriber_counts.pop(session_id, None) + self._no_subscriber_since[session_id] = time.time() + else: + self._subscriber_counts[session_id] = n + async def start(self) -> None: """Start shared background resources.""" self.persistence_store = get_session_store() await self.persistence_store.init() await self.messaging_gateway.start() + self._heartbeat_task = asyncio.create_task(self._lease_heartbeat_loop()) + self._grace_sweep_task = asyncio.create_task(self._grace_period_sweep_loop()) + self._idle_eviction_task = asyncio.create_task(self._idle_eviction_loop()) async def close(self) -> None: """Flush and close shared background resources.""" + if self._heartbeat_task is not None: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + pass + self._heartbeat_task = None + if self._grace_sweep_task is not None: + self._grace_sweep_task.cancel() + try: + await self._grace_sweep_task + except asyncio.CancelledError: + pass + self._grace_sweep_task = None + if self._idle_eviction_task is not None: + self._idle_eviction_task.cancel() + try: + await self._idle_eviction_task + except asyncio.CancelledError: + pass + self._idle_eviction_task = None await self.messaging_gateway.close() if self.persistence_store is not None: await self.persistence_store.close() + async def _lease_heartbeat_loop(self) -> None: + """Renew leases every TTL/3 seconds for sessions held by this process. + + On CAS-mismatch for a session (``renew_lease`` returns ``None``): + requeue that session's claimed submissions, drop the session, log + WARN. On a transient ``PyMongoError`` from ``renew_lease``: log a + warning and skip this tick for the affected session — do NOT treat + a Mongo flap as lease theft. The loop must never crash; any other + unexpected exception is logged and the loop sleeps before retrying. + """ + HEARTBEAT_INTERVAL_S = 10 # TTL=30s, renew at TTL/3 + while True: + try: + await asyncio.sleep(HEARTBEAT_INTERVAL_S) + store = self._store() + if not getattr(store, "enabled", False): + continue # NoopSessionStore — nothing to renew + # Snapshot session_ids under lock to avoid mutation during iteration. + async with self._lock: + session_ids = list(self.sessions.keys()) + for session_id in session_ids: + try: + renewed = await store.renew_lease( + session_id, self._holder_id, ttl_s=30 + ) + except PyMongoError as e: + logger.warning( + f"renew_lease transient error for {session_id} " + f"({self._holder_id}); skipping tick: {e}" + ) + continue + if renewed is None: + # Lease lost — someone else holds it now. + await self._on_lease_lost(session_id) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Heartbeat loop error: {e}") + # Don't crash the loop; sleep briefly and retry. + await asyncio.sleep(1) + + async def _on_lease_lost(self, session_id: str) -> None: + """Called when our lease for ``session_id`` has been taken by another holder. + + Requeue claimed submissions for THIS session only, drop the session, + log WARN. 
We must NOT requeue all claimed submissions for this
+        holder — losing one session's lease shouldn't cause double-execution
+        on every other session this Main still holds. The heartbeat loop
+        must keep going, so we don't await the cancelled task here.
+        """
+        store = self._store()
+        try:
+            requeued = await store.requeue_claimed_for(
+                self._holder_id, session_id=session_id
+            )
+            logger.warning(
+                "Lease lost for session %s (held_by=%s); requeued %d claimed submissions",
+                session_id, self._holder_id, requeued,
+            )
+        except Exception as e:
+            logger.error(
+                f"requeue_claimed_for failed during lease-loss for {session_id}: {e}"
+            )
+        async with self._lock:
+            agent_session = self.sessions.pop(session_id, None)
+        if agent_session and agent_session.task and not agent_session.task.done():
+            agent_session.task.cancel()
+            # Don't await — heartbeat loop must keep going.
+
+    async def release_session_to_background(
+        self, session_id: str, reason: str = "manual"
+    ) -> bool:
+        """Emit migrating event, requeue the session's claimed submissions, release lease.
+
+        Used by the lifespan shutdown sweep, the grace-period sweeper, and
+        the manual ``/background`` route. Idempotent on already-released or
+        unknown session IDs (returns False in that case).
+        """
+        async with self._lock:
+            agent_session = self.sessions.get(session_id)
+        if agent_session is None:
+            return False
+        # 1) Emit migrating event so the frontend can render a "reconnecting"
+        #    state. send_event also durably appends via append_event, so a
+        #    non-holder reader will see it on the next change-stream tick.
+        try:
+            await agent_session.session.send_event(
+                Event(event_type="migrating", data={"reason": reason})
+            )
+            logger.info(f"migrating_emitted session_id={session_id} reason={reason}")
+        except Exception as e:
+            logger.warning(
+                f"Failed to emit migrating event for {session_id}: {e}"
+            )
+        # 2) Requeue this session's claimed submissions back to pending so a
+        #    Worker can pick them up. Scoped to session_id: flipping every
+        #    claim this holder owns would re-deliver (and double-execute)
+        #    submissions on sessions it is still actively processing.
+        store = self._store()
+        if getattr(store, "enabled", False):
+            try:
+                n = await store.requeue_claimed_for(
+                    self._holder_id, session_id=session_id
+                )
+                if n > 0:
+                    logger.info(f"requeue_claimed holder_id={self._holder_id} count={n}")
+            except Exception as e:
+                logger.warning(
+                    f"requeue_claimed_for failed during release of "
+                    f"{session_id}: {e}"
+                )
+        # 3) Release the lease.
+        try:
+            await store.release_lease(session_id, self._holder_id)
+            logger.info(f"lease_release session_id={session_id} holder_id={self._holder_id} reason={reason}")
+        except Exception as e:
+            logger.warning(f"release_lease failed for {session_id}: {e}")
+        # 4) Drop from in-memory and cancel the agent task. Don't await the
+        #    cancel — heartbeat / sweep loops must keep going.
+        async with self._lock:
+            popped = self.sessions.pop(session_id, None)
+        if popped and popped.task and not popped.task.done():
+            popped.task.cancel()
+        return True
+
+    async def _grace_period_sweep_loop(self) -> None:
+        """Every 30s, scan sessions held by this process. If a session has
+        had zero subscribers for longer than ``GRACE_PERIOD_SECONDS`` AND has
+        either in-flight work or pending submissions, release it to
+        background. Idle-with-no-work sessions are NOT auto-backgrounded —
+        they wait for idle eviction (US-007) or shutdown.
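+
+        Illustrative timeline (GRACE_PERIOD_SECONDS=180, 30 s sweep cadence;
+        exact tick alignment will vary)::
+
+            t=0s      last SSE reader detaches; _no_subscriber_since[sid] set
+            t~=180s   sweep tick: no subscribers for >= 180 s and the session
+                      has in-flight or pending work -> released to background
+            t~=181s   a Worker claim tick finds the unheld lease and pending
+                      work, and takes the session over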
+ """ + SWEEP_INTERVAL_S = 30 + GRACE_PERIOD_S = float(os.environ.get("GRACE_PERIOD_SECONDS", "180")) + while True: + try: + await asyncio.sleep(SWEEP_INTERVAL_S) + now = time.time() + async with self._lock: + session_ids = list(self.sessions.keys()) + store = self._store() + for session_id in session_ids: + agent_session = self.sessions.get(session_id) + if agent_session is None: + continue + if agent_session.holder_id != self._holder_id: + continue + no_sub_since = self._no_subscriber_since.get(session_id) + if no_sub_since is None: + # Either someone is connected now, or no one has + # ever connected — neither case is an eviction. + continue + if now - no_sub_since < GRACE_PERIOD_S: + continue + has_pending = False + if getattr(store, "enabled", False): + try: + pending_docs = await store.poll_pending_submissions_after( + session_id, None + ) + has_pending = len(pending_docs) > 0 + except Exception: + has_pending = False + has_work = ( + agent_session.is_processing + or has_pending + or getattr(agent_session.session, "is_in_tool_call", False) + ) + if not has_work: + continue + logger.info( + f"Grace period elapsed for {session_id} " + f"(no subs for {now - no_sub_since:.0f}s); " + "releasing to background" + ) + await self.release_session_to_background( + session_id, reason="grace_period_elapsed" + ) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Grace sweep loop error: {e}") + await asyncio.sleep(1) + + async def _idle_eviction_loop(self) -> None: + """Every 60s, drop sessions held by this process that are fully idle + past ``IDLE_EVICTION_SECONDS`` (default 1800s = 30min). + + "Idle" predicate (US-007 spec): + * not ``is_in_tool_call`` (tool may still be executing) + * not ``is_processing`` (agent loop currently busy) + * no pending submissions in Mongo + * ``now - last_submission_at > IDLE_TTL_S`` + + On eviction, the lease is released and the session is dropped from + the in-memory map. No ``migrating`` event is emitted — by definition + nobody is watching. + """ + SWEEP_INTERVAL_S = 60 + IDLE_TTL_S = float(os.environ.get("IDLE_EVICTION_SECONDS", "1800")) + while True: + try: + await asyncio.sleep(SWEEP_INTERVAL_S) + now = time.time() + store = self._store() + async with self._lock: + session_ids = list(self.sessions.keys()) + for sid in session_ids: + agent_session = self.sessions.get(sid) + if agent_session is None: + continue + if agent_session.holder_id != self._holder_id: + continue + if agent_session.is_processing: + continue + if getattr(agent_session.session, "is_in_tool_call", False): + continue + if agent_session.last_submission_at == 0.0: + # Never had a submission yet (just-created); allow + # IDLE_TTL grace measured from creation. + last = ( + agent_session.created_at.timestamp() + if hasattr(agent_session.created_at, "timestamp") + else 0.0 + ) + else: + last = agent_session.last_submission_at + if now - last < IDLE_TTL_S: + continue + # Pending submissions in Mongo? If so, skip — a worker + # tick will pick them up. + if getattr(store, "enabled", False): + try: + pending_docs = await store.poll_pending_submissions_after( + sid, None + ) + if pending_docs: + continue + except Exception: + # If Mongo flapped, err on the side of NOT + # evicting — heartbeat will renew the lease and + # we'll try again on the next sweep. 
+ continue + logger.info( + f"Idle-evicting {sid} (idle for {now - last:.0f}s)" + ) + try: + await store.release_lease(sid, self._holder_id) + logger.info(f"lease_release session_id={sid} holder_id={self._holder_id} reason=idle_eviction") + except Exception as e: + logger.warning( + f"release_lease failed during idle evict for {sid}: {e}" + ) + async with self._lock: + popped = self.sessions.pop(sid, None) + if popped and popped.task and not popped.task.done(): + popped.task.cancel() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Idle eviction loop error: {e}") + await asyncio.sleep(2) + def _store(self): if self.persistence_store is None: self.persistence_store = get_session_store() @@ -328,7 +683,6 @@ async def _start_agent_session( task = asyncio.create_task( self._run_session( agent_session.session_id, - agent_session.submission_queue, event_queue, tool_router, ) @@ -407,57 +761,29 @@ async def persist_session_snapshot( e, ) - async def ensure_session_loaded( + async def _rebuild_agent_session_from_store( self, + loaded: dict[str, Any], session_id: str, - user_id: str, - hf_token: str | None = None, - hf_username: str | None = None, - ) -> AgentSession | None: - """Return a live runtime session, lazily restoring it from Mongo.""" - async with self._lock: - existing = self.sessions.get(session_id) - if existing: - if self._can_access_session(existing, user_id): - self._update_hf_identity( - existing, - hf_token=hf_token, - hf_username=hf_username, - ) - return existing - return None - - store = self._store() - loaded = await store.load_session(session_id) - if not loaded: - return None - - async with self._lock: - existing = self.sessions.get(session_id) - if existing: - if self._can_access_session(existing, user_id): - self._update_hf_identity( - existing, - hf_token=hf_token, - hf_username=hf_username, - ) - return existing - return None - - meta = loaded.get("metadata") or {} - owner = str(meta.get("user_id") or "") - if user_id != "dev" and owner != "dev" and owner != user_id: - return None + hf_token: str | None, + hf_username: str | None, + owner: str, + ) -> AgentSession: + """Reconstruct an ``AgentSession`` from a Mongo ``load_session`` result. + Caller is expected to have already claimed the lease (or persistence + is disabled). Shared by ``ensure_session_loaded`` and + ``claim_dormant_session``. 
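+
+        Condensed sketch of the two call sites (keyword args as in
+        ``claim_dormant_session``)::
+
+            loaded = await store.load_session(session_id)
+            if loaded and await store.claim_lease(session_id, self._holder_id, ttl_s=30):
+                agent_session = await self._rebuild_agent_session_from_store(
+                    loaded=loaded, session_id=session_id,
+                    hf_token=None, hf_username=None, owner="dev",
+                )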
+ """ from litellm import Message + meta = loaded.get("metadata") or {} model = meta.get("model") or self.config.model_name event_queue: asyncio.Queue = asyncio.Queue() - submission_queue: asyncio.Queue = asyncio.Queue() tool_router, session = await asyncio.to_thread( self._create_session_sync, session_id=session_id, - user_id=owner or user_id, + user_id=owner, hf_username=hf_username, hf_token=hf_token, model=model, @@ -497,8 +823,7 @@ async def ensure_session_loaded( session_id=session_id, session=session, tool_router=tool_router, - submission_queue=submission_queue, - user_id=owner or user_id, + user_id=owner, hf_username=hf_username, hf_token=hf_token, created_at=created_at, @@ -506,6 +831,7 @@ async def ensure_session_loaded( is_processing=False, claude_counted=bool(meta.get("claude_counted")), title=meta.get("title"), + holder_id=self._holder_id, ) started = await self._start_agent_session( agent_session=agent_session, @@ -519,9 +845,154 @@ async def ensure_session_loaded( hf_username=hf_username, ) return started + return agent_session + + async def ensure_session_loaded( + self, + session_id: str, + user_id: str, + hf_token: str | None = None, + hf_username: str | None = None, + ) -> AgentSession | None: + """Return a live runtime session, lazily restoring it from Mongo.""" + async with self._lock: + existing = self.sessions.get(session_id) + if existing: + if self._can_access_session(existing, user_id): + self._update_hf_identity( + existing, + hf_token=hf_token, + hf_username=hf_username, + ) + return existing + return None + + store = self._store() + loaded = await store.load_session(session_id) + if not loaded: + return None + + async with self._lock: + existing = self.sessions.get(session_id) + if existing: + if self._can_access_session(existing, user_id): + self._update_hf_identity( + existing, + hf_token=hf_token, + hf_username=hf_username, + ) + return existing + return None + + meta = loaded.get("metadata") or {} + owner = str(meta.get("user_id") or "") + if user_id != "dev" and owner != "dev" and owner != user_id: + return None + + if getattr(store, "enabled", False): + claimed = await store.claim_lease( + session_id, self._holder_id, ttl_s=30 + ) + if claimed is None: + # Another holder owns an unexpired lease. Return a stub so + # the route layer's access check passes and the SSE slow + # path / submission-enqueue paths can deliver to the actual + # holder. The stub is NOT inserted into ``self.sessions`` — + # only the real holder owns that map entry. + foreign_lease = (meta.get("lease") or {}) + foreign_holder = foreign_lease.get("holder_id") + if not foreign_holder: + # Defensive: post-backfill every active session has a + # lease subdoc; if it's missing we can't safely build a + # stub. Preserve the legacy behaviour and return None. 
logger.info(
+                        f"Refusing restore of {session_id}: lease "
+                        "held by another process (no holder_id on doc)"
+                    )
+                    return None
+                created_at = meta.get("created_at")
+                if not isinstance(created_at, datetime):
+                    created_at = datetime.now(UTC)
+                logger.info(
+                    f"ensure_session_loaded stub session_id={session_id} "
+                    f"foreign_holder={foreign_holder} (lease held elsewhere)"
+                )
+                return AgentSession(
+                    session_id=session_id,
+                    session=None,
+                    tool_router=None,
+                    user_id=owner or user_id,
+                    hf_username=hf_username,
+                    hf_token=hf_token,
+                    task=None,
+                    created_at=created_at,
+                    is_active=True,
+                    is_processing=False,
+                    broadcaster=None,
+                    holder_id=foreign_holder,
+                )
+            logger.info(f"lease_claim session_id={session_id} holder_id={self._holder_id}")
+
+        agent_session = await self._rebuild_agent_session_from_store(
+            loaded=loaded,
+            session_id=session_id,
+            hf_token=hf_token,
+            hf_username=hf_username,
+            owner=owner or user_id,
+        )
         logger.info("Restored session %s for user %s", session_id, owner or user_id)
         return agent_session
 
+    async def claim_dormant_session(
+        self, session_id: str
+    ) -> AgentSession | None:
+        """Internal: claim and load a dormant session without a user-ownership
+        check. Used by ``worker_loop``'s claim tick — the worker process is
+        process-level trusted, and the lease CAS still enforces the
+        "one holder at a time" invariant.
+
+        Returns the live ``AgentSession`` on success; ``None`` if the session
+        doesn't exist, the lease is already held, or persistence is disabled.
+        """
+        async with self._lock:
+            existing = self.sessions.get(session_id)
+            if existing:
+                return existing
+
+        store = self._store()
+        if not getattr(store, "enabled", False):
+            return None
+
+        loaded = await store.load_session(session_id)
+        if not loaded:
+            return None
+
+        # Claim the lease BEFORE building the session — fast bail out if
+        # someone else holds it.
+        claimed = await store.claim_lease(
+            session_id, self._holder_id, ttl_s=30
+        )
+        if claimed is None:
+            logger.debug(
+                f"Worker refusing to claim {session_id}: lease held by another process"
+            )
+            return None
+        logger.info(f"lease_claim session_id={session_id} holder_id={self._holder_id}")
+
+        meta = loaded.get("metadata") or {}
+        owner = str(meta.get("user_id") or "") or "dev"
+        agent_session = await self._rebuild_agent_session_from_store(
+            loaded=loaded,
+            session_id=session_id,
+            hf_token=None,
+            hf_username=None,
+            owner=owner,
+        )
+        logger.info(
+            "Worker claimed dormant session %s (owner=%s)", session_id, owner
+        )
+        return agent_session
+
     async def create_session(
         self,
         user_id: str = "dev",
@@ -568,8 +1039,8 @@
 
         session_id = str(uuid.uuid4())
 
-        # Create queues for this session
-        submission_queue: asyncio.Queue = asyncio.Queue()
+        # Create queue for this session (events still flow through an
+        # in-process queue; submissions now live in Mongo).
         event_queue: asyncio.Queue = asyncio.Queue()
 
         # Run blocking constructors in a thread to keep the event loop responsive.
@@ -588,12 +1059,35 @@
             session_id=session_id,
             session=session,
             tool_router=tool_router,
-            submission_queue=submission_queue,
             user_id=user_id,
             hf_username=hf_username,
             hf_token=hf_token,
         )
 
+        # Claim the lease before starting the runtime task — brand-new
+        # session_id, so this should always succeed; failure is treated as
+        # an internal error.
+ claimed = await self._store().claim_lease( + session_id, self._holder_id, ttl_s=30 + ) + if ( + claimed is None + and getattr(self._store(), "enabled", False) + ): + logger.warning( + f"Failed to claim lease for new session {session_id} " + f"(holder={self._holder_id})" + ) + raise RuntimeError( + f"Failed to claim lease for new session {session_id}" + ) + # Tag the session with our holder identity so the SSE fast-path + # branch knows we own it. Always safe to set: either we just + # claimed (Mongo path) or persistence is disabled (single-process + # local dev — we are trivially the only holder). + agent_session.holder_id = self._holder_id + logger.info(f"lease_claim session_id={session_id} holder_id={self._holder_id}") + await self._start_agent_session( agent_session=agent_session, event_queue=event_queue, @@ -727,19 +1221,141 @@ async def _cleanup_sandbox(session: Session) -> None: f"Orphan — sweep script will pick it up." ) + async def _consume_submissions(self, agent_session: AgentSession) -> None: + """Consume pending submissions for this session. + + Tries the Mongo change stream first (push-based, low-latency); on + replica-set unavailability or any ``PyMongoError`` from ``watch()`` + falls back to a 500 ms polling loop. Either path drains all + currently-pending submissions through ``_drain_and_process``. + """ + session = agent_session.session + session_id = agent_session.session_id + store = self._store() + use_change_stream = bool(getattr(store, "enabled", False)) + + # Drain anything that arrived before the consumer started — covers + # the race where ``enqueue`` happened during runtime startup. + await self._drain_and_process(agent_session) + + while session.is_running: + try: + if use_change_stream: + try: + async for _change_doc in store.change_stream_pending_submissions( + session_id + ): + await self._drain_and_process(agent_session) + if not session.is_running: + break + # Stream exited without error (e.g. shutdown). Break. + if not session.is_running: + break + except PyMongoError as e: + logger.warning( + f"Change stream failed for {session_id}, " + f"falling back to polling: {e}" + ) + use_change_stream = False + except NotImplementedError: + # NoopSessionStore (or any store without watch()) + use_change_stream = False + else: + await asyncio.sleep(0.5) + await self._drain_and_process(agent_session) + except asyncio.CancelledError: + break + except Exception as e: + logger.error( + f"Submission consume error for {session_id}: {e}" + ) + await asyncio.sleep(1) + + async def _drain_and_process(self, agent_session: AgentSession) -> None: + """Claim and process all pending submissions for ``agent_session``, FIFO. + + Handles ``interrupt`` and ``shutdown`` ops inline (they don't go + through the agent loop). All other ops are reconstructed into a + ``Submission(Operation(...))`` and dispatched to ``process_submission``. + Marks each submission ``done`` in a finally so a poison submission + never gets redelivered. + """ + session = agent_session.session + session_id = agent_session.session_id + store = self._store() + if not getattr(store, "enabled", False): + return + while session.is_running: + claimed = await store.claim_pending_submission(session_id, self._holder_id) + if claimed is None: + return + submission_id = claimed.get("_id") + op_type = claimed.get("op_type") + payload = claimed.get("payload") or {} + try: + # Wall-clock (time.time()) so it composes with the same clock + # used by ``_no_subscriber_since`` and the idle-eviction loop. 
+ agent_session.last_submission_at = time.time() + created_at = claimed.get("created_at") + if isinstance(created_at, datetime): + _ca = created_at if created_at.tzinfo else created_at.replace(tzinfo=UTC) + lag = (datetime.now(UTC) - _ca).total_seconds() + if lag > 0.1: + logger.debug( + f"pending_submission_lag session_id={session_id} " + f"op_type={op_type} lag_ms={int(lag * 1000)}" + ) + # Inline ops: interrupt + shutdown bypass the agent loop. + if op_type == "interrupt": + session.cancel() + continue + if op_type == "shutdown": + session.is_running = False + return + agent_session.is_processing = True + try: + operation = self._build_operation(op_type, payload) + submission = Submission( + id=f"sub_{uuid.uuid4().hex[:8]}", + operation=operation, + ) + should_continue = await process_submission(session, submission) + finally: + agent_session.is_processing = False + await self.persist_session_snapshot(agent_session) + if not should_continue: + session.is_running = False + return + except Exception as e: + logger.error( + f"Error processing submission {submission_id} " + f"for {session_id}: {e}" + ) + finally: + # Always mark done so a poison row is not redelivered. + try: + await store.mark_submission_done(submission_id) + except Exception as e: + logger.debug( + f"mark_submission_done failed for {submission_id}: {e}" + ) + + def _build_operation(self, op_type: Any, payload: dict) -> Operation: + """Reconstruct an ``Operation`` from a ``pending_submissions`` row.""" + if isinstance(op_type, OpType): + enum_op = op_type + else: + enum_op = OpType(op_type) + return Operation(op_type=enum_op, data=payload or None) + async def _run_session( self, session_id: str, - submission_queue: asyncio.Queue, event_queue: asyncio.Queue, tool_router: ToolRouter, ) -> None: """Run the agent loop for a session and broadcast events via EventBroadcaster.""" - agent_session = self.sessions.get(session_id) - if not agent_session: - logger.error(f"Session {session_id} not found") - return - + agent_session = self.sessions[session_id] session = agent_session.session # Start event broadcaster task @@ -754,30 +1370,15 @@ async def _run_session( Event(event_type="ready", data={"message": "Agent initialized"}) ) - while session.is_running: - try: - # Wait for submission with timeout to allow checking is_running - submission = await asyncio.wait_for( - submission_queue.get(), timeout=1.0 - ) - agent_session.is_processing = True - try: - should_continue = await process_submission(session, submission) - finally: - agent_session.is_processing = False - await self.persist_session_snapshot(agent_session) - if not should_continue: - break - except asyncio.TimeoutError: - continue - except asyncio.CancelledError: - logger.info(f"Session {session_id} cancelled") - break - except Exception as e: - logger.error(f"Error in session {session_id}: {e}") - await session.send_event( - Event(event_type="error", data={"error": str(e)}) - ) + try: + await self._consume_submissions(agent_session) + except asyncio.CancelledError: + logger.info(f"Session {session_id} cancelled") + except Exception as e: + logger.error(f"Error in session {session_id}: {e}") + await session.send_event( + Event(event_type="error", data={"error": str(e)}) + ) finally: broadcast_task.cancel() @@ -788,6 +1389,14 @@ async def _run_session( await self._cleanup_sandbox(session) + try: + await self._store().release_lease(session_id, self._holder_id) + logger.info(f"lease_release session_id={session_id} holder_id={self._holder_id} reason=session_end") + except 
Exception as e:
+            logger.debug(
+                f"release_lease failed for {session_id} on session end: {e}"
+            )
+
         # Final-flush: always save on session death so we capture ended
         # sessions even if the client disconnects without /shutdown.
         # Idempotent via session_id key; detached subprocess.
@@ -808,45 +1417,89 @@ async def _run_session(
         logger.info(f"Session {session_id} ended")
 
-    async def submit(self, session_id: str, operation: Operation) -> bool:
-        """Submit an operation to a session."""
-        async with self._lock:
-            agent_session = self.sessions.get(session_id)
-
-            if not agent_session or not agent_session.is_active:
-                logger.warning(f"Session {session_id} not found or inactive")
+    async def _enqueue_or_false(
+        self, session_id: str, op_type: str, payload: dict[str, Any]
+    ) -> bool:
+        """Enqueue a pending submission, returning False when no session
+        exists in either runtime memory or the durable store.
+
+        The route layer's ``_check_session_access`` already gates by user;
+        this method only verifies the session exists somewhere we can
+        deliver to. When the store is the no-op (Mongo disabled) there is
+        no durable queue at all, so we always refuse; the in-memory lookup
+        only shapes the warning we log. The legacy in-memory flow has been
+        removed, and there is no other holder to forward to.
+        """
+        store = self._store()
+        in_memory = self.sessions.get(session_id)
+        if not getattr(store, "enabled", False):
+            if in_memory is None or not in_memory.is_active:
+                logger.warning(f"Session {session_id} not found or inactive")
                 return False
-
-        submission = Submission(id=f"sub_{uuid.uuid4().hex[:8]}", operation=operation)
-        await agent_session.submission_queue.put(submission)
+            # No durable queue — without Mongo we cannot enqueue. Drop and
+            # warn; this path is exercised in CLI/local-dev only and the
+            # legacy in-memory flow has been removed.
+            logger.warning(
+                f"Cannot enqueue submission for {session_id}: "
+                "Mongo persistence disabled"
+            )
+            return False
+        if in_memory is None:
+            doc = await store.load_session(session_id)
+            if doc is None:
+                logger.warning(f"Session {session_id} not found")
+                return False
+        await store.enqueue_pending_submission(
+            session_id, op_type=op_type, payload=payload
+        )
         return True
 
+    async def submit(self, session_id: str, operation: Operation) -> bool:
+        """Submit an operation to a session via the durable pending queue."""
+        return await self._enqueue_or_false(
+            session_id,
+            op_type=operation.op_type.value,
+            payload=operation.data or {},
+        )
+
     async def submit_user_input(self, session_id: str, text: str) -> bool:
         """Submit user input to a session."""
-        operation = Operation(op_type=OpType.USER_INPUT, data={"text": text})
-        return await self.submit(session_id, operation)
+        return await self._enqueue_or_false(
+            session_id, op_type="user_input", payload={"text": text}
+        )
 
     async def submit_approval(
         self, session_id: str, approvals: list[dict[str, Any]]
     ) -> bool:
         """Submit tool approvals to a session."""
-        operation = Operation(
-            op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals}
+        return await self._enqueue_or_false(
+            session_id, op_type="exec_approval", payload={"approvals": approvals}
         )
-        return await self.submit(session_id, operation)
 
     async def interrupt(self, session_id: str) -> bool:
-        """Interrupt a session by signalling cancellation directly (bypasses queue)."""
-        agent_session = self.sessions.get(session_id)
-        if not agent_session or not agent_session.is_active:
+        """Interrupt by signalling cancellation. Holder fast-path; non-holder
+        enqueues an interrupt op for the actual lease holder to consume.
+        """
+        async with self._lock:
+            agent_session = self.sessions.get(session_id)
+        if agent_session and agent_session.is_active:
+            # We are the holder — fast path, cancel directly.
+            agent_session.session.cancel()
+            return True
+        store = self._store()
+        if not getattr(store, "enabled", False):
+            return False
+        if await store.load_session(session_id) is None:
             return False
-        agent_session.session.cancel()
+        await store.enqueue_pending_submission(
+            session_id, op_type="interrupt", payload={}
+        )
         return True
 
     async def undo(self, session_id: str) -> bool:
         """Undo last turn in a session."""
-        operation = Operation(op_type=OpType.UNDO)
-        return await self.submit(session_id, operation)
+        return await self._enqueue_or_false(
+            session_id, op_type="undo", payload={}
+        )
 
     async def truncate(self, session_id: str, user_message_index: int) -> bool:
         """Truncate conversation to before a specific user message (direct, no queue)."""
@@ -861,23 +1514,43 @@ async def truncate(self, session_id: str, user_message_index: int) -> bool:
 
     async def compact(self, session_id: str) -> bool:
         """Compact context in a session."""
-        operation = Operation(op_type=OpType.COMPACT)
-        return await self.submit(session_id, operation)
+        return await self._enqueue_or_false(
+            session_id, op_type="compact", payload={}
+        )
 
     async def shutdown_session(self, session_id: str) -> bool:
-        """Shutdown a specific session."""
-        operation = Operation(op_type=OpType.SHUTDOWN)
-        success = await self.submit(session_id, operation)
+        """Shutdown a specific session.
+
+        Enqueues a ``shutdown`` op (the consumer drains it inline by setting
+        ``session.is_running = False``), then releases the lease and awaits
+        the task locally so ``DELETE`` callers see a clean stop.
+
+        We only acquire ``self._lock`` for the dict lookup — external I/O
+        (``release_lease`` Mongo round-trip, ``wait_for(task)`` agent loop
+        drain) runs without the lock so heartbeat snapshots, grace sweeps,
+        idle eviction, and other shutdowns aren't serialized behind us.
+        """
+        success = await self._enqueue_or_false(
+            session_id, op_type="shutdown", payload={}
+        )
         if success:
             async with self._lock:
                 agent_session = self.sessions.get(session_id)
-                if agent_session and agent_session.task:
-                    # Wait for task to complete
-                    try:
-                        await asyncio.wait_for(agent_session.task, timeout=5.0)
-                    except asyncio.TimeoutError:
-                        agent_session.task.cancel()
+            if agent_session and agent_session.task:
+                try:
+                    await self._store().release_lease(
+                        session_id, self._holder_id
+                    )
+                except Exception as e:
+                    logger.debug(
+                        f"release_lease failed during shutdown of {session_id}: {e}"
+                    )
+                # Wait for task to complete
+                try:
+                    await asyncio.wait_for(agent_session.task, timeout=5.0)
+                except asyncio.TimeoutError:
+                    agent_session.task.cancel()
         return success
 
@@ -895,6 +1568,13 @@ async def delete_session(self, session_id: str) -> bool:
 
         # Clean up sandbox Space before cancelling the task
         await self._cleanup_sandbox(agent_session.session)
 
+        try:
+            await self._store().release_lease(session_id, self._holder_id)
+        except Exception as e:
+            logger.debug(
+                f"release_lease failed during delete of {session_id}: {e}"
+            )
+
         # Cancel the task if running
         if agent_session.task and not agent_session.task.done():
             agent_session.task.cancel()
diff --git a/backend/start.sh b/backend/start.sh
index 72b35198..fc864de7 100755
--- a/backend/start.sh
+++ b/backend/start.sh
@@ -4,7 +4,14 @@
 # Only the first instance can bind port 7860 — the rest must exit
 # with code 0 so the dev mode daemon doesn't mark the app as crashed.
-# Run uvicorn; if it fails due to port conflict, exit cleanly.
+# Worker mode: no HTTP listener, no port binding — just run the
+# session-claim loop forever. Invoked from WORKDIR=/app/backend so the
+# module path matches the existing `uvicorn main:app` style below.
+if [ "$MODE" = "worker" ]; then
+    exec python -m worker
+fi
+
+# Main mode (default): existing port-conflict graceful behavior.
 uvicorn main:app --host 0.0.0.0 --port 7860
 EXIT_CODE=$?
diff --git a/backend/worker.py b/backend/worker.py
new file mode 100644
index 00000000..4d6188f9
--- /dev/null
+++ b/backend/worker.py
@@ -0,0 +1,8 @@
+"""Worker mode entrypoint. Runs the session-claim loop forever."""
+
+import asyncio
+
+from main import worker_loop
+
+if __name__ == "__main__":
+    asyncio.run(worker_loop())
diff --git a/docs/deployment.md b/docs/deployment.md
new file mode 100644
index 00000000..7060719a
--- /dev/null
+++ b/docs/deployment.md
@@ -0,0 +1,103 @@
+# Background sessions deployment
+
+This codebase runs as **two HF Space tiers** for the background-sessions feature:
+
+1. **Main Space** — the current FastAPI/React app; hosts the UI and interactive sessions. Started with `MODE=main` (the default).
+2. **Worker Space(s)** — the same Docker image with the `MODE=worker` env var. They run agent loops for backgrounded sessions and expose no public HTTP routes.
+
+Both rely on a **MongoDB replica set** (Atlas, or self-hosted with `--replSet`). Change streams require a replica set; the app falls back to 500 ms polling on a single-node deployment, but production should run a replica set.
+
+## Deploy ordering
+
+When shipping a new release that touches the agent loop:
+
+1. **Roll Workers first.** Each Worker reads `pending_submissions` and claims dormant sessions; rolling them ahead of Main means the new Main never enqueues work for an old-protocol Worker.
+2. **Then roll Main.** Main's `lifespan` startup runs `MongoSessionStore.init()`, which:
+   - Backfills `lease={holder_id: null, expires_at: 0}` on sessions with `last_active_at > now-1h` (recoverable)
+   - Flips older sessions' `runtime_state` to `"idle"` (still recoverable, never `"ended"`)
+3. The lifespan shutdown sweep on the OLD Main releases active-turn leases via `release_session_to_background(reason="main_shutdown")`. Workers pick them up within ~30 s (TTL).
+
+## Pre-deploy blast-radius check
+
+Run before deploying:
+
+```js
+db.sessions.aggregate([
+  { $match: { runtime_state: "processing" } },
+  { $count: "active_turns_at_deploy" },
+])
+```
+
+Capture the count as a baseline metric. Each turn active at deploy time will see a `migrating` event and then a brief (~30 s) handover window. Sessions with no in-flight work see no user-visible event.
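+
+To spot-check the step 2 backfill right after the new Main boots, a quick query helps (a sketch; it assumes only the `sessions` fields already named in this doc: `last_active_at`, `lease.holder_id`, `runtime_state`):
+
+```js
+// Recently-active sessions should now carry a recoverable null-holder
+// lease, and nothing recoverable should have been flipped to "ended".
+const oneHourAgo = new Date(Date.now() - 60 * 60 * 1000);
+db.sessions.countDocuments({
+  last_active_at: { $gt: oneHourAgo },
+  "lease.holder_id": null,
+  runtime_state: { $ne: "ended" },
+})
+```
+
+Compare the count against the blast-radius baseline above; a large gap suggests the backfill skipped sessions it should have recovered.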
+ +## Required env vars + +| Var | Default | Effect | +| --- | --- | --- | +| `MODE` | `main` | `worker` flips to the worker-loop entrypoint | +| `MONGODB_URI` | unset | Required for the control plane; without it falls back to `NoopSessionStore` (CLI compatibility) | +| `GRACE_PERIOD_SECONDS` | `180` | SSE-drop grace before background migration | +| `IDLE_EVICTION_SECONDS` | `1800` | Worker idle eviction TTL | + +## Local development + +To run a 2-process stack locally for a "close laptop, come back" drill: + +```bash +# 1) Start a Mongo replica set in Docker +docker run -d --name mongo-rs -p 27017:27017 mongo:7 --replSet rs0 +# Initiate the replica set (one-time) +docker exec mongo-rs mongosh --eval 'rs.initiate()' +docker exec mongo-rs mongosh --eval 'rs.status()' | grep PRIMARY + +# 2) Start Main +MODE=main MONGODB_URI=mongodb://localhost:27017/?replicaSet=rs0 \ + uvicorn backend.main:app --host 0.0.0.0 --port 7860 + +# 3) In another terminal, start Worker(s) +MODE=worker MONGODB_URI=mongodb://localhost:27017/?replicaSet=rs0 \ + python -m backend.worker + +# 4) Run Drill 1 (manually): create a session, close the tab, +# wait > GRACE_PERIOD_SECONDS, reopen, observe the migrating event +# in the SSE stream and verify lease.holder_id flipped to worker:* in db.sessions. +``` + +### Chaos test (verifies change-stream resume token) + +```bash +# Mid-drill, briefly pause the Mongo container: +docker pause mongo-rs && sleep 5 && docker unpause mongo-rs +# SSE should reconnect via resume token without losing events. +``` + +## Observability + +Grep production logs for: + +- `lease_claim`, `lease_release` — lease churn +- `requeue_claimed count=N` with N>0 — handover happened +- `migrating_emitted reason=...` — sessions moving to background +- `replay_event_count` with high counts — long-session replay scan +- `pending_submission_lag` (DEBUG) — Mongo or change-stream backpressure + +## Acceptance drills (run post-deploy) + +### Drill 1 — close laptop, come back + +1. `POST /api/session` → `session_id`. `db.sessions.findOne({_id})` shows `lease.holder_id` matching `main:*`. +2. `POST /api/chat/{session_id}` with a long-running message (HF Job). SSE streams initial events. +3. Close the browser tab. Wait `GRACE_PERIOD_SECONDS + ~10s`. `db.sessions.findOne({_id})` shows `lease.holder_id` matching `worker:*`. A `migrating` event was emitted at handover. +4. Reopen the browser, `GET /api/sessions`, then `GET /api/events/{session_id}?after=`. SSE replays all missed events including any `approval_required`. +5. `POST /api/approve` for any pending tool. Worker resumes the turn within ≤2 s (change-stream) or ≤500 ms × 1 (polling fallback). + +**Pass**: all 5 succeed end-to-end, including across a deliberate `docker restart` of Main between steps 2 and 3. + +### Drill 2 — Main restart with active turn + +1. Start a session on Main, send a message that triggers a long-running tool call. +2. Force-restart Main. Lifespan shutdown sweep emits `migrating` for each in-flight session and calls `release_lease`. +3. Within 30 s a Worker claims the lease (`lease.holder_id` flips to `worker:*`). +4. Fresh Main comes back; user opens new tab, hits `GET /api/events/{id}?after=`. No `interrupted` event from the restart itself. + +**Pass**: single uninterrupted user-visible turn across the restart. 
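+
+### Drill helper: watching the durable queue
+
+During Drill 1 step 5 it can help to watch the durable queue directly while the approval is in flight (a sketch; `session_id`, `op_type`, and `created_at` are the `pending_submissions` fields the backend consumer reads, and any other row fields should be treated as unknown):
+
+```js
+// Rows that linger here past a couple of polling cycles mean no holder
+// is consuming them — check lease.holder_id on db.sessions next.
+db.pending_submissions
+  .find({ session_id: "<session_id>" })
+  .sort({ created_at: 1 })
+  .forEach((doc) => printjson({
+    op_type: doc.op_type,
+    age_ms: Date.now() - doc.created_at.getTime(),
+  }))
+```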
diff --git a/frontend/src/components/Chat/ActivityStatusBar.tsx b/frontend/src/components/Chat/ActivityStatusBar.tsx
index 3dd0af53..fac435fe 100644
--- a/frontend/src/components/Chat/ActivityStatusBar.tsx
+++ b/frontend/src/components/Chat/ActivityStatusBar.tsx
@@ -112,6 +112,7 @@ function statusLabel(status: ActivityStatus): string {
     }
     case 'waiting-approval': return 'Waiting for approval';
     case 'cancelled': return 'What should the agent do instead?';
+    case 'migrating': return 'Running in background';
     default: return '';
   }
 }
diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx
index 99dafeab..09e32f63 100644
--- a/frontend/src/components/Chat/ChatInput.tsx
+++ b/frontend/src/components/Chat/ChatInput.tsx
@@ -3,11 +3,13 @@
 import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem } from '@mui/material';
 import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
 import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
 import StopIcon from '@mui/icons-material/Stop';
+import BedtimeOutlinedIcon from '@mui/icons-material/BedtimeOutlined';
 import { apiFetch } from '@/utils/api';
 import { useUserQuota } from '@/hooks/useUserQuota';
 import ClaudeCapDialog from '@/components/ClaudeCapDialog';
 import JobsUpgradeDialog from '@/components/JobsUpgradeDialog';
 import { useAgentStore } from '@/store/agentStore';
+import { useSessionStore } from '@/store/sessionStore';
 import {
   CLAUDE_MODEL_PATH,
   FIRST_FREE_MODEL_PATH,
@@ -86,7 +88,9 @@ interface ChatInputProps {
   sessionId?: string;
   onSend: (text: string) => void;
   onStop?: () => void;
+  onBackground?: () => Promise<void>;
   isProcessing?: boolean;
+  hasPendingApproval?: boolean;
   disabled?: boolean;
   placeholder?: string;
 }
@@ -95,10 +99,12 @@
 const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath);
 const isPremiumModel = (m: ModelOption) => isPremiumPath(m.modelPath);
 const firstFreeModel = (options: ModelOption[]) => options.find(m => !isPremiumModel(m)) ?? options[0];
 
-export default function ChatInput({ sessionId, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
+export default function ChatInput({ sessionId, onSend, onStop, onBackground, isProcessing = false, hasPendingApproval = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
   const [input, setInput] = useState('');
+  const [isBackgroundLoading, setIsBackgroundLoading] = useState(false);
   const inputRef = useRef<HTMLInputElement>(null);
   const [modelOptions, setModelOptions] = useState(DEFAULT_MODEL_OPTIONS);
+  const isBackgrounded = useSessionStore((s) => s.sessions.find((sess) => sess.id === sessionId)?.isBackgrounded ?? false);
   const modelOptionsRef = useRef(DEFAULT_MODEL_OPTIONS);
   const sessionIdRef = useRef(sessionId);
   const [selectedModelId, setSelectedModelId] = useState(DEFAULT_MODEL_OPTIONS[0].id);
@@ -231,6 +237,16 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
     } catch { /* ignore */ }
   };
 
+  const handleBackground = useCallback(async () => {
+    if (!onBackground) return;
+    setIsBackgroundLoading(true);
+    try {
+      await onBackground();
+    } finally {
+      setIsBackgroundLoading(false);
+    }
+  }, [onBackground]);
+
   // Dialog close: just clear the flag. The typed text is already restored.
   const handleCapDialogClose = useCallback(() => {
     setClaudeQuotaExhausted(false);
@@ -385,6 +401,32 @@
         }
       }}
     />
+      {(isProcessing || hasPendingApproval) && !isBackgrounded && onBackground ? (
+        <IconButton
+          onClick={handleBackground}
+          disabled={isBackgroundLoading}
+          aria-label="Run in background"
+        >
+          {isBackgroundLoading && (
+            <CircularProgress size={16} />
+          )}
+          <BedtimeOutlinedIcon fontSize="small" />
+        </IconButton>
+      ) : null}
       {isProcessing ? (
diff --git a/frontend/src/components/Chat/SessionChat.tsx b/frontend/src/components/Chat/SessionChat.tsx
--- a/frontend/src/components/Chat/SessionChat.tsx
+++ b/frontend/src/components/Chat/SessionChat.tsx
   const isExpired = sessions.find((s) => s.id === sessionId)?.expired === true;
+  const isBackgrounded = sessions.find((s) => s.id === sessionId)?.isBackgrounded === true;
 
   const { messages, sendMessage, stop, status, undoLastTurn, editAndRegenerate, approveTools } = useAgentChat({
     sessionId,
@@ -63,6 +65,20 @@
     updateSession(sessionId, { activityStatus: { type: 'cancelled' } });
   }, [stop, updateSession, sessionId]);
 
+  const handleBackground = useCallback(async () => {
+    try {
+      const res = await apiFetch(`/api/session/${sessionId}/background`, { method: 'POST' });
+      if (!res.ok) {
+        console.error(`Background request failed: ${res.status}`);
+        return;
+      }
+      updateSession(sessionId, { activityStatus: { type: 'migrating' } });
+      setBackgrounded(sessionId, true);
+    } catch (e) {
+      console.error('Failed to send session to background:', e);
+    }
+  }, [sessionId, updateSession, setBackgrounded]);
+
   // SDK status is the ground truth — if it's streaming/submitted, agent is busy
   const sdkBusy = status === 'streaming' || status === 'submitted';
   const busy = isProcessing || sdkBusy;
@@ -107,6 +123,27 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: SessionChatProps) {
         onUndoLastTurn={undoLastTurn}
         onEditAndRegenerate={editAndRegenerate}
       />
+      {(isBackgrounded || activityStatus.type === 'migrating') && (
+        <Box sx={{ px: 2, py: 1 }}>
+          <Typography variant="body2" color="text.secondary">
+            Running in background — you can close this tab and come back anytime.
+          </Typography>
+        </Box>
+      )}
       {isExpired ? (
       ) : (
@@ -114,7 +151,9 @@
           sessionId={sessionId}
           onSend={handleSendMessage}
           onStop={handleStop}
+          onBackground={handleBackground}
           isProcessing={busy}
+          hasPendingApproval={activityStatus.type === 'waiting-approval'}
           disabled={!isConnected || activityStatus.type === 'waiting-approval'}
           placeholder={
             activityStatus.type === 'waiting-approval'
diff --git a/frontend/src/components/SessionSidebar/SessionSidebar.tsx b/frontend/src/components/SessionSidebar/SessionSidebar.tsx
index 243e48c7..ece92e45 100644
--- a/frontend/src/components/SessionSidebar/SessionSidebar.tsx
+++ b/frontend/src/components/SessionSidebar/SessionSidebar.tsx
@@ -16,6 +16,7 @@ import {
 import AddIcon from '@mui/icons-material/Add';
 import DeleteOutlineIcon from '@mui/icons-material/DeleteOutline';
 import ChatBubbleOutlineIcon from '@mui/icons-material/ChatBubbleOutline';
+import CloudOutlinedIcon from '@mui/icons-material/CloudOutlined';
 import { useSessionStore } from '@/store/sessionStore';
 import { useAgentStore } from '@/store/agentStore';
 import { apiFetch } from '@/utils/api';
@@ -255,14 +256,25 @@ export default function SessionSidebar({ onClose }: SessionSidebarProps) {
                 },
               }}
             >
-              <ChatBubbleOutlineIcon fontSize="small" />
+              {session.isBackgrounded ? (
+                <CloudOutlinedIcon fontSize="small" />
+              ) : (
+                <ChatBubbleOutlineIcon fontSize="small" />
+              )}
diff --git a/frontend/src/hooks/useAgentChat.ts b/frontend/src/hooks/useAgentChat.ts
--- a/frontend/src/hooks/useAgentChat.ts
+++ b/frontend/src/hooks/useAgentChat.ts
       onInterrupted: () => { /* no-op — handled by stop() caller */ },
+      onMigrating: () => {
+        updateSession(sessionId, { activityStatus: { type: 'migrating' } });
+        setBackgrounded(sessionId, true);
+        // Reconnecting to the slow path (change-stream SSE) is handled in US-FE-002;
+        // the transport's reconnectToStream() will route to the non-holder path
+        // because holder_id !== session_manager._holder_id after backgrounding.
+        transportRef.current?.reconnectToStream();
+      },
     }),
     // eslint-disable-next-line react-hooks/exhaustive-deps
     [sessionId],
@@ -398,6 +406,14 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionDead
       return;
     }
 
+    // If session info 404s but messages exist, the session was backgrounded
+    // and its in-memory state was cleaned up. Treat as backgrounded rather
+    // than expired so the banner renders and the user can reload without
+    // losing their conversation.
+    if (infoRes.status === 404 && msgsRes.ok) {
+      setBackgrounded(sessionId, true);
+    }
+
     let pendingIds: Set<string> | undefined;
     let backendIsProcessing = false;
     if (infoRes.ok) {
@@ -564,6 +580,10 @@
       const state = event.data?.state as string;
       const toolName = event.data?.tool as string;
       if (state === 'running' && toolName) sideChannel.onToolRunning(toolName);
+    } else if (et === 'migrating') {
+      sideChannel.onMigrating();
+      stopReconnect();
+      return true;
     } else if (et === 'turn_complete' || et === 'error' || et === 'interrupted') {
       sideChannel.onProcessingDone();
       stopReconnect();
diff --git a/frontend/src/lib/sse-chat-transport.ts b/frontend/src/lib/sse-chat-transport.ts
index cfe94789..5325ba5e 100644
--- a/frontend/src/lib/sse-chat-transport.ts
+++ b/frontend/src/lib/sse-chat-transport.ts
@@ -40,6 +40,7 @@ export interface SideChannelCallbacks {
   onStreaming: () => void;
   onToolRunning: (toolName: string, description?: string) => void;
   onInterrupted: () => void;
+  onMigrating: () => void;
 }
 
 // ---------------------------------------------------------------------------
@@ -318,6 +319,13 @@ function createEventToChunkStream(sideChannel: SideChannelCallbacks): TransformStream
         break;
       }
 
+      case 'migrating':
+        endTextPart(controller);
+        controller.enqueue({ type: 'finish-step' });
+        controller.enqueue({ type: 'finish', finishReason: 'stop' });
+        sideChannel.onMigrating();
+        break;
+
       default:
         logger.log('SSE transport: unknown event', event);
     }
diff --git a/frontend/src/store/agentStore.ts b/frontend/src/store/agentStore.ts
index 08a68a29..65a5e873 100644
--- a/frontend/src/store/agentStore.ts
+++ b/frontend/src/store/agentStore.ts
@@ -62,7 +62,8 @@ export type ActivityStatus =
   | { type: 'tool'; toolName: string; description?: string }
   | { type: 'waiting-approval' }
   | { type: 'streaming' }
-  | { type: 'cancelled' };
+  | { type: 'cancelled' }
+  | { type: 'migrating' };
 
 export interface ResearchAgentStats {
   toolCount: number;
diff --git a/frontend/src/store/sessionStore.ts b/frontend/src/store/sessionStore.ts
index 4115ef6b..c72dbfe7 100644
--- a/frontend/src/store/sessionStore.ts
+++ b/frontend/src/store/sessionStore.ts
@@ -44,6 +44,8 @@ interface SessionStore {
    * Used when we rehydrate an expired session into a freshly-created backend
    * session — preserves title, timestamps, and messages. */
   renameSession: (oldId: string, newId: string) => void;
+  /** Mark a session as backgrounded (sticky — no path back to false). */
+  setBackgrounded: (id: string, value: boolean) => void;
 }
 
 export const useSessionStore = create<SessionStore>()(
@@ -212,6 +214,14 @@ export const useSessionStore = create<SessionStore>()(
         ),
       }));
     },
+
+    setBackgrounded: (id: string, value: boolean) => {
+      set((state) => ({
+        sessions: state.sessions.map((s) =>
+          s.id === id ?
{ ...s, isBackgrounded: value } : s + ), + })); + }, }), { name: 'hf-agent-sessions', diff --git a/frontend/src/types/agent.ts b/frontend/src/types/agent.ts index 7737399d..1c37b6b4 100644 --- a/frontend/src/types/agent.ts +++ b/frontend/src/types/agent.ts @@ -25,6 +25,7 @@ export interface SessionMeta { autoApprovalCostCapUsd?: number | null; autoApprovalEstimatedSpendUsd?: number; autoApprovalRemainingUsd?: number | null; + isBackgrounded?: boolean; } export interface ToolApproval { diff --git a/frontend/src/types/events.ts b/frontend/src/types/events.ts index 54795827..d998a4a0 100644 --- a/frontend/src/types/events.ts +++ b/frontend/src/types/events.ts @@ -18,6 +18,7 @@ export type EventType = | 'error' | 'shutdown' | 'interrupted' + | 'migrating' | 'undo_complete' | 'plan_update';