diff --git a/.claude/commands/wtnew.md b/.claude/commands/wtnew.md index 59d9746cd..9fef56647 100644 --- a/.claude/commands/wtnew.md +++ b/.claude/commands/wtnew.md @@ -1,6 +1,6 @@ # 创建 Worktree -基于最新 `origin/main` 创建隔离的 worktree 开发环境。 +基于最新 `origin/dev` 创建隔离的 worktree 开发环境,并自动拉起一个并行工作的 Kitty + Codex 开发位。 ## 参数 @@ -21,7 +21,7 @@ PROJECT_NAME=$(basename "$MAIN_REPO") git fetch origin ``` -确保基于最新的 `origin/main` 创建,避免从过时的 base 分叉。 +确保基于最新的 `origin/dev` 创建,避免从过时的 base 分叉。 ## Step 2:启用 worktreeConfig @@ -38,7 +38,7 @@ git config extensions.worktreeConfig true 路径规则:`~/worktrees/<项目名>--<目录名>`(如 `~/worktrees/leon--feat-eval`) ```bash -git worktree add "$HOME/worktrees/$PROJECT_NAME--<目录名>" -b $ARGUMENTS origin/main +git worktree add "$HOME/worktrees/$PROJECT_NAME--<目录名>" -b $ARGUMENTS origin/dev ``` - worktree 存放在 `~/worktrees/`,与主仓库完全隔离 @@ -163,16 +163,44 @@ ln -s "$MAIN_REPO/CLAUDE.local.md" CLAUDE.local.md 2>/dev/null 输出: - worktree 路径 - 分支名 +- base 分支(必须明确是 `origin/dev`) - 分配的端口(backend / frontend) - 自动生成的描述 - `CLAUDE.local.md` 符号链接状态 -询问用户:是否在新 worktree 中打开新的 Claude 会话? +## Step 9:自动拉起 Kitty + Codex 并行工作位 -如果是,用 osascript 打开新终端并启动 claude(**必须将路径替换为实际计算出的完整绝对路径,不得使用变量或占位符**): +不要再询问“是否打开新的 Claude 会话”。默认直接拉起一个新的 Kitty tab,并在里面启动 Codex。 + +要求: +- tab title 固定为 `dev-feature` +- Codex 必须在新建好的 worktree 路径里启动 +- 必须用实际计算出的完整绝对路径,不得保留变量或占位符 +- 如果当前 shell 没有 `KITTY_LISTEN_ON`,要明确报错并停下,不要静默跳过 + +执行命令(**必须将路径替换为实际计算出的完整绝对路径,不得使用变量或占位符**): ```bash -osascript -e 'tell app "Terminal" to do script "cd \"/Users/apple/worktrees/<项目名>--<目录名>\" && claude"' +if [ -z "$KITTY_LISTEN_ON" ]; then + echo "❌ 错误:未设置 KITTY_LISTEN_ON,无法自动创建 dev-feature kitty tab" + exit 1 +fi + +kitty @ --to "$KITTY_LISTEN_ON" launch \ + --type tab \ + --tab-title "dev-feature" \ + --title "dev-feature" \ + zsh -lc 'cd "/Users/apple/worktrees/<项目名>--<目录名>" && codex --cd "/Users/apple/worktrees/<项目名>--<目录名>"' ``` -关键:`cd` 和 `claude` 必须写在 osascript 的 `do script` 字符串内部,不是写在外层 Bash 命令里。 +关键: +- `cd` 和 `codex --cd ...` 必须写在新 tab 的命令字符串内部 +- `codex --cd` 和前面的 `cd` 都必须指向同一个实际 worktree 绝对路径 +- 不要退回 Terminal / osascript;这里的标准交互面就是 Kitty tab + +## Step 10:最终输出 + +除了原有输出,再追加: +- `Codex tab: dev-feature` +- `Codex cwd: ` +- 如果启动成功,明确说明“并行开发位已就绪” diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b384072f3..4a11dc769 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,8 +2,9 @@ name: CI on: push: - branches: [main] + branches: [main, dev] pull_request: + branches: [main, dev] jobs: lint: @@ -41,6 +42,10 @@ jobs: - name: Run tests # --maxfail=5: surface up to 5 failures per platform before stopping # e2e tests self-skip via skipif when provider secrets are absent + env: + SUPABASE_PUBLIC_URL: ${{ secrets.SUPABASE_PUBLIC_URL }} + LEON_SUPABASE_SERVICE_ROLE_KEY: ${{ secrets.SUPABASE_SERVICE_KEY }} + SUPABASE_ANON_KEY: ${{ secrets.SUPABASE_ANON_KEY }} run: uv run pytest tests/ --ignore=tests/test_e2e_providers.py --ignore=tests/test_sandbox_e2e.py --ignore=tests/test_daytona_e2e.py --ignore=tests/test_e2e_backend_api.py --ignore=tests/test_e2e_summary_persistence.py --ignore=tests/test_p3_e2e.py --maxfail=5 --timeout=60 -q frontend: diff --git a/.github/workflows/deploy-staging.yml b/.github/workflows/deploy-staging.yml index ee18d0d38..54eee564c 100644 --- a/.github/workflows/deploy-staging.yml +++ b/.github/workflows/deploy-staging.yml @@ -7,6 +7,9 @@ name: Deploy Staging # Both update the staging apps to the target branch, then deploy. 
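The header comment above summarizes the staging flow this workflow implements: PATCH the application's `git_branch` to the target ref, then trigger a deploy and poll it. A minimal Python sketch of that two-call sequence, assuming the Coolify endpoints and response shape used later in this file; `requests`, the function name, and the `coolify_url` / `token` / `app_uuid` parameters are illustrative placeholders, not part of this repo:

```python
import requests


def point_and_deploy(coolify_url: str, token: str, app_uuid: str, ref: str) -> str:
    """Update a Coolify application's tracked branch, then trigger a deploy.

    Returns the deployment_uuid that the workflow later polls for status.
    """
    headers = {"Authorization": f"Bearer {token}"}

    # Step 1: point the application at the target ref.
    resp = requests.patch(
        f"{coolify_url}/api/v1/applications/{app_uuid}",
        headers={**headers, "Content-Type": "application/json"},
        json={"git_branch": ref},
        timeout=30,
    )
    resp.raise_for_status()

    # Step 2: trigger a non-forced deploy of the updated application.
    resp = requests.get(
        f"{coolify_url}/api/v1/deploy",
        params={"uuid": app_uuid, "force": "false"},
        headers=headers,
        timeout=30,
    )
    resp.raise_for_status()
    # The deploy endpoint returns deployments[0].deployment_uuid,
    # which the "Wait for staging deployment" step polls below.
    return resp.json()["deployments"][0]["deployment_uuid"]
```
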
on: + push: + branches: + - pr188-agent-optimize pull_request: types: [labeled] workflow_dispatch: @@ -23,9 +26,12 @@ jobs: deploy-staging: # For label trigger: only run when the label is exactly "deploy-staging" if: > + github.event_name == 'push' || github.event_name == 'workflow_dispatch' || (github.event_name == 'pull_request' && github.event.label.name == 'deploy-staging') runs-on: ubuntu-latest + env: + STAGING_STACK_UUID: fasbsube26s75ag6qus5bpi2 steps: - name: Resolve target ref @@ -33,33 +39,99 @@ jobs: run: | if [ "${{ github.event_name }}" = "pull_request" ]; then echo "ref=${{ github.head_ref }}" >> "$GITHUB_OUTPUT" + elif [ "${{ github.event_name }}" = "push" ]; then + echo "ref=${{ github.ref_name }}" >> "$GITHUB_OUTPUT" else echo "ref=${{ inputs.ref }}" >> "$GITHUB_OUTPUT" fi - - name: Update staging backend branch + - name: Check out target ref + uses: actions/checkout@v4 + with: + ref: ${{ steps.ref.outputs.ref }} + + - name: Resolve target commit + id: target run: | - curl -s -X PATCH "${{ secrets.COOLIFY_URL }}/api/v1/applications/${{ secrets.COOLIFY_BACKEND_STAGING_UUID }}" \ - -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}" \ - -H "Content-Type: application/json" \ - -d '{"git_branch": "${{ steps.ref.outputs.ref }}"}' + set -euo pipefail + echo "sha=$(git rev-parse HEAD)" >> "$GITHUB_OUTPUT" + + - name: Assert repo staging compose contract + run: | + set -euo pipefail + grep -F "leon-home:/root/.leon" docker-compose.yml >/dev/null + grep -F "volumes:" docker-compose.yml >/dev/null - - name: Update staging frontend branch + - name: Update staging stack branch run: | - curl -s -X PATCH "${{ secrets.COOLIFY_URL }}/api/v1/applications/${{ secrets.COOLIFY_FRONTEND_STAGING_UUID }}" \ + set -euo pipefail + body="$(curl -sS --fail-with-body -X PATCH "${{ secrets.COOLIFY_URL }}/api/v1/applications/${STAGING_STACK_UUID}" \ -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}" \ -H "Content-Type: application/json" \ - -d '{"git_branch": "${{ steps.ref.outputs.ref }}"}' + -d "{\"git_branch\": \"${{ steps.ref.outputs.ref }}\"}")" + echo "$body" + printf '%s' "$body" | jq -e --arg uuid "$STAGING_STACK_UUID" '.uuid == $uuid' >/dev/null + + - name: Deploy staging stack + id: deploy + run: | + set -euo pipefail + body="$(curl -sS --fail-with-body "${{ secrets.COOLIFY_URL }}/api/v1/deploy?uuid=${STAGING_STACK_UUID}&force=false" \ + -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}")" + echo "$body" + printf '%s' "$body" | jq -e --arg uuid "$STAGING_STACK_UUID" '.deployments[0].resource_uuid == $uuid' >/dev/null + echo "deployment_uuid=$(printf '%s' "$body" | jq -r '.deployments[0].deployment_uuid')" >> "$GITHUB_OUTPUT" + + - name: Wait for staging deployment + run: | + set -euo pipefail + deployment_uuid="${{ steps.deploy.outputs.deployment_uuid }}" + for _ in $(seq 1 60); do + body="$(curl -sS --fail-with-body "${{ secrets.COOLIFY_URL }}/api/v1/deployments/${deployment_uuid}" \ + -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}")" + status="$(printf '%s' "$body" | jq -r '.status')" + echo "deployment status: $status" + if [ "$status" = "finished" ]; then + exit 0 + fi + if [ "$status" != "queued" ] && [ "$status" != "in_progress" ]; then + echo "$body" + exit 1 + fi + sleep 10 + done + echo "Timed out waiting for staging deployment ${deployment_uuid}" + exit 1 - - name: Deploy backend to staging + - name: Verify Coolify staging contract run: | - curl -sX GET "${{ secrets.COOLIFY_URL }}/api/v1/deploy?uuid=${{ secrets.COOLIFY_BACKEND_STAGING_UUID }}&force=false" \ - 
-H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}" + set -euo pipefail + body="$(curl -sS --fail-with-body "${{ secrets.COOLIFY_URL }}/api/v1/applications/${STAGING_STACK_UUID}" \ + -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}")" + echo "$body" | jq '{uuid,git_branch,docker_compose_location}' + printf '%s' "$body" | jq -e --arg ref "${{ steps.ref.outputs.ref }}" '.git_branch == $ref' >/dev/null + printf '%s' "$body" | jq -e '.docker_compose_raw | contains("leon-home:/root/.leon")' >/dev/null + printf '%s' "$body" | jq -e --arg volume "${STAGING_STACK_UUID}_leon-home:/root/.leon" '.docker_compose | contains($volume)' >/dev/null + printf '%s' "$body" | jq -e --arg sha "${{ steps.target.outputs.sha }}" '.docker_compose | contains($sha)' >/dev/null - - name: Deploy frontend to staging + - name: Verify staging health contract run: | - curl -sX GET "${{ secrets.COOLIFY_URL }}/api/v1/deploy?uuid=${{ secrets.COOLIFY_FRONTEND_STAGING_UUID }}&force=false" \ - -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}" + set -euo pipefail + for attempt in $(seq 1 18); do + status="$(curl -sS -o /tmp/staging-health.json -w '%{http_code}' "https://app.staging.mycel.nextmind.space/api/monitor/health")" + echo "health attempt ${attempt}: status=${status}" + if [ "$status" = "200" ]; then + body="$(cat /tmp/staging-health.json)" + echo "$body" + printf '%s' "$body" | jq -e '.db.path == "/root/.leon/sandbox.db"' >/dev/null + printf '%s' "$body" | jq -e '.db.exists == true' >/dev/null + exit 0 + fi + cat /tmp/staging-health.json || true + sleep 10 + done + echo "Staging health contract did not become ready in time" + exit 1 - name: Comment on PR with staging URL if: github.event_name == 'pull_request' @@ -70,5 +142,5 @@ jobs: issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, - body: `🚀 **预发部署已触发**\n\n- 前端: https://app.staging.mycel.nextmind.space\n- 后端: https://api.staging.mycel.nextmind.space\n\n分支: \`${{ steps.ref.outputs.ref }}\`` + body: `🚀 **预发部署已触发**\n\n- 共享 Staging: https://app.staging.mycel.nextmind.space\n- API(同域反代): https://app.staging.mycel.nextmind.space/api\n\n分支: \`${{ steps.ref.outputs.ref }}\`` }) diff --git a/.gitignore b/.gitignore index be4d3c775..e24215ae8 100644 --- a/.gitignore +++ b/.gitignore @@ -102,6 +102,8 @@ worktrees/ # Development artifacts — never commit docs/lessons/ docs/plans/ +docs/superpowers/plans/ +docs/superpowers/specs/ frontend/.vite/ .playwright-cli/ ops diff --git a/Dockerfile b/Dockerfile index e875ed19f..36bb7bf5a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,11 +7,13 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv # Install dependencies (cached layer before source copy) COPY pyproject.toml uv.lock ./ -RUN uv sync --frozen --no-dev --no-install-project +# @@@sandbox-sdk-image-parity - shared staging/provider inventory should reflect runtime truth, +# not "SDK missing from image" accidents while config files are present. +RUN uv sync --frozen --no-dev --extra sandbox --extra e2b --extra daytona --no-install-project # Copy source and install project COPY . . -RUN uv sync --frozen --no-dev +RUN uv sync --frozen --no-dev --extra sandbox --extra e2b --extra daytona ENV PATH="/app/.venv/bin:$PATH" diff --git a/README.md b/README.md index a7fdc9af7..f75571e6f 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ Full-featured web platform for managing and interacting with agents: ### Multi-Agent Communication -Agents are first-class social entities. 
They can discover each other, send messages, and collaborate autonomously: +Agents are first-class social entities. They can list chats, read messages, send messages, and collaborate autonomously: ``` Member (template) @@ -103,8 +103,10 @@ Member (template) └→ Thread (agent brain / conversation) ``` -- **`chat_send`**: Agent A messages Agent B; B responds autonomously -- **`directory`**: Agents browse and discover other entities +- **`list_chats`**: List active conversations with unread counts and participants +- **`read_messages`**: Read message history before responding +- **`send_message`**: Agent A messages Agent B; B responds autonomously +- **`search_messages`**: Search message history across chats - **Real-time delivery**: SSE-based chat with typing indicators and read receipts Humans also have entities — agents can initiate conversations with humans, not just the other way around. diff --git a/README.zh.md b/README.zh.md index 12bb8981a..1b3d31c87 100644 --- a/README.zh.md +++ b/README.zh.md @@ -95,7 +95,7 @@ cd frontend/app && npm run dev ### 多 Agent 通讯 -Agent 是一等公民的社交实体,可以互相发现、发送消息、自主协作: +Agent 是一等公民的社交实体,可以列出对话、读取消息、发送消息、自主协作: ``` Member(模板) @@ -103,8 +103,10 @@ Member(模板) └→ Thread(Agent 大脑 / 对话) ``` -- **`chat_send`**:Agent A 给 Agent B 发消息,B 自主回复 -- **`directory`**:Agent 浏览和发现其他实体 +- **`list_chats`**:列出活跃对话、未读数和参与者 +- **`read_messages`**:先读取消息历史,再决定如何回复 +- **`send_message`**:Agent A 给 Agent B 发消息,B 自主回复 +- **`search_messages`**:跨对话搜索消息历史 - **实时投递**:基于 SSE 的聊天,支持输入提示和已读回执 人类也有 Entity——Agent 可以主动找人类对话,而不只是被动响应。 diff --git a/backend/taskboard/_service_loader.py b/backend/taskboard/_service_loader.py new file mode 100644 index 000000000..c59e44605 --- /dev/null +++ b/backend/taskboard/_service_loader.py @@ -0,0 +1,25 @@ +"""Typed task_service loader for taskboard surfaces.""" + +from __future__ import annotations + +from typing import Any, Protocol, cast + + +class TaskServiceProtocol(Protocol): + def list_tasks(self) -> list[dict[str, Any]]: ... + def get_task(self, task_id: str) -> dict[str, Any] | None: ... + def get_highest_priority_pending_task(self) -> dict[str, Any] | None: ... + def create_task(self, **fields: Any) -> dict[str, Any]: ... + def update_task(self, task_id: str, **fields: Any) -> dict[str, Any] | None: ... 
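`TaskServiceProtocol` above only declares the surface the taskboard middleware and tools actually call; `require_task_service()` (defined next in this file) resolves the real module at call time and raises instead of leaving callers to trip over a dead `None` module. A short usage sketch, assuming the backend web service is importable and that board tasks expose an `id` field; the `status` / `thread_id` fields mirror the claim handlers later in this diff:

```python
from backend.taskboard._service_loader import require_task_service


def claim_highest_priority(thread_id: str) -> dict | None:
    # Raises RuntimeError when backend.web.services.task_service is
    # unavailable (e.g. a CLI-only install), rather than failing later
    # with an AttributeError on a None module.
    svc = require_task_service()
    task = svc.get_highest_priority_pending_task()
    if task is None:
        return None
    # Mirror _handle_claim / _claim_task: mark running and record the owner.
    return svc.update_task(task["id"], status="running", thread_id=thread_id)
```
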
+ + +try: + from backend.web.services import task_service as _task_service +except ImportError: + _task_service = None + + +def require_task_service() -> TaskServiceProtocol: + if _task_service is None: + raise RuntimeError("backend.web.services.task_service is unavailable") + return cast(TaskServiceProtocol, _task_service) diff --git a/backend/taskboard/middleware.py b/backend/taskboard/middleware.py index 69a274624..6f9f3f83f 100644 --- a/backend/taskboard/middleware.py +++ b/backend/taskboard/middleware.py @@ -16,7 +16,7 @@ import json import logging import time -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Mapping from typing import Any from langchain.agents.middleware.types import ( @@ -26,12 +26,9 @@ ToolCallRequest, ) from langchain_core.messages import ToolMessage +from langchain_core.messages.tool import ToolCall -# Lazy import: backend is only available when running as web service -try: - from backend.web.services import task_service -except ImportError: - task_service = None # type: ignore[assignment] +from backend.taskboard._service_loader import require_task_service logger = logging.getLogger(__name__) @@ -76,7 +73,7 @@ def __init__( # Tool schemas # ------------------------------------------------------------------ - def _get_tool_schemas(self) -> list[dict]: + def _get_tool_schemas(self) -> list[dict[str, Any]]: """Return OpenAI-format function schemas, filtered by blocked_tools.""" schemas = [ { @@ -263,7 +260,7 @@ async def awrap_tool_call( # Dispatch # ------------------------------------------------------------------ - def _handle_tool_call(self, tool_call: dict) -> ToolMessage: + def _handle_tool_call(self, tool_call: Mapping[str, Any] | ToolCall) -> ToolMessage: tool_name = tool_call.get("name") tool_id = tool_call.get("id", "") args = tool_call.get("args", {}) @@ -292,6 +289,7 @@ def _handle_tool_call(self, tool_call: dict) -> ToolMessage: def _handle_list(self, args: dict) -> dict: """List board tasks with optional status/priority filter.""" + task_service = require_task_service() try: tasks = task_service.list_tasks() except Exception as e: @@ -310,6 +308,7 @@ def _handle_list(self, args: dict) -> dict: def _handle_claim(self, args: dict) -> dict: """Claim a task: set running + thread_id + started_at.""" + task_service = require_task_service() task_id = args.get("TaskId", "") now_ms = int(time.time() * 1000) updated = task_service.update_task( @@ -324,6 +323,7 @@ def _handle_claim(self, args: dict) -> dict: def _handle_progress(self, args: dict) -> dict: """Update task progress and optionally append a note.""" + task_service = require_task_service() task_id = args.get("TaskId", "") progress = args.get("Progress", 0) @@ -346,6 +346,7 @@ def _handle_progress(self, args: dict) -> dict: def _handle_complete(self, args: dict) -> dict: """Complete a task with result.""" + task_service = require_task_service() task_id = args.get("TaskId", "") result_text = args.get("Result", "") now_ms = int(time.time() * 1000) @@ -362,6 +363,7 @@ def _handle_complete(self, args: dict) -> dict: def _handle_fail(self, args: dict) -> dict: """Fail a task with reason.""" + task_service = require_task_service() task_id = args.get("TaskId", "") reason = args.get("Reason", "") now_ms = int(time.time() * 1000) @@ -381,6 +383,7 @@ def _handle_fail(self, args: dict) -> dict: async def on_idle(self) -> dict[str, Any] | None: """Called when agent enters IDLE state. 
Returns highest-priority pending task, or None.""" + task_service = require_task_service() return await asyncio.to_thread(task_service.get_highest_priority_pending_task) # ------------------------------------------------------------------ @@ -389,6 +392,7 @@ async def on_idle(self) -> dict[str, Any] | None: def _handle_create(self, args: dict) -> dict: """Create a board task with source='agent'.""" + task_service = require_task_service() try: task = task_service.create_task( title=args.get("Title", "New task"), diff --git a/backend/taskboard/service.py b/backend/taskboard/service.py index e1c99b568..e00a32b65 100644 --- a/backend/taskboard/service.py +++ b/backend/taskboard/service.py @@ -17,14 +17,9 @@ import time from typing import Any +from backend.taskboard._service_loader import require_task_service from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry -# Lazy import: backend is only available when running as web service -try: - from backend.web.services import task_service -except ImportError: - task_service = None # type: ignore[assignment] - logger = logging.getLogger(__name__) @@ -218,6 +213,7 @@ def _get_thread_id(self) -> str: # ------------------------------------------------------------------ async def _list_tasks(self, Status: str = "", Priority: str = "") -> str: + task_service = require_task_service() try: tasks = await asyncio.to_thread(task_service.list_tasks) except Exception as e: @@ -232,6 +228,7 @@ async def _list_tasks(self, Status: str = "", Priority: str = "") -> str: return json.dumps({"tasks": tasks, "total": len(tasks)}, ensure_ascii=False) async def _claim_task(self, TaskId: str) -> str: + task_service = require_task_service() thread_id = self._get_thread_id() now_ms = int(time.time() * 1000) try: @@ -250,6 +247,7 @@ async def _claim_task(self, TaskId: str) -> str: return json.dumps({"task": updated}, ensure_ascii=False) async def _update_progress(self, TaskId: str, Progress: int, Note: str = "") -> str: + task_service = require_task_service() update_kwargs: dict[str, Any] = {"progress": Progress} if Note: @@ -273,6 +271,7 @@ async def _update_progress(self, TaskId: str, Progress: int, Note: str = "") -> return json.dumps({"task": updated}, ensure_ascii=False) async def _complete_task(self, TaskId: str, Result: str) -> str: + task_service = require_task_service() now_ms = int(time.time() * 1000) try: updated = await asyncio.to_thread( @@ -291,6 +290,7 @@ async def _complete_task(self, TaskId: str, Result: str) -> str: return json.dumps({"task": updated}, ensure_ascii=False) async def _fail_task(self, TaskId: str, Reason: str) -> str: + task_service = require_task_service() now_ms = int(time.time() * 1000) try: updated = await asyncio.to_thread( @@ -308,6 +308,7 @@ async def _fail_task(self, TaskId: str, Reason: str) -> str: return json.dumps({"task": updated}, ensure_ascii=False) async def _create_task(self, Title: str, Description: str = "", Priority: str = "medium") -> str: + task_service = require_task_service() try: task = await asyncio.to_thread( task_service.create_task, @@ -327,4 +328,5 @@ async def _create_task(self, Title: str, Description: str = "", Priority: str = async def on_idle(self) -> dict[str, Any] | None: """Called when agent enters IDLE state. 
Returns highest-priority pending task, or None.""" + task_service = require_task_service() return await asyncio.to_thread(task_service.get_highest_priority_pending_task) diff --git a/backend/web/core/config.py b/backend/web/core/config.py index 23da41471..ab9d87372 100644 --- a/backend/web/core/config.py +++ b/backend/web/core/config.py @@ -4,10 +4,9 @@ from pathlib import Path from config.user_paths import user_home_path -from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path -# Database paths -DB_PATH = resolve_role_db_path(SQLiteDBRole.MAIN) +# Legacy DB_PATH — used only by SQLite sandbox repos as default path +DB_PATH = user_home_path("leon.db") SANDBOXES_DIR = user_home_path("sandboxes") SANDBOX_VOLUME_ROOT = Path(os.environ.get("LEON_SANDBOX_VOLUME_ROOT", str(user_home_path("volumes")))).expanduser().resolve() diff --git a/backend/web/core/dependencies.py b/backend/web/core/dependencies.py index 52bc277a0..85ece805b 100644 --- a/backend/web/core/dependencies.py +++ b/backend/web/core/dependencies.py @@ -1,7 +1,6 @@ """FastAPI dependency injection functions.""" import asyncio -import os from typing import Annotated, Any from fastapi import Depends, FastAPI, HTTPException, Request @@ -9,18 +8,6 @@ from backend.web.services.agent_pool import get_or_create_agent, resolve_thread_sandbox from sandbox.thread_context import set_current_thread_id -# Dev bypass: set LEON_DEV_SKIP_AUTH=1 to skip JWT verification and inject a mock identity. -# WARNING: this bypasses ALL auth — never set in production. -_DEV_SKIP_AUTH = os.environ.get("LEON_DEV_SKIP_AUTH", "").lower() in ("1", "true", "yes") -_DEV_PAYLOAD = {"user_id": "dev-user"} - -if _DEV_SKIP_AUTH: - import logging as _logging - - _logging.getLogger(__name__).warning( - "LEON_DEV_SKIP_AUTH is active — JWT auth is BYPASSED for all requests. This must never be enabled in production." - ) - async def get_app(request: Request) -> FastAPI: """Get FastAPI app instance from request.""" @@ -37,8 +24,6 @@ def _get_auth_service(app: FastAPI): def _extract_jwt_payload(request: Request) -> dict: """Extract and verify JWT payload from Bearer token. Returns {user_id}.""" - if _DEV_SKIP_AUTH: - return _DEV_PAYLOAD auth_header = request.headers.get("Authorization", "") if not auth_header.startswith("Bearer "): raise HTTPException(401, "Missing or invalid Authorization header") @@ -52,8 +37,6 @@ def _extract_jwt_payload(request: Request) -> dict: async def get_current_user_id(request: Request) -> str: """Extract user_id from JWT and verify user exists. Returns 401 if user was deleted (e.g. 
DB reset).""" user_id = _extract_jwt_payload(request)["user_id"] - if _DEV_SKIP_AUTH: - return user_id member_repo = getattr(request.app.state, "member_repo", None) if member_repo and member_repo.get_by_id(user_id) is None: raise HTTPException(401, "User no longer exists — please re-login") diff --git a/backend/web/core/lifespan.py b/backend/web/core/lifespan.py index 13a76a4b2..b985254ec 100644 --- a/backend/web/core/lifespan.py +++ b/backend/web/core/lifespan.py @@ -3,193 +3,73 @@ import asyncio import os from contextlib import asynccontextmanager -from typing import Any +from typing import Any, cast from fastapi import FastAPI +from psycopg import AsyncConnection from backend.web.services.event_buffer import RunEventBuffer, ThreadEventBuffer from backend.web.services.idle_reaper import idle_reaper_loop -from backend.web.services.resource_cache import resource_overview_refresh_loop -from config.env_manager import ConfigManager +from backend.web.services.resource_cache import monitor_resource_overview_refresh_loop from core.runtime.middleware.queue import MessageQueueManager -def _seed_dev_user(app: FastAPI) -> None: - """Create dev-user human member + initial agents if not yet seeded. +def _require_web_runtime_contract() -> None: + # @@@web-checkpointer-contract - web routes can create LeonAgent on first + # message, so missing Postgres checkpointer config is a startup contract + # violation, not a late per-request error. + if not os.getenv("LEON_POSTGRES_URL"): + raise RuntimeError("LEON_POSTGRES_URL is required for backend web runtime") - Mirrors AuthService.register() but uses the fixed 'dev-user' ID that - matches _DEV_PAYLOAD, so list_members('dev-user') returns results. - """ - import logging - import time - from pathlib import Path - from backend.web.services.member_service import MEMBERS_DIR, _write_agent_md, _write_json - from storage.contracts import MemberRow, MemberType - from storage.providers.sqlite.member_repo import generate_member_id +async def _validate_web_checkpointer_contract() -> None: + pg_url = os.getenv("LEON_POSTGRES_URL") + if not pg_url: + raise RuntimeError("LEON_POSTGRES_URL is required for backend web runtime") - log = logging.getLogger(__name__) - member_repo = app.state.member_repo - - dev_user_id = "dev-user" - - if member_repo.get_by_id(dev_user_id) is not None: - return # already seeded - - log.info("DEV: seeding dev-user member + initial agents") - now = time.time() - - # Human member row - member_repo.create( - MemberRow( - id=dev_user_id, - name="Dev", - type=MemberType.HUMAN, - created_at=now, - ) - ) - - # Initial agents (same as register()) - initial_agents = [ - {"name": "Toad", "description": "Curious and energetic assistant", "avatar": "toad.jpeg"}, - {"name": "Morel", "description": "Thoughtful senior analyst", "avatar": "morel.jpeg"}, - ] - assets_dir = Path(__file__).resolve().parents[3] / "assets" - - for agent_def in initial_agents: - agent_id = generate_member_id() - agent_dir = MEMBERS_DIR / agent_id - agent_dir.mkdir(parents=True, exist_ok=True) - _write_agent_md(agent_dir / "agent.md", name=agent_def["name"], description=agent_def["description"]) - _write_json( - agent_dir / "meta.json", - { - "status": "active", - "version": "1.0.0", - "created_at": int(now * 1000), - "updated_at": int(now * 1000), - }, - ) - member_repo.create( - MemberRow( - id=agent_id, - name=agent_def["name"], - type=MemberType.MYCEL_AGENT, - description=agent_def["description"], - config_dir=str(agent_dir), - owner_user_id=dev_user_id, - created_at=now, - ) - ) 
- src_avatar = assets_dir / agent_def["avatar"] - if src_avatar.exists(): - try: - from backend.web.routers.entities import process_and_save_avatar - - avatar_path = process_and_save_avatar(src_avatar, agent_id) - member_repo.update(agent_id, avatar=avatar_path, updated_at=now) - except Exception as e: - log.warning("DEV: avatar copy failed for %s: %s", agent_def["name"], e) + conn = await AsyncConnection.connect(pg_url) + try: + async with conn.cursor() as cursor: + await cursor.execute("SELECT 1") + await cursor.fetchone() + finally: + await conn.close() @asynccontextmanager async def lifespan(app: FastAPI): """FastAPI lifespan context manager for startup and shutdown.""" - # Load configuration - config_manager = ConfigManager() - config_manager.load_to_env() - - # Ensure event store table exists (lazy init, not at module import) - from backend.web.services.event_store import init_event_store - - init_event_store() - - from backend.web.services.library_service import ensure_library_dir - from backend.web.services.member_service import ensure_members_dir - - ensure_members_dir() - ensure_library_dir() - - # ---- Entity-Chat repos + services ---- - _storage_strategy = os.getenv("LEON_STORAGE_STRATEGY", "sqlite") - - if _storage_strategy == "supabase": - from backend.web.core.supabase_factory import create_supabase_client - from storage.container import StorageContainer - from storage.providers.supabase import ( - SupabaseAccountRepo, - SupabaseChatEntityRepo, - SupabaseChatMessageRepo, - SupabaseChatRepo, - SupabaseContactRepo, - SupabaseEntityRepo, - SupabaseInviteCodeRepo, - SupabaseMemberRepo, - SupabaseRecipeRepo, - SupabaseThreadLaunchPrefRepo, - SupabaseThreadRepo, - SupabaseUserSettingsRepo, - ) - - _supabase_client = create_supabase_client() - app.state.member_repo = SupabaseMemberRepo(_supabase_client) - app.state.account_repo = SupabaseAccountRepo(_supabase_client) - app.state.entity_repo = SupabaseEntityRepo(_supabase_client) - app.state.thread_repo = SupabaseThreadRepo(_supabase_client) - app.state.thread_launch_pref_repo = SupabaseThreadLaunchPrefRepo(_supabase_client) - app.state.recipe_repo = SupabaseRecipeRepo(_supabase_client) - app.state.chat_repo = SupabaseChatRepo(_supabase_client) - app.state.chat_entity_repo = SupabaseChatEntityRepo(_supabase_client) - app.state.chat_message_repo = SupabaseChatMessageRepo(_supabase_client) - app.state.invite_code_repo = SupabaseInviteCodeRepo(_supabase_client) - app.state.user_settings_repo = SupabaseUserSettingsRepo(_supabase_client) - app.state._supabase_client = _supabase_client - app.state._storage_container = StorageContainer(strategy="supabase", supabase_client=_supabase_client) - else: - from storage.providers.sqlite.chat_repo import SQLiteChatEntityRepo, SQLiteChatMessageRepo, SQLiteChatRepo - from storage.providers.sqlite.entity_repo import SQLiteEntityRepo - from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path - from storage.providers.sqlite.member_repo import SQLiteAccountRepo, SQLiteMemberRepo - from storage.providers.sqlite.recipe_repo import SQLiteRecipeRepo - from storage.providers.sqlite.thread_launch_pref_repo import SQLiteThreadLaunchPrefRepo - from storage.providers.sqlite.thread_repo import SQLiteThreadRepo - - db = resolve_role_db_path(SQLiteDBRole.MAIN) - chat_db = resolve_role_db_path(SQLiteDBRole.CHAT) - - app.state.member_repo = SQLiteMemberRepo(db) - app.state.account_repo = SQLiteAccountRepo(db) - app.state.entity_repo = SQLiteEntityRepo(db) - app.state.thread_repo = 
SQLiteThreadRepo(db) - app.state.thread_launch_pref_repo = SQLiteThreadLaunchPrefRepo(db) - app.state.recipe_repo = SQLiteRecipeRepo(db) - app.state.chat_repo = SQLiteChatRepo(chat_db) - app.state.chat_entity_repo = SQLiteChatEntityRepo(chat_db) - app.state.chat_message_repo = SQLiteChatMessageRepo(chat_db) + _require_web_runtime_contract() + await _validate_web_checkpointer_contract() + + # ---- Member-Chat repos + services ---- + from backend.web.core.supabase_factory import create_supabase_auth_client, create_supabase_client + from storage.container import StorageContainer + + _supabase_client = create_supabase_client() + storage_container = StorageContainer(supabase_client=_supabase_client) + app.state.member_repo = storage_container.member_repo() + app.state.thread_repo = storage_container.thread_repo() + app.state.thread_launch_pref_repo = storage_container.thread_launch_pref_repo() + app.state.recipe_repo = storage_container.recipe_repo() + app.state.chat_repo = storage_container.chat_repo() + app.state.invite_code_repo = storage_container.invite_code_repo() + app.state.user_settings_repo = storage_container.user_settings_repo() + app.state.agent_config_repo = storage_container.agent_config_repo() + app.state.panel_task_repo = storage_container.panel_task_repo() + app.state.cron_job_repo = storage_container.cron_job_repo() + app.state._supabase_client = _supabase_client + app.state._supabase_auth_client_factory = create_supabase_auth_client + app.state._storage_container = storage_container from backend.web.services.auth_service import AuthService - if _storage_strategy == "supabase": - app.state.auth_service = AuthService( - members=app.state.member_repo, - accounts=app.state.account_repo, - entities=app.state.entity_repo, - supabase_client=_supabase_client, - invite_codes=app.state.invite_code_repo, - ) - else: - app.state.auth_service = AuthService( - members=app.state.member_repo, - accounts=app.state.account_repo, - entities=app.state.entity_repo, - supabase_client=None, - ) - - # Dev bypass: seed dev-user + initial agents on first startup - from backend.web.core.dependencies import _DEV_SKIP_AUTH - - if _DEV_SKIP_AUTH: - _seed_dev_user(app) + app.state.auth_service = AuthService( + members=app.state.member_repo, + supabase_client=_supabase_client, + supabase_auth_client_factory=create_supabase_auth_client, + invite_codes=app.state.invite_code_repo, + ) from backend.web.services.chat_events import ChatEventBus from backend.web.services.typing_tracker import TypingTracker @@ -197,92 +77,91 @@ async def lifespan(app: FastAPI): app.state.chat_event_bus = ChatEventBus() app.state.typing_tracker = TypingTracker(app.state.chat_event_bus) - from backend.web.services.delivery_resolver import DefaultDeliveryResolver + app.state.contact_repo = storage_container.contact_repo() - if _storage_strategy == "supabase": - app.state.contact_repo = SupabaseContactRepo(_supabase_client) - else: - from storage.providers.sqlite.contact_repo import SQLiteContactRepo + # Wire chat delivery after event loop is available + # ---- Messaging system (Supabase-backed, required) ---- + from backend.web.core.supabase_factory import create_messaging_supabase_client + from core.agents.communication.delivery import make_chat_delivery_fn + from messaging.delivery.resolver import HireVisitDeliveryResolver + from messaging.relationships.service import RelationshipService + from messaging.service import MessagingService + from storage.providers.supabase.messaging_repo import ( + SupabaseChatMemberRepo, + 
SupabaseMessageReadRepo, + SupabaseMessagesRepo, + SupabaseRelationshipRepo, + ) - app.state.contact_repo = SQLiteContactRepo(chat_db) + _msg_supabase = create_messaging_supabase_client() + _chat_member_repo = SupabaseChatMemberRepo(_msg_supabase) + _messages_repo = SupabaseMessagesRepo(_msg_supabase) + _message_read_repo = SupabaseMessageReadRepo(_msg_supabase) + app.state.relationship_repo = SupabaseRelationshipRepo(_msg_supabase) + app.state.chat_member_repo = _chat_member_repo + app.state.messages_repo = _messages_repo - delivery_resolver = DefaultDeliveryResolver(app.state.contact_repo, app.state.chat_entity_repo) + app.state.relationship_service = RelationshipService( + app.state.relationship_repo, + member_repo=app.state.member_repo, + thread_repo=app.state.thread_repo, + ) - from backend.web.services.chat_service import ChatService + _msg_delivery_resolver = HireVisitDeliveryResolver( + contact_repo=app.state.contact_repo, + chat_member_repo=_chat_member_repo, + relationship_repo=app.state.relationship_repo, + ) - app.state.chat_service = ChatService( + app.state.messaging_service = MessagingService( chat_repo=app.state.chat_repo, - chat_entity_repo=app.state.chat_entity_repo, - chat_message_repo=app.state.chat_message_repo, - entity_repo=app.state.entity_repo, + chat_member_repo=_chat_member_repo, + messages_repo=_messages_repo, + message_read_repo=_message_read_repo, member_repo=app.state.member_repo, + thread_repo=app.state.thread_repo, event_bus=app.state.chat_event_bus, - delivery_resolver=delivery_resolver, + delivery_resolver=_msg_delivery_resolver, ) - - # Wire chat delivery after event loop is available - from core.agents.communication.delivery import make_chat_delivery_fn - - app.state.chat_service.set_delivery_fn(make_chat_delivery_fn(app)) + app.state.messaging_service.set_delivery_fn(make_chat_delivery_fn(app)) # ---- Existing state ---- app.state.queue_manager = MessageQueueManager() - app.state.agent_pool: dict[str, Any] = {} - app.state.thread_sandbox: dict[str, str] = {} - app.state.thread_cwd: dict[str, str] = {} - app.state.thread_locks: dict[str, asyncio.Lock] = {} + app.state.agent_pool = cast(dict[str, Any], {}) + app.state.thread_sandbox = cast(dict[str, str], {}) + app.state.thread_cwd = cast(dict[str, str], {}) + app.state.thread_locks = cast(dict[str, asyncio.Lock], {}) app.state.thread_locks_guard = asyncio.Lock() - app.state.thread_tasks: dict[str, asyncio.Task] = {} - app.state.thread_event_buffers: dict[str, ThreadEventBuffer] = {} - app.state.subagent_buffers: dict[str, RunEventBuffer] = {} + app.state.thread_tasks = cast(dict[str, asyncio.Task[Any]], {}) + app.state.thread_event_buffers = cast(dict[str, ThreadEventBuffer], {}) + app.state.subagent_buffers = cast(dict[str, RunEventBuffer], {}) from backend.web.services.display_builder import DisplayBuilder app.state.display_builder = DisplayBuilder() - app.state.thread_last_active: dict[str, float] = {} # thread_id → epoch timestamp - app.state.idle_reaper_task: asyncio.Task | None = None + app.state.thread_last_active = cast(dict[str, float], {}) # thread_id → epoch timestamp + app.state.idle_reaper_task = cast(asyncio.Task[Any] | None, None) app.state.cron_service = None app.state._event_loop = asyncio.get_running_loop() - app.state.monitor_resources_task: asyncio.Task | None = None + app.state.monitor_resources_task = cast(asyncio.Task[Any] | None, None) try: # Start idle reaper background task app.state.idle_reaper_task = asyncio.create_task(idle_reaper_loop(app)) # Start resource overview refresh 
loop - app.state.monitor_resources_task = asyncio.create_task(resource_overview_refresh_loop()) + app.state.monitor_resources_task = asyncio.create_task(monitor_resource_overview_refresh_loop()) # Start cron scheduler from backend.web.services.cron_service import CronService - cron_svc = CronService() + cron_svc = CronService( + cron_job_repo=app.state.cron_job_repo, + task_repo=app.state.panel_task_repo, + ) await cron_svc.start() app.state.cron_service = cron_svc - # @@@wechat-registry — create registry with delivery callback, auto-start all - from backend.web.services.wechat_service import WeChatConnectionRegistry, migrate_entity_id_dirs - from core.runtime.middleware.queue.formatters import format_wechat_message - - migrate_entity_id_dirs() - - async def _wechat_deliver(conn, msg): - """Delivery callback — routes WeChat messages to configured thread/chat.""" - routing = conn.routing - if not routing.type or not routing.id: - return - sender_name = msg.from_user_id.split("@")[0] or msg.from_user_id - if routing.type == "thread": - from backend.web.services.message_routing import route_message_to_brain - - content = format_wechat_message(sender_name, msg.from_user_id, msg.text) - await route_message_to_brain(app, routing.id, content, source="owner", sender_name=sender_name) - elif routing.type == "chat": - content = format_wechat_message(sender_name, msg.from_user_id, msg.text) - app.state.chat_service.send_message(routing.id, conn.user_id, content) - - app.state.wechat_registry = WeChatConnectionRegistry(delivery_fn=_wechat_deliver) - app.state.wechat_registry.auto_start_all() - yield finally: # @@@background-task-shutdown-order - cancel monitor/reaper before provider cleanup. @@ -295,10 +174,6 @@ async def _wechat_deliver(conn, msg): except asyncio.CancelledError: pass - # Cleanup: stop WeChat connections - if hasattr(app.state, "wechat_registry"): - await app.state.wechat_registry.shutdown() - # Cleanup: stop cron scheduler if app.state.cron_service: await app.state.cron_service.stop() @@ -312,3 +187,8 @@ async def _wechat_deliver(conn, msg): agent.close() except Exception as e: print(f"[web] Agent cleanup error: {e}") + + # Cleanup: stop LSP language servers + from core.tools.lsp.service import lsp_pool + + await lsp_pool.close_all() diff --git a/backend/web/core/storage_factory.py b/backend/web/core/storage_factory.py index 8e189dd9d..8f63d3333 100644 --- a/backend/web/core/storage_factory.py +++ b/backend/web/core/storage_factory.py @@ -6,15 +6,10 @@ from __future__ import annotations -import os from functools import lru_cache from typing import Any -def _strategy() -> str: - return os.getenv("LEON_STORAGE_STRATEGY", "sqlite") - - @lru_cache(maxsize=1) def _supabase_client() -> Any: from backend.web.core.supabase_factory import create_supabase_client @@ -23,90 +18,24 @@ def _supabase_client() -> Any: def make_panel_task_repo() -> Any: - if _strategy() == "supabase": - from storage.providers.supabase.panel_task_repo import SupabasePanelTaskRepo - - return SupabasePanelTaskRepo(_supabase_client()) - from backend.web.core.config import DB_PATH - from storage.providers.sqlite.panel_task_repo import SQLitePanelTaskRepo + from storage.providers.supabase.panel_task_repo import SupabasePanelTaskRepo - return SQLitePanelTaskRepo(db_path=DB_PATH) + return SupabasePanelTaskRepo(_supabase_client()) def make_cron_job_repo() -> Any: - if _strategy() == "supabase": - from storage.providers.supabase.cron_job_repo import SupabaseCronJobRepo + from storage.providers.supabase.cron_job_repo import 
SupabaseCronJobRepo - return SupabaseCronJobRepo(_supabase_client()) - from backend.web.core.config import DB_PATH - from storage.providers.sqlite.cron_job_repo import SQLiteCronJobRepo - - return SQLiteCronJobRepo(db_path=DB_PATH) + return SupabaseCronJobRepo(_supabase_client()) def make_sandbox_monitor_repo() -> Any: - if _strategy() == "supabase": - from storage.providers.supabase.sandbox_monitor_repo import SupabaseSandboxMonitorRepo - - return SupabaseSandboxMonitorRepo(_supabase_client()) from storage.providers.sqlite.sandbox_monitor_repo import SQLiteSandboxMonitorRepo return SQLiteSandboxMonitorRepo() -def make_agent_registry_repo() -> Any: - if _strategy() == "supabase": - from storage.providers.supabase.agent_registry_repo import SupabaseAgentRegistryRepo - - return SupabaseAgentRegistryRepo(_supabase_client()) - from storage.providers.sqlite.agent_registry_repo import SQLiteAgentRegistryRepo - - return SQLiteAgentRegistryRepo() - - -def make_tool_task_repo(db_path: Any = None) -> Any: - if _strategy() == "supabase": - from storage.providers.supabase.tool_task_repo import SupabaseToolTaskRepo - - return SupabaseToolTaskRepo(_supabase_client()) - from storage.providers.sqlite.tool_task_repo import SQLiteToolTaskRepo - - if db_path is None: - from core.tools.task.service import DEFAULT_DB_PATH - - db_path = DEFAULT_DB_PATH - return SQLiteToolTaskRepo(db_path=db_path) - - -def make_sync_file_repo() -> Any: - if _strategy() == "supabase": - from storage.providers.supabase.sync_file_repo import SupabaseSyncFileRepo - - return SupabaseSyncFileRepo(_supabase_client()) - from storage.providers.sqlite.sync_file_repo import SQLiteSyncFileRepo - - return SQLiteSyncFileRepo() - - -def upsert_resource_snapshot(**kwargs: Any) -> None: - """Strategy-aware resource snapshot upsert.""" - if _strategy() == "supabase": - from storage.providers.supabase.resource_snapshot_repo import upsert_lease_resource_snapshot - - upsert_lease_resource_snapshot(**kwargs, client=_supabase_client()) - else: - from storage.providers.sqlite.resource_snapshot_repo import upsert_lease_resource_snapshot - - kwargs.pop("client", None) - upsert_lease_resource_snapshot(**kwargs) - - def list_resource_snapshots(lease_ids: list[str]) -> dict[str, Any]: - """Strategy-aware resource snapshot list.""" - if _strategy() == "supabase": - from storage.providers.supabase.resource_snapshot_repo import list_snapshots_by_lease_ids - - return list_snapshots_by_lease_ids(lease_ids, client=_supabase_client()) - from storage.providers.sqlite.resource_snapshot_repo import list_snapshots_by_lease_ids + from storage.providers.supabase.resource_snapshot_repo import list_snapshots_by_lease_ids - return list_snapshots_by_lease_ids(lease_ids) + return list_snapshots_by_lease_ids(lease_ids, client=_supabase_client()) diff --git a/backend/web/core/supabase_factory.py b/backend/web/core/supabase_factory.py index c8dc9abd1..2e3cfca26 100644 --- a/backend/web/core/supabase_factory.py +++ b/backend/web/core/supabase_factory.py @@ -1,4 +1,4 @@ -"""Runtime Supabase client factory for storage wiring.""" +"""Runtime Supabase client factories for storage and auth wiring.""" from __future__ import annotations @@ -6,6 +6,19 @@ import httpx from supabase import ClientOptions, create_client +from supabase_auth._sync.gotrue_client import SyncGoTrueClient + + +def _resolve_supabase_url() -> str: + url = os.getenv("SUPABASE_INTERNAL_URL") or os.getenv("SUPABASE_PUBLIC_URL") + if not url: + raise RuntimeError("SUPABASE_INTERNAL_URL or SUPABASE_PUBLIC_URL is 
required.") + return url + + +def _resolve_supabase_auth_url() -> str: + url = os.getenv("SUPABASE_AUTH_URL") or _resolve_supabase_url() + return url def create_supabase_client(): @@ -16,13 +29,46 @@ def create_supabase_client(): httpx client never routes through any system/VPN proxy. """ # Prefer internal URL (same-host direct connection) over public tunnel URL. - url = os.getenv("SUPABASE_INTERNAL_URL") or os.getenv("SUPABASE_PUBLIC_URL") + url = _resolve_supabase_url() key = os.getenv("LEON_SUPABASE_SERVICE_ROLE_KEY") - if not url: - raise RuntimeError("SUPABASE_INTERNAL_URL or SUPABASE_PUBLIC_URL is required.") if not key: raise RuntimeError("LEON_SUPABASE_SERVICE_ROLE_KEY is required for Supabase storage runtime.") schema = os.getenv("LEON_DB_SCHEMA", "public") timeout = httpx.Timeout(30.0, connect=10.0) http_client = httpx.Client(timeout=timeout, trust_env=False) return create_client(url, key, options=ClientOptions(httpx_client=http_client, schema=schema)) + + +def create_supabase_auth_client(): + """Build a supabase-py auth client for end-user auth flows. + + Uses the anon key rather than service-role credentials so auth endpoints + behave like real caller traffic instead of admin/server traffic. + """ + url = _resolve_supabase_auth_url() + key = os.getenv("SUPABASE_ANON_KEY") + if not key: + raise RuntimeError("SUPABASE_ANON_KEY is required for Supabase auth runtime.") + timeout = httpx.Timeout(30.0, connect=10.0) + http_client = httpx.Client(timeout=timeout, trust_env=False) + auth_url = os.getenv("SUPABASE_AUTH_URL") + if auth_url: + # @@@direct-gotrue - local auth may bypass Kong and hit GoTrue directly at /token. + return SyncGoTrueClient(url=auth_url, headers={"apikey": key}, http_client=http_client) + return create_client(url, key, options=ClientOptions(httpx_client=http_client)) + + +def create_messaging_supabase_client(): + """Build a server-side Supabase client for messaging repos. + + @@@messaging-public-schema - messaging tables still live in public while + main product storage moved to LEON_DB_SCHEMA, so this client must stay on + public and use server credentials. + """ + url = _resolve_supabase_url() + key = os.getenv("LEON_SUPABASE_SERVICE_ROLE_KEY") + if not key: + raise RuntimeError("LEON_SUPABASE_SERVICE_ROLE_KEY is required for messaging.") + timeout = httpx.Timeout(30.0, connect=10.0) + http_client = httpx.Client(timeout=timeout, trust_env=False) + return create_client(url, key, options=ClientOptions(httpx_client=http_client, schema="public")) diff --git a/backend/web/main.py b/backend/web/main.py index 64f60e0a5..8f6252bbe 100644 --- a/backend/web/main.py +++ b/backend/web/main.py @@ -1,10 +1,7 @@ """Leon Web Backend - FastAPI Application.""" import os -import sqlite3 import subprocess -import sys -from pathlib import Path # Load .env file if ENV_FILE is specified (e.g. 
ENV_FILE=.env for local dev) _env_file = os.getenv("ENV_FILE") @@ -17,85 +14,25 @@ from fastapi import FastAPI # noqa: E402 from fastapi.middleware.cors import CORSMiddleware # noqa: E402 - -def _ensure_windows_db_env_defaults() -> None: - """On Windows, default Leon DBs to a LOCALAPPDATA-backed path.""" - if sys.platform != "win32": - return - - root = _resolve_windows_db_root() - root.mkdir(parents=True, exist_ok=True) - defaults = { - "LEON_DB_PATH": root / "leon.db", - "LEON_RUN_EVENT_DB_PATH": root / "events.db", - "LEON_QUEUE_DB_PATH": root / "queue.db", - "LEON_CHAT_DB_PATH": root / "chat.db", - "LEON_SANDBOX_DB_PATH": root / "sandbox.db", - "LEON_SUBAGENT_DB_PATH": root / "subagent.db", - "LEON_EVAL_DB_PATH": root / "eval.db", - } - for key, value in defaults.items(): - os.environ.setdefault(key, str(value)) - - -def _resolve_windows_db_root() -> Path: - local_appdata = Path(os.getenv("LOCALAPPDATA") or (Path.home() / "AppData" / "Local")) - candidates = [ - local_appdata / "Leon", - Path.home() / ".codex" / "memories" / "mycel-run", - Path.home() / ".leon-win", - ] - seen: set[Path] = set() - for root in candidates: - if root in seen: - continue - seen.add(root) - if _sqlite_root_supports_wal(root): - return root - return candidates[0] - - -def _sqlite_root_supports_wal(root: Path) -> bool: - probe = root / ".leon-probe.db" - conn: sqlite3.Connection | None = None - try: - root.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(probe), timeout=1.0) - mode = conn.execute("PRAGMA journal_mode=WAL").fetchone() - conn.execute("CREATE TABLE IF NOT EXISTS _probe(x INTEGER)") - conn.commit() - return bool(mode and str(mode[0]).lower() == "wal") - except Exception: - return False - finally: - if conn is not None: - conn.close() - for suffix in ("", "-wal", "-shm"): - try: - (root / f".leon-probe.db{suffix}").unlink(missing_ok=True) - except OSError: - pass - - -_ensure_windows_db_env_defaults() - from backend.web.core.lifespan import lifespan # noqa: E402 from backend.web.routers import ( # noqa: E402 auth, - chats, - connections, - debug, + contacts, + conversations, # noqa: E402 entities, invite_codes, marketplace, monitor, panel, + resources, sandbox, settings, thread_files, threads, webhooks, ) +from backend.web.routers import messaging as messaging_router # noqa: E402 +from messaging.relationships.router import router as relationships_router # noqa: E402 # Create FastAPI app app = FastAPI(title="Leon Web Backend", lifespan=lifespan) @@ -113,19 +50,23 @@ def _sqlite_root_supports_wal(root: Path) -> bool: app.include_router(auth.router) app.include_router(invite_codes.router) app.include_router(threads.router) -app.include_router(chats.router) + +app.include_router(messaging_router.router) + +app.include_router(contacts.router) +app.include_router(relationships_router) app.include_router(entities.router) app.include_router(entities.members_router) app.include_router(sandbox.router) app.include_router(webhooks.router) -app.include_router(connections.router) app.include_router(thread_files.router) app.include_router(thread_files._public) app.include_router(settings.router) -app.include_router(debug.router) app.include_router(panel.router) app.include_router(monitor.router) +app.include_router(resources.router) app.include_router(marketplace.router) +app.include_router(conversations.router) def _resolve_port() -> int: @@ -158,5 +99,5 @@ def _resolve_port() -> int: host="0.0.0.0", port=port, reload=True, - reload_dirs=["backend", "core", "config", "storage", "sandbox"], + 
reload_dirs=["backend", "core", "config", "storage", "sandbox", "messaging"], ) diff --git a/backend/web/models/requests.py b/backend/web/models/requests.py index 05a108bf0..582ec7f4c 100644 --- a/backend/web/models/requests.py +++ b/backend/web/models/requests.py @@ -1,8 +1,8 @@ """Pydantic request models for Leon web API.""" -from typing import Literal +from typing import Any, Literal -from pydantic import BaseModel, Field +from pydantic import AliasChoices, BaseModel, Field from sandbox.config import MountSpec @@ -20,7 +20,7 @@ class RecipeSnapshotRequest(BaseModel): class CreateThreadRequest(BaseModel): member_id: str # which agent template to create thread from - sandbox: str = "local" + sandbox: str = Field(default="local", validation_alias=AliasChoices("sandbox", "sandbox_type")) recipe: RecipeSnapshotRequest | None = None lease_id: str | None = None cwd: str | None = None @@ -53,3 +53,22 @@ class RunRequest(BaseModel): class SendMessageRequest(BaseModel): message: str attachments: list[str] = Field(default_factory=list) + + +class AskUserAnswerRequest(BaseModel): + header: str | None = None + question: str | None = None + selected_options: list[str] = Field(default_factory=list) + free_text: str | None = None + + +class ResolvePermissionRequest(BaseModel): + decision: Literal["allow", "deny"] + message: str | None = None + answers: list[AskUserAnswerRequest] | None = None + annotations: dict[str, Any] | None = None + + +class ThreadPermissionRuleRequest(BaseModel): + behavior: Literal["allow", "deny", "ask"] + tool_name: str diff --git a/backend/web/routers/auth.py b/backend/web/routers/auth.py index 5c5f87b5b..582a642fa 100644 --- a/backend/web/routers/auth.py +++ b/backend/web/routers/auth.py @@ -11,6 +11,15 @@ router = APIRouter(prefix="/api/auth", tags=["auth"]) +async def _call_auth_service(app: Any, status_code: int, method_name: str, *args: Any) -> Any: + try: + service = _get_auth_service(app) + method = getattr(service, method_name) + return await asyncio.to_thread(method, *args) + except ValueError as e: + raise HTTPException(status_code, str(e)) + + # ── Registration step 1: send OTP ────────────────────────────────────────── @@ -22,11 +31,8 @@ class SendOtpRequest(BaseModel): @router.post("/send-otp") async def send_otp(payload: SendOtpRequest, app: Annotated[Any, Depends(get_app)]) -> dict: - try: - await asyncio.to_thread(_get_auth_service(app).send_otp, payload.email, payload.password, payload.invite_code) - return {"ok": True} - except ValueError as e: - raise HTTPException(400, str(e)) + await _call_auth_service(app, 400, "send_otp", payload.email, payload.password, payload.invite_code) + return {"ok": True} # ── Registration step 2: verify OTP ──────────────────────────────────────── @@ -39,10 +45,7 @@ class VerifyOtpRequest(BaseModel): @router.post("/verify-otp") async def verify_otp(payload: VerifyOtpRequest, app: Annotated[Any, Depends(get_app)]) -> dict: - try: - return await asyncio.to_thread(_get_auth_service(app).verify_register_otp, payload.email, payload.token) - except ValueError as e: - raise HTTPException(400, str(e)) + return await _call_auth_service(app, 400, "verify_register_otp", payload.email, payload.token) # ── Registration step 3: set password + invite code ──────────────────────── @@ -55,10 +58,7 @@ class CompleteRegisterRequest(BaseModel): @router.post("/complete-register") async def complete_register(payload: CompleteRegisterRequest, app: Annotated[Any, Depends(get_app)]) -> dict: - try: - return await 
asyncio.to_thread(_get_auth_service(app).complete_register, payload.temp_token, payload.invite_code) - except ValueError as e: - raise HTTPException(400, str(e)) + return await _call_auth_service(app, 400, "complete_register", payload.temp_token, payload.invite_code) # ── Login ─────────────────────────────────────────────────────────────────── @@ -71,7 +71,4 @@ class LoginRequest(BaseModel): @router.post("/login") async def login(payload: LoginRequest, app: Annotated[Any, Depends(get_app)]) -> dict: - try: - return await asyncio.to_thread(_get_auth_service(app).login, payload.identifier, payload.password) - except ValueError as e: - raise HTTPException(401, str(e)) + return await _call_auth_service(app, 401, "login", payload.identifier, payload.password) diff --git a/backend/web/routers/chats.py b/backend/web/routers/chats.py deleted file mode 100644 index 5e7e3ff9e..000000000 --- a/backend/web/routers/chats.py +++ /dev/null @@ -1,316 +0,0 @@ -"""Chat API router — entity-to-entity communication.""" - -import asyncio -import json -import logging -from typing import Annotated, Any, Literal - -from fastapi import APIRouter, Depends, HTTPException, Query -from fastapi.responses import StreamingResponse -from pydantic import BaseModel - -from backend.web.core.dependencies import get_app, get_current_user_id -from backend.web.utils.serializers import avatar_url - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/chats", tags=["chats"]) - - -class CreateChatBody(BaseModel): - user_ids: list[str] - title: str | None = None - - -class SendMessageBody(BaseModel): - content: str - sender_id: str - mentioned_ids: list[str] | None = None - - -@router.get("") -async def list_chats( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -): - """List all chats for the current user (social identity from JWT).""" - return app.state.chat_service.list_chats_for_user(user_id) - - -@router.post("") -async def create_chat( - body: CreateChatBody, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -): - """Create a chat between users. 
2 users = 1:1 chat, 3+ = group chat.""" - chat_service = app.state.chat_service - try: - if len(body.user_ids) >= 3: - chat = chat_service.create_group_chat(body.user_ids, body.title) - else: - chat = chat_service.find_or_create_chat(body.user_ids, body.title) - return {"id": chat.id, "title": chat.title, "status": chat.status, "created_at": chat.created_at} - except ValueError as e: - raise HTTPException(400, str(e)) - - -@router.get("/{chat_id}") -async def get_chat( - chat_id: str, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -): - """Get chat details with member list.""" - chat = app.state.chat_repo.get_by_id(chat_id) - if not chat: - raise HTTPException(404, "Chat not found") - participants = app.state.chat_entity_repo.list_participants(chat_id) - entity_repo = app.state.entity_repo - member_repo = app.state.member_repo - entities_info = [] - for p in participants: - e = entity_repo.get_by_id(p.user_id) - if e: - m = member_repo.get_by_id(e.member_id) - entities_info.append( - { - "id": p.user_id, - "name": e.name, - "type": e.type, - "avatar_url": avatar_url(e.member_id, bool(m.avatar if m else None)), - } - ) - else: - m = member_repo.get_by_id(p.user_id) - if m: - entities_info.append( - { - "id": p.user_id, - "name": m.name, - "type": "human", - "avatar_url": avatar_url(m.id, bool(m.avatar)), - } - ) - return { - "id": chat.id, - "title": chat.title, - "status": chat.status, - "created_at": chat.created_at, - "entities": entities_info, - } - - -@router.get("/{chat_id}/messages") -async def list_messages( - chat_id: str, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], - limit: int = Query(50, ge=1, le=200), - before: float | None = Query(None), -): - """List messages in a chat.""" - msgs = app.state.chat_message_repo.list_by_chat(chat_id, limit=limit, before=before) - entity_repo = app.state.entity_repo - member_repo = app.state.member_repo - sender_ids = {m.sender_id for m in msgs} - sender_names: dict[str, str] = {} - for sid in sender_ids: - e = entity_repo.get_by_id(sid) - if e: - sender_names[sid] = e.name - else: - m = member_repo.get_by_id(sid) - sender_names[sid] = m.name if m else "unknown" - return [ - { - "id": m.id, - "chat_id": m.chat_id, - "sender_id": m.sender_id, - "sender_name": sender_names.get(m.sender_id, "unknown"), - "content": m.content, - "mentioned_ids": m.mentioned_ids, - "created_at": m.created_at, - } - for m in msgs - ] - - -@router.post("/{chat_id}/read") -async def mark_read( - chat_id: str, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -): - """Mark all messages in this chat as read for the current user.""" - import time - - app.state.chat_entity_repo.update_last_read(chat_id, user_id, time.time()) - return {"status": "ok"} - - -@router.post("/{chat_id}/messages") -async def send_message( - chat_id: str, - body: SendMessageBody, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -): - """Send a message in a chat.""" - if not body.content.strip(): - raise HTTPException(400, "Content cannot be empty") - # Verify sender_id belongs to the authenticated user - _verify_participant_ownership(app, body.sender_id, user_id) - chat_service = app.state.chat_service - msg = chat_service.send_message(chat_id, body.sender_id, body.content, body.mentioned_ids) - return { - "id": msg.id, - "chat_id": msg.chat_id, - "sender_id": msg.sender_id, - "content": msg.content, 
- "mentioned_ids": msg.mentioned_ids, - "created_at": msg.created_at, - } - - -@router.get("/{chat_id}/events") -async def stream_chat_events( - chat_id: str, - token: str | None = None, - app: Annotated[Any, Depends(get_app)] = None, -): - """SSE stream for chat events. Uses ?token= for auth.""" - from backend.web.core.dependencies import _DEV_SKIP_AUTH - - if not _DEV_SKIP_AUTH: - if not token: - raise HTTPException(401, "Missing token") - try: - app.state.auth_service.verify_token(token) - except ValueError as e: - raise HTTPException(401, str(e)) - - event_bus = app.state.chat_event_bus - queue = event_bus.subscribe(chat_id) - - async def event_generator(): - try: - yield "retry: 5000\n\n" - while True: - try: - event = await asyncio.wait_for(queue.get(), timeout=30) - event_type = event.get("event", "message") - data = event.get("data", {}) - yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" - except TimeoutError: - yield ": keepalive\n\n" - finally: - event_bus.unsubscribe(chat_id, queue) - - return StreamingResponse(event_generator(), media_type="text/event-stream") - - -# --------------------------------------------------------------------------- -# Contact management (block/mute) -# --------------------------------------------------------------------------- - - -class SetContactBody(BaseModel): - owner_id: str - target_id: str - relation: Literal["normal", "blocked", "muted"] - - -def _verify_participant_ownership(app: Any, participant_id: str, user_id: str) -> None: - """Raise 403 if participant_id does not belong to the authenticated user. - - For humans: participant_id == user_id (direct match). - For agents: participant_id == member_id, and agent_member.owner_user_id == user_id. - """ - if participant_id == user_id: - return - # Check if it's an agent member owned by this user - agent_member = app.state.member_repo.get_by_id(participant_id) - if agent_member and agent_member.owner_user_id == user_id: - return - raise HTTPException(403, "Participant does not belong to you") - - -@router.post("/contacts") -async def set_contact( - body: SetContactBody, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -): - """Set a directional contact relationship (block/mute/normal).""" - _verify_participant_ownership(app, body.owner_id, user_id) - import time - - from storage.contracts import ContactRow - - contact_repo = app.state.contact_repo - contact_repo.upsert( - ContactRow( - owner_id=body.owner_id, - target_id=body.target_id, - relation=body.relation, - created_at=time.time(), - updated_at=time.time(), - ) - ) - return {"status": "ok", "relation": body.relation} - - -@router.delete("/contacts/{owner_id}/{target_id}") -async def delete_contact( - owner_id: str, - target_id: str, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -): - """Delete a contact relationship.""" - _verify_participant_ownership(app, owner_id, user_id) - contact_repo = app.state.contact_repo - contact_repo.delete(owner_id, target_id) - return {"status": "deleted"} - - -# --------------------------------------------------------------------------- -# Chat mute -# --------------------------------------------------------------------------- - - -class MuteChatBody(BaseModel): - user_id: str - muted: bool - mute_until: float | None = None - - -@router.post("/{chat_id}/mute") -async def mute_chat( - chat_id: str, - body: MuteChatBody, - user_id: Annotated[str, Depends(get_current_user_id)], - app: 
Annotated[Any, Depends(get_app)], -): - """Mute/unmute a chat for the current user.""" - _verify_participant_ownership(app, body.user_id, user_id) - chat_entity_repo = app.state.chat_entity_repo - chat_entity_repo.update_mute(chat_id, body.user_id, body.muted, body.mute_until) - return {"status": "ok", "muted": body.muted} - - -@router.delete("/{chat_id}") -async def delete_chat( - chat_id: str, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -): - """Delete a chat. Caller must be a participant.""" - chat = app.state.chat_repo.get_by_id(chat_id) - if not chat: - raise HTTPException(404, "Chat not found") - if not app.state.chat_entity_repo.is_participant_in_chat(chat_id, user_id): - raise HTTPException(403, "Not a participant of this chat") - app.state.chat_repo.delete(chat_id) - return {"status": "deleted"} diff --git a/backend/web/routers/connections.py b/backend/web/routers/connections.py deleted file mode 100644 index c5fa0adc2..000000000 --- a/backend/web/routers/connections.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Connection endpoints — manage external platform connections (WeChat, etc.). - -@@@per-user — all endpoints scoped by user_id (the user's social identity). -""" - -from typing import Annotated, Any - -from fastapi import APIRouter, Depends, HTTPException - -from backend.web.core.dependencies import get_app, get_current_user_id -from backend.web.services.wechat_service import ( - QrPollRequest, - RoutingConfig, - RoutingSetRequest, - WeChatConnectionRegistry, -) - -router = APIRouter(prefix="/api/connections", tags=["connections"]) - - -def _get_registry(app: Any) -> WeChatConnectionRegistry: - return app.state.wechat_registry - - -# --- WeChat --- - - -@router.get("/wechat/state") -async def wechat_state( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - return _get_registry(app).get(user_id).get_state() - - -@router.post("/wechat/qrcode") -async def wechat_qrcode( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - conn = _get_registry(app).get(user_id) - if conn.connected: - raise HTTPException(400, "Already connected. 
Disconnect first.") - return await conn.get_qr_code() - - -@router.post("/wechat/qrcode/poll") -async def wechat_qrcode_poll( - body: QrPollRequest, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - registry = _get_registry(app) - conn = registry.get(user_id) - result = await conn.poll_qr_status(body.qrcode) - # Evict duplicates after successful connection - if result.get("status") == "confirmed" and conn._credentials: - registry.evict_duplicates(conn._credentials.account_id, user_id) - return result - - -@router.post("/wechat/disconnect") -async def wechat_disconnect( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - _get_registry(app).get(user_id).disconnect() - return {"ok": True} - - -@router.post("/wechat/polling/start") -async def wechat_start_polling( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - conn = _get_registry(app).get(user_id) - if not conn.connected: - raise HTTPException(400, "Not connected") - conn.start_polling() - return {"ok": True, "polling": True} - - -@router.post("/wechat/polling/stop") -async def wechat_stop_polling( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - _get_registry(app).get(user_id).stop_polling() - return {"ok": True, "polling": False} - - -# --- Routing config --- - - -@router.get("/wechat/routing") -async def wechat_get_routing( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - return _get_registry(app).get(user_id).routing.model_dump() - - -@router.post("/wechat/routing") -async def wechat_set_routing( - body: RoutingSetRequest, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - _get_registry(app).get(user_id).set_routing(RoutingConfig(type=body.type, id=body.id, label=body.label)) - return {"ok": True} - - -@router.delete("/wechat/routing") -async def wechat_clear_routing( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - _get_registry(app).get(user_id).set_routing(RoutingConfig()) - return {"ok": True} - - -# --- List targets for routing picker --- - - -@router.get("/wechat/routing/targets") -async def wechat_routing_targets( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - """List available threads and chats for the routing picker.""" - from backend.web.utils.serializers import avatar_url - - raw_threads = app.state.thread_repo.list_by_owner_user_id(user_id) - threads = [ - { - "id": t["id"], - "label": t.get("entity_name") or t.get("member_name") or t["id"][:12], - "avatar_url": avatar_url(t.get("member_id"), bool(t.get("member_avatar"))), - } - for t in raw_threads - ] - - raw_chats = app.state.chat_service.list_chats_for_user(user_id) - chats = [] - for c in raw_chats: - others = [e for e in c.get("entities", []) if e["id"] != user_id] - name = ", ".join(e["name"] for e in others) or "Unknown" - chats.append({"id": c["id"], "label": name}) - - return {"threads": threads, "chats": chats} diff --git a/backend/web/routers/contacts.py b/backend/web/routers/contacts.py new file mode 100644 index 000000000..689ff0f8b --- /dev/null +++ b/backend/web/routers/contacts.py @@ -0,0 +1,68 @@ +"""Contacts API router — /api/contacts endpoints.""" + +from 
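+# NOTE: contact rows are directional (owner_user_id → target_user_id): blocking
+# or muting A→B creates no B→A row. Same semantics as the "directional contact
+# relationship" endpoints in the old chats.py that this router replaces.
+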
+from __future__ import annotations
+
+import time
+from typing import Annotated, Any, Literal
+
+from fastapi import APIRouter, Depends
+from pydantic import BaseModel
+
+from backend.web.core.dependencies import get_app, get_current_user_id
+from storage.contracts import ContactRow
+
+router = APIRouter(prefix="/api/contacts", tags=["contacts"])
+
+
+class SetContactBody(BaseModel):
+    target_id: str
+    relation: Literal["normal", "blocked", "muted"]
+
+
+@router.get("")
+async def list_contacts(
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    """List contacts (blocked/muted) for the current user."""
+    rows = app.state.contact_repo.list_for_user(user_id)
+    return [
+        {
+            "owner_user_id": row.owner_id,
+            "target_user_id": row.target_id,
+            "relation": row.relation,
+            "created_at": row.created_at,
+            "updated_at": row.updated_at,
+        }
+        for row in rows
+    ]
+
+
+@router.post("")
+async def set_contact(
+    body: SetContactBody,
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    """Upsert contact (block/mute/normal)."""
+    app.state.contact_repo.upsert(
+        ContactRow(
+            owner_id=user_id,
+            target_id=body.target_id,
+            relation=body.relation,
+            created_at=time.time(),
+            updated_at=time.time(),
+        )
+    )
+    return {"status": "ok", "relation": body.relation}
+
+
+@router.delete("/{target_id}")
+async def delete_contact(
+    target_id: str,
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    """Remove contact entry."""
+    app.state.contact_repo.delete(user_id, target_id)
+    return {"status": "deleted"}
diff --git a/backend/web/routers/conversations.py b/backend/web/routers/conversations.py
new file mode 100644
index 000000000..57cd48256
--- /dev/null
+++ b/backend/web/routers/conversations.py
@@ -0,0 +1,164 @@
+"""Unified conversation list API — merges threads (hire) and chats (visit).
+
+GET /api/conversations returns a single sorted list so the frontend
+ConversationList can render a unified sidebar.
+"""
+
+from __future__ import annotations
+
+from datetime import UTC, datetime
+from typing import Annotated, Any
+
+from fastapi import APIRouter, Depends
+
+from backend.web.core.dependencies import get_app, get_current_user_id
+from backend.web.utils.serializers import avatar_url
+from core.runtime.middleware.monitor import AgentState
+
+router = APIRouter(prefix="/api/conversations", tags=["conversations"])
+
+
+def _is_internal_child_thread(thread_id: str) -> bool:
+    return thread_id.startswith("subagent-")
+
+
+def _resolve_display_member(app: Any, social_user_id: str) -> Any | None:
+    member = app.state.member_repo.get_by_id(social_user_id)
+    if member is not None:
+        return member
+    thread = app.state.thread_repo.get_by_user_id(social_user_id)
+    if thread is None:
+        return None
+    member_id = thread.get("member_id")
+    if not member_id:
+        return None
+    return app.state.member_repo.get_by_id(member_id)
+
+
+def _conversation_updated_at_key(item: dict[str, Any]) -> float:
+    raw = item.get("updated_at")
+    if raw is None:
+        return float("-inf")
+    if isinstance(raw, (int, float)):
+        return float(raw)
+    if isinstance(raw, str):
+        # @@@mixed-updated-at-sort - hire rows currently carry ISO strings while
+        # visit chats can still surface numeric timestamps from older chat storage.
+        # Normalize both before sorting so /api/conversations stays honest.
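+        # e.g. "2025-01-01T00:00:00+00:00" parses to 1735689600.0, the same
+        # epoch-seconds scale as raw float rows; "Z" is rewritten first because
+        # datetime.fromisoformat() only accepts it natively on Python 3.11+.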
+        try:
+            return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
+        except ValueError:
+            return float("-inf")
+    return float("-inf")
+
+
+@router.get("")
+async def list_conversations(
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)] = None,
+) -> list[dict[str, Any]]:
+    """Return hire threads + visit chats merged by updated_at desc."""
+    items: list[dict[str, Any]] = []
+
+    # ── Hire threads ──
+    raw_threads = app.state.thread_repo.list_by_owner_user_id(user_id)
+    pool = app.state.agent_pool
+    for t in raw_threads:
+        tid = t["id"]
+        if _is_internal_child_thread(tid):
+            continue
+        sandbox_type = t.get("sandbox_type", "local")
+        running = False
+        agent = pool.get(f"{tid}:{sandbox_type}")
+        if agent and hasattr(agent, "runtime"):
+            running = agent.runtime.current_state == AgentState.ACTIVE
+        last_active = app.state.thread_last_active.get(tid)
+        updated_at = datetime.fromtimestamp(last_active, tz=UTC).isoformat() if last_active else None
+        items.append(
+            {
+                "id": tid,
+                "type": "hire",
+                "title": t.get("member_name") or "Agent",
+                "member_id": t.get("member_id"),
+                "avatar_url": avatar_url(t.get("member_id"), bool(t.get("member_avatar"))),
+                "updated_at": updated_at,
+                "unread_count": 0,
+                "running": running,
+            }
+        )
+
+    # ── Visit chats ──
+    messaging = getattr(app.state, "messaging_service", None)
+    if messaging:
+        chats = messaging.list_chats_for_user(user_id)
+        messages_repo = getattr(app.state, "messages_repo", None)
+
+        # Pre-fetch all member data to avoid N+1 per-member lookups
+        all_member_ids: set[str] = set()
+        chat_members_cache: dict[str, list[dict[str, Any]]] = {}
+        chat_obj_cache: dict[str, Any] = {}
+
+        chat_ids = [c["id"] if isinstance(c, dict) else c for c in chats]
+        for chat_id in chat_ids:
+            chat_obj = app.state.chat_repo.get_by_id(chat_id) if hasattr(app.state, "chat_repo") else None
+            if not chat_obj:
+                continue
+            chat_obj_cache[chat_id] = chat_obj
+            members_list = messaging.list_chat_members(chat_id)
+            chat_members_cache[chat_id] = members_list
+            for m in members_list:
+                uid = m.get("user_id")
+                if uid and uid != user_id:
+                    all_member_ids.add(uid)
+
+        # Batch resolve members
+        member_cache: dict[str, Any] = {}
+        for uid in all_member_ids:
+            mem = _resolve_display_member(app, uid)
+            if mem:
+                member_cache[uid] = mem
+
+        for chat_id in chat_ids:
+            chat_obj = chat_obj_cache.get(chat_id)
+            if not chat_obj:
+                continue
+            members_list = chat_members_cache[chat_id]
+
+            # Determine display name + avatar in single pass
+            title = getattr(chat_obj, "title", None) or ""
+            chat_avatar = None
+            other_names: list[str] = []
+            for m in members_list:
+                uid = m.get("user_id")
+                if not uid or uid == user_id:
+                    continue
+                mem = member_cache.get(uid)
+                if not mem:
+                    continue
+                other_names.append(mem.name)
+                if chat_avatar is None:
+                    chat_avatar = avatar_url(mem.id, bool(mem.avatar))
+            if not title:
+                title = ", ".join(other_names) or "Chat"
+
+            # Unread count
+            unread = 0
+            if messages_repo:
+                unread = messages_repo.count_unread(chat_id, user_id)
+
+            items.append(
+                {
+                    "id": chat_id,
+                    "type": "visit",
+                    "title": title,
+                    "member_id": None,
+                    "avatar_url": chat_avatar,
+                    "updated_at": getattr(chat_obj, "updated_at", None) or getattr(chat_obj, "created_at", None),
+                    "unread_count": unread,
+                    "running": False,
+                }
+            )
+
+    # Sort by updated_at descending (None goes last)
+    items.sort(key=_conversation_updated_at_key, reverse=True)
+    return items
diff --git a/backend/web/routers/debug.py b/backend/web/routers/debug.py
deleted file mode 100644
index 57299f219..000000000
--- a/backend/web/routers/debug.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Debug logging endpoints."""
-
-from fastapi import APIRouter
-from pydantic import BaseModel
-
-router = APIRouter(prefix="/api/debug", tags=["debug"])
-
-
-class LogMessage(BaseModel):
-    message: str
-    timestamp: str
-
-
-@router.post("/log")
-async def log_frontend_message(payload: LogMessage) -> dict:
-    """Receive frontend console logs and write to file."""
-    with open("/tmp/leon-frontend-console.log", "a") as f:
-        f.write(f"[{payload.timestamp}] {payload.message}\n")
-    return {"status": "ok"}
diff --git a/backend/web/routers/entities.py b/backend/web/routers/entities.py
index 96f636955..f1686eb51 100644
--- a/backend/web/routers/entities.py
+++ b/backend/web/routers/entities.py
@@ -1,4 +1,4 @@
-"""Entity & Member endpoints — new entity-chat system."""
+"""Member endpoints — social identity discovery and agent thread lookup."""
 
 import io
 import logging
@@ -12,6 +12,7 @@
 from backend.web.core.dependencies import get_app, get_current_user_id
 from backend.web.core.paths import avatars_dir
 from backend.web.utils.serializers import avatar_url
+from storage.contracts import MemberType
 
 logger = logging.getLogger(__name__)
 
@@ -40,7 +41,7 @@ def process_and_save_avatar(source: Path | bytes, member_id: str) -> str:
     img = ImageOps.exif_transpose(img)
     if img.mode not in ("RGB", "RGBA"):
         img = img.convert("RGB")
-    img = ImageOps.fit(img, (AVATAR_SIZE, AVATAR_SIZE), method=Image.LANCZOS)
+    img = ImageOps.fit(img, (AVATAR_SIZE, AVATAR_SIZE), method=Image.Resampling.LANCZOS)
     AVATARS_DIR.mkdir(parents=True, exist_ok=True)
     img.save(AVATARS_DIR / f"{member_id}.png", format="PNG", optimize=True)
     return f"avatars/{member_id}.png"
@@ -89,6 +90,15 @@ def _avatar_path(member_id: str) -> Path:
     return AVATARS_DIR / f"{safe_id}.png"
 
 
+def _get_owned_avatar_member_or_404(member_id: str, current_user_id: str, member_repo: Any) -> Any:
+    member = member_repo.get_by_id(member_id)
+    if not member:
+        raise HTTPException(404, "Member not found")
+    if member_id == current_user_id or member.owner_user_id == current_user_id:
+        return member
+    raise HTTPException(403, "Not authorized")
+
+
 @members_router.put("/{member_id}/avatar")
 async def upload_avatar(
     member_id: str,
@@ -98,11 +108,7 @@
 ) -> dict[str, str]:
     """Upload/replace avatar image. Resizes to 256x256 PNG."""
     repo = app.state.member_repo
-    member = repo.get_by_id(member_id)
-    if not member:
-        raise HTTPException(404, "Member not found")
-    if member_id != current_user_id and member.owner_user_id != current_user_id:
-        raise HTTPException(403, "Not authorized")
+    _get_owned_avatar_member_or_404(member_id, current_user_id, repo)
     ct = file.content_type or ""
     if ct not in ALLOWED_CONTENT_TYPES:
         raise HTTPException(400, f"Unsupported image type: {ct}")
@@ -137,11 +143,7 @@
 ) -> dict[str, str]:
     """Delete avatar."""
     repo = app.state.member_repo
-    member = repo.get_by_id(member_id)
-    if not member:
-        raise HTTPException(404, "Member not found")
-    if member_id != current_user_id and member.owner_user_id != current_user_id:
-        raise HTTPException(403, "Not authorized")
+    _get_owned_avatar_member_or_404(member_id, current_user_id, repo)
     path = _avatar_path(member_id)
     if path.exists():
         path.unlink()
@@ -160,66 +162,85 @@ async def list_entities(
     app: Annotated[Any, Depends(get_app)],
 ):
     """List chattable entities for discovery (New Chat picker).
-    Humans are represented by their user_id; agents by their member_id.
-    Excludes the current user (you don't chat with yourself)."""
-    entity_repo = app.state.entity_repo
+    Humans are keyed by user_id; agent templates are keyed by member_id plus
+    their default representative thread. Excludes the current user."""
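+    # Illustrative item shapes (keys mirror the dicts built below):
+    #   human → {"user_id": m.id, "type": "human", "default_thread_id": None, ...}
+    #   agent → {"member_id": m.id, "type": e.g. "agent", "default_thread_id": <id or None>, ...}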
     member_repo = app.state.member_repo
-
     members = member_repo.list_all()
     member_map = {m.id: m for m in members}
     items = []
-    # Human participants: all human members except self
     for m in members:
-        if m.type != "human" or m.id == user_id:
+        if m.id == user_id:
             continue
-        items.append(
-            {
-                "id": m.id,  # user_id IS the social identity for humans
-                "name": m.name,
-                "type": "human",
-                "avatar_url": avatar_url(m.id, bool(m.avatar)),
-                "owner_name": None,
-                "member_name": m.name,
-                "thread_id": None,
-                "is_main": None,
-                "branch_index": None,
-            }
-        )
-
-    # Agent participants: from entity_repo (agent entities have id = member_id)
-    all_entities = entity_repo.list_by_type("agent")
-    for entity in all_entities:
-        member = member_map.get(entity.member_id)
-        owner = member_map.get(member.owner_user_id) if member and member.owner_user_id else None
-        thread = app.state.thread_repo.get_by_id(entity.thread_id) if entity.thread_id else None
-        items.append(
-            {
-                "id": entity.id,  # entity.id = member_id = social identity for agents
-                "name": entity.name,
-                "type": entity.type,
-                "avatar_url": avatar_url(entity.member_id, bool(member.avatar if member else None)),
-                "owner_name": owner.name if owner else None,
-                "member_name": member.name if member else None,
-                "thread_id": entity.thread_id,
-                "is_main": thread["is_main"] if thread else None,
-                "branch_index": thread["branch_index"] if thread else None,
-            }
-        )
+        if m.type == MemberType.HUMAN:
+            items.append(
+                {
+                    "user_id": m.id,
+                    "name": m.name,
+                    "type": "human",
+                    "avatar_url": avatar_url(m.id, bool(m.avatar)),
+                    "owner_name": None,
+                    "member_name": m.name,
+                    "default_thread_id": None,
+                    "is_default_thread": None,
+                    "branch_index": None,
+                }
+            )
+        else:
+            owner = member_map.get(m.owner_user_id) if m.owner_user_id else None
+            default_thread = app.state.thread_repo.get_default_thread(m.id)
+            items.append(
+                {
+                    "member_id": m.id,
+                    "name": m.name,
+                    "type": m.type.value if hasattr(m.type, "value") else str(m.type),
+                    "avatar_url": avatar_url(m.id, bool(m.avatar)),
+                    "owner_name": owner.name if owner else None,
+                    "member_name": m.name,
+                    "default_thread_id": default_thread["id"] if default_thread else None,
+                    "is_default_thread": default_thread["is_main"] if default_thread else None,
+                    "branch_index": default_thread["branch_index"] if default_thread else None,
+                }
+            )
 
     return items
 
 
-@router.get("/{user_id}/agent-thread")
+@router.get("/{member_id}/profile")
+async def get_entity_profile(
+    member_id: str,
+    app: Annotated[Any, Depends(get_app)],
+):
+    """Public agent profile. No auth required (frontend uses plain fetch)."""
+    member = _get_member_or_404(app, member_id)
+    member_type = member.type.value if hasattr(member.type, "value") else str(member.type)
+    if "agent" not in member_type:
+        raise HTTPException(404, "Profile not available for this member type")
+    return {
+        "id": member.id,
+        "name": member.name,
+        "type": member_type,
+        "avatar_url": avatar_url(member.id, bool(member.avatar)),
+        "description": member.description,
+    }
+
+
+@router.get("/{member_id}/agent-thread")
 async def get_agent_thread(
-    user_id: str,
+    member_id: str,
     current_user_id: Annotated[str, Depends(get_current_user_id)],
     app: Annotated[Any, Depends(get_app)],
 ):
-    """Get the thread_id for an agent's main thread. user_id here is the agent's member_id."""
-    entity = app.state.entity_repo.get_by_id(user_id)
-    if not entity:
-        raise HTTPException(404, "Entity not found")
-    if entity.type == "agent" and entity.thread_id:
-        return {"user_id": user_id, "thread_id": entity.thread_id}
+    """Get the default representative thread for an agent template."""
+    member = _get_member_or_404(app, member_id)
+    default_thread = app.state.thread_repo.get_default_thread(member_id)
+    if member.type != MemberType.HUMAN and default_thread is not None:
+        return {"member_id": member_id, "default_thread_id": default_thread["id"]}
     raise HTTPException(404, "No agent thread found")
+
+
+def _get_member_or_404(app: Any, member_id: str) -> Any:
+    member = app.state.member_repo.get_by_id(member_id)
+    if not member:
+        raise HTTPException(404, "Member not found")
+    return member
diff --git a/backend/web/routers/invite_codes.py b/backend/web/routers/invite_codes.py
index 53a17efeb..290b43631 100644
--- a/backend/web/routers/invite_codes.py
+++ b/backend/web/routers/invite_codes.py
@@ -11,15 +11,26 @@
 router = APIRouter(prefix="/api/invite-codes", tags=["invite-codes"])
 
 
-def _get_invite_code_repo(app: Any):
-    """Get SupabaseInviteCodeRepo from app state, or raise 503 if unavailable."""
-    sb_client = getattr(app.state, "_supabase_client", None)
+async def _call_invite_code_repo(
+    request: Request,
+    error_prefix: str,
+    method_name: str,
+    *args: Any,
+    **kwargs: Any,
+) -> Any:
+    sb_client = getattr(request.app.state, "_supabase_client", None)
     if sb_client is None:
         raise HTTPException(503, "邀请码服务不可用(当前为 SQLite 模式)")
-    repo = getattr(app.state, "invite_code_repo", None)
+    repo = getattr(request.app.state, "invite_code_repo", None)
     if repo is None:
         raise HTTPException(503, "邀请码仓库未初始化")
-    return repo
+    try:
+        method = getattr(repo, method_name)
+        return await asyncio.to_thread(method, *args, **kwargs)
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(500, f"{error_prefix}{e}") from e
 
 
 # ── List all invite codes ────────────────────────────────────────────────────
@@ -30,14 +41,8 @@ async def list_invite_codes(
     request: Request,
     user_id: Annotated[str, Depends(get_current_user_id)],
 ) -> dict:
-    repo = _get_invite_code_repo(request.app)
-    try:
-        codes = await asyncio.to_thread(repo.list_all)
-        return {"codes": codes}
-    except HTTPException:
-        raise
-    except Exception as e:
-        raise HTTPException(500, f"获取邀请码列表失败:{e}") from e
+    codes = await _call_invite_code_repo(request, "获取邀请码列表失败:", "list_all")
+    return {"codes": codes}
 
 
 # ── Generate a new invite code ───────────────────────────────────────────────
@@ -53,18 +58,13 @@ async def generate_invite_code(
     request: Request,
     user_id: Annotated[str, Depends(get_current_user_id)],
 ) -> dict:
-    repo = _get_invite_code_repo(request.app)
-    try:
-        code = await asyncio.to_thread(
-            repo.generate,
-            created_by=user_id,
-            expires_days=payload.expires_days,
-        )
-        return code
-    except HTTPException:
-        raise
-    except Exception as e:
-        raise HTTPException(500, f"生成邀请码失败:{e}") from e
+    return await _call_invite_code_repo(
+        request,
+        "生成邀请码失败:",
+        "generate",
+        created_by=user_id,
+        expires_days=payload.expires_days,
+    )
 
 
 # ── Revoke (delete) an invite code ──────────────────────────────────────────
@@ -76,16 +76,10 @@ async def revoke_invite_code(
     request: Request,
     user_id: Annotated[str, Depends(get_current_user_id)],
 ) -> dict:
-    repo = _get_invite_code_repo(request.app)
-    try:
-        ok = await asyncio.to_thread(repo.revoke, code)
-        if not ok:
-            raise HTTPException(404, "邀请码不存在")
-        return {"ok": True}
-    except HTTPException:
-        raise
-    except Exception as e:
-        raise HTTPException(500, f"吊销邀请码失败:{e}") from e
+    ok = await _call_invite_code_repo(request, "吊销邀请码失败:", "revoke", code)
+    if not ok:
+        raise HTTPException(404, "邀请码不存在")
+    return {"ok": True}
 
 
 # ── Validate an invite code (no auth required) ───────────────────────────────
@@ -93,11 +87,5 @@
 @router.get("/validate/{code}")
 async def validate_invite_code(code: str, request: Request) -> dict:
-    repo = _get_invite_code_repo(request.app)
-    try:
-        valid = await asyncio.to_thread(repo.is_valid, code)
-        return {"valid": valid}
-    except HTTPException:
-        raise
-    except Exception as e:
-        raise HTTPException(500, f"校验邀请码失败:{e}") from e
+    valid = await _call_invite_code_repo(request, "校验邀请码失败:", "is_valid", code)
+    return {"valid": valid}
diff --git a/backend/web/routers/messaging.py b/backend/web/routers/messaging.py
new file mode 100644
index 000000000..ce2b2579a
--- /dev/null
+++ b/backend/web/routers/messaging.py
@@ -0,0 +1,329 @@
+"""Messaging API router — replaces chats.py.
+
+All operations go through MessagingService (Supabase-backed).
+No legacy fallback.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+from datetime import UTC, datetime
+from typing import Annotated, Any
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from pydantic import BaseModel
+
+from backend.web.core.dependencies import get_app, get_current_user_id
+from backend.web.utils.serializers import avatar_url
+
+router = APIRouter(prefix="/api/chats", tags=["chats"])
+
+
+# ---------------------------------------------------------------------------
+# Request models
+# ---------------------------------------------------------------------------
+
+
+class CreateChatBody(BaseModel):
+    user_ids: list[str]
+    title: str | None = None
+
+
+class SendMessageBody(BaseModel):
+    content: str
+    sender_id: str
+    mentioned_ids: list[str] | None = None
+    message_type: str = "human"
+    signal: str | None = None
+
+
+class MuteChatBody(BaseModel):
+    user_id: str
+    muted: bool
+    mute_until: float | None = None
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _messaging(app: Any):
+    svc = getattr(app.state, "messaging_service", None)
+    if svc is None:
+        raise HTTPException(503, "MessagingService not initialized")
+    return svc
+
+
+def _verify_member_ownership(app: Any, member_id: str, user_id: str) -> None:
+    # @@@thread-social-owner-check - sender_id can be a thread-owned social user_id, so
+    # ownership must resolve through the thread back to the template member before checking owner.
+    member = _resolve_display_member(app, member_id)
+    if not member:
+        raise HTTPException(403, "Member not found")
+    if member.id == user_id:
+        return  # human member sending as themselves
+    if member.owner_user_id == user_id:
+        return  # agent owned by current user
+    raise HTTPException(403, "Member does not belong to you")
+
+
+def _get_accessible_chat_or_404(app: Any, chat_id: str, user_id: str) -> Any:
+    chat = app.state.chat_repo.get_by_id(chat_id)
+    if not chat:
+        raise HTTPException(404, "Chat not found")
+    if not _messaging(app).is_chat_member(chat_id, user_id):
+        raise HTTPException(403, "Not a participant of this chat")
+    return chat
+
+
+def _resolve_display_member(app: Any, social_user_id: str) -> Any | None:
+    member = app.state.member_repo.get_by_id(social_user_id)
+    if member is not None:
+        return member
+    thread_repo = getattr(app.state, "thread_repo", None)
+    if thread_repo is None:
+        return None
+    thread = thread_repo.get_by_user_id(social_user_id)
+    if thread is None:
+        return None
+    member_id = thread.get("member_id")
+    if not member_id:
+        return None
+    return app.state.member_repo.get_by_id(member_id)
+
+
+def _msg_response(m: dict[str, Any], app: Any) -> dict[str, Any]:
+    sender = _resolve_display_member(app, m.get("sender_id", ""))
+    return {
+        "id": m["id"],
+        "chat_id": m["chat_id"],
+        "sender_id": m.get("sender_id"),
+        "sender_name": sender.name if sender else "unknown",
+        "content": m["content"],
+        "message_type": m.get("message_type", "human"),
+        "mentioned_ids": m.get("mentioned_ids") or m.get("mentions") or [],
+        "signal": m.get("signal"),
+        "retracted_at": m.get("retracted_at"),
+        "created_at": m.get("created_at"),
+    }
+
+
+# ---------------------------------------------------------------------------
+# Chat list / create
+# ---------------------------------------------------------------------------
+
+
+@router.get("")
+async def list_chats(
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    return _messaging(app).list_chats_for_user(user_id)
+
+
+@router.post("")
+async def create_chat(
+    body: CreateChatBody,
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    try:
+        if len(body.user_ids) >= 3:
+            chat = _messaging(app).create_group_chat(body.user_ids, body.title)
+        else:
+            chat = _messaging(app).find_or_create_chat(body.user_ids, body.title)
+        return {
+            "id": chat["id"],
+            "title": chat.get("title"),
+            "status": chat.get("status"),
+            "created_at": chat.get("created_at"),
+        }
+    except ValueError as e:
+        raise HTTPException(400, str(e))
+
+
+# ---------------------------------------------------------------------------
+# Chat detail
+# ---------------------------------------------------------------------------
+
+
+@router.get("/{chat_id}")
+async def get_chat(
+    chat_id: str,
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    chat = _get_accessible_chat_or_404(app, chat_id, user_id)
+    members_list = _messaging(app).list_chat_members(chat_id)
+    members_info = []
+    for m in members_list:
+        uid = m.get("user_id")
+        if not uid:
+            continue
+        mem = _resolve_display_member(app, uid)
+        if mem:
+            members_info.append(
+                {
+                    "id": uid,
+                    "name": mem.name,
+                    "type": mem.type.value if hasattr(mem.type, "value") else str(mem.type),
+                    "avatar_url": avatar_url(mem.id, bool(mem.avatar)),
+                }
+            )
+    return {
+        "id": chat.id,
+        "title": chat.title,
+        "status": chat.status,
+        "created_at": chat.created_at,
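+        # Members are serialized under the legacy "entities" key — the same field
+        # name the old chats.py returned — so existing consumers keep working.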
        "entities": members_info,
+    }
+
+
+# ---------------------------------------------------------------------------
+# Messages
+# ---------------------------------------------------------------------------
+
+
+@router.get("/{chat_id}/messages")
+async def list_messages(
+    chat_id: str,
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+    limit: int = Query(50, ge=1, le=200),
+    before: str | None = Query(None),
+):
+    if not _messaging(app).is_chat_member(chat_id, user_id):
+        raise HTTPException(403, "Not a participant of this chat")
+    msgs = _messaging(app).list_messages(chat_id, limit=limit, before=before, viewer_id=user_id)
+    return [_msg_response(m, app) for m in msgs]
+
+
+@router.post("/{chat_id}/messages")
+async def send_message(
+    chat_id: str,
+    body: SendMessageBody,
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    if not body.content.strip():
+        raise HTTPException(400, "Content cannot be empty")
+    _verify_member_ownership(app, body.sender_id, user_id)
+    msg = _messaging(app).send(
+        chat_id,
+        body.sender_id,
+        body.content,
+        mentions=body.mentioned_ids,
+        signal=body.signal,
+        message_type=body.message_type,
+    )
+    return _msg_response(msg, app)
+
+
+@router.post("/{chat_id}/messages/{message_id}/retract")
+async def retract_message(
+    chat_id: str,
+    message_id: str,
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    ok = _messaging(app).retract(message_id, user_id)
+    if not ok:
+        raise HTTPException(400, "Cannot retract: not sender, already retracted, or 2-min window expired")
+    return {"status": "retracted"}
+
+
+@router.delete("/{chat_id}/messages/{message_id}")
+async def delete_message_for_self(
+    chat_id: str,
+    message_id: str,
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    _messaging(app).delete_for(message_id, user_id)
+    return {"status": "deleted"}
+
+
+@router.post("/{chat_id}/read")
+async def mark_read(
+    chat_id: str,
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    _messaging(app).mark_read(chat_id, user_id)
+    return {"status": "ok"}
+
+
+# ---------------------------------------------------------------------------
+# Delete chat
+# ---------------------------------------------------------------------------
+
+
+@router.delete("/{chat_id}")
+async def delete_chat(
+    chat_id: str,
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    _get_accessible_chat_or_404(app, chat_id, user_id)
+    app.state.chat_repo.delete(chat_id)
+    return {"status": "deleted"}
+
+
+# ---------------------------------------------------------------------------
+# SSE stream (typing indicators fallback, messages come via Supabase Realtime)
+# ---------------------------------------------------------------------------
+
+
+@router.get("/{chat_id}/events")
+async def stream_chat_events(
+    chat_id: str,
+    token: str | None = None,
+    app: Annotated[Any, Depends(get_app)] = None,
+):
+    if not token:
+        raise HTTPException(401, "Missing token")
+    try:
+        app.state.auth_service.verify_token(token)
+    except ValueError as e:
+        raise HTTPException(401, str(e))
+
+    from fastapi.responses import StreamingResponse
+
+    event_bus = app.state.chat_event_bus
+    queue = event_bus.subscribe(chat_id)
+
+    async def event_generator():
+        try:
+            yield "retry: 5000\n\n"
+            while True:
+                try:
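+                    # Wake at least every 30s: TimeoutError below emits an SSE
+                    # comment line as a keepalive so idle connections stay open.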
+                    event = await asyncio.wait_for(queue.get(), timeout=30)
+                    event_type = event.get("event", "message")
+                    data = event.get("data", {})
+                    yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
+                except TimeoutError:
+                    yield ": keepalive\n\n"
+        finally:
+            event_bus.unsubscribe(chat_id, queue)
+
+    return StreamingResponse(event_generator(), media_type="text/event-stream")
+
+
+# ---------------------------------------------------------------------------
+# Chat mute
+# ---------------------------------------------------------------------------
+
+
+@router.post("/{chat_id}/mute")
+async def mute_chat(
+    chat_id: str,
+    body: MuteChatBody,
+    user_id: Annotated[str, Depends(get_current_user_id)],
+    app: Annotated[Any, Depends(get_app)],
+):
+    _verify_member_ownership(app, body.user_id, user_id)
+    mute_until_iso = datetime.fromtimestamp(body.mute_until, tz=UTC).isoformat() if body.mute_until else None
+    _messaging(app).update_mute(chat_id, body.user_id, body.muted, mute_until_iso)
+    return {"status": "ok", "muted": body.muted}
diff --git a/backend/web/routers/monitor.py b/backend/web/routers/monitor.py
index 8b389c308..eb1781db6 100644
--- a/backend/web/routers/monitor.py
+++ b/backend/web/routers/monitor.py
@@ -1,57 +1,26 @@
-"""Sandbox Monitor API - thin router over monitor core."""
+"""Monitor router compatibility layer.
+
+Expose the richer monitor implementation from ``backend.web.monitor`` while
+preserving the newer resource/health helper endpoints added on main.
+"""
 
 import asyncio
 
-from fastapi import APIRouter, HTTPException, Query
+from fastapi import HTTPException, Query
+from pydantic import BaseModel, Field
 
+from backend.web.monitor import list_leases, router
 from backend.web.services import monitor_service
 from backend.web.services.resource_cache import (
-    get_resource_overview_snapshot,
-    refresh_resource_overview_sync,
+    get_monitor_resource_overview_snapshot,
+    refresh_monitor_resource_overview_sync,
 )
 
-router = APIRouter(prefix="/api/monitor")
-
-
-@router.get("/threads")
-def list_threads():
-    return monitor_service.list_threads()
-
-
-@router.get("/thread/{thread_id}")
-def get_thread(thread_id: str):
-    return monitor_service.get_thread(thread_id)
-
-
-@router.get("/leases")
-def list_leases():
-    return monitor_service.list_leases()
-
-
-@router.get("/lease/{lease_id}")
-def get_lease(lease_id: str):
-    try:
-        return monitor_service.get_lease(lease_id)
-    except KeyError as e:
-        raise HTTPException(status_code=404, detail=str(e)) from e
-
-
-@router.get("/diverged")
-def list_diverged():
-    return monitor_service.list_diverged()
-
-
-@router.get("/events")
-def list_events(limit: int = 100):
-    return monitor_service.list_events(limit=limit)
-
-
-@router.get("/event/{event_id}")
-def get_event(event_id: str):
-    try:
-        return monitor_service.get_event(event_id)
-    except KeyError as e:
-        raise HTTPException(status_code=404, detail=str(e)) from e
+class ResourceCleanupRequest(BaseModel):
+    action: str = Field(default="cleanup_residue")
+    lease_ids: list[str]
+    expected_category: str
 
 
@@ -59,15 +28,60 @@
 def health_snapshot():
     return monitor_service.runtime_health_snapshot()
 
 
+@router.get("/dashboard")
+def dashboard_snapshot():
+    health = monitor_service.runtime_health_snapshot()
+    resources = get_monitor_resource_overview_snapshot()
+    leases = list_leases()
+
+    resource_summary = resources.get("summary") or {}
+    lease_summary = leases.get("summary") or {}
+
+    return {
+        "snapshot_at": health.get("snapshot_at"),
+        "resources_summary": resource_summary,
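+        # Defensive int(... or 0) coercion throughout: summary dicts may be empty
+        # or carry None values (e.g. before the first refresh), so every counter
+        # degrades to 0 instead of raising.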
"infra": { + "providers_active": int(resource_summary.get("active_providers") or 0), + "providers_unavailable": int(resource_summary.get("unavailable_providers") or 0), + "leases_total": int(lease_summary.get("total") or leases.get("count") or 0), + "leases_diverged": int(lease_summary.get("diverged") or 0) + int(lease_summary.get("orphan_diverged") or 0), + "leases_orphan": int(lease_summary.get("orphan") or 0) + int(lease_summary.get("orphan_diverged") or 0), + "leases_healthy": int(lease_summary.get("healthy") or 0), + }, + "workload": { + "db_sessions_total": int(((health.get("db") or {}).get("counts") or {}).get("chat_sessions") or 0), + "provider_sessions_total": int(((health.get("sessions") or {}).get("total")) or 0), + "running_sessions": int(resource_summary.get("running_sessions") or 0), + "evaluations_running": 0, + }, + "latest_evaluation": None, + } + + @router.get("/resources") def resources_overview(): - return get_resource_overview_snapshot() + return get_monitor_resource_overview_snapshot() @router.post("/resources/refresh") async def resources_refresh(): # @@@refresh-off-main-loop - provider I/O stays off event loop to avoid request head-of-line blocking. - return await asyncio.to_thread(refresh_resource_overview_sync) + return await asyncio.to_thread(refresh_monitor_resource_overview_sync) + + +@router.post("/resources/cleanup") +async def resources_cleanup(payload: ResourceCleanupRequest): + from backend.web.services import monitor_service + + try: + return await asyncio.to_thread( + monitor_service.cleanup_resource_leases, + action=payload.action, + lease_ids=payload.lease_ids, + expected_category=payload.expected_category, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc @router.get("/sandbox/{lease_id}/browse") diff --git a/backend/web/routers/panel.py b/backend/web/routers/panel.py index 3fe2f481b..0b5a8bd45 100644 --- a/backend/web/routers/panel.py +++ b/backend/web/routers/panel.py @@ -27,6 +27,15 @@ router = APIRouter(prefix="/api/panel", tags=["panel"]) +def _get_owned_member_or_404(member_id: str, user_id: str) -> dict[str, Any]: + item = member_service.get_member(member_id) + if not item: + raise HTTPException(404, "Member not found") + if item.get("owner_user_id") != user_id: + raise HTTPException(403, "Forbidden") + return item + + # ── Members ── @@ -41,11 +50,11 @@ async def list_members( @router.get("/members/{member_id}") -async def get_member(member_id: str) -> dict[str, Any]: - item = await asyncio.to_thread(member_service.get_member, member_id) - if not item: - raise HTTPException(404, "Member not found") - return item +async def get_member( + member_id: str, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + return await asyncio.to_thread(_get_owned_member_or_404, member_id, user_id) @router.post("/members") @@ -55,20 +64,30 @@ async def create_member( request: Request, ) -> dict[str, Any]: member_repo = getattr(request.app.state, "member_repo", None) - return await asyncio.to_thread(member_service.create_member, req.name, req.description, owner_user_id=user_id, member_repo=member_repo) + agent_config_repo = getattr(request.app.state, "agent_config_repo", None) + return await asyncio.to_thread( + member_service.create_member, + req.name, + req.description, + owner_user_id=user_id, + member_repo=member_repo, + agent_config_repo=agent_config_repo, + ) @router.put("/members/{member_id}") -async def update_member(member_id: str, req: UpdateMemberRequest, request: Request) -> 
dict[str, Any]: +async def update_member( + member_id: str, + req: UpdateMemberRequest, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: member_repo = getattr(request.app.state, "member_repo", None) - entity_repo = getattr(request.app.state, "entity_repo", None) - thread_repo = getattr(request.app.state, "thread_repo", None) + await asyncio.to_thread(_get_owned_member_or_404, member_id, user_id) item = await asyncio.to_thread( member_service.update_member, member_id, member_repo=member_repo, - entity_repo=entity_repo, - thread_repo=thread_repo, **req.model_dump(), ) if not item: @@ -77,29 +96,64 @@ async def update_member(member_id: str, req: UpdateMemberRequest, request: Reque @router.put("/members/{member_id}/config") -async def update_member_config(member_id: str, req: MemberConfigPayload) -> dict[str, Any]: - item = await asyncio.to_thread(member_service.update_member_config, member_id, req.model_dump()) +async def update_member_config( + member_id: str, + req: MemberConfigPayload, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + await asyncio.to_thread(_get_owned_member_or_404, member_id, user_id) + agent_config_repo = getattr(request.app.state, "agent_config_repo", None) + item = await asyncio.to_thread( + member_service.update_member_config, + member_id, + req.model_dump(), + agent_config_repo=agent_config_repo, + ) if not item: raise HTTPException(404, "Member not found") return item @router.put("/members/{member_id}/publish") -async def publish_member(member_id: str, req: PublishMemberRequest) -> dict[str, Any]: +async def publish_member( + member_id: str, + req: PublishMemberRequest, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: if member_id == "__leon__": raise HTTPException(403, "Cannot publish builtin member") - item = await asyncio.to_thread(member_service.publish_member, member_id, req.bump_type) + await asyncio.to_thread(_get_owned_member_or_404, member_id, user_id) + agent_config_repo = getattr(request.app.state, "agent_config_repo", None) + item = await asyncio.to_thread( + member_service.publish_member, + member_id, + req.bump_type, + agent_config_repo=agent_config_repo, + ) if not item: raise HTTPException(404, "Member not found") return item @router.delete("/members/{member_id}") -async def delete_member(member_id: str, request: Request) -> dict[str, Any]: +async def delete_member( + member_id: str, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: if member_id == "__leon__": raise HTTPException(403, "Cannot delete builtin member") + await asyncio.to_thread(_get_owned_member_or_404, member_id, user_id) member_repo = getattr(request.app.state, "member_repo", None) - ok = await asyncio.to_thread(member_service.delete_member, member_id, member_repo=member_repo) + agent_config_repo = getattr(request.app.state, "agent_config_repo", None) + ok = await asyncio.to_thread( + member_service.delete_member, + member_id, + member_repo=member_repo, + agent_config_repo=agent_config_repo, + ) if not ok: raise HTTPException(404, "Member not found") return {"success": True} @@ -109,39 +163,95 @@ async def delete_member(member_id: str, request: Request) -> dict[str, Any]: @router.get("/tasks") -async def list_tasks() -> dict[str, Any]: - items = await asyncio.to_thread(task_service.list_tasks) +async def list_tasks( + request: Request, + user_id: Annotated[str, 
Depends(get_current_user_id)], +) -> dict[str, Any]: + items = await asyncio.to_thread( + task_service.list_tasks, + owner_user_id=user_id, + repo=request.app.state.panel_task_repo, + thread_repo=request.app.state.thread_repo, + ) return {"items": items} @router.post("/tasks") -async def create_task(req: CreateTaskRequest) -> dict[str, Any]: - return await asyncio.to_thread(task_service.create_task, **req.model_dump()) +async def create_task( + req: CreateTaskRequest, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + return await asyncio.to_thread( + task_service.create_task, + owner_user_id=user_id, + repo=request.app.state.panel_task_repo, + **req.model_dump(), + ) @router.put("/tasks/bulk-status") -async def bulk_update_status(req: BulkTaskStatusRequest) -> dict[str, Any]: - count = await asyncio.to_thread(task_service.bulk_update_task_status, req.ids, req.status) +async def bulk_update_status( + req: BulkTaskStatusRequest, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + count = await asyncio.to_thread( + task_service.bulk_update_task_status, + req.ids, + req.status, + owner_user_id=user_id, + repo=request.app.state.panel_task_repo, + ) return {"updated": count} @router.post("/tasks/bulk-delete") -async def bulk_delete_tasks(req: BulkDeleteTasksRequest) -> dict[str, Any]: - count = await asyncio.to_thread(task_service.bulk_delete_tasks, req.ids) +async def bulk_delete_tasks( + req: BulkDeleteTasksRequest, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + count = await asyncio.to_thread( + task_service.bulk_delete_tasks, + req.ids, + owner_user_id=user_id, + repo=request.app.state.panel_task_repo, + ) return {"deleted": count} @router.put("/tasks/{task_id}") -async def update_task(task_id: str, req: UpdateTaskRequest) -> dict[str, Any]: - item = await asyncio.to_thread(task_service.update_task, task_id, **req.model_dump()) +async def update_task( + task_id: str, + req: UpdateTaskRequest, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + item = await asyncio.to_thread( + task_service.update_task, + task_id, + owner_user_id=user_id, + repo=request.app.state.panel_task_repo, + **req.model_dump(), + ) if not item: raise HTTPException(404, "Task not found") return item @router.delete("/tasks/{task_id}") -async def delete_task(task_id: str) -> dict[str, Any]: - ok = await asyncio.to_thread(task_service.delete_task, task_id) +async def delete_task( + task_id: str, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + ok = await asyncio.to_thread( + task_service.delete_task, + task_id, + owner_user_id=user_id, + repo=request.app.state.panel_task_repo, + ) if not ok: raise HTTPException(404, "Task not found") return {"success": True} @@ -151,49 +261,86 @@ async def delete_task(task_id: str) -> dict[str, Any]: @router.get("/cron-jobs") -async def list_cron_jobs() -> dict[str, Any]: - items = await asyncio.to_thread(cron_job_service.list_cron_jobs) +async def list_cron_jobs( + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + items = await asyncio.to_thread( + cron_job_service.list_cron_jobs, + owner_user_id=user_id, + repo=request.app.state.cron_job_repo, + ) return {"items": items} @router.post("/cron-jobs") -async def create_cron_job(req: CreateCronJobRequest) -> dict[str, Any]: +async def 
create_cron_job( + req: CreateCronJobRequest, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: job = await asyncio.to_thread( cron_job_service.create_cron_job, name=req.name, cron_expression=req.cron_expression, + repo=request.app.state.cron_job_repo, description=req.description, task_template=req.task_template, enabled=int(req.enabled), + owner_user_id=user_id, ) return {"item": job} @router.put("/cron-jobs/{job_id}") -async def update_cron_job(job_id: str, req: UpdateCronJobRequest) -> dict[str, Any]: +async def update_cron_job( + job_id: str, + req: UpdateCronJobRequest, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: fields = req.model_dump(exclude_none=True) if "enabled" in fields: fields["enabled"] = int(fields["enabled"]) - job = await asyncio.to_thread(cron_job_service.update_cron_job, job_id, **fields) + job = await asyncio.to_thread( + cron_job_service.update_cron_job, + job_id, + owner_user_id=user_id, + repo=request.app.state.cron_job_repo, + **fields, + ) if not job: raise HTTPException(404, "Cron job not found") return {"item": job} @router.delete("/cron-jobs/{job_id}") -async def delete_cron_job(job_id: str) -> dict[str, Any]: - ok = await asyncio.to_thread(cron_job_service.delete_cron_job, job_id) +async def delete_cron_job( + job_id: str, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + ok = await asyncio.to_thread( + cron_job_service.delete_cron_job, + job_id, + owner_user_id=user_id, + repo=request.app.state.cron_job_repo, + ) if not ok: raise HTTPException(404, "Cron job not found") return {"ok": True} @router.post("/cron-jobs/{job_id}/run") -async def trigger_cron_job(job_id: str, request: Request) -> dict[str, Any]: +async def trigger_cron_job( + job_id: str, + request: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: cron_service = getattr(request.app.state, "cron_service", None) if not cron_service: raise HTTPException(503, "Cron service not available") - task = await cron_service.trigger_job(job_id) + task = await cron_service.trigger_job(job_id, owner_user_id=user_id) if not task: raise HTTPException(404, "Cron job not found or disabled") return {"item": task} @@ -315,10 +462,17 @@ async def update_resource_content(resource_type: str, resource_id: str, req: Upd @router.get("/profile") -async def get_profile() -> dict[str, Any]: - return await asyncio.to_thread(profile_service.get_profile) +async def get_profile( + user_id: Annotated[str, Depends(get_current_user_id)], + request: Request, +) -> dict[str, Any]: + member = request.app.state.member_repo.get_by_id(user_id) + return await asyncio.to_thread(profile_service.get_profile, member) @router.put("/profile") -async def update_profile(req: UpdateProfileRequest) -> dict[str, Any]: +async def update_profile( + req: UpdateProfileRequest, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: return await asyncio.to_thread(profile_service.update_profile, **req.model_dump()) diff --git a/backend/web/routers/resources.py b/backend/web/routers/resources.py new file mode 100644 index 000000000..4fc56e7a5 --- /dev/null +++ b/backend/web/routers/resources.py @@ -0,0 +1,28 @@ +"""User-scoped resource endpoints.""" + +from __future__ import annotations + +import asyncio +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, HTTPException, Request + +from backend.web.core.dependencies import 
get_current_user_id +from backend.web.services import resource_projection_service + +router = APIRouter(prefix="/api/resources", tags=["resources"]) + + +@router.get("/overview") +async def resources_overview( + user_id: Annotated[str, Depends(get_current_user_id)], + request: Request, +) -> dict[str, Any]: + try: + return await asyncio.to_thread( + resource_projection_service.list_user_resource_providers, + request.app, + user_id, + ) + except RuntimeError as exc: + raise HTTPException(500, str(exc)) from exc diff --git a/backend/web/routers/settings.py b/backend/web/routers/settings.py index f765c0962..daf049255 100644 --- a/backend/web/routers/settings.py +++ b/backend/web/routers/settings.py @@ -6,11 +6,12 @@ import json from pathlib import Path -from typing import Any +from typing import Annotated, Any -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from pydantic import BaseModel +from backend.web.core.dependencies import get_current_user_id from config.models_loader import ModelsLoader from config.models_schema import ModelsConfig from config.user_paths import user_home_path, user_home_read_candidates @@ -42,6 +43,27 @@ class DirectoryItem(BaseModel): is_dir: bool +def _resolve_workspace_path_or_400( + workspace: str, + *, + missing_detail: str, + not_dir_detail: str, +) -> str: + workspace_path = Path(workspace).expanduser().resolve() + if not workspace_path.exists(): + raise HTTPException(status_code=400, detail=missing_detail) + if not workspace_path.is_dir(): + raise HTTPException(status_code=400, detail=not_dir_detail) + return str(workspace_path) + + +def _remember_recent_workspace(settings: "WorkspaceSettings", workspace_str: str) -> None: + if workspace_str in settings.recent_workspaces: + settings.recent_workspaces.remove(workspace_str) + settings.recent_workspaces.insert(0, workspace_str) + settings.recent_workspaces = settings.recent_workspaces[:5] + + def load_settings() -> WorkspaceSettings: try: data = _load_user_json("preferences.json") @@ -71,6 +93,25 @@ def _try_get_user_id(request: Request) -> str | None: return None +def _load_models_for_user(repo, user_id: str | None) -> dict[str, Any]: + """Load models config: Supabase first, filesystem fallback.""" + if repo and user_id: + data = repo.get_models_config(user_id) + if data is not None: + return data + return _load_user_json("models.json") + + +def _save_models_for_user(repo, user_id: str | None, data: dict[str, Any]) -> None: + """Save models config: Supabase if available, else filesystem.""" + if repo and user_id: + repo.set_models_config(user_id, data) + else: + MODELS_FILE.parent.mkdir(parents=True, exist_ok=True) + with open(MODELS_FILE, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + # ============================================================================ # Models config (models.json) # ============================================================================ @@ -81,13 +122,6 @@ def load_models() -> dict[str, Any]: return _load_user_json("models.json") -def save_models(data: dict[str, Any]) -> None: - """Save models.json to disk (user-level).""" - MODELS_FILE.parent.mkdir(parents=True, exist_ok=True) - with open(MODELS_FILE, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2, ensure_ascii=False) - - def load_merged_models() -> ModelsConfig: """Load fully merged ModelsConfig (system + user).""" return ModelsLoader().load() @@ -149,7 +183,7 @@ async def get_settings(request: 
Request) -> UserSettings: # Build compat view mapping = {k: v.model for k, v in models.mapping.items()} providers = {k: ProviderConfig(api_key=v.api_key, base_url=v.base_url) for k, v in models.providers.items()} - raw = load_models() + raw = _load_models_for_user(repo, user_id) custom_config = raw.get("pool", {}).get("custom_config", {}) return UserSettings( @@ -214,51 +248,49 @@ async def read_local_file(path: str = Query(...)) -> dict[str, Any]: @router.post("/workspace") -async def set_default_workspace(request: WorkspaceRequest, req: Request) -> dict[str, Any]: +async def set_default_workspace( + request: WorkspaceRequest, + req: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: """Set default workspace path.""" - workspace_path = Path(request.workspace).expanduser().resolve() - if not workspace_path.exists(): - raise HTTPException(status_code=400, detail="Workspace path does not exist") - if not workspace_path.is_dir(): - raise HTTPException(status_code=400, detail="Workspace path is not a directory") - - workspace_str = str(workspace_path) + workspace_str = _resolve_workspace_path_or_400( + request.workspace, + missing_detail="Workspace path does not exist", + not_dir_detail="Workspace path is not a directory", + ) repo = _get_settings_repo(req) - user_id = _try_get_user_id(req) if repo else None if repo and user_id: repo.set_default_workspace(user_id, workspace_str) else: settings = load_settings() settings.default_workspace = workspace_str - if workspace_str in settings.recent_workspaces: - settings.recent_workspaces.remove(workspace_str) - settings.recent_workspaces.insert(0, workspace_str) - settings.recent_workspaces = settings.recent_workspaces[:5] + _remember_recent_workspace(settings, workspace_str) save_settings(settings) return {"success": True, "workspace": workspace_str} @router.post("/workspace/recent") -async def add_recent_workspace(request: WorkspaceRequest, req: Request) -> dict[str, Any]: +async def add_recent_workspace( + request: WorkspaceRequest, + req: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: """Add a workspace to recent list.""" - workspace_path = Path(request.workspace).expanduser().resolve() - if not workspace_path.exists() or not workspace_path.is_dir(): - raise HTTPException(status_code=400, detail="Invalid workspace path") - - workspace_str = str(workspace_path) + workspace_str = _resolve_workspace_path_or_400( + request.workspace, + missing_detail="Invalid workspace path", + not_dir_detail="Invalid workspace path", + ) repo = _get_settings_repo(req) - user_id = _try_get_user_id(req) if repo else None if repo and user_id: repo.add_recent_workspace(user_id, workspace_str) else: settings = load_settings() - if workspace_str in settings.recent_workspaces: - settings.recent_workspaces.remove(workspace_str) - settings.recent_workspaces.insert(0, workspace_str) - settings.recent_workspaces = settings.recent_workspaces[:5] + _remember_recent_workspace(settings, workspace_str) save_settings(settings) return {"success": True} @@ -269,10 +301,13 @@ class DefaultModelRequest(BaseModel): @router.post("/default-model") -async def set_default_model(request: DefaultModelRequest, req: Request) -> dict[str, Any]: +async def set_default_model( + request: DefaultModelRequest, + req: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: """Set default virtual model preference.""" repo = _get_settings_repo(req) - user_id = _try_get_user_id(req) if repo else 
None if repo and user_id: repo.set_default_model(user_id, request.model) else: @@ -387,9 +422,14 @@ class ModelMappingRequest(BaseModel): @router.post("/model-mapping") -async def update_model_mapping(request: ModelMappingRequest) -> dict[str, Any]: - """Update virtual model mapping → models.json.""" - data = load_models() +async def update_model_mapping( + request: ModelMappingRequest, + req: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + """Update virtual model mapping → models config.""" + repo = _get_settings_repo(req) + data = _load_models_for_user(repo, user_id) mapping = data.get("mapping", {}) for name, spec in request.mapping.items(): if isinstance(spec, dict): @@ -398,7 +438,7 @@ async def update_model_mapping(request: ModelMappingRequest) -> dict[str, Any]: else: mapping[name] = spec data["mapping"] = mapping - save_models(data) + _save_models_for_user(repo, user_id, data) return {"success": True, "model_mapping": request.mapping} @@ -413,9 +453,14 @@ class ModelToggleRequest(BaseModel): @router.post("/models/toggle") -async def toggle_model(request: ModelToggleRequest) -> dict[str, Any]: - """Enable or disable a model → models.json pool.enabled.""" - data = load_models() +async def toggle_model( + request: ModelToggleRequest, + req: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + """Enable or disable a model.""" + repo = _get_settings_repo(req) + data = _load_models_for_user(repo, user_id) pool = data.setdefault("pool", {"enabled": [], "custom": []}) enabled = pool.setdefault("enabled", []) @@ -426,7 +471,7 @@ async def toggle_model(request: ModelToggleRequest) -> dict[str, Any]: if request.model_id in enabled: enabled.remove(request.model_id) - save_models(data) + _save_models_for_user(repo, user_id, data) return {"success": True, "enabled_models": enabled} @@ -438,9 +483,14 @@ class CustomModelRequest(BaseModel): @router.post("/models/custom") -async def add_custom_model(request: CustomModelRequest) -> dict[str, Any]: - """Add a custom model → models.json pool.custom + auto-enable.""" - data = load_models() +async def add_custom_model( + request: CustomModelRequest, + req: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + """Add a custom model + auto-enable.""" + repo = _get_settings_repo(req) + data = _load_models_for_user(repo, user_id) pool = data.setdefault("pool", {"enabled": [], "custom": []}) custom = pool.setdefault("custom", []) enabled = pool.setdefault("enabled", []) @@ -463,7 +513,7 @@ async def add_custom_model(request: CustomModelRequest) -> dict[str, Any]: cfg["context_limit"] = request.context_limit custom_config[request.model_id] = cfg - save_models(data) + _save_models_for_user(repo, user_id, data) return {"success": True, "custom_models": custom, "enabled_models": enabled} @@ -528,9 +578,11 @@ async def test_model(request: ModelTestRequest) -> dict[str, Any]: @router.delete("/models/custom") -async def remove_custom_model(model_id: str = Query(...)) -> dict[str, Any]: - """Remove a custom model from models.json pool.custom + pool.enabled.""" - data = load_models() +async def remove_custom_model(req: Request, model_id: str = Query(...)) -> dict[str, Any]: + """Remove a custom model.""" + repo = _get_settings_repo(req) + user_id = _try_get_user_id(req) if repo else None + data = _load_models_for_user(repo, user_id) pool = data.setdefault("pool", {"enabled": [], "custom": []}) custom = pool.setdefault("custom", []) enabled = 
pool.setdefault("enabled", []) @@ -546,7 +598,7 @@ async def remove_custom_model(model_id: str = Query(...)) -> dict[str, Any]: custom_config = pool.get("custom_config", {}) custom_config.pop(model_id, None) - save_models(data) + _save_models_for_user(repo, user_id, data) return {"success": True, "custom_models": custom} @@ -558,9 +610,11 @@ class CustomModelConfigRequest(BaseModel): @router.post("/models/custom/config") -async def update_custom_model_config(request: CustomModelConfigRequest) -> dict[str, Any]: +async def update_custom_model_config(request: CustomModelConfigRequest, req: Request) -> dict[str, Any]: """Update based_on/context_limit/provider for a custom model.""" - data = load_models() + repo = _get_settings_repo(req) + user_id = _try_get_user_id(req) if repo else None + data = _load_models_for_user(repo, user_id) pool = data.setdefault("pool", {}) custom_config = pool.setdefault("custom_config", {}) cfg: dict[str, Any] = custom_config.get(request.model_id, {}) @@ -572,7 +626,7 @@ async def update_custom_model_config(request: CustomModelConfigRequest) -> dict[ if request.provider: custom_providers = pool.setdefault("custom_providers", {}) custom_providers[request.model_id] = request.provider - save_models(data) + _save_models_for_user(repo, user_id, data) return {"success": True, "custom_config": custom_config} @@ -588,9 +642,14 @@ class ProviderRequest(BaseModel): @router.post("/providers") -async def update_provider(request: ProviderRequest, req: Request) -> dict[str, Any]: - """Update provider config → models.json providers, then reload all agents.""" - data = load_models() +async def update_provider( + request: ProviderRequest, + req: Request, + user_id: Annotated[str, Depends(get_current_user_id)], +) -> dict[str, Any]: + """Update provider config, then reload all agents.""" + repo = _get_settings_repo(req) + data = _load_models_for_user(repo, user_id) providers = data.setdefault("providers", {}) provider_data: dict[str, Any] = {} if request.api_key is not None: @@ -598,7 +657,7 @@ async def update_provider(request: ProviderRequest, req: Request) -> dict[str, A if request.base_url is not None: provider_data["base_url"] = request.base_url providers[request.provider] = provider_data - save_models(data) + _save_models_for_user(repo, user_id, data) # @@@reload-agents-on-key-change — hot-reload all cached agents so they pick up new API keys pool = getattr(req.app.state, "agent_pool", {}) @@ -633,8 +692,14 @@ class ObservationRequest(BaseModel): @router.get("/observation") -async def get_observation_settings() -> dict[str, Any]: +async def get_observation_settings(req: Request) -> dict[str, Any]: """Get observation provider configuration.""" + repo = _get_settings_repo(req) + user_id = _try_get_user_id(req) if repo else None + if repo and user_id: + data = repo.get_observation_config(user_id) + if data is not None: + return data from config.observation_loader import ObservationLoader config = ObservationLoader().load() @@ -642,13 +707,19 @@ async def get_observation_settings() -> dict[str, Any]: @router.post("/observation") -async def update_observation_settings(request: ObservationRequest) -> dict[str, Any]: - """Update observation provider config (persists to observation.json). +async def update_observation_settings(request: ObservationRequest, req: Request) -> dict[str, Any]: + """Update observation provider config. New threads will pick up the active provider at creation time. Existing threads keep their locked provider — only credentials are read live. 
""" - data = _load_user_json("observation.json") + repo = _get_settings_repo(req) + user_id = _try_get_user_id(req) if repo else None + + if repo and user_id: + data = repo.get_observation_config(user_id) or {} + else: + data = _load_user_json("observation.json") data["active"] = request.active if request.langfuse is not None: @@ -660,9 +731,12 @@ async def update_observation_settings(request: ObservationRequest) -> dict[str, existing.update(request.langsmith) data["langsmith"] = existing - OBSERVATION_FILE.parent.mkdir(parents=True, exist_ok=True) - with open(OBSERVATION_FILE, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2, ensure_ascii=False) + if repo and user_id: + repo.set_observation_config(user_id, data) + else: + OBSERVATION_FILE.parent.mkdir(parents=True, exist_ok=True) + with open(OBSERVATION_FILE, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) return {"success": True, "active": data.get("active")} @@ -740,8 +814,15 @@ class SandboxConfigRequest(BaseModel): @router.get("/sandboxes") -async def list_sandbox_configs() -> dict[str, Any]: - """List all sandbox configurations from ~/.leon/sandboxes/.""" +async def list_sandbox_configs(req: Request) -> dict[str, Any]: + """List all sandbox configurations.""" + repo = _get_settings_repo(req) + user_id = _try_get_user_id(req) if repo else None + if repo and user_id: + data = repo.get_sandbox_configs(user_id) + if data is not None: + return {"sandboxes": data} + # Filesystem fallback sandboxes: dict[str, Any] = {} seen: set[Path] = set() for root in user_home_read_candidates("sandboxes"): @@ -760,13 +841,23 @@ async def list_sandbox_configs() -> dict[str, Any]: @router.post("/sandboxes") -async def save_sandbox_config(request: SandboxConfigRequest) -> dict[str, Any]: - """Save a sandbox configuration to ~/.leon/sandboxes/.json.""" +async def save_sandbox_config(request: SandboxConfigRequest, req: Request) -> dict[str, Any]: + """Save a sandbox configuration.""" + repo = _get_settings_repo(req) + user_id = _try_get_user_id(req) if repo else None + from sandbox.config import SandboxConfig try: cfg = SandboxConfig(**request.config) - path = cfg.save(request.name) - return {"success": True, "path": str(path)} + if repo and user_id: + # Save to Supabase + existing = repo.get_sandbox_configs(user_id) or {} + existing[request.name] = cfg.model_dump() + repo.set_sandbox_configs(user_id, existing) + return {"success": True, "path": f"supabase://user_settings/{user_id}/sandbox_configs/{request.name}"} + else: + path = cfg.save(request.name) + return {"success": True, "path": str(path)} except Exception as e: raise HTTPException(status_code=400, detail=str(e)) diff --git a/backend/web/routers/thread_files.py b/backend/web/routers/thread_files.py index ef92a670d..30b0fcd09 100644 --- a/backend/web/routers/thread_files.py +++ b/backend/web/routers/thread_files.py @@ -21,6 +21,17 @@ _public = APIRouter(prefix="/api/threads/{thread_id}/files", tags=["thread-files"]) +async def _call_channel_file_service(method, *args, missing_status: int | None = None, **kwargs): + try: + return await asyncio.to_thread(method, *args, **kwargs) + except ValueError as e: + raise HTTPException(400, str(e)) from e + except FileNotFoundError as e: + if missing_status is None: + raise + raise HTTPException(missing_status, str(e)) from e + + @router.get("/list") async def list_workspace_path( thread_id: str, @@ -185,16 +196,12 @@ async def download_file( path: str = Query(...), ) -> FileResponse: """Download a file from 
thread-scoped files directory.""" - try: - target = await asyncio.to_thread( - file_channel_service.resolve_channel_file, - thread_id=thread_id, - relative_path=path, - ) - except ValueError as e: - raise HTTPException(400, str(e)) from e - except FileNotFoundError as e: - raise HTTPException(404, str(e)) from e + target = await _call_channel_file_service( + file_channel_service.resolve_channel_file, + thread_id=thread_id, + relative_path=path, + missing_status=404, + ) return FileResponse(path=str(target), filename=target.name, media_type="application/octet-stream") @@ -204,16 +211,12 @@ async def delete_workspace_file( path: str = Query(...), ) -> dict[str, Any]: """Delete a file from workspace.""" - try: - await asyncio.to_thread( - file_channel_service.delete_channel_file, - thread_id=thread_id, - relative_path=path, - ) - except ValueError as e: - raise HTTPException(400, str(e)) from e - except FileNotFoundError as e: - raise HTTPException(404, str(e)) from e + await _call_channel_file_service( + file_channel_service.delete_channel_file, + thread_id=thread_id, + relative_path=path, + missing_status=404, + ) return {"ok": True, "path": path} @@ -222,11 +225,8 @@ async def list_channel_files( thread_id: str, ) -> dict[str, Any]: """List files under thread-scoped files directory.""" - try: - entries = await asyncio.to_thread( - file_channel_service.list_channel_files, - thread_id=thread_id, - ) - except ValueError as e: - raise HTTPException(400, str(e)) from e + entries = await _call_channel_file_service( + file_channel_service.list_channel_files, + thread_id=thread_id, + ) return {"thread_id": thread_id, "entries": entries} diff --git a/backend/web/routers/threads.py b/backend/web/routers/threads.py index 33a75b8aa..8b380e050 100644 --- a/backend/web/routers/threads.py +++ b/backend/web/routers/threads.py @@ -21,25 +21,29 @@ from backend.web.models.requests import ( CreateThreadRequest, ResolveMainThreadRequest, + ResolvePermissionRequest, SaveThreadLaunchConfigRequest, SendMessageRequest, + ThreadPermissionRuleRequest, ) from backend.web.services import sandbox_service from backend.web.services.agent_pool import get_or_create_agent, resolve_thread_sandbox from backend.web.services.event_buffer import ThreadEventBuffer from backend.web.services.file_channel_service import get_file_channel_source -from backend.web.services.resource_cache import clear_resource_overview_cache +from backend.web.services.resource_cache import clear_monitor_resource_overview_cache from backend.web.services.sandbox_service import destroy_thread_resources_sync, init_providers_and_managers from backend.web.services.streaming_service import ( get_or_create_thread_buffer, observe_thread_events, ) from backend.web.services.thread_launch_config_service import ( + build_existing_launch_config, + build_new_launch_config, resolve_default_config, save_last_confirmed_config, save_last_successful_config, ) -from backend.web.services.thread_naming import canonical_entity_name, sidebar_label +from backend.web.services.thread_naming import sidebar_label from backend.web.services.thread_state_service import ( get_lease_status, get_sandbox_info, @@ -50,19 +54,45 @@ from backend.web.utils.serializers import avatar_url, serialize_message from core.runtime.middleware.monitor import AgentState from sandbox.config import MountSpec +from sandbox.manager import bind_thread_to_existing_lease from sandbox.recipes import normalize_recipe_snapshot, provider_type_from_name from sandbox.thread_context import set_current_thread_id -from 
storage.contracts import EntityRow logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/threads", tags=["threads"]) +class _NoopAsyncLock: + async def __aenter__(self) -> None: + return None + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + +def _is_internal_child_thread(thread_id: str) -> bool: + return thread_id.startswith("subagent-") + + def _invalidate_resource_overview_cache() -> None: - # @@@resource-overview-invalidation - thread/lease mutations change the monitor topology immediately. + # @@@monitor-resource-overview-invalidation - thread/lease mutations change the monitor topology immediately. # Clear the overview snapshot so the next /api/monitor/resources read reflects the fresh binding/state. - clear_resource_overview_cache() + clear_monitor_resource_overview_cache() + + +def _find_owned_member(app: Any, member_id: str, owner_user_id: str) -> Any | None: + member = app.state.member_repo.get_by_id(member_id) + if not member or member.owner_user_id != owner_user_id: + return None + return member + + +def _require_owned_member(app: Any, member_id: str, owner_user_id: str) -> Any: + member = _find_owned_member(app, member_id, owner_user_id) + if member is None: + raise HTTPException(403, "Not authorized") + return member async def _prepare_attachment_message( @@ -179,6 +209,86 @@ async def _validate_mount_capability_gate( ) +def _provider_unavailable_response(sandbox_type: str) -> JSONResponse: + return JSONResponse( + status_code=400, + content={ + "error": "sandbox_provider_unavailable", + "provider": sandbox_type, + }, + ) + + +def _format_ask_user_question_followup( + pending_request: dict[str, Any], + *, + answers: list[dict[str, Any]], + annotations: dict[str, Any] | None, +) -> str: + payload: dict[str, Any] = { + "questions": (pending_request.get("args") or {}).get("questions", []), + "answers": answers, + } + if annotations is not None: + payload["annotations"] = annotations + # @@@ask-user-followup-payload - keep this as one narrow, structured owner reply + # so the resumed run can continue from the user's choices without inventing + # a bespoke second continuation channel. + return ( + "The user answered your AskUserQuestion prompt. 
Continue the task using these answers.\n" + "\n" + f"{json.dumps(payload, ensure_ascii=False, indent=2)}\n" + "" + ) + + +def _build_ask_user_question_answered_payload( + pending_request: dict[str, Any], + *, + answers: list[dict[str, Any]], + annotations: dict[str, Any] | None, +) -> dict[str, Any]: + payload: dict[str, Any] = { + "questions": (pending_request.get("args") or {}).get("questions", []), + "answers": answers, + } + if annotations is not None: + payload["annotations"] = annotations + return payload + + +def _serialize_permission_answers(payload: Any) -> list[dict[str, Any]] | None: + raw_answers = getattr(payload, "answers", None) + if raw_answers is None: + return None + serialized: list[dict[str, Any]] = [] + for item in raw_answers: + if hasattr(item, "model_dump"): + serialized.append(item.model_dump(exclude_none=True)) + elif isinstance(item, dict): + serialized.append({key: value for key, value in item.items() if value is not None}) + else: + serialized.append({key: value for key, value in vars(item).items() if value is not None}) + return serialized + + +def _validate_sandbox_provider_gate(app: Any, owner_user_id: str, payload: CreateThreadRequest) -> JSONResponse | None: + sandbox_type = payload.sandbox or "local" + if payload.lease_id: + owned_lease = next( + (lease for lease in sandbox_service.list_user_leases(owner_user_id) if lease["lease_id"] == payload.lease_id), + None, + ) + if owned_lease is not None: + sandbox_type = str(owned_lease["provider_name"] or sandbox_type) + if sandbox_type == "local": + return None + provider = sandbox_service.build_provider_from_config_name(sandbox_type) + if provider is not None: + return None + return _provider_unavailable_response(sandbox_type) + + def _get_agent_for_thread(app: Any, thread_id: str) -> Any | None: """Get agent instance for a thread from the agent pool.""" pool = getattr(app.state, "agent_pool", None) @@ -194,15 +304,13 @@ def _thread_payload(app: Any, thread_id: str, sandbox_type: str) -> dict[str, An if thread is None: raise HTTPException(404, "Thread not found") member = app.state.member_repo.get_by_id(thread["member_id"]) - entity = app.state.entity_repo.get_by_id(thread["member_id"]) - if member is None or entity is None: - raise HTTPException(500, f"Thread {thread_id} missing member/entity") + if member is None: + raise HTTPException(500, f"Thread {thread_id} missing member") return { "thread_id": thread_id, "sandbox": sandbox_type, "member_id": member.id, "member_name": member.name, - "entity_name": entity.name, "branch_index": thread["branch_index"], "sidebar_label": sidebar_label(is_main=thread["is_main"], branch_index=thread["branch_index"]), "avatar_url": avatar_url(member.id, bool(member.avatar)), @@ -210,7 +318,165 @@ def _thread_payload(app: Any, thread_id: str, sandbox_type: str) -> dict[str, An } -def _create_thread_sandbox_resources(thread_id: str, sandbox_type: str, recipe: dict[str, Any] | None) -> None: +_IDLE_REPLAYABLE_RUN_EVENTS = frozenset({"error", "cancelled", "retry"}) + + +def _checkpoint_tail_is_pending_owner_turn(messages: list[dict[str, Any]]) -> bool: + if not messages: + return False + tail = messages[-1] + if tail.get("type") != "HumanMessage": + return False + meta = tail.get("metadata") or {} + return meta.get("source") not in {"system", "external"} + + +async def _get_thread_display_entries(app: Any, thread_id: str) -> list[dict[str, Any]]: + display_builder = app.state.display_builder + entries = display_builder.get_entries(thread_id) + if entries is not None: + 
_normalize_blocking_subagent_terminal_status(entries) + sandbox_type = resolve_thread_sandbox(app, thread_id) + agent = await get_or_create_agent(app, sandbox_type, thread_id=thread_id) + if entries is not None and getattr(agent.runtime, "current_state", None) != AgentState.IDLE: + return entries + + set_current_thread_id(thread_id) + config = {"configurable": {"thread_id": thread_id}} + state = await agent.agent.aget_state(config) + values = getattr(state, "values", {}) if state else {} + messages = values.get("messages", []) if isinstance(values, dict) else [] + serialized = [serialize_message(msg) for msg in messages] + + from core.runtime.visibility import annotate_owner_visibility + + annotated, _ = annotate_owner_visibility(serialized) + if entries is not None and not _display_entries_need_idle_rebuild(entries, annotated): + return entries + entries = display_builder.build_from_checkpoint(thread_id, annotated) + if _checkpoint_tail_is_pending_owner_turn(annotated): + await _replay_latest_run_failure_events( + thread_id=thread_id, + display_builder=display_builder, + ) + entries = display_builder.get_entries(thread_id) or entries + _normalize_blocking_subagent_terminal_status(entries) + return entries + + +def _display_entries_need_idle_rebuild(entries: list[dict[str, Any]], messages: list[dict[str, Any]]) -> bool: + if not messages: + return bool(entries) + if not entries: + return True + # @@@idle-cache-honesty - idle detail must not trust cached assistant shells after + # clear/restart. Rebuild only when cache is visibly impossible for the persisted checkpoint. + return any(entry.get("role") == "assistant" and not entry.get("segments") for entry in entries) + + +def _normalize_blocking_subagent_terminal_status(entries: list[dict[str, Any]]) -> None: + for entry in entries: + if entry.get("role") != "assistant": + continue + for seg in entry.get("segments", []): + if seg.get("type") != "tool": + continue + step = seg.get("step") or {} + if step.get("name") != "Agent" or step.get("status") != "done": + continue + stream = step.get("subagent_stream") + if not isinstance(stream, dict): + continue + result_text = step.get("result") + existing_status = str(stream.get("status") or "").lower() + terminal_status = ( + existing_status + if existing_status in {"completed", "error", "cancelled"} + else ("error" if isinstance(result_text, str) and result_text.startswith("") else "completed") + ) + if stream.get("status") != terminal_status: + # @@@blocking-subagent-terminal-honesty - a finished blocking Agent tool + # must not keep exposing a stale running child status on refresh/detail/tasks. 
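+                # Hedged illustration (all names are the ones in this function):
+                #   stream still {"status": "running"} on a done Agent step
+                #       -> rewritten to the derived terminal_status here
+                #   stream already terminal ("completed"/"error"/"cancelled")
+                #       -> terminal_status matches it, so this branch is skipped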
+ stream["status"] = terminal_status + if terminal_status == "error" and not stream.get("error") and isinstance(result_text, str): + stream["error"] = result_text + + +def _collect_display_subagent_tasks(entries: list[dict[str, Any]]) -> dict[str, dict[str, Any]]: + tasks: dict[str, dict[str, Any]] = {} + for entry in entries: + if entry.get("role") != "assistant": + continue + for seg in entry.get("segments", []): + if seg.get("type") != "tool": + continue + step = seg.get("step") or {} + if step.get("name") != "Agent": + continue + stream = step.get("subagent_stream") + if not isinstance(stream, dict) or not stream.get("task_id"): + continue + task_id = str(stream["task_id"]) + raw_args = step.get("args") + args: dict[str, Any] = raw_args if isinstance(raw_args, dict) else {} + description = stream.get("description") or args.get("description") or args.get("prompt") + status = str(stream.get("status") or ("completed" if step.get("status") == "done" else "running")) + result_text = step.get("result") or stream.get("text") + # @@@dual-source-task-surface - blocking Agent subagents never enter parent _background_runs, + # so /tasks must also project persisted subagent_stream state from display history. + tasks[task_id] = { + "task_id": task_id, + "task_type": "agent", + "status": status, + "command_line": None, + "description": description, + "exit_code": None, + "error": stream.get("error"), + "result": result_text, + "text": result_text, + "thread_id": stream.get("thread_id"), + } + return tasks + + +async def _replay_latest_run_failure_events( + *, + thread_id: str, + display_builder: Any, +) -> None: + from backend.web.services.event_store import get_latest_run_id, read_events_after + + run_id = await get_latest_run_id(thread_id) + if not run_id or run_id.startswith("activity_"): + return + + events = await read_events_after(thread_id, run_id, 0) + if not any(event.get("event") in _IDLE_REPLAYABLE_RUN_EVENTS for event in events): + return + + # @@@idle-run-error-replay - checkpoint can stop at the owner's input when + # the run dies before first persisted AI/Tool message. Rebuild must replay + # the latest run-level failure events so refresh/detail stays honest. + for event in events: + event_type = event.get("event", "") + if event_type not in {"run_start", "run_done", *_IDLE_REPLAYABLE_RUN_EVENTS}: + continue + raw_data = event.get("data", "{}") + try: + payload = json.loads(raw_data) if isinstance(raw_data, str) else raw_data + except (json.JSONDecodeError, TypeError): + payload = {} + if not isinstance(payload, dict): + payload = {} + display_builder.apply_event(thread_id, event_type, payload) + + +def _create_thread_sandbox_resources( + thread_id: str, + sandbox_type: str, + recipe: dict[str, Any] | None, + cwd: str | None = None, +) -> None: """Create volume, lease, and terminal eagerly so volume exists before file uploads.""" from datetime import datetime @@ -250,11 +516,11 @@ def _create_thread_sandbox_resources(thread_id: str, sandbox_type: str, recipe: terminal_repo = SQLiteTerminalRepo(db_path=sandbox_db) try: terminal_id = f"term-{uuid.uuid4().hex[:12]}" - # @@@initial-cwd - use project root for local, provider default for remote + # @@@initial-cwd - local threads own their requested cwd; remote threads start from provider defaults. 
from backend.web.core.config import LOCAL_WORKSPACE_ROOT if sandbox_type == "local": - initial_cwd = str(LOCAL_WORKSPACE_ROOT) + initial_cwd = cwd or str(LOCAL_WORKSPACE_ROOT) else: from backend.web.services.sandbox_service import build_provider_from_config_name from sandbox.manager import resolve_provider_cwd @@ -271,43 +537,6 @@ def _create_thread_sandbox_resources(thread_id: str, sandbox_type: str, recipe: terminal_repo.close() -def _resolve_existing_lease_cwd(lease_id: str, fallback_cwd: str | None) -> str: - if fallback_cwd: - return fallback_cwd - - from backend.web.core.config import LOCAL_WORKSPACE_ROOT - from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path - from storage.providers.sqlite.terminal_repo import SQLiteTerminalRepo - - terminal_repo = SQLiteTerminalRepo(db_path=resolve_role_db_path(SQLiteDBRole.SANDBOX)) - try: - row = terminal_repo.get_latest_by_lease(lease_id) - finally: - terminal_repo.close() - if row and row.get("cwd"): - return str(row["cwd"]) - - return str(LOCAL_WORKSPACE_ROOT) - - -def _bind_thread_to_existing_lease(thread_id: str, lease_id: str, *, cwd: str | None) -> str: - from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path - from storage.providers.sqlite.terminal_repo import SQLiteTerminalRepo - - initial_cwd = _resolve_existing_lease_cwd(lease_id, cwd) - terminal_repo = SQLiteTerminalRepo(db_path=resolve_role_db_path(SQLiteDBRole.SANDBOX)) - try: - terminal_repo.create( - terminal_id=f"term-{uuid.uuid4().hex[:12]}", - thread_id=thread_id, - lease_id=lease_id, - initial_cwd=initial_cwd, - ) - finally: - terminal_repo.close() - return initial_cwd - - def _create_owned_thread( app: Any, owner_user_id: str, @@ -342,16 +571,17 @@ def _create_owned_thread( raise HTTPException(403, "Lease not authorized") sandbox_type = str(owned_lease["provider_name"] or sandbox_type) - # @@@non-atomic-create - these 3 steps (seq++, thread, entity) are not atomic. - seq = app.state.member_repo.increment_entity_seq(agent_member_id) + # @@@non-atomic-create - these 3 steps (seq++, thread) are not atomic. + seq = app.state.member_repo.increment_thread_seq(agent_member_id) new_thread_id = f"{agent_member_id}-{seq}" - has_main = app.state.thread_repo.get_main_thread(agent_member_id) is not None + has_main = app.state.thread_repo.get_default_thread(agent_member_id) is not None resolved_is_main = is_main or not has_main branch_index = 0 if resolved_is_main else app.state.thread_repo.get_next_branch_index(agent_member_id) app.state.thread_repo.create( thread_id=new_thread_id, member_id=agent_member_id, + user_id=new_thread_id, sandbox_type=sandbox_type, cwd=payload.cwd, created_at=time.time(), @@ -360,29 +590,6 @@ def _create_owned_thread( branch_index=branch_index, ) - # @@@entity-name-convention - entity display names derive from member + thread role, never sandbox strings. - entity_name = canonical_entity_name(agent_member.name, is_main=resolved_is_main, branch_index=branch_index) - - # @@@entity-id-is-member-id - agent entity id = member_id (per-agent, not per-thread). - # thread_id field on the entity points to the current main thread. - # If entity already exists, update thread_id (main thread changed); otherwise create. 
- existing_entity = app.state.entity_repo.get_by_id(agent_member_id) - if existing_entity is not None: - if resolved_is_main: - app.state.entity_repo.update(agent_member_id, thread_id=new_thread_id, name=entity_name) - # Branch threads don't update the entity — it represents the main identity - else: - app.state.entity_repo.create( - EntityRow( - id=agent_member_id, - type="agent", - member_id=agent_member_id, - name=entity_name, - thread_id=new_thread_id if resolved_is_main else None, - created_at=time.time(), - ) - ) - # Set thread state app.state.thread_sandbox[new_thread_id] = sandbox_type if payload.cwd: @@ -390,7 +597,7 @@ def _create_owned_thread( if selected_lease_id: # @@@reuse-lease-binding - Reuse an existing lease by attaching a fresh terminal for the new thread. - bound_cwd = _bind_thread_to_existing_lease( + bound_cwd = bind_thread_to_existing_lease( new_thread_id, selected_lease_id, cwd=payload.cwd, @@ -403,29 +610,22 @@ def _create_owned_thread( new_thread_id, sandbox_type, payload.recipe.model_dump() if payload.recipe else None, + payload.cwd, ) if selected_lease_id and owned_lease is not None: - successful_config = { - "create_mode": "existing", - "provider_config": sandbox_type, - "recipe": owned_lease.get("recipe"), - "lease_id": owned_lease["lease_id"], - "model": payload.model, - "workspace": app.state.thread_cwd.get(new_thread_id), - } + successful_config = build_existing_launch_config( + lease=owned_lease, + model=payload.model, + workspace=app.state.thread_cwd.get(new_thread_id), + ) else: - successful_config = { - "create_mode": "new", - "provider_config": sandbox_type, - "recipe": normalize_recipe_snapshot( - provider_type_from_name(sandbox_type), - payload.recipe.model_dump() if payload.recipe else None, - ), - "lease_id": None, - "model": payload.model, - "workspace": app.state.thread_cwd.get(new_thread_id) or payload.cwd, - } + successful_config = build_new_launch_config( + provider_config=sandbox_type, + recipe=payload.recipe.model_dump() if payload.recipe else None, + model=payload.model, + workspace=app.state.thread_cwd.get(new_thread_id) or payload.cwd, + ) save_last_successful_config(app, owner_user_id, agent_member_id, successful_config) return { @@ -433,7 +633,6 @@ def _create_owned_thread( "sandbox": sandbox_type, "member_id": agent_member_id, "member_name": agent_member.name, - "entity_name": entity_name, "branch_index": branch_index, "sidebar_label": sidebar_label(is_main=resolved_is_main, branch_index=branch_index), "avatar_url": avatar_url(agent_member_id, bool(agent_member.avatar)), @@ -448,6 +647,9 @@ async def create_thread( app: Annotated[Any, Depends(get_app)] = None, ) -> dict[str, Any] | JSONResponse: """Create a new child thread for an agent member.""" + provider_error = _validate_sandbox_provider_gate(app, user_id, payload) + if provider_error is not None: + return provider_error # Validate bind_mounts capability before creating thread sandbox_type = payload.sandbox or "local" requested_mounts = payload.bind_mounts if payload.bind_mounts else [] @@ -467,17 +669,45 @@ async def resolve_main_thread( user_id: Annotated[str, Depends(get_current_user_id)], app: Annotated[Any, Depends(get_app)] = None, ) -> dict[str, Any]: - """Return the main thread for a member, or null when none exists.""" - agent_member = app.state.member_repo.get_by_id(payload.member_id) - if not agent_member or agent_member.owner_user_id != user_id: + """Return the default representative thread for a member template.""" + agent_member = _find_owned_member(app, 
payload.member_id, user_id) + if agent_member is None: # Return null instead of 403 — member may not exist yet (stale client state) # or belong to another user (harmless to reveal "no thread") - return {"thread": None} + return { + "member_id": payload.member_id, + "default_thread_id": None, + "thread": None, + } - existing = app.state.thread_repo.get_main_thread(payload.member_id) - if existing is None: - return {"thread": None} - return {"thread": _thread_payload(app, existing["id"], existing.get("sandbox_type", "local"))} + default_thread = app.state.thread_repo.get_default_thread(payload.member_id) + if default_thread is None: + return { + "member_id": payload.member_id, + "default_thread_id": None, + "thread": None, + } + try: + return { + "member_id": payload.member_id, + "default_thread_id": default_thread["id"], + "thread": _thread_payload(app, default_thread["id"], default_thread.get("sandbox_type", "local")), + } + except HTTPException as exc: + # @@@orphan-default-thread - stale bootstrap data can leave the member pointing at a thread whose + # member rows are gone. Treat that as "no resolvable default thread" instead of surfacing a 500. + if exc.status_code == 500 and "missing member" in str(exc.detail): + logger.warning( + "resolve_main_thread ignored orphaned default thread %s for member %s", + default_thread["id"], + payload.member_id, + ) + return { + "member_id": payload.member_id, + "default_thread_id": None, + "thread": None, + } + raise @router.get("/default-config") @@ -486,9 +716,7 @@ async def get_default_thread_config( user_id: Annotated[str, Depends(get_current_user_id)], app: Annotated[Any, Depends(get_app)] = None, ) -> dict[str, Any]: - agent_member = app.state.member_repo.get_by_id(member_id) - if not agent_member or agent_member.owner_user_id != user_id: - raise HTTPException(403, "Not authorized") + _require_owned_member(app, member_id, user_id) return resolve_default_config(app, user_id, member_id) @@ -498,9 +726,7 @@ async def save_default_thread_config( user_id: Annotated[str, Depends(get_current_user_id)], app: Annotated[Any, Depends(get_app)] = None, ) -> dict[str, Any]: - agent_member = app.state.member_repo.get_by_id(payload.member_id) - if not agent_member or agent_member.owner_user_id != user_id: - raise HTTPException(403, "Not authorized") + _require_owned_member(app, payload.member_id, user_id) save_last_confirmed_config(app, user_id, payload.member_id, payload.model_dump()) return {"ok": True} @@ -518,6 +744,8 @@ async def list_threads( threads = [] for t in raw: tid = t["id"] + if _is_internal_child_thread(tid): + continue sandbox_type = t.get("sandbox_type", "local") # Check if agent is currently running — pool key is "{thread_id}:{sandbox_type}" running = False @@ -536,7 +764,6 @@ async def list_threads( "sandbox": t.get("sandbox_type", "local"), "member_name": t.get("member_name"), "member_id": t.get("member_id"), - "entity_name": t.get("entity_name"), "branch_index": t.get("branch_index"), "sidebar_label": sidebar_label( is_main=bool(t.get("is_main", False)), @@ -562,26 +789,10 @@ async def get_thread_messages( @@@display-builder — returns pre-computed ChatEntry[] from DisplayBuilder. Hot path: return in-memory state. Cold path: rebuild from checkpoint. 
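+
+    Idle-path honesty, sketched from _get_thread_display_entries (no new names):
+        agent not IDLE with cached entries        -> serve the hot cache as-is
+        cache impossible for the checkpoint
+        (assistant entries without segments)      -> rebuild from checkpoint
+        checkpoint tail is a pending owner turn   -> replay latest run failure events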
""" - display_builder = app.state.display_builder sandbox_type = resolve_thread_sandbox(app, thread_id) agent = await get_or_create_agent(app, sandbox_type, thread_id=thread_id) - - # Hot path: return cached display entries - entries = display_builder.get_entries(thread_id) - if entries is None: - # Cold path: rebuild from checkpoint - set_current_thread_id(thread_id) - config = {"configurable": {"thread_id": thread_id}} - state = await agent.agent.aget_state(config) - values = getattr(state, "values", {}) if state else {} - messages = values.get("messages", []) if isinstance(values, dict) else [] - serialized = [serialize_message(msg) for msg in messages] - - from core.runtime.visibility import annotate_owner_visibility - - annotated, _ = annotate_owner_visibility(serialized) - entries = display_builder.build_from_checkpoint(thread_id, annotated) - + display_builder = app.state.display_builder + entries = await _get_thread_display_entries(app, thread_id) sandbox_info = get_sandbox_info(agent, thread_id, sandbox_type) return { "thread_id": thread_id, @@ -622,17 +833,8 @@ async def delete_thread( except Exception as exc: logger.warning("Failed to destroy sandbox resources for thread %s: %s", thread_id, exc) await asyncio.to_thread(delete_thread_in_db, thread_id) - # Also delete from threads table (entity-chat addition) - thread_data = app.state.thread_repo.get_by_id(thread_id) - member_id = thread_data["member_id"] if thread_data else None + # Also delete from threads table (member-chat addition) app.state.thread_repo.delete(thread_id) - # Entity is keyed by member_id (shared across threads) — update its thread_id - # to the next main thread, or clear it if no threads remain - if member_id: - entity = app.state.entity_repo.get_by_id(member_id) - if entity and entity.thread_id == thread_id: - next_main = app.state.thread_repo.get_main_thread(member_id) - app.state.entity_repo.update(member_id, thread_id=next_main["id"] if next_main else None) # Clean up thread-specific state app.state.thread_sandbox.pop(thread_id, None) @@ -647,6 +849,28 @@ async def delete_thread( return {"ok": True, "thread_id": thread_id} +@router.post("/{thread_id}/clear") +async def clear_thread_history( + thread_id: str, + user_id: Annotated[str, Depends(verify_thread_owner)], + app: Annotated[Any, Depends(get_app)] = None, +) -> dict[str, Any]: + """Clear replayable thread history while preserving the thread itself.""" + sandbox_type = resolve_thread_sandbox(app, thread_id) + + lock = await get_thread_lock(app, thread_id) + async with lock: + agent = await get_or_create_agent(app, sandbox_type, thread_id=thread_id) + if hasattr(agent, "runtime") and agent.runtime.current_state == AgentState.ACTIVE: + raise HTTPException(status_code=409, detail="Cannot clear thread while run is in progress") + await agent.aclear_thread(thread_id) + + app.state.display_builder.clear(thread_id) + app.state.thread_event_buffers.pop(thread_id, None) + app.state.queue_manager.clear_all(thread_id) + return {"ok": True, "thread_id": thread_id} + + @router.post("/{thread_id}/messages") async def send_message( thread_id: str, @@ -705,7 +929,7 @@ async def get_thread_history( thread_id: str, limit: int = 20, truncate: int = 300, - user_id: Annotated[str, Depends(verify_thread_owner)] = None, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, app: Annotated[Any, Depends(get_app)] = None, ) -> dict[str, Any]: """Compact conversation history for debugging — no raw LangChain noise. 
@@ -743,6 +967,8 @@ def _expand(msg: Any) -> list[dict[str, Any]]: cls = msg.__class__.__name__ if cls == "HumanMessage": metadata = getattr(msg, "metadata", {}) or {} + if metadata.get("source") == "internal": + return [] if metadata.get("source") == "system": return [{"role": "notification", "text": _trunc(extract_text_content(msg.content))}] return [{"role": "human", "text": _trunc(extract_text_content(msg.content))}] @@ -759,7 +985,7 @@ def _expand(msg: Any) -> list[dict[str, Any]]: text = extract_text_content(msg.content) if text: entries.append({"role": "assistant", "text": _trunc(text)}) - return entries or [{"role": "assistant", "text": ""}] + return entries if cls == "ToolMessage": return [ { @@ -782,11 +1008,155 @@ def _expand(msg: Any) -> list[dict[str, Any]]: } +@router.get("/{thread_id}/permissions") +async def get_thread_permissions( + thread_id: str, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, + thread_lock: Annotated[asyncio.Lock | None, Depends(get_thread_lock)] = None, + agent: Annotated[Any, Depends(get_thread_agent)] = None, +) -> dict[str, Any]: + # @@@permission-state-lock - owner polling and resolve can race on idle + # threads. Serialize the lightweight /permissions read with resolve/persist + # so stale checkpoint hydration cannot resurrect an already-resolved request. + async with thread_lock or _NoopAsyncLock(): + await agent.agent.aget_state({"configurable": {"thread_id": thread_id}}) + rule_state = agent.get_thread_permission_rules(thread_id) + return { + "thread_id": thread_id, + "requests": agent.get_pending_permission_requests(thread_id), + "session_rules": rule_state["rules"], + "managed_only": rule_state["managed_only"], + } + + +@router.post("/{thread_id}/permissions/{request_id}/resolve") +async def resolve_thread_permission_request( + thread_id: str, + request_id: str, + payload: ResolvePermissionRequest, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, + agent: Annotated[Any, Depends(get_thread_agent)] = None, + app: Annotated[Any, Depends(get_app)] = None, + thread_lock: Annotated[asyncio.Lock | None, Depends(get_thread_lock)] = None, +) -> dict[str, Any]: + async with thread_lock or _NoopAsyncLock(): + await agent.agent.aget_state({"configurable": {"thread_id": thread_id}}) + pending_requests = { + item.get("request_id"): item + for item in agent.get_pending_permission_requests(thread_id) + if isinstance(item, dict) and item.get("request_id") + } + pending_request = pending_requests.get(request_id) + is_ask_user_question = bool(pending_request and pending_request.get("tool_name") == "AskUserQuestion") + answers = _serialize_permission_answers(payload) + if is_ask_user_question and payload.decision == "allow" and not answers: + raise HTTPException(status_code=400, detail="AskUserQuestion answers are required when approving the request") + ok = agent.resolve_permission_request( + request_id, + decision=payload.decision, + message=payload.message, + answers=answers, + annotations=getattr(payload, "annotations", None), + ) + if not ok: + raise HTTPException(status_code=404, detail="Permission request not found") + await agent.agent.apersist_state(thread_id) + if is_ask_user_question and payload.decision == "allow" and answers is not None: + # @@@ask-user-lifecycle - the owner's answer is about to become a + # real follow-up user message. Clear the old request before that + # run starts so checkpoint replay cannot resurrect the popup. 
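+        # Approval sequence, sketched from the calls in this handler:
+        #   resolve_permission_request -> apersist_state -> drop_permission_request
+        #   -> apersist_state -> route_message_to_brain (the follow-up run below).
+        # Dropping the request any later than this would reopen the replay window
+        # described above.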
+ agent.drop_permission_request(request_id) + await agent.agent.apersist_state(thread_id) + + followup: dict[str, Any] | None = None + if is_ask_user_question and payload.decision == "allow" and pending_request is not None and answers is not None: + from backend.web.services.message_routing import route_message_to_brain + + answered_payload = _build_ask_user_question_answered_payload( + pending_request, + answers=answers, + annotations=getattr(payload, "annotations", None), + ) + + followup = await route_message_to_brain( + app, + thread_id, + _format_ask_user_question_followup( + pending_request, + answers=answers, + annotations=getattr(payload, "annotations", None), + ), + source="internal", + message_metadata={"ask_user_question_answered": answered_payload}, + ) + + response = {"ok": True, "thread_id": thread_id, "request_id": request_id} + if followup is not None: + response["followup"] = followup + return response + + +@router.post("/{thread_id}/permissions/rules") +async def add_thread_permission_rule( + thread_id: str, + payload: ThreadPermissionRuleRequest, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, + agent: Annotated[Any, Depends(get_thread_agent)] = None, +) -> dict[str, Any]: + await agent.agent.aget_state({"configurable": {"thread_id": thread_id}}) + rule_state = agent.get_thread_permission_rules(thread_id) + if rule_state["managed_only"]: + raise HTTPException(status_code=409, detail="Managed permission rules only; session overrides are disabled") + ok = agent.add_thread_permission_rule( + thread_id, + behavior=payload.behavior, + tool_name=payload.tool_name, + ) + if not ok: + raise HTTPException(status_code=400, detail="Could not add thread permission rule") + await agent.agent.apersist_state(thread_id) + updated = agent.get_thread_permission_rules(thread_id) + return { + "ok": True, + "thread_id": thread_id, + "scope": "session", + "rules": updated["rules"], + "managed_only": updated["managed_only"], + } + + +@router.delete("/{thread_id}/permissions/rules/{behavior}/{tool_name}") +async def delete_thread_permission_rule( + thread_id: str, + behavior: str, + tool_name: str, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, + agent: Annotated[Any, Depends(get_thread_agent)] = None, +) -> dict[str, Any]: + await agent.agent.aget_state({"configurable": {"thread_id": thread_id}}) + ok = agent.remove_thread_permission_rule( + thread_id, + behavior=behavior, + tool_name=tool_name, + ) + if not ok: + raise HTTPException(status_code=404, detail="Thread permission rule not found") + await agent.agent.apersist_state(thread_id) + updated = agent.get_thread_permission_rules(thread_id) + return { + "ok": True, + "thread_id": thread_id, + "scope": "session", + "rules": updated["rules"], + "managed_only": updated["managed_only"], + } + + @router.get("/{thread_id}/runtime") async def get_thread_runtime( thread_id: str, stream: bool = False, - user_id: Annotated[str, Depends(verify_thread_owner)] = None, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, app: Annotated[Any, Depends(get_app)] = None, ) -> dict[str, Any]: """Get runtime status for a thread.""" @@ -902,12 +1272,9 @@ async def get_thread_terminal_status( async def get_thread_lease_status( thread_id: str, agent: Annotated[Any, Depends(get_thread_agent)] = None, -) -> dict[str, Any]: +) -> dict[str, Any] | None: """Get SandboxLease status for a thread.""" - try: - return await get_lease_status(agent, thread_id) - except ValueError as e: - raise HTTPException(404, 
    str(e)) from e
+    return await get_lease_status(agent, thread_id)
 
 
 # SSE response headers: disable proxy buffering for real-time streaming
@@ -931,17 +1298,12 @@ async def stream_thread_events(
     app: Annotated[Any, Depends(get_app)] = None,
 ) -> EventSourceResponse:
     """Persistent SSE event stream — uses ?token= for auth (EventSource can't set headers)."""
-    from backend.web.core.dependencies import _DEV_PAYLOAD, _DEV_SKIP_AUTH
-
-    if _DEV_SKIP_AUTH:
-        sse_user_id = _DEV_PAYLOAD["user_id"]
-    else:
-        if not token:
-            raise HTTPException(401, "Missing token")
-        try:
-            sse_user_id = app.state.auth_service.verify_token(token)["user_id"]
-        except ValueError as e:
-            raise HTTPException(401, str(e))
+    if not token:
+        raise HTTPException(401, "Missing token")
+    try:
+        sse_user_id = app.state.auth_service.verify_token(token)["user_id"]
+    except ValueError as e:
+        raise HTTPException(401, str(e))
     thread = app.state.thread_repo.get_by_id(thread_id)
     if not thread:
         raise HTTPException(404, "Thread not found")
@@ -995,7 +1357,7 @@
 @router.post("/{thread_id}/runs/cancel")
 async def cancel_run(
     thread_id: str,
-    user_id: Annotated[str, Depends(verify_thread_owner)] = None,
+    user_id: Annotated[str | None, Depends(verify_thread_owner)] = None,
     app: Annotated[Any, Depends(get_app)] = None,
 ):
     """Cancel an active run for the given thread."""
@@ -1016,6 +1378,33 @@ def _get_background_runs(app: Any, thread_id: str) -> dict:
     return getattr(agent, "_background_runs", {}) if agent else {}
 
 
+def _background_run_type(run: Any) -> str:
+    return "bash" if run.__class__.__name__ == "_BashBackgroundRun" else "agent"
+
+
+def _serialize_background_run(task_id: str, run: Any, *, include_result: bool) -> dict[str, Any]:
+    run_type = _background_run_type(run)
+    result_text = run.get_result() if include_result and run.is_done else None
+    payload = {
+        "task_id": task_id,
+        "task_type": run_type,
+        "status": "completed" if run.is_done else "running",
+        "command_line": getattr(run, "command", None) if run_type == "bash" else None,
+    }
+    if include_result:
+        payload["result"] = result_text
+        payload["text"] = result_text
+        return payload
+    payload["description"] = getattr(run, "description", None)
+    payload["exit_code"] = getattr(getattr(run, "_cmd", None), "exit_code", None) if run_type == "bash" else None
+    payload["error"] = None
+    return payload
+
+
+async def _get_display_task_map(app: Any, thread_id: str) -> dict[str, dict[str, Any]]:
+    return _collect_display_subagent_tasks(await _get_thread_display_entries(app, thread_id))
+
+
 @router.get("/{thread_id}/tasks")
 async def list_tasks(
     thread_id: str,
@@ -1023,18 +1412,20 @@
 ) -> list[dict]:
     """List all background runs (bash + agent) for this thread."""
     runs = _get_background_runs(request.app, thread_id)
-    result = []
-    for task_id, run in runs.items():
-        run_type = "bash" if run.__class__.__name__ == "_BashBackgroundRun" else "agent"
+    result = [_serialize_background_run(task_id, run, include_result=False) for task_id, run in runs.items()]
+    seen_task_ids = set(runs)
+    for task_id, task in (await _get_display_task_map(request.app, thread_id)).items():
+        if task_id in seen_task_ids:
+            continue
         result.append(
             {
-                "task_id": task_id,
-                "task_type": run_type,
-                "status": "completed" if run.is_done else "running",
-                "command_line": getattr(run, "command", None) if run_type == "bash" else None,
-                "description": getattr(run, "description", None),
-                "exit_code": getattr(getattr(run, "_cmd", None), "exit_code", None) if run_type == "bash" else None,
-                "error": 
None, + "task_id": task["task_id"], + "task_type": task["task_type"], + "status": task["status"], + "command_line": task["command_line"], + "description": task["description"], + "exit_code": task["exit_code"], + "error": task["error"], } ) return result @@ -1050,18 +1441,19 @@ async def get_task( runs = _get_background_runs(request.app, thread_id) run = runs.get(task_id) if not run: - raise HTTPException(status_code=404, detail="Task not found") + task = (await _get_display_task_map(request.app, thread_id)).get(task_id) + if task is None: + raise HTTPException(status_code=404, detail="Task not found") + return { + "task_id": task["task_id"], + "task_type": task["task_type"], + "status": task["status"], + "command_line": task["command_line"], + "result": task["result"], + "text": task["text"], + } - run_type = "bash" if run.__class__.__name__ == "_BashBackgroundRun" else "agent" - result_text = run.get_result() if run.is_done else None - return { - "task_id": task_id, - "task_type": run_type, - "status": "completed" if run.is_done else "running", - "command_line": getattr(run, "command", None) if run_type == "bash" else None, - "result": result_text, - "text": result_text, - } + return _serialize_background_run(task_id, run, include_result=True) @router.post("/{thread_id}/tasks/{task_id}/cancel") @@ -1074,7 +1466,16 @@ async def cancel_task( runs = _get_background_runs(request.app, thread_id) run = runs.get(task_id) if not run: - raise HTTPException(status_code=404, detail="Task not found") + task = (await _get_display_task_map(request.app, thread_id)).get(task_id) + if task is None: + raise HTTPException(status_code=404, detail="Task not found") + if task["status"] != "running": + raise HTTPException(status_code=400, detail="Task is not running") + thread_task = request.app.state.thread_tasks.get(thread_id) + if thread_task is None or thread_task.done(): + raise HTTPException(status_code=400, detail="Task is not independently cancellable") + thread_task.cancel() + return {"ok": True, "message": "Run cancellation requested", "task_id": task_id} if run.is_done: raise HTTPException(status_code=400, detail="Task is not running") @@ -1112,7 +1513,7 @@ async def _notify_task_cancelled(app: Any, thread_id: str, task_id: str, run: An agent_id=task_id, agent_name=f"cancel-{task_id[:8]}", ) - await emit_fn( + emission = emit_fn( { "event": "task_done", "data": json.dumps( @@ -1125,6 +1526,8 @@ async def _notify_task_cancelled(app: Any, thread_id: str, task_id: str, run: An ), } ) + if asyncio.iscoroutine(emission): + await emission except Exception: logger.warning("Failed to emit task_done for cancelled task %s", task_id, exc_info=True) diff --git a/backend/web/services/agent_pool.py b/backend/web/services/agent_pool.py index 50ecb5dbf..b3041c6a9 100644 --- a/backend/web/services/agent_pool.py +++ b/backend/web/services/agent_pool.py @@ -1,18 +1,21 @@ """Agent pool management service.""" import asyncio -import os +import logging from pathlib import Path from typing import Any from fastapi import FastAPI +from config.user_paths import preferred_existing_user_home_path from core.identity.agent_registry import get_or_create_agent_id from core.runtime.agent import create_leon_agent from sandbox.manager import lookup_sandbox_for_thread from sandbox.thread_context import set_current_thread_id from storage.runtime import build_storage_container +logger = logging.getLogger(__name__) + # Thread lock for config updates _config_update_locks: dict[str, asyncio.Lock] = {} _agent_create_locks: dict[str, 
asyncio.Lock] = {} @@ -23,15 +26,16 @@ def create_agent_sync( workspace_root: Path | None = None, model_name: str | None = None, agent: str | None = None, + bundle_dir: Path | None = None, + thread_repo: Any = None, + member_repo: Any = None, queue_manager: Any = None, chat_repos: dict | None = None, extra_allowed_paths: list[str] | None = None, + web_app: Any = None, ) -> Any: """Create a LeonAgent with the given sandbox. Runs in a thread.""" - storage_container = build_storage_container( - main_db_path=os.getenv("LEON_DB_PATH"), - eval_db_path=os.getenv("LEON_EVAL_DB_PATH"), - ) + storage_container = build_storage_container() # @@@web-file-ops-repo - inject storage-backed repo so file_operations route to correct provider. from core.operations import FileOperationRecorder, set_recorder @@ -41,10 +45,15 @@ def create_agent_sync( workspace_root=workspace_root or Path.cwd(), sandbox=sandbox_name if sandbox_name != "local" else None, storage_container=storage_container, + permission_resolver_scope="thread", + thread_repo=thread_repo, + member_repo=member_repo, queue_manager=queue_manager, chat_repos=chat_repos, + web_app=web_app, verbose=True, agent=agent, + bundle_dir=bundle_dir, extra_allowed_paths=extra_allowed_paths, ) @@ -76,11 +85,27 @@ async def get_or_create_agent(app_obj: FastAPI, sandbox_type: str, thread_id: st thread_data = app_obj.state.thread_repo.get_by_id(thread_id) if hasattr(app_obj.state, "thread_repo") else None if sandbox_type == "local": cwd = app_obj.state.thread_cwd.get(thread_id) + cwd_from_live_map = cwd is not None if not cwd and thread_data and thread_data.get("cwd"): cwd = thread_data["cwd"] - app_obj.state.thread_cwd[thread_id] = cwd if cwd: - workspace_root = Path(cwd).resolve() + path = Path(cwd).expanduser() + # @@@fresh-local-cwd-owns-workspace - a cwd chosen in this live backend session is + # the caller contract for local threads; create it instead of silently falling + # back to the repo root. Persisted paths from another host stay advisory. + if cwd_from_live_map: + path.mkdir(parents=True, exist_ok=True) + workspace_root = path.resolve() + app_obj.state.thread_cwd[thread_id] = str(workspace_root) + # @@@host-local-cwd-is-advisory - persisted local thread cwd can come from another + # host (for example a macOS path stored in shared Supabase but replayed inside a + # Linux staging container). Only pin workspace_root when that path exists here. + elif path.exists() and path.is_dir(): + workspace_root = path.resolve() + app_obj.state.thread_cwd[thread_id] = str(workspace_root) + else: + app_obj.state.thread_cwd.pop(thread_id, None) + logger.warning("Ignoring unavailable local cwd for thread %s: %s", thread_id, cwd) # Look up model for this thread (threads table → preferences default) model_name = thread_data.get("model") if thread_data else None @@ -93,29 +118,35 @@ async def get_or_create_agent(app_obj: FastAPI, sandbox_type: str, thread_id: st # @@@agent-vs-member - thread_config.agent stores a member ID (e.g. "__leon__") for display, # NOT an agent type name ("bash", "general", etc.). Never pass it to create_leon_agent. 
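A minimal sketch of the invariant behind that note (all names here are hypothetical, not the project's API): member IDs stored for display must never be forwarded as agent type names.

```python
# Hypothetical sketch — "__leon__" is a member ID for display, never a type name.
AGENT_TYPES = {"bash", "general"}  # assumed type names, for illustration only


def resolve_agent_type(caller_agent: str | None) -> str | None:
    """Pass through explicit type names; anything else falls back to the default agent."""
    return caller_agent if caller_agent in AGENT_TYPES else None


assert resolve_agent_type("__leon__") is None  # member ID → default Leon agent
assert resolve_agent_type("bash") == "bash"    # explicit type → honored
```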
agent_name = agent # explicit caller-provided type only; None → default Leon agent + bundle_dir = None + if thread_data and thread_data.get("member_id"): + member_dir = preferred_existing_user_home_path("members", str(thread_data["member_id"])) + if member_dir.is_dir(): + bundle_dir = member_dir.resolve() - # @@@chat-repos - construct chat_repos for ChatToolService if entity system is available + # @@@chat-repos - construct chat_repos for ChatToolService (v2 messaging) chat_repos = None - if hasattr(app_obj.state, "entity_repo") and thread_data: - entity_repo = app_obj.state.entity_repo - member_repo = getattr(app_obj.state, "member_repo", None) - # Entity id = member_id in the new model; look up by member_id, not thread_id + if hasattr(app_obj.state, "member_repo") and thread_data: + member_repo = app_obj.state.member_repo agent_member_id = thread_data.get("member_id") - agent_entity = entity_repo.get_by_id(agent_member_id) if agent_member_id else None - if agent_entity: - # agent social identity = member_id - agent_member = member_repo.get_by_id(agent_entity.member_id) if member_repo else None - # owner social identity = owner's user_id (same as their member_id for humans) - owner_user_id = agent_member.owner_user_id if agent_member else "" + agent_member = member_repo.get_by_id(agent_member_id) if agent_member_id else None + if agent_member: + chat_identity_id = thread_data.get("user_id") + # @@@thread-chat-identity-source - agent chat identity must come from the + # thread-owned dedicated user_id, never from the member template id. + if not chat_identity_id: + raise RuntimeError(f"thread.user_id is required for agent chat identity: {thread_id}") + owner_id = agent_member.owner_user_id or "" chat_repos = { - "user_id": agent_entity.member_id, # agent's social identity = member_id - "owner_user_id": owner_user_id, - "entity_repo": entity_repo, - "chat_service": getattr(app_obj.state, "chat_service", None), - "chat_entity_repo": getattr(app_obj.state, "chat_entity_repo", None), - "chat_message_repo": getattr(app_obj.state, "chat_message_repo", None), + "chat_identity_id": chat_identity_id, + "user_id": chat_identity_id, + "owner_id": owner_id, "member_repo": member_repo, - "chat_event_bus": getattr(app_obj.state, "chat_event_bus", None), + "messaging_service": getattr(app_obj.state, "messaging_service", None), + "chat_member_repo": getattr(app_obj.state, "chat_member_repo", None), + "messages_repo": getattr(app_obj.state, "messages_repo", None), + "relationship_repo": getattr(app_obj.state, "relationship_repo", None), + "agent_config_repo": getattr(app_obj.state, "agent_config_repo", None), } # @@@per-thread-file-access - ensure thread files are accessible from agent @@ -136,12 +167,23 @@ async def get_or_create_agent(app_obj: FastAPI, sandbox_type: str, thread_id: st except FileNotFoundError: pass - extra_allowed_paths = extra_allowed_paths or None + extra_allowed_paths_or_none: list[str] | None = extra_allowed_paths or None # @@@ agent-init-thread - LeonAgent.__init__ uses run_until_complete, must run in thread qm = getattr(app_obj.state, "queue_manager", None) agent_obj = await asyncio.to_thread( - create_agent_sync, sandbox_type, workspace_root, model_name, agent_name, qm, chat_repos, extra_allowed_paths + create_agent_sync, + sandbox_name=sandbox_type, + workspace_root=workspace_root, + model_name=model_name, + agent=agent_name, + bundle_dir=bundle_dir, + thread_repo=getattr(app_obj.state, "thread_repo", None), + member_repo=getattr(app_obj.state, "member_repo", None), + 
queue_manager=qm, + chat_repos=chat_repos, + extra_allowed_paths=extra_allowed_paths_or_none, + web_app=app_obj, ) member = agent_name or "leon" agent_id = get_or_create_agent_id( diff --git a/backend/web/services/auth_service.py b/backend/web/services/auth_service.py index 85c9c21c6..dd7b46c21 100644 --- a/backend/web/services/auth_service.py +++ b/backend/web/services/auth_service.py @@ -5,10 +5,11 @@ import logging import os import time +from collections.abc import Callable import jwt -from storage.contracts import AccountRepo, EntityRepo, InviteCodeRepo, MemberRepo, MemberRow, MemberType +from storage.contracts import InviteCodeRepo, MemberRepo, MemberRow, MemberType logger = logging.getLogger(__name__) @@ -19,15 +20,15 @@ class AuthService: def __init__( self, members: MemberRepo, - accounts: AccountRepo, - entities: EntityRepo, supabase_client=None, + supabase_auth_client=None, + supabase_auth_client_factory: Callable[[], object] | None = None, invite_codes: InviteCodeRepo | None = None, ) -> None: self._members = members - self._accounts = accounts - self._entities = entities - self._sb = supabase_client # None in sqlite-only mode + self._sb = supabase_client # storage/service-role client + self._sb_auth = supabase_auth_client # end-user auth client + self._sb_auth_factory = supabase_auth_client_factory self._invite_codes = invite_codes # ------------------------------------------------------------------ @@ -39,6 +40,7 @@ def __init__( def send_otp(self, email: str, password: str, invite_code: str) -> None: """Validate invite code, create user via signUp (sends confirmation OTP to email).""" + auth_client = self._auth_api(self._require_auth_client()) if self._sb is None: raise RuntimeError("Supabase client required.") if self._invite_codes is None or not self._invite_codes.is_valid(invite_code): @@ -46,7 +48,7 @@ def send_otp(self, email: str, password: str, invite_code: str) -> None: from supabase_auth.errors import AuthApiError try: - self._sb.auth.sign_up({"email": email, "password": password}) + auth_client.sign_up({"email": email, "password": password}) except AuthApiError as e: msg = e.message or "" if "already registered" in msg or "already exists" in msg: @@ -55,12 +57,13 @@ def send_otp(self, email: str, password: str, invite_code: str) -> None: def verify_register_otp(self, email: str, token: str) -> dict: """Verify signup OTP. Returns temp_token to be used in complete_register.""" + auth_client = self._auth_api(self._require_auth_client()) if self._sb is None: raise RuntimeError("Supabase client required.") from supabase_auth.errors import AuthApiError try: - resp = self._sb.auth.verify_otp({"email": email, "token": token, "type": "signup"}) + resp = auth_client.verify_otp({"email": email, "token": token, "type": "signup"}) except AuthApiError as e: raise ValueError(f"验证码错误: {e.message}") from e if resp.user is None or resp.session is None: @@ -129,8 +132,7 @@ def complete_register(self, temp_token: str, invite_code: str) -> dict: def login(self, identifier: str, password: str) -> dict: """Login with email or mycel_id + password.""" - if self._sb is None: - raise RuntimeError("Supabase client required for login. 
Set LEON_STORAGE_STRATEGY=supabase.") + auth_client = self._auth_api(self._require_auth_client()) # Resolve email email = self._resolve_email(identifier) @@ -139,7 +141,7 @@ def login(self, identifier: str, password: str) -> dict: # Sign in via Supabase try: - resp = self._sb.auth.sign_in_with_password({"email": email, "password": password}) + resp = auth_client.sign_in_with_password({"email": email, "password": password}) except AuthApiError: raise ValueError("邮箱或密码错误") if resp.user is None or resp.session is None: @@ -175,6 +177,16 @@ def login(self, identifier: str, password: str) -> dict: def verify_token(self, token: str) -> dict: """Verify Supabase JWT. Returns {user_id}.""" + auth_client = self._sb_auth_factory() if self._sb_auth_factory is not None else self._sb_auth + if auth_client is not None: + auth_api = self._auth_api(auth_client) + try: + user_resp = auth_api.get_user(token) + except Exception as e: + raise ValueError(f"Token 无效: {e}") from e + if user_resp is None or getattr(user_resp, "user", None) is None: + raise ValueError("Token 无效: user not found") + return {"user_id": str(user_resp.user.id)} jwt_secret = os.getenv("SUPABASE_JWT_SECRET") if not jwt_secret: raise RuntimeError("SUPABASE_JWT_SECRET env var required for token verification.") @@ -204,12 +216,22 @@ def _resolve_email(self, identifier: str) -> str: return member.email return identifier.strip() + def _require_auth_client(self): + if self._sb_auth_factory is not None: + return self._sb_auth_factory() + if self._sb_auth is None: + raise RuntimeError("Supabase auth client required. Configure SUPABASE_ANON_KEY for auth runtime.") + return self._sb_auth + + def _auth_api(self, auth_client): + return getattr(auth_client, "auth", auth_client) + def _create_initial_agents(self, owner_user_id: str, now: float) -> dict | None: """Create Toad and Morel agents for a new user. 
Returns first agent info.""" from pathlib import Path from backend.web.services.member_service import MEMBERS_DIR, _write_agent_md, _write_json - from storage.providers.sqlite.member_repo import generate_member_id + from storage.utils import generate_member_id initial_agents = [ {"name": "Toad", "description": "Curious and energetic assistant", "avatar": "toad.jpeg"}, diff --git a/backend/web/services/chat_service.py b/backend/web/services/chat_service.py deleted file mode 100644 index 51a5ebbeb..000000000 --- a/backend/web/services/chat_service.py +++ /dev/null @@ -1,255 +0,0 @@ -"""Chat service — entity-to-entity communication.""" - -from __future__ import annotations - -import logging -import time -import uuid -from collections.abc import Callable -from typing import Any - -from backend.web.utils.serializers import avatar_url -from storage.contracts import ( - ChatEntityRepo, - ChatMessageRepo, - ChatMessageRow, - ChatRepo, - ChatRow, - DeliveryResolver, - EntityRepo, - MemberRepo, -) - -logger = logging.getLogger(__name__) - - -class ChatService: - def __init__( - self, - chat_repo: ChatRepo, - chat_entity_repo: ChatEntityRepo, - chat_message_repo: ChatMessageRepo, - entity_repo: EntityRepo, - member_repo: MemberRepo, - event_bus: Any = None, - delivery_fn: Callable | None = None, - delivery_resolver: DeliveryResolver | None = None, - ) -> None: - self._chats = chat_repo - self._chat_entities = chat_entity_repo - self._messages = chat_message_repo - self._entities = entity_repo - self._members = member_repo - self._event_bus = event_bus - self._delivery_fn = delivery_fn - self._delivery_resolver = delivery_resolver - - def _resolve_name(self, user_id: str) -> str: - """Resolve display name: entity_repo (agents) → member_repo (humans).""" - e = self._entities.get_by_id(user_id) - if e: - return e.name - m = self._members.get_by_id(user_id) if self._members else None - return m.name if m else "unknown" - - def find_or_create_chat(self, user_ids: list[str], title: str | None = None) -> ChatRow: - """Find existing 1:1 chat between two social identities, or create one.""" - if len(user_ids) != 2: - raise ValueError("Use create_group_chat() for 3+ participants") - - existing_id = self._chat_entities.find_chat_between(user_ids[0], user_ids[1]) - if existing_id: - return self._chats.get_by_id(existing_id) - - now = time.time() - chat_id = str(uuid.uuid4()) - self._chats.create(ChatRow(id=chat_id, title=title, created_at=now)) - for uid in user_ids: - self._chat_entities.add_participant(chat_id, uid, now) - return self._chats.get_by_id(chat_id) - - def create_group_chat(self, user_ids: list[str], title: str | None = None) -> ChatRow: - """Create a group chat with 3+ participants.""" - if len(user_ids) < 3: - raise ValueError("Group chat requires 3+ participants") - now = time.time() - chat_id = str(uuid.uuid4()) - self._chats.create(ChatRow(id=chat_id, title=title, created_at=now)) - for uid in user_ids: - self._chat_entities.add_participant(chat_id, uid, now) - return self._chats.get_by_id(chat_id) - - def send_message( - self, - chat_id: str, - sender_id: str, - content: str, - mentioned_ids: list[str] | None = None, - signal: str | None = None, - ) -> ChatMessageRow: - """Send a message in a chat.""" - logger.debug( - "[send_message] chat=%s sender=%s content=%.50s signal=%s", - chat_id[:8], - sender_id[:15], - content[:50], - signal, - ) - mentions = mentioned_ids or [] - now = time.time() - msg_id = str(uuid.uuid4()) - msg = ChatMessageRow( - id=msg_id, - chat_id=chat_id, - 
sender_id=sender_id, - content=content, - mentioned_ids=mentions, - created_at=now, - ) - self._messages.create(msg) - - sender_name = self._resolve_name(sender_id) - - if self._event_bus: - self._event_bus.publish( - chat_id, - { - "event": "message", - "data": { - "id": msg_id, - "chat_id": chat_id, - "sender_id": sender_id, - "sender_name": sender_name, - "content": content, - "mentioned_ids": mentions, - "created_at": now, - }, - }, - ) - - self._deliver_to_agents(chat_id, sender_id, sender_name, content, mentions, signal=signal) - return msg - - def _deliver_to_agents( - self, - chat_id: str, - sender_id: str, - sender_name: str, - content: str, - mentioned_ids: list[str] | None = None, - signal: str | None = None, - ) -> None: - """For each non-sender agent participant in the chat, deliver to their brain thread.""" - mentions = set(mentioned_ids or []) - participants = self._chat_entities.list_participants(chat_id) - sender_avatar_url = None - sender_mid = sender_id - sender_entity = self._entities.get_by_id(sender_id) - if sender_entity: - sender_mid = sender_entity.member_id - m = self._members.get_by_id(sender_mid) if self._members else None - sender_avatar_url = avatar_url(sender_mid, bool(m.avatar if m else None)) - - for ce in participants: - if ce.user_id == sender_id: - continue - entity = self._entities.get_by_id(ce.user_id) - if not entity or entity.type != "agent" or not entity.thread_id: - logger.debug( - "[deliver] SKIP %s type=%s thread=%s", - ce.user_id, - getattr(entity, "type", None), - getattr(entity, "thread_id", None), - ) - continue - # @@@delivery-strategy-gate — check contact block/mute + chat mute - # @@@mention-override — mentioned entities skip mute (but not block) - if self._delivery_resolver: - from storage.contracts import DeliveryAction - - is_mentioned = ce.user_id in mentions - action = self._delivery_resolver.resolve( - ce.user_id, - chat_id, - sender_id, - is_mentioned=is_mentioned, - ) - if action != DeliveryAction.DELIVER: - logger.info( - "[deliver] POLICY %s for %s (sender=%s chat=%s mentioned=%s)", - action.value, - ce.user_id, - sender_id, - chat_id[:8], - is_mentioned, - ) - continue - if self._delivery_fn: - logger.debug("[deliver] → %s (thread=%s) from=%s", entity.id, entity.thread_id, sender_name) - try: - self._delivery_fn(entity, content, sender_name, chat_id, sender_id, sender_avatar_url, signal=signal) - except Exception: - logger.exception("Failed to deliver chat message to entity %s", entity.id) - else: - logger.warning("[deliver] NO delivery_fn for %s", entity.id) - - def set_delivery_fn(self, fn) -> None: - self._delivery_fn = fn - - def list_chats_for_user(self, user_id: str) -> list[dict]: - """List all chats for a user (social identity) with summary info.""" - chat_ids = self._chat_entities.list_chats_for_user(user_id) - result = [] - for cid in chat_ids: - chat = self._chats.get_by_id(cid) - if not chat or chat.status != "active": - continue - participants = self._chat_entities.list_participants(cid) - entities_info = [] - for p in participants: - e = self._entities.get_by_id(p.user_id) - if e: - m = self._members.get_by_id(e.member_id) if self._members else None - entities_info.append( - { - "id": p.user_id, - "name": e.name, - "type": e.type, - "avatar_url": avatar_url(e.member_id, bool(m.avatar if m else None)), - } - ) - else: - m = self._members.get_by_id(p.user_id) if self._members else None - if m: - entities_info.append( - { - "id": p.user_id, - "name": m.name, - "type": "human", - "avatar_url": avatar_url(m.id, 
bool(m.avatar)), - } - ) - msgs = self._messages.list_by_chat(cid, limit=1) - last_msg = None - if msgs: - m = msgs[0] - last_msg = { - "content": m.content, - "sender_name": self._resolve_name(m.sender_id), - "created_at": m.created_at, - } - unread = self._messages.count_unread(cid, user_id) - has_mention = self._messages.has_unread_mention(cid, user_id) - result.append( - { - "id": cid, - "title": chat.title, - "status": chat.status, - "created_at": chat.created_at, - "entities": entities_info, - "last_message": last_msg, - "unread_count": unread, - "has_mention": has_mention, - } - ) - return result diff --git a/backend/web/services/cron_job_service.py b/backend/web/services/cron_job_service.py index e7b3a7330..c59b54e5e 100644 --- a/backend/web/services/cron_job_service.py +++ b/backend/web/services/cron_job_service.py @@ -9,45 +9,55 @@ def _repo() -> Any: return make_cron_job_repo() -def list_cron_jobs() -> list[dict[str, Any]]: - repo = _repo() +def list_cron_jobs(owner_user_id: str | None = None, repo: Any = None) -> list[dict[str, Any]]: + own_repo = repo is None + repo = repo or _repo() try: - return repo.list_all() + return repo.list_all(owner_user_id=owner_user_id) finally: - repo.close() + if own_repo: + repo.close() -def get_cron_job(job_id: str) -> dict[str, Any] | None: - repo = _repo() +def get_cron_job(job_id: str, owner_user_id: str | None = None, repo: Any = None) -> dict[str, Any] | None: + own_repo = repo is None + repo = repo or _repo() try: - return repo.get(job_id) + return repo.get(job_id, owner_user_id=owner_user_id) finally: - repo.close() + if own_repo: + repo.close() -def create_cron_job(*, name: str, cron_expression: str, **fields: Any) -> dict[str, Any]: +def create_cron_job(*, name: str, cron_expression: str, repo: Any = None, **fields: Any) -> dict[str, Any]: if not name or not name.strip(): raise ValueError("name must not be empty") if not cron_expression or not cron_expression.strip(): raise ValueError("cron_expression must not be empty") - repo = _repo() + own_repo = repo is None + repo = repo or _repo() try: return repo.create(name=name, cron_expression=cron_expression, **fields) finally: - repo.close() + if own_repo: + repo.close() -def update_cron_job(job_id: str, **fields: Any) -> dict[str, Any] | None: - repo = _repo() +def update_cron_job(job_id: str, owner_user_id: str | None = None, repo: Any = None, **fields: Any) -> dict[str, Any] | None: + own_repo = repo is None + repo = repo or _repo() try: - return repo.update(job_id, **fields) + return repo.update(job_id, owner_user_id=owner_user_id, **fields) finally: - repo.close() + if own_repo: + repo.close() -def delete_cron_job(job_id: str) -> bool: - repo = _repo() +def delete_cron_job(job_id: str, owner_user_id: str | None = None, repo: Any = None) -> bool: + own_repo = repo is None + repo = repo or _repo() try: - return repo.delete(job_id) + return repo.delete(job_id, owner_user_id=owner_user_id) finally: - repo.close() + if own_repo: + repo.close() diff --git a/backend/web/services/cron_service.py b/backend/web/services/cron_service.py index bfb0ca244..2c9c8993f 100644 --- a/backend/web/services/cron_service.py +++ b/backend/web/services/cron_service.py @@ -26,9 +26,11 @@ class CronService: """Background cron scheduler that creates panel_tasks from cron job templates.""" - def __init__(self) -> None: + def __init__(self, *, cron_job_repo: Any = None, task_repo: Any = None) -> None: self._running = False self._task: asyncio.Task | None = None + self._cron_job_repo = cron_job_repo + 
self._task_repo = task_repo # -- public API ---------------------------------------------------------- @@ -52,13 +54,18 @@ async def stop(self) -> None: self._task = None logger.info("[cron-service] stopped") - async def trigger_job(self, job_id: str) -> dict[str, Any] | None: + async def trigger_job(self, job_id: str, owner_user_id: str | None = None) -> dict[str, Any] | None: """Manually trigger a cron job. Creates a task from template. Returns the created task dict, or None if the job doesn't exist, is disabled, or has an invalid template. """ - job = await asyncio.to_thread(cron_job_service.get_cron_job, job_id) + job = await asyncio.to_thread( + cron_job_service.get_cron_job, + job_id, + owner_user_id=owner_user_id, + repo=self._cron_job_repo, + ) if job is None: return None if not job.get("enabled"): @@ -76,12 +83,19 @@ async def trigger_job(self, job_id: str) -> dict[str, Any] | None: task_fields: dict[str, Any] = {k: v for k, v in template.items() if k in _ALLOWED_TEMPLATE_KEYS} task_fields["source"] = "cron" task_fields["cron_job_id"] = job_id + task_fields["owner_user_id"] = job.get("owner_user_id") - task = await asyncio.to_thread(task_service.create_task, **task_fields) + task = await asyncio.to_thread(task_service.create_task, repo=self._task_repo, **task_fields) # Update last_run_at on the cron job now_ms = int(time.time() * 1000) - await asyncio.to_thread(cron_job_service.update_cron_job, job_id, last_run_at=now_ms) + await asyncio.to_thread( + cron_job_service.update_cron_job, + job_id, + owner_user_id=job.get("owner_user_id"), + repo=self._cron_job_repo, + last_run_at=now_ms, + ) logger.info("[cron-service] triggered job %s → task %s", job_id, task.get("id")) return task @@ -129,7 +143,7 @@ async def _scheduler_loop(self) -> None: async def _check_and_trigger(self) -> None: """Check all enabled cron jobs and trigger those that are due.""" - jobs = await asyncio.to_thread(cron_job_service.list_cron_jobs) + jobs = await asyncio.to_thread(cron_job_service.list_cron_jobs, repo=self._cron_job_repo) for job in jobs: if self.is_due(job): try: diff --git a/backend/web/services/delivery_resolver.py b/backend/web/services/delivery_resolver.py deleted file mode 100644 index 43e6e6bd7..000000000 --- a/backend/web/services/delivery_resolver.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Delivery strategy resolver — evaluates per-recipient delivery action. - -@@@delivery-strategy-gate — single evaluation point between message storage -and agent delivery. Checks contact-level block/mute → chat-level mute → default. -""" - -from __future__ import annotations - -import logging -import time - -from storage.contracts import ChatEntityRepo, ContactRepo, DeliveryAction - -logger = logging.getLogger(__name__) - - -class DefaultDeliveryResolver: - """Evaluates delivery action for a chat message recipient. - - Priority (highest wins): - 1. Contact block (sender blocked by recipient) → DROP - 2. Contact mute (sender muted by recipient) → NOTIFY - 3. Chat mute (recipient muted this chat) → NOTIFY - 4. Default → DELIVER - """ - - def __init__(self, contact_repo: ContactRepo, chat_entity_repo: ChatEntityRepo) -> None: - self._contacts = contact_repo - self._chat_entities = chat_entity_repo - - def resolve( - self, - recipient_id: str, - chat_id: str, - sender_id: str, - *, - is_mentioned: bool = False, - ) -> DeliveryAction: - # 1. 
Contact-level block — always DROP, even if mentioned - contact = self._contacts.get(recipient_id, sender_id) - if contact and contact.relation == "blocked": - logger.debug("[resolver] DROP: %s blocked %s", recipient_id[:15], sender_id[:15]) - return DeliveryAction.DROP - - # @@@mention-override — mentioned entities skip mute checks - if is_mentioned: - return DeliveryAction.DELIVER - - # 2. Contact-level mute - if contact and contact.relation == "muted": - logger.debug("[resolver] NOTIFY: %s muted %s", recipient_id[:15], sender_id[:15]) - return DeliveryAction.NOTIFY - - # 3. Chat-level mute - if self._is_chat_muted(recipient_id, chat_id): - logger.debug("[resolver] NOTIFY: %s muted chat %s", recipient_id[:15], chat_id[:8]) - return DeliveryAction.NOTIFY - - # 4. Default - return DeliveryAction.DELIVER - - def _is_chat_muted(self, user_id: str, chat_id: str) -> bool: - """Check if user has muted this specific chat.""" - participants = self._chat_entities.list_participants(chat_id) - for ce in participants: - if ce.user_id == user_id: - muted = getattr(ce, "muted", False) - if not muted: - return False - mute_until = getattr(ce, "mute_until", None) - if mute_until is not None and mute_until < time.time(): - return False # mute expired - return True - return False diff --git a/backend/web/services/display_builder.py b/backend/web/services/display_builder.py index 25f034ed5..6af91d91d 100644 --- a/backend/web/services/display_builder.py +++ b/backend/web/services/display_builder.py @@ -38,18 +38,46 @@ # Helpers — ported from message-mapper.ts # --------------------------------------------------------------------------- -_CHAT_MESSAGE_RE = re.compile(r"<chat_message[^>]*>([\s\S]*?)</chat_message>") - - -def _extract_chat_message(text: str) -> str | None: - m = _CHAT_MESSAGE_RE.search(text) - return m.group(1).strip() if m else None +# NOTE: tag literals reconstructed — the original angle-bracket patterns were lost to markup stripping. +_TASK_NOTIFICATION_RUN_ID_RE = re.compile(r"<run_id>(.*?)</run_id>", re.IGNORECASE | re.DOTALL) +_TASK_NOTIFICATION_STATUS_RE = re.compile(r"<status>(.*?)</status>", re.IGNORECASE | re.DOTALL) def _make_id(prefix: str = "db") -> str: return f"{prefix}-{uuid.uuid4().hex[:12]}" +def _extract_terminal_task_status(notification_type: str | None, text: str) -> tuple[str | None, str | None]: + if notification_type != "agent" or "<run_id>" not in text: + return None, None + task_match = _TASK_NOTIFICATION_RUN_ID_RE.search(text) + status_match = _TASK_NOTIFICATION_STATUS_RE.search(text) + task_id = task_match.group(1).strip() if task_match else None + status = status_match.group(1).strip().lower() if status_match else None + return task_id, status + + +def _reconcile_subagent_stream_status( + entries: list[dict], + current_turn: dict | None, + task_id: str, + status: str, +) -> None: + # @@@checkpoint-status-reconcile - idle detail rebuild only sees persisted + # checkpoint messages, not live task_done events. If a later persisted + # terminal notification names the child task, reconcile the earlier Agent + # subagent_stream status so cold rebuild does not regress it back to running. 
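The walk that follows can be exercised on its own; a rough standalone sketch with an assumed entry shape (only the fields the function actually touches):

```python
# Assumed minimal shape of a persisted assistant turn with an Agent tool segment.
turn = {
    "role": "assistant",
    "segments": [
        {"type": "tool", "step": {"name": "Agent",
                                  "subagent_stream": {"task_id": "t-1", "status": "running"}}},
    ],
}


def reconcile(turns: list[dict], task_id: str, status: str) -> None:
    # Same traversal as _reconcile_subagent_stream_status: first matching stream wins.
    for t in turns:
        for seg in t.get("segments", []):
            stream = seg.get("step", {}).get("subagent_stream")
            if seg.get("type") == "tool" and stream and stream.get("task_id") == task_id:
                stream["status"] = status
                return


reconcile([turn], "t-1", "completed")
assert turn["segments"][0]["step"]["subagent_stream"]["status"] == "completed"
```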
+ turns: list[dict] = [] + if current_turn is not None: + turns.append(current_turn) + turns.extend(entry for entry in reversed(entries) if entry.get("role") == "assistant" and entry is not current_turn) + for turn in turns: + for seg in turn.get("segments", []): + stream = seg.get("step", {}).get("subagent_stream") + if seg.get("type") == "tool" and stream and stream.get("task_id") == task_id: + stream["status"] = status + return + + # --------------------------------------------------------------------------- # Entry builders # --------------------------------------------------------------------------- @@ -89,6 +117,39 @@ def _append_to_turn(turn: dict, msg_id: str, segments: list[dict]) -> None: turn.setdefault("messageIds", []).append(msg_id) +def _build_subagent_stream( + *, + task_id: str, + thread_id: str, + description: str | None, + status: str, +) -> dict[str, Any]: + return { + "task_id": task_id, + "thread_id": thread_id, + "description": description, + "text": "", + "tool_calls": [], + "status": status, + } + + +def _build_hidden_ask_user_answer_entry( + *, + msg_id: str | None, + payload: dict[str, Any], + now: int, +) -> dict[str, Any]: + return { + "id": msg_id or _make_id("hist-user"), + "role": "user", + "content": "", + "timestamp": now, + "showing": False, + "ask_user_question_answered": payload, + } + + # --------------------------------------------------------------------------- # ThreadDisplay — per-thread in-memory state # --------------------------------------------------------------------------- @@ -234,6 +295,15 @@ def _handle_human( # Hidden if display.get("showing") is False: + ask_answered = meta.get("ask_user_question_answered") + if isinstance(ask_answered, dict): + entries.append( + _build_hidden_ask_user_answer_entry( + msg_id=msg.get("id"), + payload=ask_answered, + now=now, + ) + ) return None, None # System / external chat notification → notice @@ -242,6 +312,9 @@ def _handle_human( if source == "system" or (source == "external" and ntype == "chat"): content = _extract_text_content(msg.get("content")) msg_run_id = meta.get("run_id") or None + task_id, task_status = _extract_terminal_task_status(ntype, content) + if task_id and task_status: + _reconcile_subagent_stream_status(entries, current_turn, task_id, task_status) # Fold into current turn if same run if current_turn and (not msg_run_id or msg_run_id == current_run_id): @@ -332,19 +405,12 @@ def _handle_tool(self, msg: dict, _i: int, current_turn: dict | None, _now: int) seg["step"]["result"] = content_str seg["step"]["status"] = "done" - # Restore subagent_stream from metadata meta = msg.get("metadata") or {} - task_id = meta.get("task_id") - sub_thread = meta.get("subagent_thread_id") or (f"subagent-{task_id}" if task_id else None) - - if not task_id and seg["step"].get("name") == "Agent": - try: - parsed = json.loads(content_str) - if isinstance(parsed, dict) and parsed.get("task_id"): - task_id = parsed["task_id"] - sub_thread = parsed.get("thread_id") or f"subagent-{task_id}" - except (json.JSONDecodeError, TypeError): - pass + task_id, sub_thread, task_status = _extract_subagent_stream_identity( + seg["step"].get("name"), + meta, + content_str, + ) if sub_thread and not seg["step"].get("subagent_stream"): seg["step"]["subagent_stream"] = { @@ -353,7 +419,7 @@ def _handle_tool(self, msg: dict, _i: int, current_turn: dict | None, _now: int) "description": meta.get("description"), "text": "", "tool_calls": [], - "status": "completed", + "status": task_status, } break @@ -381,6 +447,18 @@ def 
_handle_user_message(td: ThreadDisplay, data: dict) -> dict | None: run_start/run_done events. This allows steers to appear at the bottom while the agent keeps streaming above. """ + if data.get("showing") is False: + ask_answered = data.get("ask_user_question_answered") + if not isinstance(ask_answered, dict): + return None + entry = _build_hidden_ask_user_answer_entry( + msg_id=None, + payload=ask_answered, + now=int(time.time() * 1000), + ) + td.entries.append(entry) + return {"type": "append_entry", "entry": entry} + content = data.get("content", "") entry: dict = { "id": _make_id("user"), @@ -502,18 +580,18 @@ def _handle_tool_result(td: ThreadDisplay, data: dict) -> dict | None: seg["step"]["result"] = result seg["step"]["status"] = "done" - # Subagent stream tracking - task_id = metadata.get("task_id") - sub_thread = metadata.get("subagent_thread_id") or (f"subagent-{task_id}" if task_id else None) + task_id, sub_thread, task_status = _extract_subagent_stream_identity( + seg["step"].get("name"), + metadata, + result, + ) if sub_thread and not seg["step"].get("subagent_stream"): - seg["step"]["subagent_stream"] = { - "task_id": task_id or "", - "thread_id": sub_thread, - "description": metadata.get("description"), - "text": "", - "tool_calls": [], - "status": "running", - } + seg["step"]["subagent_stream"] = _build_subagent_stream( + task_id=task_id or "", + thread_id=sub_thread, + description=metadata.get("description"), + status=task_status, + ) return { "type": "update_segment", @@ -526,8 +604,15 @@ def _handle_tool_result(td: ThreadDisplay, data: dict) -> dict | None: def _handle_notice(td: ThreadDisplay, data: dict) -> dict | None: content = data.get("content", "") ntype = data.get("notification_type") + task_id, task_status = _extract_terminal_task_status(ntype, content) turn = _get_current_turn(td) + if task_id and task_status: + # @@@live-notice-status-reconcile - live parent detail stays on the + # in-memory display while the followthrough run is still active, so the + # terminal notice must reconcile the earlier Agent step immediately + # instead of waiting for a later cold rebuild from checkpoint. + _reconcile_subagent_stream_status(td.entries, turn, task_id, task_status) if turn: # Fold into current turn seg = {"type": "notice", "content": content, "notification_type": ntype} @@ -629,22 +714,18 @@ def _handle_task_start(td: ThreadDisplay, data: dict) -> dict | None: task_id = data["task_id"] sub_thread = data.get("thread_id") or f"subagent-{task_id}" - # Find most recent Agent tool call without subagent_stream + # @@@late-task-start-race - background Agent tools can return their + # immediate "started" ToolMessage before the async task_start activity + # reaches the parent thread. Still patch the newest Agent step that + # has no child stream, even if its tool_result already marked it done. 
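A toy reproduction of that race (segment shape assumed): the old filter required `status == "calling"`, which dropped any task_start that arrived after the tool result had already marked the step done.

```python
# The Agent step after its immediate "started" ToolMessage has already landed.
seg = {"type": "tool", "step": {"name": "Agent", "status": "done"}}


def old_match(s: dict) -> bool:
    # Previous filter — only matched steps still mid-call.
    return s["step"].get("status") == "calling" and not s["step"].get("subagent_stream")


def new_match(s: dict) -> bool:
    # Relaxed filter — any Agent step that has not attached a child stream yet.
    return s["step"].get("name") == "Agent" and not s["step"].get("subagent_stream")


assert not old_match(seg)  # a late task_start would have been silently dropped
assert new_match(seg)      # now the subagent stream still gets attached
```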
for seg in reversed(turn["segments"]): - if ( - seg.get("type") == "tool" - and seg.get("step", {}).get("name") == "Agent" - and seg.get("step", {}).get("status") == "calling" - and not seg.get("step", {}).get("subagent_stream") - ): - seg["step"]["subagent_stream"] = { - "task_id": task_id, - "thread_id": sub_thread, - "description": data.get("description"), - "text": "", - "tool_calls": [], - "status": "running", - } + if seg.get("type") == "tool" and seg.get("step", {}).get("name") == "Agent" and not seg.get("step", {}).get("subagent_stream"): + seg["step"]["subagent_stream"] = _build_subagent_stream( + task_id=task_id, + thread_id=sub_thread, + description=data.get("description"), + status="running", + ) idx = _find_seg_index(turn, seg["step"]["id"]) return { "type": "update_segment", @@ -679,6 +760,28 @@ def _find_seg_index(turn: dict, tc_id: str) -> int: return -1 +def _extract_subagent_stream_identity(step_name: str | None, metadata: dict, content: str) -> tuple[str | None, str | None, str]: + task_id = metadata.get("task_id") + sub_thread = metadata.get("subagent_thread_id") or (f"subagent-{task_id}" if task_id else None) + task_status = "completed" if task_id else "running" + + if task_id or step_name != "Agent": + return task_id, sub_thread, task_status + + try: + parsed = json.loads(content) + except (json.JSONDecodeError, TypeError): + return task_id, sub_thread, task_status + + if not isinstance(parsed, dict) or not parsed.get("task_id"): + return task_id, sub_thread, task_status + + task_id = parsed["task_id"] + sub_thread = parsed.get("thread_id") or f"subagent-{task_id}" + task_status = parsed.get("status") or "running" + return task_id, sub_thread, task_status + + # Event type → handler _EVENT_HANDLERS: dict[str, Any] = { "user_message": _handle_user_message, diff --git a/backend/web/services/event_buffer.py b/backend/web/services/event_buffer.py index df2db5263..103622ca3 100644 --- a/backend/web/services/event_buffer.py +++ b/backend/web/services/event_buffer.py @@ -70,6 +70,9 @@ class ThreadEventBuffer: _ring: deque[dict] = field(default_factory=lambda: deque(maxlen=2000)) _notify: asyncio.Condition = field(default_factory=asyncio.Condition) _total_count: int = 0 # monotonic counter (total events ever put) + # @@@thread-buffer-never-finishes - keep the same observer protocol surface + # as RunEventBuffer, but thread buffers never mark completion. 
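A sketch of why the never-set event is still worth carrying (consumer code assumed, not taken from the repo): observers written once against the shared buffer surface need no special case for thread buffers.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class Buffer:  # stand-in for either buffer kind
    finished: asyncio.Event = field(default_factory=asyncio.Event)


def snapshot_state(buf: Buffer) -> str:
    # One code path for both kinds: a run buffer eventually sets `finished`,
    # while a thread buffer never does, so it reads as "streaming" forever.
    return "done" if buf.finished.is_set() else "streaming"


run_buf, thread_buf = Buffer(), Buffer()
run_buf.finished.set()
assert snapshot_state(run_buf) == "done"
assert snapshot_state(thread_buf) == "streaming"
```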
+ finished: asyncio.Event = field(default_factory=asyncio.Event) async def put(self, event: dict) -> None: self._ring.append(event) diff --git a/backend/web/services/event_store.py b/backend/web/services/event_store.py index 998b08018..b33eb61ea 100644 --- a/backend/web/services/event_store.py +++ b/backend/web/services/event_store.py @@ -2,56 +2,34 @@ import asyncio import json -from pathlib import Path from typing import Any from storage.contracts import RunEventRepo -from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path from storage.runtime import build_storage_container -_DB_PATH = resolve_role_db_path(SQLiteDBRole.MAIN) _default_run_event_repo: RunEventRepo | None = None -_default_run_event_repo_path: Path | None = None -def init_event_store() -> None: - """Initialize run event storage for current provider strategy.""" - global _default_run_event_repo, _default_run_event_repo_path - if _default_run_event_repo is not None: - _default_run_event_repo.close() - _default_run_event_repo = None - _default_run_event_repo_path = None - - container = build_storage_container(main_db_path=_DB_PATH) - provider = container.provider_for("run_event_repo") - if provider != "sqlite": - return - - # Connection factory in RunEventRepo already guarantees WAL + PRAGMA settings. - repo = container.run_event_repo() - repo.close() - - -def _resolve_run_event_repo(run_event_repo: RunEventRepo | None) -> RunEventRepo: +def _resolve_run_event_repo(run_event_repo: RunEventRepo | None) -> RunEventRepo | None: if run_event_repo is not None: return run_event_repo - global _default_run_event_repo, _default_run_event_repo_path - if _default_run_event_repo is not None and _default_run_event_repo_path == _DB_PATH: - return _default_run_event_repo - + global _default_run_event_repo if _default_run_event_repo is not None: - _default_run_event_repo.close() - _default_run_event_repo = None - _default_run_event_repo_path = None + return _default_run_event_repo - container = build_storage_container(main_db_path=_DB_PATH) + try: + container = build_storage_container() + except RuntimeError: + return None # @@@event-store-single-path - keep one persistence boundary; when caller omits repo, resolve default repo from storage container. 
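When no storage container can be built, the module degrades to a no-op store; a minimal sketch of that fallback contract (simplified from the `append_event` path below): sequence numbers stay monotonic so resume logic keeps working, but nothing persists and reads come back empty.

```python
_noop_seq = 0


def append_event_noop() -> int:
    # Mirror of the storage-less branch: hand out strictly increasing fake seqs.
    global _noop_seq
    _noop_seq += 1
    return _noop_seq


def read_events_noop() -> list[dict]:
    return []  # reads are empty — callers must tolerate losing history


assert [append_event_noop() for _ in range(3)] == [1, 2, 3]
assert read_events_noop() == []
```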
_default_run_event_repo = container.run_event_repo() - _default_run_event_repo_path = _DB_PATH return _default_run_event_repo +_noop_seq = 0 + + async def append_event( thread_id: str, run_id: str, @@ -61,6 +39,10 @@ async def append_event( ) -> int: """Persist one SSE event and return its sequence number.""" repo = _resolve_run_event_repo(run_event_repo) + if repo is None: + global _noop_seq + _noop_seq += 1 + return _noop_seq payload = _event_payload_to_dict(event) return int( await asyncio.to_thread( @@ -82,6 +64,8 @@ async def read_events_after( ) -> list[dict[str, Any]]: """Return events with seq > after_seq for the given run.""" repo = _resolve_run_event_repo(run_event_repo) + if repo is None: + return [] rows = await asyncio.to_thread( repo.list_events, thread_id, @@ -103,18 +87,24 @@ async def read_events_after( async def get_last_seq(thread_id: str, run_event_repo: RunEventRepo | None = None) -> int: """Return the highest seq for a thread, or 0.""" repo = _resolve_run_event_repo(run_event_repo) + if repo is None: + return 0 return int(await asyncio.to_thread(repo.latest_seq, thread_id)) async def get_run_start_seq(thread_id: str, run_id: str, run_event_repo: RunEventRepo | None = None) -> int: """Return the first seq for a specific run, or 0.""" repo = _resolve_run_event_repo(run_event_repo) + if repo is None: + return 0 return int(await asyncio.to_thread(repo.run_start_seq, thread_id, run_id)) async def get_latest_run_id(thread_id: str, run_event_repo: RunEventRepo | None = None) -> str | None: """Return the run_id of the most recent run for a thread, or None.""" repo = _resolve_run_event_repo(run_event_repo) + if repo is None: + return None return await asyncio.to_thread(repo.latest_run_id, thread_id) @@ -125,6 +115,8 @@ async def cleanup_old_runs( ) -> int: """Delete all but the N most recent runs for a thread. Returns deleted count.""" repo = _resolve_run_event_repo(run_event_repo) + if repo is None: + return 0 run_ids = await asyncio.to_thread(repo.list_run_ids, thread_id) if len(run_ids) <= keep_latest: return 0 @@ -136,12 +128,6 @@ async def cleanup_old_runs( return int(await asyncio.to_thread(repo.delete_runs, thread_id, old_ids)) -async def cleanup_thread(thread_id: str, run_event_repo: RunEventRepo | None = None) -> int: - """Delete all events for a thread. 
Returns deleted count.""" - repo = _resolve_run_event_repo(run_event_repo) - return int(await asyncio.to_thread(repo.delete_thread_events, thread_id)) - - def _event_payload_to_dict(event: dict[str, Any]) -> dict[str, Any]: raw_data = event.get("data", {}) if isinstance(raw_data, dict): diff --git a/backend/web/services/idle_reaper.py b/backend/web/services/idle_reaper.py index 90651365a..a739aa9fb 100644 --- a/backend/web/services/idle_reaper.py +++ b/backend/web/services/idle_reaper.py @@ -40,7 +40,7 @@ async def idle_reaper_loop(app_obj: FastAPI) -> None: try: count = await asyncio.to_thread(run_idle_reaper_once, app_obj) if count > 0: - print(f"[idle-reaper] paused+closed {count} expired chat session(s)") + print(f"[idle-reaper] reclaimed+closed {count} expired chat session(s)") except Exception as e: print(f"[idle-reaper] error: {e}") await asyncio.sleep(IDLE_REAPER_INTERVAL_SEC) diff --git a/backend/web/services/library_service.py b/backend/web/services/library_service.py index 2919f8dd6..a33886e17 100644 --- a/backend/web/services/library_service.py +++ b/backend/web/services/library_service.py @@ -15,19 +15,6 @@ LIBRARY_DIR = library_dir() -def ensure_library_dir() -> None: - LIBRARY_DIR.mkdir(parents=True, exist_ok=True) - (LIBRARY_DIR / "skills").mkdir(exist_ok=True) - (LIBRARY_DIR / "agents").mkdir(exist_ok=True) - legacy_recipe_dir = LIBRARY_DIR / "recipes" - # @@@recipe-storage-cutover - recipes now live in SQLite only; delete the dead file tree so it cannot masquerade as live state. - if legacy_recipe_dir.exists(): - if legacy_recipe_dir.is_dir(): - shutil.rmtree(legacy_recipe_dir) - else: - legacy_recipe_dir.unlink() - - def _read_json(path: Path, default: Any = None) -> Any: if not path.exists(): return default if default is not None else {} diff --git a/backend/web/services/marketplace_client.py b/backend/web/services/marketplace_client.py index 49de82258..47dc1fb49 100644 --- a/backend/web/services/marketplace_client.py +++ b/backend/web/services/marketplace_client.py @@ -17,7 +17,7 @@ HUB_URL = os.environ.get("MYCEL_HUB_URL", "http://localhost:8090") -_hub_client = httpx.Client(timeout=30.0) +_hub_client = httpx.Client(timeout=30.0, trust_env=False) def _hub_api(method: str, path: str, **kwargs: Any) -> dict: diff --git a/backend/web/services/member_service.py b/backend/web/services/member_service.py index ac295e4f4..d1ae1f965 100644 --- a/backend/web/services/member_service.py +++ b/backend/web/services/member_service.py @@ -22,7 +22,6 @@ import yaml from backend.web.core.paths import avatars_dir, members_dir -from backend.web.services.thread_naming import canonical_entity_name from backend.web.utils.serializers import avatar_url from config.defaults.tool_catalog import TOOLS_BY_NAME, ToolDef from config.loader import AgentLoader @@ -38,10 +37,6 @@ def _load_tools_catalog() -> dict[str, ToolDef]: return TOOLS_BY_NAME -def ensure_members_dir() -> None: - MEMBERS_DIR.mkdir(parents=True, exist_ok=True) - - # ── Low-level I/O helpers ── @@ -346,15 +341,8 @@ def list_members(owner_user_id: str | None = None, member_repo: Any = None) -> l # @@@auth-scope — scoped by owner from DB, config from filesystem if owner_user_id: if member_repo is None: - from storage.providers.sqlite.member_repo import SQLiteMemberRepo - - repo = SQLiteMemberRepo() - try: - agents = repo.list_by_owner_user_id(owner_user_id) - finally: - repo.close() - else: - agents = member_repo.list_by_owner_user_id(owner_user_id) + raise RuntimeError("member_repo is required when owner_user_id is provided") + 
agents = member_repo.list_by_owner_user_id(owner_user_id) results = [] for agent in agents: agent_dir = MEMBERS_DIR / agent.id @@ -391,9 +379,15 @@ def get_member(member_id: str) -> dict[str, Any] | None: return _member_to_dict(member_dir) -def create_member(name: str, description: str = "", owner_user_id: str | None = None, member_repo: Any = None) -> dict[str, Any]: +def create_member( + name: str, + description: str = "", + owner_user_id: str | None = None, + member_repo: Any = None, + agent_config_repo: Any = None, +) -> dict[str, Any]: from storage.contracts import MemberRow, MemberType - from storage.providers.sqlite.member_repo import generate_member_id + from storage.utils import generate_member_id now = time.time() now_ms = int(now * 1000) @@ -411,6 +405,19 @@ def create_member(name: str, description: str = "", owner_user_id: str | None = }, ) + # Dual-write to Supabase repo + if agent_config_repo: + _save_config_to_repo( + agent_config_repo, + member_id, + name=name, + description=description, + status="draft", + version="0.1.0", + created_at=now_ms, + updated_at=now_ms, + ) + # Persist to members table so list_members finds it if owner_user_id: row = MemberRow( @@ -422,16 +429,9 @@ def create_member(name: str, description: str = "", owner_user_id: str | None = owner_user_id=owner_user_id, created_at=now, ) - if member_repo is not None: - member_repo.create(row) - else: - from storage.providers.sqlite.member_repo import SQLiteMemberRepo - - repo = SQLiteMemberRepo() - try: - repo.create(row) - finally: - repo.close() + if member_repo is None: + raise RuntimeError("member_repo is required when owner_user_id is provided") + member_repo.create(row) return get_member(member_id) # type: ignore @@ -439,8 +439,6 @@ def create_member(name: str, description: str = "", owner_user_id: str | None = def update_member( member_id: str, member_repo: Any = None, - entity_repo: Any = None, - thread_repo: Any = None, **fields: Any, ) -> dict[str, Any] | None: if member_id == "__leon__": @@ -472,45 +470,15 @@ def update_member( meta["updated_at"] = int(time.time() * 1000) _write_json(member_dir / "meta.json", meta) - # Sync name to DB if "name" in updates: if member_repo is None: - from storage.providers.sqlite.member_repo import SQLiteMemberRepo - - member_repo = SQLiteMemberRepo() - if entity_repo is None: - from storage.providers.sqlite.entity_repo import SQLiteEntityRepo - - entity_repo = SQLiteEntityRepo() - if thread_repo is None: - from storage.providers.sqlite.thread_repo import SQLiteThreadRepo - - thread_repo = SQLiteThreadRepo() - + raise RuntimeError("member_repo is required to update member name") member_repo.update(member_id, name=updates["name"]) - member = member_repo.get_by_id(member_id) - if member is None: - raise ValueError(f"Member {member_id} not found after update") - for entity in entity_repo.get_by_member_id(member_id): - if entity.thread_id is None: - entity_repo.update(entity.id, name=member.name) - continue - thread = thread_repo.get_by_id(entity.thread_id) - if thread is None: - raise ValueError(f"Entity {entity.id} references missing thread {entity.thread_id}") - entity_repo.update( - entity.id, - name=canonical_entity_name( - member.name, - is_main=bool(thread["is_main"]), - branch_index=int(thread["branch_index"]), - ), - ) return get_member(member_id) -def update_member_config(member_id: str, config_patch: dict[str, Any]) -> dict[str, Any] | None: +def update_member_config(member_id: str, config_patch: dict[str, Any], agent_config_repo: Any = None) -> dict[str, Any] 
| None: if member_id == "__leon__": member_dir = _ensure_leon_dir() else: @@ -549,9 +517,94 @@ def update_member_config(member_id: str, config_patch: dict[str, Any]) -> dict[s meta = _read_json(member_dir / "meta.json", {}) meta["updated_at"] = int(time.time() * 1000) _write_json(member_dir / "meta.json", meta) + + # Dual-write full state to Supabase repo + if agent_config_repo: + try: + bundle = AgentLoader().load_bundle(member_dir) + _save_config_to_repo( + agent_config_repo, + member_id, + name=bundle.agent.name, + description=bundle.agent.description, + model=bundle.agent.model, + tools=bundle.agent.tools, + system_prompt=bundle.agent.system_prompt, + status=bundle.meta.get("status", "draft"), + version=bundle.meta.get("version", "0.1.0"), + created_at=bundle.meta.get("created_at", 0), + updated_at=bundle.meta.get("updated_at", 0), + runtime={k: {"enabled": v.enabled, "desc": v.desc} for k, v in bundle.runtime.items()}, + mcp={n: {"command": s.command, "args": s.args, "env": s.env, "disabled": s.disabled} for n, s in bundle.mcp.items()}, + ) + # Sync rules + for rule in bundle.rules: + agent_config_repo.save_rule(member_id, f"{rule['name']}.md", rule.get("content", "")) + # Sync sub-agents + for agent_cfg in bundle.agents: + if agent_cfg.source_dir and agent_cfg.source_dir.resolve() == _SYSTEM_AGENTS_DIR: + continue # skip builtins + agent_config_repo.save_sub_agent( + member_id, + agent_cfg.name, + description=agent_cfg.description, + model=agent_cfg.model, + tools=agent_cfg.tools, + system_prompt=agent_cfg.system_prompt, + ) + # Sync skills + for skill in bundle.skills: + skill_path = Path(skill.get("path", "")) + skill_md = skill_path / "SKILL.md" + content = skill_md.read_text(encoding="utf-8") if skill_md.exists() else "" + agent_config_repo.save_skill(member_id, skill["name"], content) + except Exception: + logger.warning("Failed to sync config to repo for member %s", member_id, exc_info=True) + return get_member(member_id) +# ── Supabase repo dual-write helper ── + + +def _save_config_to_repo( + agent_config_repo: Any, + member_id: str, + *, + name: str, + description: str = "", + model: str | None = None, + tools: list[str] | None = None, + system_prompt: str = "", + status: str = "draft", + version: str = "0.1.0", + created_at: int = 0, + updated_at: int = 0, + runtime: dict | None = None, + mcp: dict | None = None, +) -> None: + """Save agent config to Supabase repo. 
Best-effort — logs errors but doesn't raise.""" + try: + agent_config_repo.save_config( + member_id, + { + "name": name, + "description": description, + "model": model, + "tools": tools or ["*"], + "system_prompt": system_prompt, + "status": status, + "version": version, + "created_at": created_at, + "updated_at": updated_at, + "runtime": runtime or {}, + "mcp": mcp or {}, + }, + ) + except Exception: + logger.warning("Failed to save config to repo for member %s", member_id, exc_info=True) + + # ── Write helpers for config fields → file structure ── @@ -678,7 +731,7 @@ def _write_mcps(member_dir: Path, mcps: list[dict[str, Any]]) -> None: # ── Publish / Delete ── -def publish_member(member_id: str, bump_type: str = "patch") -> dict[str, Any] | None: +def publish_member(member_id: str, bump_type: str = "patch", agent_config_repo: Any = None) -> dict[str, Any] | None: member_dir = MEMBERS_DIR / member_id if not member_dir.is_dir(): return None @@ -695,29 +748,47 @@ def publish_member(member_id: str, bump_type: str = "patch") -> dict[str, Any] | meta["status"] = "active" meta["updated_at"] = int(time.time() * 1000) _write_json(member_dir / "meta.json", meta) + + # Dual-write publish status to Supabase repo + if agent_config_repo: + try: + config = agent_config_repo.get_config(member_id) + if config: + agent_config_repo.save_config( + member_id, + { + **config, + "version": meta["version"], + "status": "active", + "updated_at": meta["updated_at"], + }, + ) + except Exception: + logger.warning("Failed to update repo for publish of %s", member_id, exc_info=True) + return get_member(member_id) -def delete_member(member_id: str, member_repo: Any = None) -> bool: +def delete_member(member_id: str, member_repo: Any = None, agent_config_repo: Any = None) -> bool: if member_id == "__leon__": return False member_dir = MEMBERS_DIR / member_id if not member_dir.is_dir(): return False + # Delete from Supabase repo before removing filesystem + if agent_config_repo: + try: + agent_config_repo.delete_config(member_id) + except Exception: + logger.warning("Failed to delete config from repo for %s", member_id, exc_info=True) + shutil.rmtree(member_dir) # Also remove from DB - if member_repo is not None: - member_repo.delete(member_id) - else: - from storage.providers.sqlite.member_repo import SQLiteMemberRepo - - repo = SQLiteMemberRepo() - try: - repo.delete(member_id) - finally: - repo.close() + if member_repo is None: + raise RuntimeError("member_repo is required to delete member") + member_repo.delete(member_id) return True @@ -740,10 +811,11 @@ def install_from_snapshot( owner_user_id: str, existing_member_id: str | None = None, member_repo: Any = None, + agent_config_repo: Any = None, ) -> str: """Create or update a local member from a marketplace snapshot.""" from storage.contracts import MemberRow, MemberType - from storage.providers.sqlite.member_repo import generate_member_id + from storage.utils import generate_member_id now = time.time() now_ms = int(now * 1000) @@ -843,15 +915,37 @@ def install_from_snapshot( owner_user_id=owner_user_id, created_at=now, ) - if member_repo is not None: - member_repo.create(row) - else: - from storage.providers.sqlite.member_repo import SQLiteMemberRepo - - repo = SQLiteMemberRepo() + if member_repo is None: + raise RuntimeError("member_repo is required to register new member from snapshot") + member_repo.create(row) + + # Dual-write to Supabase repo + if agent_config_repo: + _save_config_to_repo( + agent_config_repo, + member_id, + name=name, + 
description=description, + status=meta["status"], + version=meta["version"], + created_at=meta["created_at"], + updated_at=meta["updated_at"], + runtime=runtime_data if runtime_data else {}, + mcp=mcp_data if mcp_data else {}, + ) + # Sync rules from snapshot + for rule in snapshot.get("rules", []): + rule_name = _sanitize_name(rule.get("name", "default")) + try: + agent_config_repo.save_rule(member_id, f"{rule_name}.md", rule.get("content", "")) + except Exception: + logger.warning("Failed to save snapshot rule %s for member %s", rule_name, member_id, exc_info=True) + # Sync skills from snapshot + for skill in snapshot.get("skills", []): + skill_name = _sanitize_name(skill.get("name", "default")) try: - repo.create(row) - finally: - repo.close() + agent_config_repo.save_skill(member_id, skill_name, skill.get("content", "")) + except Exception: + logger.warning("Failed to save snapshot skill %s for member %s", skill_name, member_id, exc_info=True) return member_id diff --git a/backend/web/services/message_routing.py b/backend/web/services/message_routing.py index 7984e9552..d73dfef32 100644 --- a/backend/web/services/message_routing.py +++ b/backend/web/services/message_routing.py @@ -19,6 +19,7 @@ async def route_message_to_brain( sender_name: str | None = None, sender_avatar_url: str | None = None, attachments: list[str] | None = None, + message_metadata: dict[str, Any] | None = None, ) -> dict: """Route message to agent brain thread. @@ -26,6 +27,7 @@ async def route_message_to_brain( ACTIVE → enqueue as steer """ from backend.web.services.agent_pool import get_or_create_agent, resolve_thread_sandbox + from backend.web.services.resource_cache import clear_monitor_resource_overview_cache from backend.web.services.streaming_service import start_agent_run sandbox_type = resolve_thread_sandbox(app, thread_id) @@ -71,7 +73,12 @@ async def route_message_to_brain( return {"status": "injected", "routing": "steer", "thread_id": thread_id} logger.debug("[route] → START RUN (idle→active)") meta = {"source": source, "sender_name": sender_name, "sender_avatar_url": sender_avatar_url} + if message_metadata: + meta.update(message_metadata) if attachments: meta["attachments"] = attachments run_id = start_agent_run(agent, thread_id, run_content, app, message_metadata=meta) + # @@@monitor-resource-cache-run-start - a fresh run can create or resume a lease immediately. + # Drop the cached monitor snapshot so the next /api/monitor/resources read reflects the live topology. 
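A rough sketch of the snapshot cache being invalidated here (internals assumed; the real implementation lives in `resource_cache`): reads memoize the overview for a short TTL, and a run start clears the memo so the next monitor read rebuilds from live state.

```python
import time
from typing import Callable

_cache: tuple[float, dict] | None = None  # (fetched_at_monotonic, snapshot)
_TTL_SEC = 5.0


def get_resource_overview(fetch: Callable[[], dict]) -> dict:
    global _cache
    if _cache is not None and time.monotonic() - _cache[0] < _TTL_SEC:
        return _cache[1]  # fresh enough — serve the memoized snapshot
    snapshot = fetch()
    _cache = (time.monotonic(), snapshot)
    return snapshot


def clear_monitor_resource_overview_cache() -> None:
    # Called right after start_agent_run: a new run can create or resume a
    # lease immediately, so the memo must not outlive the topology change.
    global _cache
    _cache = None


assert get_resource_overview(lambda: {"leases": 1}) == {"leases": 1}  # miss → fetch
assert get_resource_overview(lambda: {"leases": 2}) == {"leases": 1}  # hit → memo
clear_monitor_resource_overview_cache()
assert get_resource_overview(lambda: {"leases": 2}) == {"leases": 2}  # cleared → refetch
```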
+ clear_monitor_resource_overview_cache() return {"status": "started", "routing": "direct", "run_id": run_id, "thread_id": thread_id} diff --git a/backend/web/services/monitor_service.py b/backend/web/services/monitor_service.py index 31f59b729..e813718a6 100644 --- a/backend/web/services/monitor_service.py +++ b/backend/web/services/monitor_service.py @@ -3,18 +3,29 @@ from __future__ import annotations import json +import re from datetime import UTC, datetime from typing import Any from backend.web.core.storage_factory import make_sandbox_monitor_repo from backend.web.services.sandbox_service import init_providers_and_managers, load_all_sessions +from storage.providers.sqlite.chat_session_repo import SQLiteChatSessionRepo from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path +from storage.providers.sqlite.lease_repo import SQLiteLeaseRepo # --------------------------------------------------------------------------- # Mapping helpers (private) # --------------------------------------------------------------------------- +def make_chat_session_repo() -> SQLiteChatSessionRepo: + return SQLiteChatSessionRepo(db_path=resolve_role_db_path(SQLiteDBRole.SANDBOX)) + + +def make_lease_repo() -> SQLiteLeaseRepo: + return SQLiteLeaseRepo(db_path=resolve_role_db_path(SQLiteDBRole.SANDBOX)) + + def _format_time_ago(iso_timestamp: str | None) -> str: if not iso_timestamp: return "never" @@ -75,6 +86,325 @@ def _lease_link(lease_id: str | None) -> dict[str, Any]: return {"lease_id": lease_id, "lease_url": f"/lease/{lease_id}" if lease_id else None} +LEASE_SEMANTIC_ORDER = [ + "orphan_diverged", + "diverged", + "orphan", + "healthy", +] + +LEASE_SEMANTIC_META = { + "orphan_diverged": { + "title": "Orphaned + Diverged", + "description": "Lease lost thread binding while desired and observed state still disagree.", + }, + "diverged": { + "title": "Diverged", + "description": "Lease is still attached to a thread, but runtime state has not converged.", + }, + "orphan": { + "title": "Orphans", + "description": "Lease has no active thread binding. Usually cleanup or historical residue.", + }, + "healthy": { + "title": "Healthy", + "description": "Lease has a thread binding and desired state matches observed state.", + }, +} + + +EVAL_NOTE_KEYS = [ + "runner", + "rc", + "sandbox", + "run_dir", + "stdout_log", + "stderr_log", +] + +LEASE_TRIAGE_ORDER = [ + "active_drift", + "detached_residue", + "orphan_cleanup", + "healthy_capacity", +] + +LEASE_TRIAGE_META = { + "active_drift": { + "title": "Active Drift", + "description": "Leases whose desired and observed state still disagree recently enough to warrant active operator attention.", + "tone": "warning", + }, + "detached_residue": { + "title": "Detached Residue", + "description": ( + "Leases still marked desired=running but observed=detached long after the runtime " + "stopped moving. Usually cleanup debt, not live pressure." 
+        ),
+        "tone": "danger",
+    },
+    "orphan_cleanup": {
+        "title": "Orphan Cleanup",
+        "description": "Lease rows that have already lost thread binding and mainly represent cleanup backlog or historical residue.",
+        "tone": "warning",
+    },
+    "healthy_capacity": {
+        "title": "Healthy Capacity",
+        "description": "Leases with attached thread context and converged runtime state.",
+        "tone": "success",
+    },
+}
+
+DETACHED_RESIDUE_THRESHOLD_HOURS = 4.0
+RESOURCE_CLEANUP_ALLOWED_CATEGORIES = {"detached_residue", "orphan_cleanup"}
+ACTIVE_CHAT_SESSION_STATUSES = {"active", "idle", "paused"}
+
+
+def _classify_lease_semantics(*, thread_id: str | None, badge: dict[str, Any]) -> dict[str, str]:
+    is_orphan = not bool(thread_id)
+    is_converged = bool(badge.get("converged"))
+    if is_orphan and not is_converged:
+        category = "orphan_diverged"
+    elif not is_converged:
+        category = "diverged"
+    elif is_orphan:
+        category = "orphan"
+    else:
+        category = "healthy"
+    meta = LEASE_SEMANTIC_META[category]
+    return {
+        "category": category,
+        "title": meta["title"],
+        "description": meta["description"],
+    }
+
+
+def _parse_local_timestamp(iso_timestamp: str | None) -> datetime | None:
+    if not iso_timestamp:
+        return None
+    try:
+        parsed = datetime.fromisoformat(iso_timestamp.replace("Z", "+00:00"))
+    except ValueError:
+        return None
+    if parsed.tzinfo is not None:
+        # Normalize aware timestamps (trailing Z or any +/-HH:MM offset) to
+        # naive local time; _hours_since subtracts against a naive
+        # datetime.now(), and mixing aware with naive would raise TypeError.
+        parsed = parsed.astimezone().replace(tzinfo=None)
+    return parsed
+
+
+def _hours_since(iso_timestamp: str | None) -> float | None:
+    dt = _parse_local_timestamp(iso_timestamp)
+    if dt is None:
+        return None
+    delta = datetime.now() - dt
+    return delta.total_seconds() / 3600
+
+
+def _classify_lease_triage(
+    *,
+    thread_id: str | None,
+    badge: dict[str, Any],
+    observed_state: str | None,
+    desired_state: str | None,
+    updated_at: str | None,
+) -> dict[str, Any]:
+    observed = str(observed_state or "").strip().lower() or None
+    desired = str(desired_state or "").strip().lower() or None
+    age_hours = _hours_since(updated_at)
+    is_orphan = not bool(thread_id)
+    is_converged = bool(badge.get("converged"))
+
+    if is_orphan:
+        key = "orphan_cleanup"
+    elif is_converged:
+        key = "healthy_capacity"
+    elif observed == "detached" and desired == "running" and age_hours is not None and age_hours >= DETACHED_RESIDUE_THRESHOLD_HOURS:
+        key = "detached_residue"
+    else:
+        key = "active_drift"
+
+    meta = LEASE_TRIAGE_META[key]
+    return {
+        "category": key,
+        "title": meta["title"],
+        "description": meta["description"],
+        "tone": meta["tone"],
+        "age_hours": age_hours,
+    }
+
+
+def _cleanable_lease_ids(lease_ids: list[str]) -> list[str]:
+    cleaned: list[str] = []
+    seen: set[str] = set()
+    for raw in lease_ids:
+        lease_id = str(raw or "").strip()
+        if not lease_id or lease_id in seen:
+            continue
+        seen.add(lease_id)
+        cleaned.append(lease_id)
+    if not cleaned:
+        raise ValueError("lease_ids must contain at least one non-empty lease id")
+    return cleaned
+
+
+def _triage_category_for_row(row: dict[str, Any]) -> str:
+    badge = _make_badge(row.get("desired_state"), row.get("observed_state"))
+    triage = _classify_lease_triage(
+        thread_id=row.get("thread_id"),
+        badge=badge,
+        observed_state=row.get("observed_state"),
+        desired_state=row.get("desired_state"),
+        updated_at=row.get("updated_at"),
+    )
+    return str(triage["category"])
+
+
+def _extract_eval_note_value(notes: str, key: str) -> str | None:
+    match = re.search(rf"(?:^|[ |]){re.escape(key)}=([^ ]+)", notes)
+    if not match:
+        return None
+    return match.group(1).strip()
+
+
+def 
build_evaluation_operator_surface( + *, + status: str, + notes: str, + score: dict[str, Any], + threads_total: int, + threads_running: int, + threads_done: int, +) -> dict[str, Any]: + extracted = {key: _extract_eval_note_value(notes, key) for key in EVAL_NOTE_KEYS} + rc_text = extracted.get("rc") + try: + rc = int(rc_text) if rc_text is not None else None + except ValueError: + rc = None + + scored = bool(score.get("scored")) + score_gate = str(score.get("score_gate") or "provisional") + artifacts = [ + { + "label": "Run directory", + "path": score.get("run_dir") or extracted.get("run_dir"), + }, + {"label": "Run manifest", "path": score.get("manifest_path")}, + {"label": "STDOUT log", "path": extracted.get("stdout_log")}, + {"label": "STDERR log", "path": extracted.get("stderr_log")}, + {"label": "Eval summary", "path": score.get("eval_summary_path")}, + {"label": "Trace summaries", "path": score.get("trace_summaries_path")}, + ] + artifacts = [ + { + **item, + "status": "present" if item["path"] else "missing", + } + for item in artifacts + ] + artifact_summary = { + "present": sum(1 for item in artifacts if item["status"] == "present"), + "missing": sum(1 for item in artifacts if item["status"] == "missing"), + "total": len(artifacts), + } + + facts = [ + {"label": "Status", "value": status}, + {"label": "Score gate", "value": score_gate}, + {"label": "Threads materialized", "value": str(threads_total)}, + {"label": "Threads running", "value": str(threads_running)}, + {"label": "Threads done", "value": str(threads_done)}, + ] + runner = extracted.get("runner") + if runner: + facts.append({"label": "Runner", "value": runner}) + if rc is not None: + facts.append({"label": "Exit code", "value": str(rc)}) + + kind = "collecting_runtime_evidence" + tone = "default" + headline = "Evaluation is still collecting runtime evidence." + summary = "Use the artifacts below to inspect progress and confirm whether thread rows are materializing." + next_steps = [ + "Open the run manifest to confirm the slice payload and output directory.", + "Inspect stdout/stderr before assuming the run is healthy.", + ] + + if status == "provisional" and not scored: + kind = "provisional_waiting_for_summary" + tone = "warning" + headline = "Evaluation is provisional. Final score is blocked." + summary = "This run has not produced the final eval summary yet, so publishable scoring is intentionally withheld." + next_steps = [ + "Check whether eval_summary_path is still missing because the run is ongoing or because the runner exited early.", + "Use stdout/stderr logs to confirm whether the solve phase actually started.", + ] + + if rc is not None and rc != 0 and threads_total == 0: + kind = "bootstrap_failure" + tone = "danger" + headline = "Runner exited before evaluation threads materialized." + summary = "Treat this as a bootstrap failure, not as an empty successful run. No evaluation thread rows were created." + next_steps = [ + "Inspect STDERR first to find the failing bootstrap step.", + "Use the run manifest and stdout log to confirm whether the slice was prepared before exit.", + "Re-run only after the failing dependency or model configuration is understood.", + ] + elif status == "running" and threads_total == 0 and threads_running > 0: + kind = "running_waiting_for_threads" + tone = "default" + headline = "Evaluation is actively running while thread rows catch up." + summary = ( + "The runner is alive, but thread rows have not materialized yet. Treat this as an ingestion lag window, not as an empty run." 
+ ) + next_steps = [ + "Refresh after the first thread row materializes.", + "Use stdout/stderr to confirm the solve loop is still advancing.", + ] + elif status == "running": + kind = "running_active" + tone = "default" + headline = "Evaluation is actively running." + summary = "Thread rows and traces may lag behind the runner. Use live progress and logs before declaring drift." + next_steps = [ + "Refresh after new thread rows materialize.", + "Inspect traces only after the first active thread appears.", + ] + elif status == "completed_with_errors" and scored: + kind = "completed_with_errors" + tone = "warning" + headline = "Evaluation completed with recorded errors." + summary = ( + "Some thread rows reached completion, but at least one instance recorded an error. Treat this as reviewable but not clean." + ) + next_steps = [ + "Inspect error-bearing threads before comparing this run against cleaner baselines.", + "Use eval summary and trace summaries to isolate failing instances.", + ] + elif status == "completed" and scored: + kind = "completed_publishable" + tone = "success" + headline = "Evaluation finished with a publishable score surface." + summary = "Score artifacts are present. Use the thread table to drill into trace-level evidence." + next_steps = [ + "Open threads with low-quality traces and inspect tool-call detail.", + "Use the eval summary and trace summaries to compare runs.", + ] + + return { + "kind": kind, + "tone": tone, + "headline": headline, + "summary": summary, + "facts": facts, + "artifacts": artifacts, + "artifact_summary": artifact_summary, + "next_steps": next_steps, + "raw_notes": notes, + } + + # --------------------------------------------------------------------------- # Mappers (private) # --------------------------------------------------------------------------- @@ -130,21 +460,82 @@ def _map_thread_detail(thread_id: str, sessions: list[dict[str, Any]]) -> dict[s def _map_leases(rows: list[dict[str, Any]]) -> dict[str, Any]: - items = [ - { - "lease_id": row["lease_id"], - "lease_url": f"/lease/{row['lease_id']}", - "provider": row["provider_name"], - "instance_id": row["current_instance_id"], - "thread": _thread_ref(row["thread_id"]), - "state_badge": _make_badge(row["desired_state"], row["observed_state"]), - "error": row["last_error"], - "updated_at": row["updated_at"], - "updated_ago": _format_time_ago(row["updated_at"]), - } - for row in rows - ] - return {"title": "All Leases", "count": len(items), "items": items} + items = [] + for row in rows: + badge = _make_badge(row["desired_state"], row["observed_state"]) + triage = _classify_lease_triage( + thread_id=row["thread_id"], + badge=badge, + observed_state=row["observed_state"], + desired_state=row["desired_state"], + updated_at=row["updated_at"], + ) + items.append( + { + "lease_id": row["lease_id"], + "lease_url": f"/lease/{row['lease_id']}", + "provider": row["provider_name"], + "instance_id": row["current_instance_id"], + "thread": _thread_ref(row["thread_id"]), + "state_badge": badge, + "semantics": _classify_lease_semantics(thread_id=row["thread_id"], badge=badge), + "triage": triage, + "error": row["last_error"], + "updated_at": row["updated_at"], + "updated_ago": _format_time_ago(row["updated_at"]), + } + ) + + summary = {key: 0 for key in LEASE_SEMANTIC_ORDER} + for item in items: + summary[item["semantics"]["category"]] += 1 + summary["total"] = len(items) + + groups = [] + for key in LEASE_SEMANTIC_ORDER: + meta = LEASE_SEMANTIC_META[key] + group_items = [item for item in items if 
item["semantics"]["category"] == key] + groups.append( + { + "key": key, + "title": meta["title"], + "description": meta["description"], + "count": len(group_items), + "items": group_items, + } + ) + + triage_summary = {key: 0 for key in LEASE_TRIAGE_ORDER} + for item in items: + triage_summary[item["triage"]["category"]] += 1 + triage_summary["total"] = len(items) + + triage_groups = [] + for key in LEASE_TRIAGE_ORDER: + meta = LEASE_TRIAGE_META[key] + group_items = [item for item in items if item["triage"]["category"] == key] + triage_groups.append( + { + "key": key, + "title": meta["title"], + "description": meta["description"], + "tone": meta["tone"], + "count": len(group_items), + "items": group_items, + } + ) + + return { + "title": "All Leases", + "count": len(items), + "summary": summary, + "groups": groups, + "triage": { + "summary": triage_summary, + "groups": triage_groups, + }, + "items": items, + } def _map_lease_detail( @@ -192,6 +583,47 @@ def _map_lease_detail( } +def _historical_lease_detail( + lease_id: str, + sessions: list[dict[str, Any]], + events: list[dict[str, Any]], +) -> dict[str, Any] | None: + if not sessions and not events: + return None + + created_candidates = [ + str(value) for value in [*(row.get("started_at") for row in sessions), *(row.get("created_at") for row in events)] if value + ] + updated_candidates = [ + str(value) + for value in [ + *(row.get("ended_at") or row.get("started_at") for row in sessions), + *(row.get("created_at") for row in events), + ] + if value + ] + first_session = sessions[0] if sessions else {} + thread_ids: list[str] = [] + seen_threads: set[str] = set() + for row in sessions: + thread_id = str(row.get("thread_id") or "").strip() + if thread_id and thread_id not in seen_threads: + seen_threads.add(thread_id) + thread_ids.append(thread_id) + + lease = { + "provider_name": first_session.get("provider_name") or "unknown", + "current_instance_id": first_session.get("current_instance_id"), + "created_at": min(created_candidates) if created_candidates else None, + "updated_at": max(updated_candidates) if updated_candidates else None, + "desired_state": first_session.get("desired_state"), + "observed_state": first_session.get("observed_state"), + "last_error": first_session.get("last_error"), + } + threads = [{"thread_id": thread_id} for thread_id in thread_ids] + return _map_lease_detail(lease_id, lease, threads, events) + + def _map_diverged(rows: list[dict[str, Any]]) -> dict[str, Any]: items = [ { @@ -297,16 +729,152 @@ def list_leases() -> dict[str, Any]: repo.close() +def cleanup_resource_leases( + *, + action: str, + lease_ids: list[str], + expected_category: str, +) -> dict[str, Any]: + if action != "cleanup_residue": + raise ValueError(f"Unsupported cleanup action: {action}") + if expected_category not in RESOURCE_CLEANUP_ALLOWED_CATEGORIES: + raise ValueError("expected_category must be one of: detached_residue, orphan_cleanup") + + target_lease_ids = _cleanable_lease_ids(lease_ids) + monitor_repo = make_sandbox_monitor_repo() + lease_repo = make_lease_repo() + chat_session_repo = make_chat_session_repo() + try: + rows_by_id = {str(row.get("lease_id") or ""): row for row in monitor_repo.query_leases() if row.get("lease_id")} + providers, _ = init_providers_and_managers() + cleaned: list[dict[str, Any]] = [] + skipped: list[str] = [] + errors: list[dict[str, Any]] = [] + + for lease_id in target_lease_ids: + row = rows_by_id.get(lease_id) + if row is None: + skipped.append(lease_id) + errors.append({"lease_id": lease_id, 
"reason": "lease_not_found"}) + continue + + actual_category = _triage_category_for_row(row) + if actual_category != expected_category: + skipped.append(lease_id) + errors.append( + { + "lease_id": lease_id, + "reason": "category_mismatch", + "expected_category": expected_category, + "actual_category": actual_category, + } + ) + continue + + sessions = monitor_repo.query_lease_sessions(lease_id) + live_session_ids = [ + str(session.get("chat_session_id")) + for session in sessions + if str(session.get("status") or "").strip().lower() in ACTIVE_CHAT_SESSION_STATUSES + ] + if live_session_ids: + skipped.append(lease_id) + errors.append( + { + "lease_id": lease_id, + "reason": "live_sessions_present", + "session_ids": live_session_ids, + } + ) + continue + + if chat_session_repo.lease_has_running_command(lease_id): + skipped.append(lease_id) + errors.append({"lease_id": lease_id, "reason": "running_command_present"}) + continue + + provider_name = str(row.get("provider_name") or "").strip() + instance_id = str(row.get("current_instance_id") or "").strip() or None + if instance_id: + provider = providers.get(provider_name) + if provider is None: + skipped.append(lease_id) + errors.append( + { + "lease_id": lease_id, + "reason": "provider_unavailable", + "provider": provider_name, + } + ) + continue + if not provider.get_capability().can_destroy: + skipped.append(lease_id) + errors.append( + { + "lease_id": lease_id, + "reason": "provider_destroy_unsupported", + "provider": provider_name, + } + ) + continue + try: + destroyed = provider.destroy_session(instance_id, sync=True) + except Exception as exc: + skipped.append(lease_id) + errors.append( + { + "lease_id": lease_id, + "reason": "provider_destroy_failed", + "provider": provider_name, + "detail": str(exc), + } + ) + continue + if not destroyed: + skipped.append(lease_id) + errors.append( + { + "lease_id": lease_id, + "reason": "provider_destroy_failed", + "provider": provider_name, + "detail": "destroy_session returned false", + } + ) + continue + + lease_repo.delete(lease_id) + cleaned.append({"lease_id": lease_id, "category": actual_category}) + + refreshed_summary = list_leases()["triage"]["summary"] + return { + "action": action, + "expected_category": expected_category, + "attempted": target_lease_ids, + "cleaned": cleaned, + "skipped": skipped, + "errors": errors, + "refreshed_summary": refreshed_summary, + } + finally: + chat_session_repo.close() + lease_repo.close() + monitor_repo.close() + + def get_lease(lease_id: str) -> dict[str, Any]: repo = make_sandbox_monitor_repo() try: lease = repo.query_lease(lease_id) - if not lease: - raise KeyError("Lease not found") threads = repo.query_lease_threads(lease_id) events = repo.query_lease_events(lease_id) + sessions = repo.query_lease_sessions(lease_id) finally: repo.close() + if not lease: + fallback = _historical_lease_detail(lease_id, sessions, events) + if fallback: + return fallback + raise KeyError("Lease not found") return _map_lease_detail(lease_id, lease, threads, events) diff --git a/backend/web/services/profile_service.py b/backend/web/services/profile_service.py index c6b755bde..60359431a 100644 --- a/backend/web/services/profile_service.py +++ b/backend/web/services/profile_service.py @@ -1,10 +1,11 @@ -"""Profile CRUD — config.json based.""" +"""Profile CRUD — config.json based, with auth-member override for signed-in shell.""" import json from pathlib import Path from typing import Any from config.user_paths import preferred_existing_user_home_path, user_home_path +from 
storage.contracts import MemberRow LEON_HOME = user_home_path() CONFIG_PATH = LEON_HOME / "config.json" @@ -24,7 +25,23 @@ def _write_json(path: Path, data: Any) -> None: path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") -def get_profile() -> dict[str, Any]: +def _initials_from_name(name: str) -> str: + stripped = name.strip() + if not stripped: + return "U" + compact = "".join(part[:1] for part in stripped.split() if part) + if len(compact) >= 2: + return compact[:2].upper() + return stripped[:2].upper() + + +def get_profile(member: MemberRow | None = None) -> dict[str, Any]: + if member is not None: + return { + "name": member.name or "用户", + "initials": _initials_from_name(member.name or ""), + "email": member.email or "", + } cfg = _read_json(preferred_existing_user_home_path("config.json"), {}) profile = cfg.get("profile", {}) return { diff --git a/backend/web/services/resource_cache.py b/backend/web/services/resource_cache.py index 4b1d5f5fe..62846a653 100644 --- a/backend/web/services/resource_cache.py +++ b/backend/web/services/resource_cache.py @@ -10,7 +10,7 @@ from datetime import UTC, datetime from typing import Any -from backend.web.services import resource_service +from backend.web.services import monitor_service, resource_service _DEFAULT_REFRESH_INTERVAL_SEC = 90.0 @@ -24,6 +24,10 @@ def clear_resource_overview_cache() -> None: _snapshot_cache = None +def clear_monitor_resource_overview_cache() -> None: + clear_resource_overview_cache() + + def _now_iso() -> str: return datetime.now(UTC).isoformat().replace("+00:00", "Z") @@ -55,12 +59,37 @@ def _with_refresh_metadata( return payload +def _attach_monitor_triage(payload: dict[str, Any]) -> dict[str, Any]: + lease_payload = monitor_service.list_leases() + triage = lease_payload.get("triage") or {"summary": {}, "groups": []} + payload["triage"] = triage + return payload + + +def _snapshot_drifted_from_live_sessions(snapshot: dict[str, Any]) -> bool: + live_stats = resource_service.visible_resource_session_stats() + for provider in snapshot.get("providers") or []: + provider_id = str(provider.get("id") or "") + current = live_stats.get(provider_id, {"sessions": 0, "running": 0}) + cached_running = int(((provider.get("telemetry") or {}).get("running") or {}).get("used") or 0) + cached_sessions = len(provider.get("sessions") or []) + if cached_running != current["running"] or cached_sessions != current["sessions"]: + return True + for provider_id, current in live_stats.items(): + if current["running"] or current["sessions"]: + cached = next((item for item in snapshot.get("providers") or [] if str(item.get("id") or "") == provider_id), None) + if cached is None: + return True + return False + + def refresh_resource_overview_sync() -> dict[str, Any]: """Refresh cached overview snapshot and return latest payload.""" global _snapshot_cache started = time.perf_counter() try: payload = resource_service.list_resource_providers() + payload = _attach_monitor_triage(payload) duration_ms = (time.perf_counter() - started) * 1000 payload = _with_refresh_metadata(payload, duration_ms=duration_ms, status="ok", error=None) with _snapshot_lock: @@ -79,16 +108,29 @@ def refresh_resource_overview_sync() -> dict[str, Any]: return degraded +def refresh_monitor_resource_overview_sync() -> dict[str, Any]: + return refresh_resource_overview_sync() + + def get_resource_overview_snapshot() -> dict[str, Any]: """Return cached snapshot; perform one synchronous refresh on cold start.""" with _snapshot_lock: cached = 
copy.deepcopy(_snapshot_cache) if cached is not None: + # @@@resource-cache-live-drift - durable session truth lands in sandbox.db after a run + # starts; if the cached Resources snapshot no longer matches visible lease/session + # counts, refresh synchronously instead of serving a stale zero-sandbox card. + if _snapshot_drifted_from_live_sessions(cached): + return refresh_resource_overview_sync() return cached # @@@cold-start-cache-fill - route fallback fills cache once to keep first call deterministic. return refresh_resource_overview_sync() +def get_monitor_resource_overview_snapshot() -> dict[str, Any]: + return get_resource_overview_snapshot() + + async def resource_overview_refresh_loop() -> None: """Continuously refresh resource overview snapshot.""" interval_sec = _read_refresh_interval_sec() @@ -116,3 +158,7 @@ async def resource_overview_refresh_loop() -> None: print("[monitor] resource refresh loop timeout") except Exception as exc: print(f"[monitor] resource refresh loop error: {exc}") + + +async def monitor_resource_overview_refresh_loop() -> None: + await resource_overview_refresh_loop() diff --git a/backend/web/services/resource_projection_service.py b/backend/web/services/resource_projection_service.py new file mode 100644 index 000000000..41f3f1327 --- /dev/null +++ b/backend/web/services/resource_projection_service.py @@ -0,0 +1,119 @@ +"""User-visible resource projection service.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +from backend.web.services import resource_service, sandbox_service +from sandbox.provider import RESOURCE_CAPABILITY_KEYS +from storage.models import map_lease_to_session_status + + +def _now_iso() -> str: + return datetime.now(UTC).isoformat().replace("+00:00", "Z") + + +def _empty_metric(unit: str) -> dict[str, Any]: + return { + "used": None, + "limit": None, + "unit": unit, + "source": "unknown", + "freshness": "stale", + } + + +def _empty_capabilities() -> dict[str, bool]: + return {key: False for key in RESOURCE_CAPABILITY_KEYS} + + +def _build_provider_card(config_name: str, leases: list[dict[str, Any]]) -> dict[str, Any]: + display = resource_service.get_provider_display_contract(config_name) + capabilities, capability_error = resource_service.get_provider_capability_contract(config_name) + provider_type = str(display["type"]) + + sessions: list[dict[str, Any]] = [] + running_count = 0 + for lease in leases: + thread_id = str((lease.get("thread_ids") or [None])[0] or "") + owner = (lease.get("agents") or [{}])[0] + status = map_lease_to_session_status(lease.get("observed_state"), lease.get("desired_state")) + if status == "running": + running_count += 1 + sessions.append( + resource_service.build_resource_session_payload( + session_identity=f"{lease['lease_id']}:{thread_id}", + lease_id=str(lease["lease_id"]), + thread_id=thread_id, + owner=owner, + status=status, + started_at=str(lease.get("created_at") or ""), + metrics=None, + ) + ) + + telemetry = { + "running": { + "used": running_count, + "limit": None, + "unit": "sandbox", + "source": "derived", + "freshness": "live", + }, + "cpu": _empty_metric("%"), + "memory": _empty_metric("GB"), + "disk": _empty_metric("GB"), + } + availability = resource_service.build_provider_availability_payload( + available=capability_error is None, + running_count=running_count, + unavailable_reason=capability_error, + ) + + return { + "id": config_name, + "name": config_name, + "description": display["description"], + "vendor": display["vendor"], + 
"type": provider_type, + **availability, + "capabilities": capabilities, + "telemetry": telemetry, + "cardCpu": dict(telemetry["cpu"]), + "consoleUrl": display["console_url"], + "sessions": sessions, + } + + +def list_user_resource_providers(app: Any, owner_user_id: str) -> dict[str, Any]: + thread_repo = getattr(app.state, "thread_repo", None) + member_repo = getattr(app.state, "member_repo", None) + if thread_repo is None or member_repo is None: + raise RuntimeError("thread_repo and member_repo are required") + + leases = sandbox_service.list_user_leases( + owner_user_id, + thread_repo=thread_repo, + member_repo=member_repo, + ) + + leases_by_provider: dict[str, list[dict[str, Any]]] = {} + for lease in leases: + config_name = str(lease.get("provider_name") or "local") + leases_by_provider.setdefault(config_name, []).append(lease) + + providers = [_build_provider_card(config_name, provider_leases) for config_name, provider_leases in sorted(leases_by_provider.items())] + + return { + "summary": { + "snapshot_at": _now_iso(), + "total_providers": len(providers), + "active_providers": len([item for item in providers if item["status"] == "active"]), + "unavailable_providers": len([item for item in providers if item["status"] == "unavailable"]), + "running_sessions": sum(int(item["telemetry"]["running"]["used"] or 0) for item in providers), + "scope": "user", + "lease_count": len(leases), + }, + "providers": providers, + } diff --git a/backend/web/services/resource_service.py b/backend/web/services/resource_service.py index 236db63ab..58a58d8f6 100644 --- a/backend/web/services/resource_service.py +++ b/backend/web/services/resource_service.py @@ -8,7 +8,7 @@ from typing import Any from backend.web.core.config import SANDBOXES_DIR -from backend.web.core.storage_factory import list_resource_snapshots, make_sandbox_monitor_repo, upsert_resource_snapshot +from backend.web.core.storage_factory import list_resource_snapshots, make_sandbox_monitor_repo from backend.web.services.config_loader import SandboxConfigLoader from backend.web.services.sandbox_service import available_sandbox_types, build_provider_from_config_name from backend.web.utils.serializers import avatar_url @@ -23,6 +23,7 @@ probe_and_upsert_for_instance, ) from storage.models import map_lease_to_session_status +from storage.runtime import build_member_repo, build_resource_snapshot_repo, build_thread_repo _CONFIG_LOADER = SandboxConfigLoader(SANDBOXES_DIR) @@ -72,7 +73,8 @@ def _resolve_console_url(provider_name: str, config_name: str, *, sandboxes_dir: if provider_name == "e2b": return "https://e2b.dev" if provider_name == "daytona": - daytona = payload.get("daytona") if isinstance(payload.get("daytona"), dict) else {} + raw_daytona = payload.get("daytona") + daytona = raw_daytona if isinstance(raw_daytona, dict) else {} target = str(daytona.get("target") or "").strip().lower() if target == "cloud": return "https://app.daytona.io" @@ -81,6 +83,18 @@ def _resolve_console_url(provider_name: str, config_name: str, *, sandboxes_dir: return None +def get_provider_display_contract(config_name: str) -> dict[str, Any]: + provider_name = resolve_provider_name(config_name, sandboxes_dir=SANDBOXES_DIR) + catalog = _CATALOG.get(provider_name) or _CatalogEntry(vendor=None, description=provider_name, provider_type="cloud") + return { + "provider_name": provider_name, + "description": catalog.description, + "vendor": catalog.vendor, + "type": _resolve_provider_type(provider_name, config_name, sandboxes_dir=SANDBOXES_DIR), + "console_url": 
_resolve_console_url(provider_name, config_name, sandboxes_dir=SANDBOXES_DIR), + } + + # --------------------------------------------------------------------------- # Capability helpers # --------------------------------------------------------------------------- @@ -102,6 +116,13 @@ def _resolve_instance_capabilities(config_name: str) -> tuple[dict[str, bool], s return {key: normalized[key] for key in RESOURCE_CAPABILITY_KEYS}, None +def get_provider_capability_contract(config_name: str) -> tuple[dict[str, bool], str | None]: + capabilities, capability_error = _resolve_instance_capabilities(config_name) + if capability_error: + return _empty_capabilities(), capability_error + return capabilities, None + + # --------------------------------------------------------------------------- # Status/metric helpers # --------------------------------------------------------------------------- @@ -113,6 +134,14 @@ def _to_resource_status(available: bool, running_count: int) -> str: return "active" if running_count > 0 else "ready" +def build_provider_availability_payload(*, available: bool, running_count: int, unavailable_reason: str | None) -> dict[str, Any]: + return { + "status": _to_resource_status(available, running_count), + "unavailableReason": unavailable_reason, + "error": ({"code": "PROVIDER_UNAVAILABLE", "message": unavailable_reason} if unavailable_reason else None), + } + + def _to_metric_freshness(collected_at: str | None) -> str: if not collected_at: return "stale" @@ -216,17 +245,13 @@ def _to_session_metrics(snapshot: dict[str, Any] | None) -> dict[str, Any] | Non def _member_meta_map(member_repo: Any = None) -> dict[str, dict[str, str | None]]: """Build member_id → display metadata map from DB.""" + repo = member_repo + own_repo = False + if repo is None: + repo = build_member_repo() + own_repo = True try: - if member_repo is not None: - members = member_repo.list_all() - else: - from storage.providers.sqlite.member_repo import SQLiteMemberRepo - - repo = SQLiteMemberRepo() - try: - members = repo.list_all() - finally: - repo.close() + members = repo.list_all() return { m.id: { "member_name": m.name, @@ -237,6 +262,9 @@ def _member_meta_map(member_repo: Any = None) -> dict[str, dict[str, str | None] } except Exception: return {} + finally: + if own_repo: + repo.close() def _thread_agent_refs(thread_ids: list[str], thread_repo: Any = None) -> dict[str, str]: @@ -244,14 +272,11 @@ def _thread_agent_refs(thread_ids: list[str], thread_repo: Any = None) -> dict[s unique = sorted({tid for tid in thread_ids if tid}) if not unique: return {} - if thread_repo is None: - from storage.providers.sqlite.thread_repo import SQLiteThreadRepo - - repo = SQLiteThreadRepo() + repo = thread_repo + own_repo = False + if repo is None: + repo = build_thread_repo() own_repo = True - else: - repo = thread_repo - own_repo = False try: refs: dict[str, str] = {} for tid in unique: @@ -350,6 +375,92 @@ def _resolve_card_cpu_metric(provider_type: str, telemetry: dict[str, Any]) -> d return cpu +def _is_resource_visible_thread(thread_id: str | None) -> bool: + raw = str(thread_id or "").strip() + if raw.startswith("subagent-"): + return False + return True + + +def _resource_session_identity(session: dict[str, Any]) -> str: + lease_id = str(session.get("lease_id") or "") + thread_id = str(session.get("thread_id") or "") + if lease_id and thread_id: + # @@@resource-session-contract - resource cards are lease/thread scoped, not chat-session scoped. 
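+        # Illustrative identity mapping (hypothetical ids, not from this change):
+        #   {"lease_id": "ls_1", "thread_id": "th_a", "session_id": "cs_9"} -> "ls_1:th_a"
+        #   {"lease_id": "ls_1", "thread_id": "th_a", "session_id": "cs_7"} -> "ls_1:th_a"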
+ # Terminal fallback rows can carry distinct session ids for the same visible lease+thread binding. + return f"{lease_id}:{thread_id}" + session_id = str(session.get("session_id") or "") + if session_id: + return session_id + return f"{lease_id}:{thread_id or 'unbound'}" + + +def build_resource_session_payload( + *, + session_identity: str, + lease_id: str, + thread_id: str, + owner: dict[str, Any], + status: str, + started_at: str, + metrics: dict[str, Any] | None, +) -> dict[str, Any]: + return { + "id": session_identity, + "leaseId": lease_id, + "threadId": thread_id, + "memberId": str(owner.get("member_id") or ""), + "memberName": str(owner.get("member_name") or "未绑定Agent"), + "avatarUrl": owner.get("avatar_url"), + "status": status, + "startedAt": started_at, + "metrics": metrics, + } + + +def _project_user_visible_resource_sessions(repo: Any, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Project raw monitor rows into the user-visible resource surface. + + @@@user-visible-resource-projection - raw monitor rows may be bound to a newer + subagent terminal even though the lease still belongs to a user-visible parent + thread. Keep raw monitor truth in the repo; only the Resources UI gets this + parent-thread preference. + """ + grouped: dict[str, list[dict[str, Any]]] = {} + for row in rows: + lease_id = str(row.get("lease_id") or "") + grouped.setdefault(lease_id, []).append(dict(row)) + + projected: list[dict[str, Any]] = [] + for lease_id, group in grouped.items(): + visible_rows = [row for row in group if _is_resource_visible_thread(row.get("thread_id"))] + if visible_rows: + projected.extend(visible_rows) + continue + + if not lease_id: + continue + + try: + thread_rows = repo.query_lease_threads(lease_id) + except Exception: + thread_rows = [] + + preferred_thread_id = next( + (str(item.get("thread_id") or "").strip() for item in thread_rows if _is_resource_visible_thread(item.get("thread_id"))), + "", + ) + if not preferred_thread_id: + continue + + base = dict(group[0]) + base["thread_id"] = preferred_thread_id + base["session_id"] = None + projected.append(base) + + return projected + + # --------------------------------------------------------------------------- # Public API: resource overview # --------------------------------------------------------------------------- @@ -359,7 +470,8 @@ def list_resource_providers() -> dict[str, Any]: # @@@overview-fast-path - avoid provider-network calls; overview uses DB session snapshot. 
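+    # Hedged sketch of this fast path (names as defined in this module):
+    #   rows = repo.list_sessions_with_leases()                   # durable sandbox.db truth
+    #   rows = _project_user_visible_resource_sessions(repo, rows)
+    #   -> a lease bound only to "subagent-*" threads is re-homed to its first
+    #      user-visible thread, or dropped from the overview if none exists.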
repo = make_sandbox_monitor_repo() try: - sessions = repo.list_sessions_with_leases() + raw_sessions = repo.list_sessions_with_leases() + sessions = _project_user_visible_resource_sessions(repo, raw_sessions) finally: repo.close() @@ -376,9 +488,8 @@ def list_resource_providers() -> dict[str, Any]: for item in available_sandbox_types(): config_name = str(item["name"]) available = bool(item.get("available")) - provider_name = resolve_provider_name(config_name, sandboxes_dir=SANDBOXES_DIR) - catalog = _CATALOG.get(provider_name) or _CatalogEntry(vendor=None, description=provider_name, provider_type="cloud") - capabilities, capability_error = _resolve_instance_capabilities(config_name) + display = get_provider_display_contract(config_name) + capabilities, capability_error = get_provider_capability_contract(config_name) effective_available = available and capability_error is None unavailable_reason: str | None = None if not effective_available: @@ -386,6 +497,7 @@ def list_resource_providers() -> dict[str, Any]: provider_sessions = grouped.get(config_name, []) normalized_sessions: list[dict[str, Any]] = [] + seen_session_ids: set[str] = set() running_count = 0 # @@@running-dedup - lease-driven query may yield multiple rows per lease (one per crew member). # Count each running lease only once. @@ -402,23 +514,26 @@ def list_resource_providers() -> dict[str, Any]: seen_running_leases.add(lease_id) session_metrics = _to_session_metrics(snapshot_by_lease.get(lease_id)) owner = owners.get(thread_id, {"member_id": None, "member_name": "未绑定Agent"}) + session_identity = _resource_session_identity(session) + # @@@resource-session-dedup - terminal fallback can surface multiple + # monitor rows for the same lease/thread binding. The overview + # contract is one session row per stable session identity. + if session_identity in seen_session_ids: + continue + seen_session_ids.add(session_identity) normalized_sessions.append( - { - # @@@resource-session-identity - monitor rows can legitimately have empty chat session ids. - # Use stable lease+thread identity so React keys do not collapse when one lease has multiple threads. 
- "id": str(session.get("session_id") or f"{lease_id}:{thread_id or 'unbound'}"), - "leaseId": lease_id, - "threadId": thread_id, - "memberId": str(owner.get("member_id") or ""), - "memberName": str(owner.get("member_name") or "未绑定Agent"), - "avatarUrl": owner.get("avatar_url"), - "status": normalized, - "startedAt": str(session.get("created_at") or ""), - "metrics": session_metrics, - } + build_resource_session_payload( + session_identity=session_identity, + lease_id=lease_id, + thread_id=thread_id, + owner=owner, + status=normalized, + started_at=str(session.get("created_at") or ""), + metrics=session_metrics, + ) ) - provider_type = _resolve_provider_type(provider_name, config_name, sandboxes_dir=SANDBOXES_DIR) + provider_type = str(display["type"]) telemetry = _aggregate_provider_telemetry( provider_sessions=provider_sessions, running_count=running_count, @@ -441,20 +556,23 @@ def list_resource_providers() -> dict[str, Any]: ), "disk": _metric(host_m.disk_used_gb, host_m.disk_total_gb, "GB", "direct", "live"), } + availability = build_provider_availability_payload( + available=effective_available, + running_count=running_count, + unavailable_reason=unavailable_reason, + ) providers.append( { "id": config_name, "name": config_name, - "description": catalog.description, - "vendor": catalog.vendor, + "description": display["description"], + "vendor": display["vendor"], "type": provider_type, - "status": _to_resource_status(effective_available, running_count), - "unavailableReason": unavailable_reason, - "error": ({"code": "PROVIDER_UNAVAILABLE", "message": unavailable_reason} if unavailable_reason else None), + **availability, "capabilities": capabilities, "telemetry": telemetry, "cardCpu": _resolve_card_cpu_metric(provider_type, telemetry), - "consoleUrl": _resolve_console_url(provider_name, config_name, sandboxes_dir=SANDBOXES_DIR), + "consoleUrl": display["console_url"], "sessions": normalized_sessions, } ) @@ -469,6 +587,36 @@ def list_resource_providers() -> dict[str, Any]: return {"summary": summary, "providers": providers} +def visible_resource_session_stats() -> dict[str, dict[str, int]]: + """Return the current user-visible session/running counts per provider.""" + repo = make_sandbox_monitor_repo() + try: + raw_sessions = repo.list_sessions_with_leases() + sessions = _project_user_visible_resource_sessions(repo, raw_sessions) + finally: + repo.close() + + stats: dict[str, dict[str, int]] = {} + seen_session_ids: set[str] = set() + seen_running_leases: set[tuple[str, str]] = set() + for session in sessions: + provider_instance = str(session.get("provider") or "local") + provider_stats = stats.setdefault(provider_instance, {"sessions": 0, "running": 0}) + session_identity = _resource_session_identity(session) + if session_identity not in seen_session_ids: + seen_session_ids.add(session_identity) + provider_stats["sessions"] += 1 + + lease_id = str(session.get("lease_id") or "") + normalized = map_lease_to_session_status(session.get("observed_state"), session.get("desired_state")) + running_identity = (provider_instance, lease_id) + if normalized == "running" and lease_id and running_identity not in seen_running_leases: + seen_running_leases.add(running_identity) + provider_stats["running"] += 1 + + return stats + + # --------------------------------------------------------------------------- # Public API: sandbox filesystem browse # --------------------------------------------------------------------------- @@ -576,6 +724,7 @@ def refresh_resource_snapshots() -> dict[str, Any]: 
probe_targets = repo.list_probe_targets() finally: repo.close() + snapshot_repo = build_resource_snapshot_repo() provider_cache: dict[str, Any] = {} probed = 0 @@ -583,44 +732,48 @@ def refresh_resource_snapshots() -> dict[str, Any]: running_targets = 0 non_running_targets = 0 - for item in probe_targets: - lease_id = item["lease_id"] - provider_key = item["provider_name"] - instance_id = item["instance_id"] - status = item["observed_state"] - # detached means running (not connected to terminal) - probe_mode = "running_runtime" if status in ("running", "detached") else "non_running_sdk" - if probe_mode == "running_runtime": - running_targets += 1 - else: - non_running_targets += 1 - - provider = provider_cache.get(provider_key) - if provider is None: - provider = build_provider_from_config_name(provider_key) - provider_cache[provider_key] = provider - if provider is None: - upsert_resource_snapshot( + try: + for item in probe_targets: + lease_id = item["lease_id"] + provider_key = item["provider_name"] + instance_id = item["instance_id"] + status = item["observed_state"] + # detached means running (not connected to terminal) + probe_mode = "running_runtime" if status in ("running", "detached") else "non_running_sdk" + if probe_mode == "running_runtime": + running_targets += 1 + else: + non_running_targets += 1 + + provider = provider_cache.get(provider_key) + if provider is None: + provider = build_provider_from_config_name(provider_key) + provider_cache[provider_key] = provider + if provider is None: + snapshot_repo.upsert_lease_resource_snapshot( + lease_id=lease_id, + provider_name=provider_key, + observed_state=status, + probe_mode=probe_mode, + probe_error=f"provider init failed: {provider_key}", + ) + errors += 1 + continue + + result = probe_and_upsert_for_instance( lease_id=lease_id, provider_name=provider_key, observed_state=status, probe_mode=probe_mode, - probe_error=f"provider init failed: {provider_key}", + provider=provider, + instance_id=instance_id, + repo=snapshot_repo, ) - errors += 1 - continue - - result = probe_and_upsert_for_instance( - lease_id=lease_id, - provider_name=provider_key, - observed_state=status, - probe_mode=probe_mode, - provider=provider, - instance_id=instance_id, - ) - probed += 1 - if not result["ok"]: - errors += 1 + probed += 1 + if not result["ok"]: + errors += 1 + finally: + snapshot_repo.close() return { "probed": probed, diff --git a/backend/web/services/sandbox_service.py b/backend/web/services/sandbox_service.py index 2e5e06cf0..4076bd280 100644 --- a/backend/web/services/sandbox_service.py +++ b/backend/web/services/sandbox_service.py @@ -16,9 +16,8 @@ from sandbox.manager import SandboxManager from sandbox.provider import ProviderCapability from sandbox.recipes import default_recipe_id, list_builtin_recipes, normalize_recipe_snapshot, provider_type_from_name +from storage.models import map_lease_to_session_status from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path -from storage.providers.sqlite.member_repo import SQLiteMemberRepo -from storage.providers.sqlite.thread_repo import SQLiteThreadRepo logger = logging.getLogger(__name__) @@ -55,9 +54,11 @@ def list_user_leases( sandbox_db_path: str | Path | None = None, ) -> list[dict[str, Any]]: monitor_repo = make_sandbox_monitor_repo() - _thread_repo = thread_repo or SQLiteThreadRepo(db_path=main_db_path) - _member_repo = member_repo or SQLiteMemberRepo(db_path=main_db_path) - own_repos = thread_repo is None # only close if we created them + if thread_repo is None 
or member_repo is None: + raise RuntimeError("thread_repo and member_repo are required for list_user_leases") + _thread_repo = thread_repo + _member_repo = member_repo + own_repos = False try: rows = monitor_repo.list_leases_with_threads() grouped: dict[str, dict[str, Any]] = {} @@ -74,13 +75,15 @@ def list_user_leases( "recipe": row.get("recipe_json"), "observed_state": row.get("observed_state"), "desired_state": row.get("desired_state"), + "created_at": row.get("created_at"), "cwd": row.get("cwd"), "thread_ids": [], "agents": [], + "_seen_member_ids": set(), }, ) thread_id = str(row.get("thread_id") or "").strip() - if not thread_id or thread_id in group["thread_ids"]: + if not _is_user_visible_lease_thread(thread_id) or thread_id in group["thread_ids"]: continue thread = _thread_repo.get_by_id(thread_id) if thread is None: @@ -89,20 +92,25 @@ def list_user_leases( if member is None or member.owner_user_id != user_id: continue group["thread_ids"].append(thread_id) - group["agents"].append( - { - "member_id": member.id, - "member_name": member.name, - "avatar_url": avatar_url(member.id, bool(member.avatar)), - } - ) + if member.id not in group["_seen_member_ids"]: + group["_seen_member_ids"].add(member.id) + group["agents"].append( + { + "member_id": member.id, + "member_name": member.name, + "avatar_url": avatar_url(member.id, bool(member.avatar)), + } + ) if not group["cwd"] and row.get("cwd"): group["cwd"] = row.get("cwd") leases: list[dict[str, Any]] = [] for lease in grouped.values(): + lease.pop("_seen_member_ids", None) if not lease["thread_ids"]: continue + if not _is_user_visible_lease_state(lease): + continue provider_name = lease["provider_name"] provider_type = provider_type_from_name(provider_name) if lease["recipe"]: @@ -123,6 +131,25 @@ def list_user_leases( monitor_repo.close() +def _is_user_visible_lease_thread(thread_id: str | None) -> bool: + raw = str(thread_id or "").strip() + if not raw: + return False + if raw.startswith("subagent-"): + return False + if is_virtual_thread_id(raw): + return False + return True + + +def _is_user_visible_lease_state(lease: dict[str, Any]) -> bool: + # @@@user-visible-lease-scope - product-facing lease surfaces should only + # expose leases the user can still act on, not historical stopped/destroying + # residue from monitor storage. 
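+    # Illustrative outcomes, assuming the mapping in storage.models (these rows
+    # are hedged examples, not asserted by this change):
+    #   observed="running",  desired="running" -> "running"  (kept)
+    #   observed="detached", desired="running" -> "running"  (kept; detached still counts as running)
+    #   observed="stopped",  desired="stopped" -> "stopped"  (filtered out)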
+ status = map_lease_to_session_status(lease.get("observed_state"), lease.get("desired_state")) + return status in {"running", "paused"} + + def available_sandbox_types() -> list[dict[str, Any]]: """Scan ~/.leon/sandboxes/ for configured providers.""" providers, _ = init_providers_and_managers() @@ -142,6 +169,16 @@ def available_sandbox_types() -> list[dict[str, Any]]: try: config = SandboxConfig.load(name) provider_obj = providers.get(name) + if provider_obj is None: + types.append( + { + "name": name, + "provider": config.provider, + "available": False, + "reason": f"Provider {name} is configured but unavailable in the current process", + } + ) + continue item: dict[str, Any] = { "name": name, "provider": config.provider, @@ -194,6 +231,8 @@ def _build_providers_and_managers() -> tuple[dict[str, Any], dict[str, Any]]: default_context_path=config.agentbay.context_path, image_id=config.agentbay.image_id, provider_name=name, + supports_pause=config.agentbay.supports_pause, + supports_resume=config.agentbay.supports_resume, ) elif config.provider == "docker": from sandbox.providers.docker import DockerProvider @@ -387,6 +426,35 @@ def mutate_sandbox_session( } +def get_session_metrics(session_id: str, provider_hint: str | None = None) -> dict[str, Any]: + """Load one session's provider metrics through the current manager inventory.""" + _, managers = init_providers_and_managers() + sessions = load_all_sessions(managers) + session, manager = find_session_and_manager(sessions, managers, session_id, provider_name=provider_hint) + if not session: + raise RuntimeError(f"Session not found: {session_id}") + if manager is None: + raise RuntimeError(f"Provider manager unavailable: {session.get('provider')}") + + target_session_id = str(session.get("instance_id") or session.get("session_id") or session_id) + metrics = manager.provider.get_metrics(target_session_id) + if metrics is None: + return {"session_id": target_session_id, "provider": session.get("provider"), "metrics": None} + return { + "session_id": target_session_id, + "provider": session.get("provider"), + "metrics": { + "cpu_percent": metrics.cpu_percent, + "memory_used_mb": metrics.memory_used_mb, + "memory_total_mb": metrics.memory_total_mb, + "disk_used_gb": metrics.disk_used_gb, + "disk_total_gb": metrics.disk_total_gb, + "network_rx_kbps": metrics.network_rx_kbps, + "network_tx_kbps": metrics.network_tx_kbps, + }, + } + + def build_provider_from_config_name(name: str, *, sandboxes_dir: Path | None = None) -> Any | None: """Build one provider instance from sandbox config name. 
Used by resource_service for per-session ops.""" providers, _ = init_providers_and_managers() diff --git a/backend/web/services/streaming_service.py b/backend/web/services/streaming_service.py index 9e6e71a77..7227a87e6 100644 --- a/backend/web/services/streaming_service.py +++ b/backend/web/services/streaming_service.py @@ -4,7 +4,6 @@ import json import logging import random -import traceback import uuid as _uuid from collections.abc import AsyncGenerator from typing import Any @@ -13,11 +12,31 @@ from backend.web.services.event_store import cleanup_old_runs from backend.web.utils.serializers import extract_text_content from core.runtime.middleware.monitor import AgentState +from core.runtime.notifications import is_terminal_background_notification from sandbox.thread_context import set_current_run_id, set_current_thread_id from storage.contracts import RunEventRepo logger = logging.getLogger(__name__) +type SSEEvent = dict[str, str | int] + +_TERMINAL_FOLLOWTHROUGH_SYSTEM_NOTE = ( + "Terminal background completion notifications require an explicit assistant followthrough. " + "Treat these notifications as fresh inputs that need a visible assistant reply. " + "You must produce at least one visible assistant message for them; " + "do not stay silent and do not end the run after only surfacing a notice. " + "Do not call TaskOutput or TaskStop for a terminal notification. " + "If no further tool is truly needed, answer directly in natural language " + "and briefly acknowledge the completion, failure, or cancellation honestly." +) + + +def _log_captured_exception(message: str, err: BaseException) -> None: + logger.error( + message, + exc_info=(type(err), err, err.__traceback__), + ) + def _resolve_run_event_repo(agent: Any) -> RunEventRepo | None: storage_container = getattr(agent, "storage_container", None) @@ -28,6 +47,18 @@ def _resolve_run_event_repo(agent: Any) -> RunEventRepo | None: return storage_container.run_event_repo() +def _augment_system_prompt_for_terminal_followthrough(system_prompt: Any) -> Any: + content = getattr(system_prompt, "content", None) + if not isinstance(content, str): + return system_prompt + if _TERMINAL_FOLLOWTHROUGH_SYSTEM_NOTE in content: + return system_prompt + # @@@terminal-followthrough-system-note - live models can otherwise treat + # terminal background notifications as internal reminders and emit no + # assistant text, leaving caller surfaces notice-only. + return system_prompt.__class__(content=f"{content}\n\n{_TERMINAL_FOLLOWTHROUGH_SYSTEM_NOTE}") + + async def prime_sandbox(agent: Any, thread_id: str) -> None: """Prime sandbox session before tool calls to avoid race conditions.""" @@ -256,8 +287,7 @@ def _ensure_thread_handlers(agent: Any, thread_id: str, app: Any) -> None: runtime = getattr(agent, "runtime", None) if not runtime: return - # Already bound? Skip. 
- if getattr(runtime, "_activity_sink", None) is not None: + if getattr(runtime, "_bound_thread_id", None) == thread_id and getattr(runtime, "_bound_thread_app", None) is app: return # Runtime must support bind_thread (AgentRuntime does, test fakes may not) if not hasattr(runtime, "bind_thread"): @@ -288,6 +318,7 @@ async def activity_sink(event: dict) -> None: if event_type and isinstance(data, dict): delta = display_builder_ref.apply_event(thread_id, event_type, data) if delta: + delta["_seq"] = seq await thread_buf.put( { "event": "display_delta", @@ -373,6 +404,8 @@ async def _start_run(): agent.runtime.transition(AgentState.IDLE) runtime.bind_thread(activity_sink=activity_sink) + runtime._bound_thread_id = thread_id + runtime._bound_thread_app = app qm.register_wake(thread_id, wake_handler) # Subscribe to EventBus so sub-agent events (spawned via AgentService) @@ -380,17 +413,227 @@ async def _start_run(): try: from backend.web.event_bus import get_event_bus - get_event_bus().subscribe(thread_id, activity_sink) + unsubscribe = getattr(runtime, "_thread_event_unsubscribe", None) + if callable(unsubscribe): + unsubscribe() + runtime._thread_event_unsubscribe = get_event_bus().subscribe(thread_id, activity_sink) except ImportError: pass +def _is_terminal_background_notification_message( + message: str, + *, + source: str | None, + notification_type: str | None, +) -> bool: + return is_terminal_background_notification( + message, + source=source, + notification_type=notification_type, + ) + + +def _partition_terminal_followups(items: list[Any]) -> tuple[list[Any], list[Any]]: + terminal = [] + passthrough = [] + for item in items: + if _is_terminal_background_notification_message( + item.content, + source=item.source or "system", + notification_type=item.notification_type, + ): + terminal.append(item) + else: + passthrough.append(item) + return terminal, passthrough + + +def _message_metadata_dict(message_metadata: dict[str, Any] | None) -> dict[str, Any]: + return dict(message_metadata or {}) + + +def _message_already_persisted(message: Any, *, content: str, metadata: dict[str, Any]) -> bool: + if message.__class__.__name__ != "HumanMessage": + return False + if getattr(message, "content", None) != content: + return False + return (getattr(message, "metadata", None) or {}) == metadata + + +async def _persist_cancelled_run_input_if_missing( + *, + agent: Any, + config: dict[str, Any], + message: str, + message_metadata: dict[str, Any] | None, +) -> None: + graph = getattr(agent, "agent", None) + if graph is None or not hasattr(graph, "aget_state") or not hasattr(graph, "aupdate_state"): + return + + from langchain_core.messages import HumanMessage + + metadata = _message_metadata_dict(message_metadata) + state = await graph.aget_state(config) + persisted = list((getattr(state, "values", None) or {}).get("messages", [])) + if persisted and _message_already_persisted(persisted[-1], content=message, metadata=metadata): + return + + # @@@cancelled-run-input-persist - a started run has already accepted this + # input at the caller boundary. If cancellation lands before the next loop + # checkpoint save, persist the input here so later turns do not pretend it + # never happened. 
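+    # Worked example (hypothetical turn, illustrative only):
+    #   state.values["messages"][-1] == HumanMessage("fix the login bug", metadata={"source": "owner"})
+    #   message == "fix the login bug" with the same metadata -> already persisted, returned above
+    #   any other tail message                                -> append the input below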
+ candidate = HumanMessage(content=message, metadata=metadata) if metadata else HumanMessage(content=message) + await graph.aupdate_state(config, {"messages": [candidate]}) + + +def _is_owner_steer_followup_message( + *, + source: str | None, + notification_type: str | None, +) -> bool: + return source == "owner" and notification_type == "steer" + + +async def _persist_cancelled_owner_steers( + *, + agent: Any, + config: dict[str, Any], + items: list[dict[str, str | None]], +) -> None: + graph = getattr(agent, "agent", None) + if graph is None or not hasattr(graph, "aupdate_state") or not items: + return + + from langchain_core.messages import HumanMessage + + # @@@cancelled-steer-persist - accepted steer is a real user turn. If the + # active run is cancelled before the next model call, we must checkpoint it + # now instead of letting it silently relaunch as a ghost instruction. + await graph.aupdate_state( + config, + { + "messages": [ + HumanMessage( + content=str(item["content"] or ""), + metadata={ + "source": "owner", + "notification_type": "steer", + "is_steer": True, + }, + ) + for item in items + ] + }, + ) + + +async def _flush_cancelled_owner_steers( + *, + agent: Any, + config: dict[str, Any], + thread_id: str, + app: Any, +) -> None: + qm = app.state.queue_manager + queued_items = qm.drain_all(thread_id) + if not queued_items: + return + + owner_steers: list[dict[str, str | None]] = [] + passthrough: list[Any] = [] + for item in queued_items: + if _is_owner_steer_followup_message( + source=item.source, + notification_type=item.notification_type, + ): + owner_steers.append( + { + "content": item.content, + "source": item.source or "owner", + "notification_type": item.notification_type, + } + ) + else: + passthrough.append(item) + + await _persist_cancelled_owner_steers(agent=agent, config=config, items=owner_steers) + + for item in passthrough: + qm.enqueue( + item.content, + thread_id, + notification_type=item.notification_type, + source=item.source, + sender_id=item.sender_id, + sender_name=item.sender_name, + sender_avatar_url=item.sender_avatar_url, + is_steer=item.is_steer, + ) + + +async def _emit_queued_terminal_followups( + *, + app: Any, + thread_id: str, + emit: Any, +) -> list[dict[str, str | None]]: + emitted_terminal: list[dict[str, str | None]] = [] + + async def _drain_once() -> bool: + queued_items = app.state.queue_manager.drain_all(thread_id) + extra_terminal, passthrough = _partition_terminal_followups(queued_items) + for item in passthrough: + app.state.queue_manager.enqueue( + item.content, + thread_id, + notification_type=item.notification_type, + source=item.source, + sender_id=item.sender_id, + sender_name=item.sender_name, + sender_avatar_url=item.sender_avatar_url, + is_steer=item.is_steer, + ) + for item in extra_terminal: + await emit( + { + "event": "notice", + "data": json.dumps( + { + "content": item.content, + "source": item.source or "system", + "notification_type": item.notification_type, + }, + ensure_ascii=False, + ), + } + ) + emitted_terminal.append( + { + "content": item.content, + "source": item.source or "system", + "notification_type": item.notification_type, + } + ) + return bool(extra_terminal) + + # @@@terminal-followup-race-window - multiple background tasks can finish + # while the first notice-only followthrough run is being emitted. Drain once + # for already-persisted notices, yield one loop tick, then drain again so + # same-turn terminal completions are folded into the same stable followthrough. 
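The drain-twice shape described in the comment above is easiest to see in isolation: a completion that lands while the first batch is being emitted is only picked up because of the single `await asyncio.sleep(0)` tick between the two drains. A self-contained toy under that assumption (queue contents are invented):

```python
import asyncio


async def demo() -> list[str]:
    queue = ["task A done"]
    emitted: list[str] = []

    async def drain_once() -> None:
        while queue:
            emitted.append(queue.pop(0))

    async def late_completion() -> None:
        queue.append("task B done")  # finishes while the first drain is emitting

    asyncio.get_running_loop().create_task(late_completion())
    await drain_once()
    await asyncio.sleep(0)  # yield one loop tick so same-turn completions land
    await drain_once()
    return emitted


assert asyncio.run(demo()) == ["task A done", "task B done"]
```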
+ await _drain_once() + await asyncio.sleep(0) + await _drain_once() + return emitted_terminal + + # --------------------------------------------------------------------------- # Producer: runs agent, writes events to ThreadEventBuffer # --------------------------------------------------------------------------- -async def _run_agent_to_buffer( +async def _run_agent_to_buffer( # pyright: ignore[reportGeneralTypeIssues] # @@@nu59-complexity-honesty agent: Any, thread_id: str, message: str, @@ -399,7 +642,8 @@ async def _run_agent_to_buffer( thread_buf: ThreadEventBuffer, run_id: str, message_metadata: dict[str, Any] | None = None, -) -> None: + input_messages: list[Any] | None = None, +) -> str: """Run agent execution and write all SSE events into *thread_buf*.""" from backend.web.services.event_store import append_event @@ -428,12 +672,16 @@ async def emit(event: dict, message_id: str | None = None) -> None: event = {**event, "data": json.dumps(data, ensure_ascii=False)} await thread_buf.put(event) - # Compute display delta and emit it (no _seq — avoids dedup conflict - # with the raw event that shares the same seq) + # Compute display delta and emit it alongside the raw event. event_type = event.get("event", "") if event_type and isinstance(data, dict): delta = display_builder.apply_event(thread_id, event_type, data) if delta: + # @@@display-delta-source-seq - replay after-filter only knows raw + # event seqs. Carry the source seq onto the derived delta so a + # reconnect after GET /thread can skip stale display_delta + # replays instead of rebuilding the same thread a second time. + delta["_seq"] = seq await thread_buf.put( { "event": "display_delta", @@ -444,6 +692,7 @@ async def emit(event: dict, message_id: str | None = None) -> None: task = None stream_gen = None pending_tool_calls: dict[str, dict] = {} + output_parts: list[str] = [] try: config = {"configurable": {"thread_id": thread_id, "run_id": run_id}} if hasattr(agent, "_current_model_config"): @@ -486,8 +735,8 @@ async def emit(event: dict, message_id: str | None = None) -> None: obs_config = ObservationLoader().load() if obs_provider == "langfuse": - from langfuse import Langfuse - from langfuse.langchain import CallbackHandler as LangfuseHandler + from langfuse import Langfuse # pyright: ignore[reportMissingImports] + from langfuse.langchain import CallbackHandler as LangfuseHandler # pyright: ignore[reportMissingImports] cfg = obs_config.langfuse if cfg.secret_key and cfg.public_key: @@ -589,7 +838,21 @@ def on_activity_event(event: dict) -> None: # enqueue time (@@@steer-instant-feedback). # Note: is_steer is NOT persisted in queue, so check notification_type too. is_steer = meta.get("is_steer") or meta.get("notification_type") == "steer" - if (not src or src == "owner") and not is_steer: + if meta.get("ask_user_question_answered"): + await emit( + { + "event": "user_message", + "data": json.dumps( + { + "content": "", + "showing": False, + "ask_user_question_answered": meta["ask_user_question_answered"], + }, + ensure_ascii=False, + ), + } + ) + elif (not src or src == "owner") and not is_steer: # @@@strip-for-display — agent sees full content (with system-reminder), # frontend sees clean text (tags stripped) from backend.web.utils.serializers import strip_system_tags @@ -625,9 +888,10 @@ def on_activity_event(event: dict) -> None: ) # @@@run-notice — emit notice right after run_start so frontend folds it - # into the (re)opened turn. Only for external notifications (not owner steer). + # into the (re)opened turn. 
Mirror the cold-path DisplayBuilder rule: + # any source=system message is a notice; external notices stay chat-only. ntype = meta.get("notification_type") - if src and src != "owner" and ntype == "chat": + if src == "system" or (src == "external" and ntype == "chat"): await emit( { "event": "notice", @@ -642,7 +906,46 @@ def on_activity_event(event: dict) -> None: } ) - if message_metadata: + terminal_followthrough_items: list[dict[str, str | None]] | None = None + original_system_prompt = None + # @@@terminal-followthrough-reentry - terminal background completions + # still surface as durable notices first, but they must then re-enter the + # model as a real followthrough turn instead of terminating at notice-only. + if _is_terminal_background_notification_message( + message, + source=src, + notification_type=ntype, + ): + terminal_followthrough_items = [ + { + "content": message, + "source": src or "system", + "notification_type": ntype, + } + ] + terminal_followthrough_items.extend(await _emit_queued_terminal_followups(app=app, thread_id=thread_id, emit=emit)) + if hasattr(agent, "agent") and hasattr(agent.agent, "system_prompt"): + original_system_prompt = agent.agent.system_prompt + agent.agent.system_prompt = _augment_system_prompt_for_terminal_followthrough(original_system_prompt) + + if terminal_followthrough_items: + from langchain_core.messages import HumanMessage + + _initial_input = { + "messages": [ + HumanMessage( + content=str(item["content"] or ""), + metadata={ + "source": item["source"] or "system", + "notification_type": item["notification_type"], + }, + ) + for item in terminal_followthrough_items + ] + } + elif input_messages is not None: + _initial_input = {"messages": input_messages} + elif message_metadata: from langchain_core.messages import HumanMessage _initial_input: dict | None = {"messages": [HumanMessage(content=message, metadata=message_metadata)]} @@ -725,7 +1028,7 @@ def _is_retryable_stream_error(err: Exception) -> bool: mode, data = chunk if mode == "messages": - msg_chunk, metadata = data + msg_chunk, _metadata = data msg_class = msg_chunk.__class__.__name__ if msg_class == "AIMessageChunk": # @@@compact-leak-guard — skip chunks from compact's summary LLM call. @@ -735,6 +1038,7 @@ def _is_retryable_stream_error(err: Exception) -> bool: content = extract_text_content(getattr(msg_chunk, "content", "")) chunk_msg_id = getattr(msg_chunk, "id", None) if content: + output_parts.append(content) await emit( { "event": "text", @@ -792,14 +1096,13 @@ def _is_retryable_stream_error(err: Exception) -> bool: msg_class = msg.__class__.__name__ if msg_class == "HumanMessage": - # @@@mid-turn-chat-notice — emit notice for chat - # notifications injected by before_model. display_builder - # folds it into the current turn as a segment (same as - # cold-path checkpoint rebuild behavior). + # @@@mid-turn-notice-parity — hot streaming must use the + # same notice contract as cold checkpoint rebuild: + # source=system always folds as notice; external stays + # limited to chat notifications. 
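Collapsed into a single predicate, the notice rule that this hot path and the cold-path rebuild below both apply reads as follows (a paraphrase for review, not code from the repo):

```python
def folds_as_notice(source: str | None, notification_type: str | None) -> bool:
    # source=system always folds as a notice; external messages only when they
    # are chat notifications; owner traffic (including steers) never does.
    return source == "system" or (source == "external" and notification_type == "chat")


assert folds_as_notice("system", "task_done")
assert folds_as_notice("external", "chat")
assert not folds_as_notice("external", "task_done")
assert not folds_as_notice("owner", "chat")
```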
meta = getattr(msg, "metadata", None) or {} - if meta.get("notification_type") == "chat" and meta.get("source") in ( - "external", - "system", + if meta.get("source") == "system" or ( + meta.get("source") == "external" and meta.get("notification_type") == "chat" ): await emit( { @@ -808,7 +1111,7 @@ def _is_retryable_stream_error(err: Exception) -> bool: { "content": msg.content if isinstance(msg.content, str) else str(msg.content), "source": meta.get("source", "external"), - "notification_type": "chat", + "notification_type": meta.get("notification_type"), }, ensure_ascii=False, ), @@ -861,8 +1164,11 @@ def _is_retryable_stream_error(err: Exception) -> bool: continue if tc_id: pending_tool_calls.pop(tc_id, None) - if hasattr(msg, "metadata") and isinstance(msg.metadata, dict): - msg.metadata["run_id"] = run_id + merged_meta = dict(getattr(msg, "metadata", None) or {}) + tool_result_meta = getattr(msg, "additional_kwargs", {}).get("tool_result_meta") + if isinstance(tool_result_meta, dict): + merged_meta = {**tool_result_meta, **merged_meta} + merged_meta["run_id"] = run_id tool_name = getattr(msg, "name", "") or "" await emit( { @@ -872,7 +1178,7 @@ def _is_retryable_stream_error(err: Exception) -> bool: "tool_call_id": tc_id, "name": tool_name, "content": str(getattr(msg, "content", "")), - "metadata": getattr(msg, "metadata", None) or {}, + "metadata": merged_meta, "showing": True, }, ensure_ascii=False, @@ -920,7 +1226,10 @@ def _is_retryable_stream_error(err: Exception) -> bool: await stream_gen.aclose() await asyncio.sleep(wait) else: - traceback.print_exc() + _log_captured_exception( + f"[streaming] stream failed for thread {thread_id}", + stream_err, + ) await emit({"event": "error", "data": json.dumps({"error": str(stream_err)}, ensure_ascii=False)}) break @@ -954,8 +1263,21 @@ def _is_retryable_stream_error(err: Exception) -> bool: # A5: emit run_done instead of done (persistent buffer — no mark_done) await emit({"event": "run_done", "data": json.dumps({"thread_id": thread_id, "run_id": run_id})}) + return "".join(output_parts).strip() except asyncio.CancelledError: cancelled_tool_call_ids = await write_cancellation_markers(agent, config, pending_tool_calls) + await _persist_cancelled_run_input_if_missing( + agent=agent, + config=config, + message=message, + message_metadata=message_metadata, + ) + await _flush_cancelled_owner_steers( + agent=agent, + config=config, + thread_id=thread_id, + app=app, + ) await emit( { "event": "cancelled", @@ -969,11 +1291,18 @@ def _is_retryable_stream_error(err: Exception) -> bool: ) # Also emit run_done so frontend knows the run ended await emit({"event": "run_done", "data": json.dumps({"thread_id": thread_id, "run_id": run_id})}) + return "" except Exception as e: - traceback.print_exc() + _log_captured_exception( + f"[streaming] run failed for thread {thread_id}", + e, + ) await emit({"event": "error", "data": json.dumps({"error": str(e)}, ensure_ascii=False)}) await emit({"event": "run_done", "data": json.dumps({"thread_id": thread_id, "run_id": run_id})}) + return "" finally: + if original_system_prompt is not None and hasattr(agent, "agent") and hasattr(agent.agent, "system_prompt"): + agent.agent.system_prompt = original_system_prompt # @@@typing-lifecycle-stop — guaranteed cleanup even on crash/cancel typing_tracker = getattr(app.state, "typing_tracker", None) if typing_tracker is not None: @@ -985,7 +1314,7 @@ def _is_retryable_stream_error(err: Exception) -> bool: if obs_handler is not None: try: if obs_active == "langfuse": - from 
langfuse import get_client + from langfuse import get_client # pyright: ignore[reportMissingImports] get_client().flush() elif obs_active == "langsmith": @@ -1036,22 +1365,29 @@ async def _consume_followup_queue(agent: Any, thread_id: str, app: Any) -> None: item = None try: qm = app.state.queue_manager + if not qm.peek(thread_id) or not app: + return + if not (hasattr(agent, "runtime") and agent.runtime.transition(AgentState.ACTIVE)): + return item = qm.dequeue(thread_id) - if item and app: - if hasattr(agent, "runtime") and agent.runtime.transition(AgentState.ACTIVE): - start_agent_run( - agent, - thread_id, - item.content, - app, - message_metadata={ - "source": item.source or "system", - "notification_type": item.notification_type, - "sender_name": item.sender_name, - "sender_avatar_url": item.sender_avatar_url, - "is_steer": getattr(item, "is_steer", False), - }, - ) + if item is None: + logger.warning("followup dequeue lost race for thread %s; reverting to IDLE", thread_id) + if hasattr(agent, "runtime"): + agent.runtime.transition(AgentState.IDLE) + return + start_agent_run( + agent, + thread_id, + item.content, + app, + message_metadata={ + "source": item.source or "system", + "notification_type": item.notification_type, + "sender_name": item.sender_name, + "sender_avatar_url": item.sender_avatar_url, + "is_steer": getattr(item, "is_steer", False), + }, + ) except Exception: logger.exception("Failed to consume followup queue for thread %s", thread_id) # Re-enqueue the message if it was already dequeued to prevent data loss @@ -1074,18 +1410,90 @@ def start_agent_run( app: Any, enable_trajectory: bool = False, message_metadata: dict[str, Any] | None = None, + input_messages: list[Any] | None = None, ) -> str: """Launch agent producer on the persistent ThreadEventBuffer. 
Returns run_id.""" thread_buf = get_or_create_thread_buffer(app, thread_id) run_id = str(_uuid.uuid4()) bg_task = asyncio.create_task( - _run_agent_to_buffer(agent, thread_id, message, app, enable_trajectory, thread_buf, run_id, message_metadata) + _run_agent_to_buffer( + agent, + thread_id, + message, + app, + enable_trajectory, + thread_buf, + run_id, + message_metadata, + input_messages, + ) ) # Store the background task so cancel_run can still cancel it app.state.thread_tasks[thread_id] = bg_task return run_id +async def run_child_thread_live( + agent: Any, + thread_id: str, + message: str, + app: Any, + *, + input_messages: list[Any], +) -> str: + """Run a spawned child agent through the normal web thread bridge.""" + from backend.web.services.agent_pool import resolve_thread_sandbox + from backend.web.utils.serializers import extract_text_content + + sandbox_type = resolve_thread_sandbox(app, thread_id) + app.state.agent_pool[f"{thread_id}:{sandbox_type}"] = agent + thread_buf = get_or_create_thread_buffer(app, thread_id) + error_cursor = thread_buf.total_count + _ensure_thread_handlers(agent, thread_id, app) + if not (hasattr(agent, "runtime") and agent.runtime.transition(AgentState.ACTIVE)): + raise RuntimeError(f"Child thread {thread_id} could not transition to active") + + start_agent_run( + agent, + thread_id, + message, + app, + input_messages=input_messages, + ) + task = app.state.thread_tasks[thread_id] + result = await task + recent_events, _ = await thread_buf.read_with_timeout(error_cursor, timeout=0.01) + if recent_events: + # @@@child-live-error-surfacing - child live runs can emit an error event + # and still return an empty string from _run_agent_to_buffer(); treat that + # as a real child failure instead of laundering it into fake completion. + for event in recent_events: + if event.get("event") != "error": + continue + try: + payload = json.loads(event.get("data", "{}")) + except (json.JSONDecodeError, TypeError): + payload = {} + error_text = payload.get("error") if isinstance(payload, dict) else None + raise RuntimeError(error_text or f"Child thread {thread_id} failed") + if isinstance(result, str) and result.strip(): + return result.strip() + + state = await agent.agent.aget_state({"configurable": {"thread_id": thread_id}}) + values = getattr(state, "values", {}) if state else {} + messages = values.get("messages", []) if isinstance(values, dict) else [] + visible_ai = [ + extract_text_content(getattr(msg, "content", "")).strip() + for msg in messages + if msg.__class__.__name__ == "AIMessage" and extract_text_content(getattr(msg, "content", "")).strip() + ] + runtime_status = agent.runtime.get_status_dict() if hasattr(agent, "runtime") and hasattr(agent.runtime, "get_status_dict") else {} + runtime_calls = runtime_status.get("calls") if isinstance(runtime_status, dict) else None + if not visible_ai and runtime_calls == 0: + raise RuntimeError(f"Child thread {thread_id} failed before first model call") + return "\n".join(visible_ai) if visible_ai else "(Agent completed with no text output)" + + # --------------------------------------------------------------------------- # Consumer: persistent thread event stream # --------------------------------------------------------------------------- @@ -1094,54 +1502,37 @@ def start_agent_run( async def observe_thread_events( thread_buf: ThreadEventBuffer, after: int = 0, -) -> AsyncGenerator[dict[str, str], None]: +) -> AsyncGenerator[SSEEvent, None]: """Consume events from a persistent ThreadEventBuffer. 
Yields SSE event dicts. Unlike observe_run_events, this never terminates on its own — the client disconnect (or server shutdown) closes the connection. run_done is a flow event, not a terminal signal. """ - yield {"retry": 5000} - # Always start from the beginning of the ring buffer. # For after=0 (new connection): replay all buffered events so we never miss # events emitted between postRun and SSE connect (race condition fix). # For after>0 (reconnect): start from ring start, filter by _seq below. - cursor = 0 - - while True: - events, cursor = await thread_buf.read_with_timeout(cursor, timeout=30) - if events is None: - yield {"comment": "keepalive"} - continue - if not events: - continue - for event in events: - parsed_data = None - try: - parsed_data = json.loads(event.get("data", "{}")) - except (json.JSONDecodeError, TypeError): - pass - - # @@@after-filter — skip events already seen on reconnect. - # Events without _seq (e.g. display_delta) are never filtered — - # they are ephemeral derivatives of persisted events. - if after > 0 and isinstance(parsed_data, dict) and "_seq" in parsed_data: - if parsed_data["_seq"] <= after: - continue - - seq_id = str(parsed_data["_seq"]) if isinstance(parsed_data, dict) and "_seq" in parsed_data else None - if seq_id: - yield {**event, "id": seq_id} - else: - yield event + async for event in _observe_sse_buffer(thread_buf, after=after, stop_on_finish=False): + yield event async def observe_run_events( buf: RunEventBuffer, after: int = 0, -) -> AsyncGenerator[dict[str, str], None]: +) -> AsyncGenerator[SSEEvent, None]: """Consume events from a RunEventBuffer (subagent streams only). Yields SSE event dicts.""" + async for event in _observe_sse_buffer(buf, after=after, stop_on_finish=True): + yield event + + +async def _observe_sse_buffer( + buf: ThreadEventBuffer | RunEventBuffer, + *, + after: int, + stop_on_finish: bool, +) -> AsyncGenerator[SSEEvent, None]: + """Shared SSE observer loop for thread and run buffers.""" yield {"retry": 5000} cursor = 0 @@ -1150,7 +1541,7 @@ async def observe_run_events( if events is None and not buf.finished.is_set(): yield {"comment": "keepalive"} continue - if not events and buf.finished.is_set(): + if stop_on_finish and not events and buf.finished.is_set(): break if not events: continue @@ -1162,8 +1553,8 @@ async def observe_run_events( pass # @@@after-filter — skip events already seen on reconnect. - # Events without _seq (e.g. display_delta) are never filtered — - # they are ephemeral derivatives of persisted events. + # display_delta now carries the source raw-event seq too, so stale + # derived deltas are filtered together with their persisted source. 
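The reconnect contract is small enough to pin down with a toy replay: any event whose JSON payload carries a `_seq` at or below the client's last-seen seq is skipped, events without `_seq` always pass, and because a derived `display_delta` now carries its source event's seq it is skipped together with that source. A sketch under those assumptions:

```python
import json


def replay(events: list[dict[str, str]], after: int) -> list[dict[str, str]]:
    kept = []
    for event in events:
        try:
            data = json.loads(event.get("data", "{}"))
        except (json.JSONDecodeError, TypeError):
            data = None
        # Events without _seq are never filtered; everything at or below the
        # client's last-seen seq is skipped on reconnect.
        if after > 0 and isinstance(data, dict) and data.get("_seq", after + 1) <= after:
            continue
        kept.append(event)
    return kept


events = [
    {"event": "text", "data": json.dumps({"_seq": 3, "content": "old"})},
    {"event": "display_delta", "data": json.dumps({"_seq": 3})},
    {"event": "text", "data": json.dumps({"_seq": 4, "content": "new"})},
]
assert [e["event"] for e in replay(events, after=3)] == ["text"]
```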
if after > 0 and isinstance(parsed_data, dict) and "_seq" in parsed_data: if parsed_data["_seq"] <= after: continue diff --git a/backend/web/services/task_service.py b/backend/web/services/task_service.py index 86197b584..3c7ae1b91 100644 --- a/backend/web/services/task_service.py +++ b/backend/web/services/task_service.py @@ -3,71 +3,114 @@ from typing import Any from backend.web.core.storage_factory import make_panel_task_repo +from storage.runtime import build_thread_repo def _repo() -> Any: return make_panel_task_repo() -def list_tasks() -> list[dict[str, Any]]: - repo = _repo() +def list_tasks(owner_user_id: str | None = None, repo: Any = None, thread_repo: Any = None) -> list[dict[str, Any]]: + own_repo = repo is None + repo = repo or _repo() try: - return repo.list_all() + return _enrich_task_thread_members(repo.list_all(owner_user_id=owner_user_id), thread_repo=thread_repo) finally: - repo.close() + if own_repo: + repo.close() -def get_task(task_id: str) -> dict[str, Any] | None: - repo = _repo() +def _enrich_task_thread_members(tasks: list[dict[str, Any]], thread_repo: Any = None) -> list[dict[str, Any]]: + thread_ids = [str(task.get("thread_id") or "").strip() for task in tasks] + thread_ids = [thread_id for thread_id in dict.fromkeys(thread_ids) if thread_id] + if not thread_ids: + return tasks + + # @@@task-thread-member-enrichment - panel tasks persist thread_id only, so enrich member_id + # from canonical thread metadata before frontend deep-links are rendered. + own_thread_repo = thread_repo is None + thread_repo = thread_repo or build_thread_repo() + try: + member_ids = {thread_id: (thread_repo.get_by_id(thread_id) or {}).get("member_id") for thread_id in thread_ids} + finally: + if own_thread_repo: + thread_repo.close() + + enriched: list[dict[str, Any]] = [] + for task in tasks: + thread_id = str(task.get("thread_id") or "").strip() + if thread_id and member_ids.get(thread_id): + enriched.append({**task, "member_id": member_ids[thread_id]}) + else: + enriched.append(task) + return enriched + + +def get_task(task_id: str, owner_user_id: str | None = None, repo: Any = None) -> dict[str, Any] | None: + own_repo = repo is None + repo = repo or _repo() try: - return repo.get(task_id) + return repo.get(task_id, owner_user_id=owner_user_id) finally: - repo.close() + if own_repo: + repo.close() -def get_highest_priority_pending_task() -> dict[str, Any] | None: - repo = _repo() +def get_highest_priority_pending_task(owner_user_id: str | None = None, repo: Any = None) -> dict[str, Any] | None: + own_repo = repo is None + repo = repo or _repo() try: - return repo.get_highest_priority_pending() + return repo.get_highest_priority_pending(owner_user_id=owner_user_id) finally: - repo.close() + if own_repo: + repo.close() -def create_task(**fields: Any) -> dict[str, Any]: - repo = _repo() +def create_task(repo: Any = None, **fields: Any) -> dict[str, Any]: + own_repo = repo is None + repo = repo or _repo() try: return repo.create(**fields) finally: - repo.close() + if own_repo: + repo.close() -def update_task(task_id: str, **fields: Any) -> dict[str, Any] | None: - repo = _repo() +def update_task(task_id: str, owner_user_id: str | None = None, repo: Any = None, **fields: Any) -> dict[str, Any] | None: + own_repo = repo is None + repo = repo or _repo() try: - return repo.update(task_id, **fields) + return repo.update(task_id, owner_user_id=owner_user_id, **fields) finally: - repo.close() + if own_repo: + repo.close() -def delete_task(task_id: str) -> bool: - repo = _repo() +def 
delete_task(task_id: str, owner_user_id: str | None = None, repo: Any = None) -> bool: + own_repo = repo is None + repo = repo or _repo() try: - return repo.delete(task_id) + return repo.delete(task_id, owner_user_id=owner_user_id) finally: - repo.close() + if own_repo: + repo.close() -def bulk_delete_tasks(ids: list[str]) -> int: - repo = _repo() +def bulk_delete_tasks(ids: list[str], owner_user_id: str | None = None, repo: Any = None) -> int: + own_repo = repo is None + repo = repo or _repo() try: - return repo.bulk_delete(ids) + return repo.bulk_delete(ids, owner_user_id=owner_user_id) finally: - repo.close() + if own_repo: + repo.close() -def bulk_update_task_status(ids: list[str], status: str) -> int: - repo = _repo() +def bulk_update_task_status(ids: list[str], status: str, owner_user_id: str | None = None, repo: Any = None) -> int: + own_repo = repo is None + repo = repo or _repo() try: - return repo.bulk_update_status(ids, status) + return repo.bulk_update_status(ids, status, owner_user_id=owner_user_id) finally: - repo.close() + if own_repo: + repo.close() diff --git a/backend/web/services/thread_launch_config_service.py b/backend/web/services/thread_launch_config_service.py index 00060e222..b9202c21c 100644 --- a/backend/web/services/thread_launch_config_service.py +++ b/backend/web/services/thread_launch_config_service.py @@ -6,7 +6,7 @@ from backend.web.services import sandbox_service from backend.web.services.library_service import list_library -from sandbox.recipes import provider_type_from_name +from sandbox.recipes import normalize_recipe_snapshot, provider_type_from_name def normalize_launch_config_payload(payload: dict[str, Any]) -> dict[str, Any]: @@ -20,22 +20,51 @@ def normalize_launch_config_payload(payload: dict[str, Any]) -> dict[str, Any]: } -def save_last_confirmed_config(app: Any, owner_user_id: str, member_id: str, payload: dict[str, Any]) -> None: - app.state.thread_launch_pref_repo.save_confirmed( - owner_user_id, - member_id, - normalize_launch_config_payload(payload), +def build_existing_launch_config( + *, + lease: dict[str, Any], + model: str | None, + workspace: str | None, +) -> dict[str, Any]: + return normalize_launch_config_payload( + { + "create_mode": "existing", + "provider_config": lease.get("provider_name"), + "recipe": lease.get("recipe"), + "lease_id": lease.get("lease_id"), + "model": model, + "workspace": workspace, + } ) -def save_last_successful_config(app: Any, owner_user_id: str, member_id: str, payload: dict[str, Any]) -> None: - app.state.thread_launch_pref_repo.save_successful( - owner_user_id, - member_id, - normalize_launch_config_payload(payload), +def build_new_launch_config( + *, + provider_config: str, + recipe: dict[str, Any] | None, + model: str | None, + workspace: str | None, +) -> dict[str, Any]: + return normalize_launch_config_payload( + { + "create_mode": "new", + "provider_config": provider_config, + "recipe": normalize_recipe_snapshot(provider_type_from_name(provider_config), recipe), + "lease_id": None, + "model": model, + "workspace": workspace, + } ) +def save_last_confirmed_config(app: Any, owner_user_id: str, member_id: str, payload: dict[str, Any]) -> None: + _save_launch_config(app.state.thread_launch_pref_repo.save_confirmed, owner_user_id, member_id, payload) + + +def save_last_successful_config(app: Any, owner_user_id: str, member_id: str, payload: dict[str, Any]) -> None: + _save_launch_config(app.state.thread_launch_pref_repo.save_successful, owner_user_id, member_id, payload) + + def 
resolve_default_config(app: Any, owner_user_id: str, member_id: str) -> dict[str, Any]: prefs = app.state.thread_launch_pref_repo.get(owner_user_id, member_id) or {} leases = sandbox_service.list_user_leases( @@ -119,6 +148,14 @@ def _validate_saved_config( } +def _save_launch_config(save_fn: Any, owner_user_id: str, member_id: str, payload: dict[str, Any]) -> None: + save_fn( + owner_user_id, + member_id, + normalize_launch_config_payload(payload), + ) + + def _derive_default_config( *, member_threads: list[dict[str, Any]], diff --git a/backend/web/services/thread_naming.py b/backend/web/services/thread_naming.py index ee65a9923..0e3fba68d 100644 --- a/backend/web/services/thread_naming.py +++ b/backend/web/services/thread_naming.py @@ -1,4 +1,4 @@ -"""Canonical thread/entity naming helpers.""" +"""Canonical thread naming helpers.""" from __future__ import annotations @@ -7,18 +7,11 @@ def validate_thread_identity(*, is_main: bool, branch_index: int) -> None: if branch_index < 0: raise ValueError(f"branch_index must be >= 0, got {branch_index}") if is_main and branch_index != 0: - raise ValueError(f"Main thread must have branch_index=0, got {branch_index}") + raise ValueError(f"Default thread must have branch_index=0, got {branch_index}") if not is_main and branch_index == 0: raise ValueError("Child thread must have branch_index>0") -def canonical_entity_name(member_name: str, *, is_main: bool, branch_index: int) -> str: - validate_thread_identity(is_main=is_main, branch_index=branch_index) - if is_main: - return member_name - return f"{member_name} · 分身{branch_index}" - - def sidebar_label(*, is_main: bool, branch_index: int) -> str | None: validate_thread_identity(is_main=is_main, branch_index=branch_index) if is_main: diff --git a/backend/web/services/thread_state_service.py b/backend/web/services/thread_state_service.py index 30e0186ec..b9acf4ae2 100644 --- a/backend/web/services/thread_state_service.py +++ b/backend/web/services/thread_state_service.py @@ -21,7 +21,14 @@ def _resolve_thread_sandbox_instance(mgr: Any, lease: Any) -> Any | None: def _display_sandbox_status(lease: Any, instance: Any) -> str: observed = getattr(lease, "observed_state", None) - return instance.status if observed in {None, "", "detached"} else observed + if observed in {None, "", "detached"}: + status = getattr(instance, "status", None) + if not isinstance(status, str) or not status: + raise RuntimeError("Sandbox instance missing status") + return status + if not isinstance(observed, str): + raise RuntimeError("Lease observed_state must be a string when present") + return observed def get_sandbox_info(agent: Any, thread_id: str, sandbox_type: str) -> dict[str, Any]: @@ -125,14 +132,14 @@ def _get_terminal(): } -async def get_lease_status(agent: Any, thread_id: str) -> dict[str, Any]: +async def get_lease_status(agent: Any, thread_id: str) -> dict[str, Any] | None: """Get SandboxLease status for a thread. 
Returns: Dict with lease_id, provider_name, states, instance info, timestamps
-    Raises:
-        ValueError: If no lease found for thread
+    Returns None if no lease is found for the thread.
     """

     def _get_lease():
@@ -147,7 +154,7 @@ def _get_lease():
     lease = await asyncio.to_thread(_get_lease)
     if not lease:
-        raise ValueError(f"No lease found for thread {thread_id}")
+        return None
     instance = lease.get_instance()
     created_at, updated_at = await asyncio.to_thread(get_lease_timestamps, lease.lease_id)
diff --git a/backend/web/services/wechat_service.py b/backend/web/services/wechat_service.py
deleted file mode 100644
index b19261d79..000000000
--- a/backend/web/services/wechat_service.py
+++ /dev/null
@@ -1,517 +0,0 @@
-"""WeChat connection service — ilink API client + connection lifecycle + background poll.
-
-Uses the official WeChat ClawBot ilink API at ilinkai.weixin.qq.com.
-Protocol: HTTP/JSON long-polling, modeled after Telegram Bot API.
-Auth: Bearer token obtained via QR code scan.
-
-@@@per-user — each human user_id gets its own WeChatConnection.
-user_id is the social identity in Leon's network (Supabase auth UUID for humans).
-Polling auto-starts at backend boot via lifespan.py for all users with saved credentials.
-
-@@@no-globals — WeChatConnectionRegistry lives on app.state, not module-level.
-"""
-
-import asyncio
-import json
-import logging
-import os
-import random
-import struct
-import time
-from base64 import b64encode
-from collections.abc import Awaitable, Callable
-from pathlib import Path
-from typing import Literal
-
-import httpx
-from pydantic import BaseModel
-
-from config.user_paths import user_home_path, user_home_read_candidates
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_BASE_URL = "https://ilinkai.weixin.qq.com"
-BOT_TYPE = "3"
-CHANNEL_VERSION = "0.1.0"
-LONG_POLL_TIMEOUT_S = 35
-SEND_TIMEOUT_S = 15
-
-MSG_TYPE_USER = 1
-MSG_TYPE_BOT = 2
-MSG_ITEM_TEXT = 1
-MSG_ITEM_VOICE = 3
-MSG_STATE_FINISH = 2
-
-CONNECTIONS_BASE = user_home_path("connections", "wechat")
-
-RoutingType = Literal["thread", "chat"]
-
-# @@@delivery-callback — injected at construction, avoids circular import of app
-DeliveryFn = Callable[["WeChatConnection", "WeChatMessage"], Awaitable[None]]
-
-
-# --- Pydantic models for API ---
-
-
-class WeChatCredentials(BaseModel):
-    token: str
-    base_url: str = DEFAULT_BASE_URL
-    account_id: str
-    user_id: str = ""
-    saved_at: str = ""
-
-
-class RoutingConfig(BaseModel):
-    type: RoutingType | None = None
-    id: str | None = None
-    label: str = ""
-
-
-class QrPollRequest(BaseModel):
-    qrcode: str
-
-
-class RoutingSetRequest(BaseModel):
-    type: RoutingType
-    id: str
-    label: str = ""
-
-
-class WeChatMessage(BaseModel):
-    from_user_id: str
-    text: str
-    context_token: str
-
-    class Config:
-        frozen = True
-
-
-class WeChatAPIError(Exception):
-    pass
-
-
-class SessionExpiredError(WeChatAPIError):
-    pass
-
-
-# --- ilink protocol helpers ---
-
-
-def _random_wechat_uin() -> str:
-    val = struct.unpack(">I", os.urandom(4))[0]
-    return b64encode(str(val).encode()).decode()
-
-
-def _build_headers(token: str | None = None, body: str | None = None) -> dict[str, str]:
-    headers: dict[str, str] = {
-        "Content-Type": "application/json",
-        "AuthorizationType": "ilink_bot_token",
-        "X-WECHAT-UIN": _random_wechat_uin(),
-    }
-    if body:
-        headers["Content-Length"] = str(len(body.encode()))
-    if token:
-        headers["Authorization"] = f"Bearer {token.strip()}"
-    return headers
-
-
-def _extract_text(msg: dict) -> str:
-    items = msg.get("item_list") or []
-    for item in items:
-        if
item.get("type") == MSG_ITEM_TEXT: - text = (item.get("text_item") or {}).get("text", "") - ref = item.get("ref_msg") - if ref and ref.get("title"): - return f"[引用: {ref['title']}]\n{text}" - return text - if item.get("type") == MSG_ITEM_VOICE: - return (item.get("voice_item") or {}).get("text", "") - return "" - - -# --- Per-user persistence (keyed by user_id) --- - - -def _user_dir(user_id: str) -> Path: - return CONNECTIONS_BASE / user_id - - -def _user_dir_candidates(user_id: str) -> tuple[Path, ...]: - return tuple(path / user_id for path in user_home_read_candidates("connections", "wechat")) - - -def _save_json(user_id: str, filename: str, data: dict) -> None: - d = _user_dir(user_id) - d.mkdir(parents=True, exist_ok=True) - path = d / filename - path.write_text(json.dumps(data, indent=2)) - if filename == "credentials.json": - path.chmod(0o600) - - -def _load_json(user_id: str, filename: str) -> dict | None: - for path in reversed(_user_dir_candidates(user_id)): - candidate = path / filename - if not candidate.exists(): - continue - try: - return json.loads(candidate.read_text()) - except (json.JSONDecodeError, KeyError) as e: - logger.error("Failed to load %s for %s: %s", filename, user_id[:12], e) - return None - - -def _delete_file(user_id: str, filename: str) -> None: - seen: set[Path] = set() - for user_dir in _user_dir_candidates(user_id): - path = user_dir / filename - if path in seen: - continue - seen.add(path) - if path.exists(): - path.unlink() - - -def migrate_entity_id_dirs() -> None: - """Startup migration: rename {user_id}-1/ → {user_id}/ for existing connections.""" - if not CONNECTIONS_BASE.exists(): - return - for user_dir in list(CONNECTIONS_BASE.iterdir()): - if not user_dir.is_dir(): - continue - name = user_dir.name - # Old entity_id format was "{user_id}-1" — strip the suffix - if name.endswith("-1"): - new_name = name[:-2] - new_dir = CONNECTIONS_BASE / new_name - if not new_dir.exists(): - try: - user_dir.rename(new_dir) - logger.info("Migrated WeChat dir: %s → %s", name, new_name) - except Exception as e: - logger.error("Failed to migrate WeChat dir %s: %s", name, e) - - -# --- WeChatConnection (one per human user) --- - - -class WeChatConnection: - """A single user's WeChat connection. Keyed by user_id.""" - - def __init__(self, user_id: str, delivery_fn: DeliveryFn | None = None) -> None: - self.user_id = user_id - self._delivery_fn = delivery_fn - self._credentials: WeChatCredentials | None = None - self._context_tokens: dict[str, str] = {} - self._sync_buf: str = "" - self._poll_task: asyncio.Task | None = None - self._routing = RoutingConfig() - # @@@no-proxy — trust_env=False prevents httpx from inheriting - # http_proxy/all_proxy which causes bimodal latency on long-poll. 
- self._http = httpx.AsyncClient( - timeout=httpx.Timeout(LONG_POLL_TIMEOUT_S + 5), - trust_env=False, - ) - - # Load persisted state - routing_data = _load_json(user_id, "routing.json") - if routing_data: - try: - self._routing = RoutingConfig(**routing_data) - except Exception: - pass - - ctx = _load_json(user_id, "context_tokens.json") - if ctx: - self._context_tokens = ctx - - creds_data = _load_json(user_id, "credentials.json") - if creds_data: - try: - self._credentials = WeChatCredentials(**creds_data) - logger.info("Loaded WeChat credentials for user=%s", user_id[:12]) - except Exception as e: - logger.error("Invalid WeChat credentials for %s: %s", user_id[:12], e) - - @property - def connected(self) -> bool: - return self._credentials is not None - - @property - def polling(self) -> bool: - return self._poll_task is not None and not self._poll_task.done() - - @property - def routing(self) -> RoutingConfig: - return self._routing - - def set_routing(self, config: RoutingConfig) -> None: - self._routing = config - _save_json(self.user_id, "routing.json", config.model_dump()) - - def get_state(self) -> dict: - if not self._credentials: - return {"connected": False, "routing": self._routing.model_dump()} - return { - "connected": True, - "polling": self.polling, - "account_id": self._credentials.account_id, - "user_id": self._credentials.user_id, - "contact_count": len(self._context_tokens), - "contacts": self.list_contacts(), - "routing": self._routing.model_dump(), - } - - def list_contacts(self) -> list[dict[str, str]]: - return [{"user_id": uid, "display_name": uid.split("@")[0] or uid} for uid in self._context_tokens] - - # --- QR Login --- - - async def get_qr_code(self) -> dict: - url = f"{DEFAULT_BASE_URL}/ilink/bot/get_bot_qrcode?bot_type={BOT_TYPE}" - resp = await self._http.get(url, timeout=10) - resp.raise_for_status() - data = resp.json() - return {"qrcode": data["qrcode"], "qrcode_img_url": data["qrcode_img_content"]} - - async def poll_qr_status(self, qrcode: str) -> dict: - url = f"{DEFAULT_BASE_URL}/ilink/bot/get_qrcode_status?qrcode={qrcode}" - try: - resp = await self._http.get( - url, - headers={"iLink-App-ClientVersion": "1"}, - timeout=LONG_POLL_TIMEOUT_S + 5, - ) - resp.raise_for_status() - data = resp.json() - except httpx.TimeoutException: - return {"status": "wait"} - - status = data.get("status", "wait") - if status == "confirmed": - bot_token = data.get("bot_token") - bot_id = data.get("ilink_bot_id") - if not bot_token or not bot_id: - return {"status": "error", "message": "Missing bot credentials in response"} - creds = WeChatCredentials( - token=bot_token, - base_url=data.get("baseurl") or DEFAULT_BASE_URL, - account_id=bot_id, - user_id=data.get("ilink_user_id", ""), - saved_at=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), - ) - self._credentials = creds - _save_json(self.user_id, "credentials.json", creds.model_dump()) - logger.info("WeChat connected for user=%s account=%s", self.user_id[:12], creds.account_id) - self.start_polling() - return {"status": "confirmed", "account_id": creds.account_id} - return {"status": status} - - # --- Disconnect --- - - def disconnect(self) -> None: - self.stop_polling() - self._credentials = None - self._context_tokens.clear() - self._sync_buf = "" - _delete_file(self.user_id, "credentials.json") - _delete_file(self.user_id, "context_tokens.json") - logger.info("WeChat disconnected for user=%s", self.user_id[:12]) - - async def close(self) -> None: - """Shutdown: stop polling + close HTTP client.""" - 
self.stop_polling() - await self._http.aclose() - - # --- Polling --- - - def start_polling(self) -> None: - if self.polling: - return - if not self._credentials: - raise RuntimeError("Cannot start polling: not connected") - self._poll_task = asyncio.create_task(self._poll_loop()) - logger.info("WeChat polling started for user=%s", self.user_id[:12]) - - def stop_polling(self) -> None: - if self._poll_task and not self._poll_task.done(): - self._poll_task.cancel() - self._poll_task = None - - async def _deliver_message(self, msg: WeChatMessage) -> None: - """Deliver via injected callback. No circular imports.""" - if not self._delivery_fn: - logger.warning("No delivery function configured for user=%s", self.user_id[:12]) - return - if not self._routing.type or not self._routing.id: - logger.debug("WeChat message not delivered — no routing configured") - return - try: - await self._delivery_fn(self, msg) - except Exception: - logger.exception("Failed to deliver WeChat message") - - async def _poll_loop(self) -> None: - consecutive_failures = 0 - while True: - try: - messages = await self._get_updates() - consecutive_failures = 0 - for msg in messages: - logger.info("WeChat[%s] from=%s: %s", self.user_id[:8], msg.from_user_id[:20], msg.text[:60]) - asyncio.create_task(self._deliver_message(msg)) - except asyncio.CancelledError: - return - except SessionExpiredError: - logger.error("WeChat session expired for user=%s", self.user_id[:12]) - self._credentials = None - _delete_file(self.user_id, "credentials.json") - return - except Exception: - consecutive_failures += 1 - logger.exception("WeChat poll error #%d user=%s", consecutive_failures, self.user_id[:12]) - if consecutive_failures >= 3: - consecutive_failures = 0 - await asyncio.sleep(30) - else: - await asyncio.sleep(2) - - async def _get_updates(self) -> list[WeChatMessage]: - if not self._credentials: - raise RuntimeError("Not connected") - body = json.dumps( - { - "get_updates_buf": self._sync_buf, - "base_info": {"channel_version": CHANNEL_VERSION}, - } - ) - headers = _build_headers(self._credentials.token, body) - try: - resp = await self._http.post( - f"{self._credentials.base_url}/ilink/bot/getupdates", - content=body, - headers=headers, - timeout=LONG_POLL_TIMEOUT_S + 5, - ) - resp.raise_for_status() - data = resp.json() - except httpx.TimeoutException: - return [] - - if data.get("ret", 0) != 0 or data.get("errcode", 0) != 0: - errcode = data.get("errcode", 0) - errmsg = data.get("errmsg", "") - if errcode == -14: - raise SessionExpiredError("Session expired") - raise WeChatAPIError(f"getUpdates: errcode={errcode} {errmsg}") - - if data.get("get_updates_buf"): - self._sync_buf = data["get_updates_buf"] - - messages = [] - tokens_changed = False - for msg in data.get("msgs") or []: - if msg.get("message_type") != MSG_TYPE_USER: - continue - text = _extract_text(msg) - if not text: - continue - sender = msg.get("from_user_id", "unknown") - ctx_token = msg.get("context_token", "") - if ctx_token: - self._context_tokens[sender] = ctx_token - tokens_changed = True - messages.append( - WeChatMessage( - from_user_id=sender, - text=text, - context_token=ctx_token, - ) - ) - if tokens_changed: - await asyncio.to_thread(_save_json, self.user_id, "context_tokens.json", self._context_tokens) - return messages - - # --- Send --- - - async def send_message(self, to_user_id: str, text: str) -> str: - if not self._credentials: - raise RuntimeError("WeChat not connected") - context_token = self._context_tokens.get(to_user_id) - if not 
context_token: - raise RuntimeError(f"No context_token for {to_user_id}. The user needs to message the bot first.") - client_id = f"leon:{int(time.time())}-{random.randint(0, 0xFFFF):04x}" - body = json.dumps( - { - "msg": { - "from_user_id": "", - "to_user_id": to_user_id, - "client_id": client_id, - "message_type": MSG_TYPE_BOT, - "message_state": MSG_STATE_FINISH, - "item_list": [{"type": MSG_ITEM_TEXT, "text_item": {"text": text}}], - "context_token": context_token, - }, - "base_info": {"channel_version": CHANNEL_VERSION}, - } - ) - headers = _build_headers(self._credentials.token, body) - resp = await self._http.post( - f"{self._credentials.base_url}/ilink/bot/sendmessage", - content=body, - headers=headers, - timeout=SEND_TIMEOUT_S, - ) - resp.raise_for_status() - return client_id - - -# --- WeChatConnectionRegistry (lives on app.state) --- - - -class WeChatConnectionRegistry: - """Manages per-user WeChatConnections. Lives on app.state, not module-level.""" - - def __init__(self, delivery_fn: DeliveryFn | None = None) -> None: - self._connections: dict[str, WeChatConnection] = {} - self._delivery_fn = delivery_fn - - def get(self, user_id: str) -> WeChatConnection: - if user_id not in self._connections: - self._connections[user_id] = WeChatConnection(user_id, self._delivery_fn) - return self._connections[user_id] - - def auto_start_all(self) -> None: - """Resume polling for all users with saved credentials on disk.""" - if not CONNECTIONS_BASE.exists(): - return - for user_dir in CONNECTIONS_BASE.iterdir(): - if user_dir.is_dir() and (user_dir / "credentials.json").exists(): - conn = self.get(user_dir.name) - if conn.connected and not conn.polling: - conn.start_polling() - - def evict_duplicates(self, account_id: str, keep_user_id: str) -> None: - """@@@unique-wechat — one WeChat account → one Leon user. 
Last one wins.""" - for uid, conn in list(self._connections.items()): - if uid == keep_user_id: - continue - if conn._credentials and conn._credentials.account_id == account_id: - logger.info("Evicting WeChat: user=%s (same account=%s)", uid[:12], account_id[:12]) - conn.disconnect() - - if CONNECTIONS_BASE.exists(): - for user_dir in CONNECTIONS_BASE.iterdir(): - if not user_dir.is_dir() or user_dir.name == keep_user_id: - continue - data = _load_json(user_dir.name, "credentials.json") - if data and data.get("account_id") == account_id: - logger.info("Evicting persisted WeChat: user=%s", user_dir.name[:12]) - _delete_file(user_dir.name, "credentials.json") - _delete_file(user_dir.name, "context_tokens.json") - - async def shutdown(self) -> None: - """Close all connections gracefully.""" - for conn in self._connections.values(): - await conn.close() - self._connections.clear() diff --git a/backend/web/utils/helpers.py b/backend/web/utils/helpers.py index b652e04f1..436f42948 100644 --- a/backend/web/utils/helpers.py +++ b/backend/web/utils/helpers.py @@ -5,19 +5,16 @@ from fastapi import HTTPException -from backend.web.core.config import DB_PATH from sandbox.sync.state import SyncState from storage.container import StorageContainer from storage.providers.sqlite.chat_session_repo import SQLiteChatSessionRepo from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path from storage.providers.sqlite.terminal_repo import SQLiteTerminalRepo -from storage.runtime import build_storage_container +from storage.runtime import build_storage_container, build_thread_repo SANDBOX_DB_PATH = resolve_role_db_path(SQLiteDBRole.SANDBOX) -# @@@cached-container - reuse a single StorageContainer across helper calls to avoid per-call rebuild. _cached_container: StorageContainer | None = None -_cached_container_db_path: Path | None = None def is_virtual_thread_id(thread_id: str | None) -> bool: @@ -71,11 +68,10 @@ def extract_webhook_instance_id(payload: dict[str, Any]) -> str | None: def _get_container() -> StorageContainer: - global _cached_container, _cached_container_db_path - if _cached_container is not None and _cached_container_db_path == DB_PATH: + global _cached_container + if _cached_container is not None: return _cached_container - _cached_container = build_storage_container(main_db_path=DB_PATH) - _cached_container_db_path = DB_PATH + _cached_container = build_storage_container() return _cached_container @@ -89,34 +85,15 @@ def _get_thread_repo(thread_repo=None): global _cached_thread_repo if _cached_thread_repo is not None: return _cached_thread_repo - from storage.providers.sqlite.thread_repo import SQLiteThreadRepo - - _cached_thread_repo = SQLiteThreadRepo(DB_PATH) + _cached_thread_repo = build_thread_repo() return _cached_thread_repo -def save_thread_config(thread_id: str, thread_repo=None, **fields: Any) -> None: - """Update specific fields of thread config.""" - allowed = {"sandbox_type", "cwd", "model", "observation_provider"} - updates = {k: v for k, v in fields.items() if k in allowed} - if not updates: - return - _get_thread_repo(thread_repo).update(thread_id, **updates) - - def load_thread_config(thread_id: str, thread_repo=None) -> dict[str, Any] | None: """Load thread data. 
Returns dict or None.""" return _get_thread_repo(thread_repo).get_by_id(thread_id) -def get_active_observation_provider() -> str | None: - """Read global observation config and return the active provider name.""" - from config.observation_loader import ObservationLoader - - config = ObservationLoader().load() - return config.active if config.active else None - - def resolve_local_workspace_path( raw_path: str | None, thread_id: str | None = None, diff --git a/backend/web/utils/serializers.py b/backend/web/utils/serializers.py index 4c070f285..082f08b44 100644 --- a/backend/web/utils/serializers.py +++ b/backend/web/utils/serializers.py @@ -38,7 +38,15 @@ def extract_text_content(raw_content: Any) -> str: def serialize_message(msg: Any) -> dict[str, Any]: """Serialize a LangChain message to a JSON-compatible dict.""" content = getattr(msg, "content", "") - metadata = getattr(msg, "metadata", None) or {} + metadata = dict(getattr(msg, "metadata", None) or {}) + additional_kwargs = getattr(msg, "additional_kwargs", None) or {} + tool_result_meta = additional_kwargs.get("tool_result_meta") + # @@@tool-result-meta-bridge - LangChain ToolMessage keeps durable tool + # metadata in additional_kwargs, but Leon display rebuild consumes + # serialized metadata. Merge the exact structured tool_result_meta here so + # checkpoint rebuild can recover blocking subagent identity honestly. + if isinstance(tool_result_meta, dict): + metadata = {**tool_result_meta, **metadata} # Strip system tags from owner HumanMessages (context-shift hints). # External HumanMessages keep their so frontend can @@ -63,4 +71,6 @@ def serialize_message(msg: Any) -> dict[str, Any]: } if metadata: result["metadata"] = metadata + if metadata.get("source") == "internal": + result["display"] = {"showing": False} return result diff --git a/config/defaults/tool_catalog.py b/config/defaults/tool_catalog.py index 294293874..1c2e67d2e 100644 --- a/config/defaults/tool_catalog.py +++ b/config/defaults/tool_catalog.py @@ -21,7 +21,9 @@ class ToolGroup(StrEnum): COMMAND = "command" WEB = "web" AGENT = "agent" + CHAT = "chat" TODO = "todo" + CRON = "cron" SKILLS = "skills" SYSTEM = "system" TASKBOARD = "taskboard" @@ -62,16 +64,26 @@ class ToolDef(BaseModel): ToolDef(name="TaskOutput", desc="获取后台任务输出", group=ToolGroup.AGENT), ToolDef(name="TaskStop", desc="停止后台任务", group=ToolGroup.AGENT), ToolDef(name="Agent", desc="启动子 Agent 执行任务", group=ToolGroup.AGENT), - ToolDef(name="SendMessage", desc="向其他 Agent 发送消息", group=ToolGroup.AGENT), + ToolDef(name="SendMessage", desc="向运行中的 Agent 发送排队消息", group=ToolGroup.AGENT), + # chat + ToolDef(name="list_chats", desc="列出当前实体可访问的聊天会话", group=ToolGroup.CHAT), + ToolDef(name="read_messages", desc="读取聊天消息并标记为已读", group=ToolGroup.CHAT), + ToolDef(name="send_message", desc="向聊天对象发送消息", group=ToolGroup.CHAT), + ToolDef(name="search_messages", desc="搜索历史聊天消息", group=ToolGroup.CHAT), # todo ToolDef(name="TaskCreate", desc="创建待办任务", group=ToolGroup.TODO, mode=ToolMode.DEFERRED), ToolDef(name="TaskGet", desc="获取任务详情", group=ToolGroup.TODO, mode=ToolMode.DEFERRED), ToolDef(name="TaskList", desc="列出所有任务", group=ToolGroup.TODO, mode=ToolMode.DEFERRED), ToolDef(name="TaskUpdate", desc="更新任务状态", group=ToolGroup.TODO, mode=ToolMode.DEFERRED), + # cron — backed by existing cron_jobs substrate; off by default until explicitly enabled + ToolDef(name="CronCreate", desc="创建定时任务", group=ToolGroup.CRON, mode=ToolMode.DEFERRED, default=False), + ToolDef(name="CronDelete", desc="删除定时任务", group=ToolGroup.CRON, 
mode=ToolMode.DEFERRED, default=False), + ToolDef(name="CronList", desc="列出定时任务", group=ToolGroup.CRON, mode=ToolMode.DEFERRED, default=False), # skills ToolDef(name="load_skill", desc="加载 Skill", group=ToolGroup.SKILLS), # system ToolDef(name="tool_search", desc="搜索可用工具", group=ToolGroup.SYSTEM), + ToolDef(name="LSP", desc="Language Server Protocol 操作", group=ToolGroup.SYSTEM, mode=ToolMode.DEFERRED, default=False), # taskboard — all off by default; enable on dedicated scheduler members ToolDef(name="ListBoardTasks", desc="列出任务板上的任务", group=ToolGroup.TASKBOARD, default=False), ToolDef(name="ClaimTask", desc="认领一个任务板任务", group=ToolGroup.TASKBOARD, default=False), diff --git a/config/env_manager.py b/config/env_manager.py deleted file mode 100644 index a5f5a6cc6..000000000 --- a/config/env_manager.py +++ /dev/null @@ -1,81 +0,0 @@ -""" -Leon 配置管理模块 -""" - -import os -from pathlib import Path - - -class ConfigManager: - """管理 Leon 的配置""" - - def __init__(self): - self.config_dir = Path.home() / ".leon" - self.config_file = self.config_dir / "config.env" - self.config_dir.mkdir(parents=True, exist_ok=True) - - def _parse_file(self) -> dict[str, str]: - if not self.config_file.exists(): - return {} - config = {} - for line in self.config_file.read_text().splitlines(): - line = line.strip() - if line and not line.startswith("#") and "=" in line: - k, v = line.split("=", 1) - config[k.strip()] = v.strip() - return config - - def get(self, key: str) -> str | None: - """获取配置值""" - return self._parse_file().get(key) - - def set(self, key: str, value: str): - """设置配置值""" - config = self._parse_file() - config[key] = value - with self.config_file.open("w") as f: - for k, v in config.items(): - f.write(f"{k}={v}\n") - - def list_all(self) -> dict[str, str]: - """列出所有配置""" - return self._parse_file() - - def load_to_env(self): - """加载配置到环境变量""" - for key, value in self.list_all().items(): - if key not in os.environ: - # 规范化 OPENAI_BASE_URL:确保包含 /v1 - if key == "OPENAI_BASE_URL" and value: - value = normalize_base_url(value) - os.environ[key] = value - - -def normalize_base_url(url: str) -> str: - """ - 规范化 OpenAI 兼容 API 的 base_url - - OpenAI SDK 会在 base_url 后直接拼接 /chat/completions, - 所以 base_url 必须以 /v1 结尾。 - - Examples: - https://api.openai.com -> https://api.openai.com/v1 - https://yunwu.ai -> https://yunwu.ai/v1 - https://yunwu.ai/v1 -> https://yunwu.ai/v1 (不变) - https://example.com/api/v1 -> https://example.com/api/v1 (不变) - """ - if not url: - return url - - url = url.rstrip("/") - - # 如果已经以 /v1 结尾,不处理 - if url.endswith("/v1"): - return url - - # 如果包含 /v1/ 在中间(如 /v1/engines),不处理 - if "/v1/" in url: - return url - - # 否则补全 /v1 - return f"{url}/v1" diff --git a/config/loader.py b/config/loader.py index 7b2f3190c..3931147ff 100644 --- a/config/loader.py +++ b/config/loader.py @@ -153,7 +153,7 @@ def _load_agents_from_members(self, members_dir: Path) -> None: continue config = self.parse_agent_file(agent_md) if config: - # source_dir is already set to member_dir by parse_agent_file + config.source_dir = member_dir.resolve() self._agents[config.name] = config @staticmethod @@ -184,7 +184,7 @@ def parse_agent_file(path: Path) -> AgentConfig | None: tools=fm.get("tools", ["*"]), system_prompt=parts[2].strip(), model=fm.get("model"), - source_dir=path.resolve().parent, + source_dir=None, ) def get_agent(self, name: str) -> AgentConfig | None: @@ -422,3 +422,74 @@ def load_config( ) -> LeonSettings: """Convenience function to load runtime configuration.""" return 
AgentLoader(workspace_root=workspace_root).load(cli_overrides=cli_overrides) + + +def load_bundle_from_repo(agent_config_repo: Any, member_id: str) -> AgentBundle | None: + """Load agent bundle from Supabase agent_config tables. Returns None if no config found.""" + config = agent_config_repo.get_config(member_id) + if not config: + return None + + # Parse agent identity from config + agent = AgentConfig( + name=config.get("name", ""), + description=config.get("description", ""), + tools=config.get("tools", ["*"]), + system_prompt=config.get("system_prompt", ""), + model=config.get("model"), + source_dir=None, + ) + + meta = { + "status": config.get("status", "draft"), + "version": config.get("version", "0.1.0"), + "created_at": config.get("created_at", 0), + "updated_at": config.get("updated_at", 0), + } + + # Runtime from config + runtime_data = config.get("runtime") or {} + runtime = {} + for rname, rcfg in runtime_data.items(): + if isinstance(rcfg, dict): + runtime[rname] = RuntimeResourceConfig(**rcfg) + + # Rules from agent_rules table + rule_rows = agent_config_repo.list_rules(member_id) + rules = [{"name": r.get("filename", "").replace(".md", ""), "content": r.get("content", "")} for r in rule_rows] + + # Sub-agents from agent_sub_agents table + sub_agent_rows = agent_config_repo.list_sub_agents(member_id) + agents = [] + for sa in sub_agent_rows: + agents.append( + AgentConfig( + name=sa.get("name", ""), + description=sa.get("description", ""), + tools=sa.get("tools", ["*"]), + system_prompt=sa.get("system_prompt", ""), + model=sa.get("model"), + source_dir=None, + ) + ) + + # Skills from agent_skills table + skill_rows = agent_config_repo.list_skills(member_id) + skills = [{"name": s.get("name", ""), "content": s.get("content", "")} for s in skill_rows] + + # MCP from config + mcp_data = config.get("mcp") or {} + mcp = {} + for mname, mcfg in mcp_data.items(): + if isinstance(mcfg, dict): + mcp[mname] = McpServerConfig(**{k: v for k, v in mcfg.items() if k in McpServerConfig.model_fields}) + + return AgentBundle( + agent=agent, + meta=meta, + runtime=runtime, + rules=rules, + agents=agents, + skills=skills, + mcp=mcp, + ) diff --git a/config/observation_schema.py b/config/observation_schema.py index eb01acd02..3d819cf78 100644 --- a/config/observation_schema.py +++ b/config/observation_schema.py @@ -3,6 +3,8 @@ Per-provider named fields, following sandbox/config.py pattern. """ +from typing import Annotated + from pydantic import BaseModel, Field @@ -11,7 +13,7 @@ class LangfuseConfig(BaseModel): secret_key: str | None = None public_key: str | None = None - host: str | None = Field(None, description="e.g. https://cloud.langfuse.com") + host: Annotated[str | None, Field(description="e.g. 
https://cloud.langfuse.com")] = None class LangSmithConfig(BaseModel): @@ -26,5 +28,5 @@ class ObservationConfig(BaseModel): """Observation configuration with per-provider named fields.""" active: str | None = Field(None, description="'langfuse' | 'langsmith' | None (disabled)") - langfuse: LangfuseConfig = Field(default_factory=LangfuseConfig) - langsmith: LangSmithConfig = Field(default_factory=LangSmithConfig) + langfuse: LangfuseConfig = Field(default_factory=lambda: LangfuseConfig()) + langsmith: LangSmithConfig = Field(default_factory=lambda: LangSmithConfig()) diff --git a/config/schema.py b/config/schema.py index 53a0cc8ea..8aff62bb7 100644 --- a/config/schema.py +++ b/config/schema.py @@ -11,7 +11,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any +from typing import Annotated, Any from pydantic import BaseModel, Field, field_validator @@ -26,15 +26,17 @@ class RuntimeConfig(BaseModel): """Runtime behavior configuration (non-model identity).""" - temperature: float | None = Field(None, ge=0.0, le=2.0, description="Temperature") - max_tokens: int | None = Field(None, gt=0, description="Max tokens") - model_kwargs: dict[str, Any] = Field(default_factory=dict, description="Extra kwargs for init_chat_model") - context_limit: int = Field(0, ge=0, description="Context window limit in tokens (0 = auto-detect from model)") - enable_audit_log: bool = Field(True, description="Enable audit logging") - allowed_extensions: list[str] | None = Field(None, description="Allowed extensions (None = all)") - block_dangerous_commands: bool = Field(True, description="Block dangerous commands") - block_network_commands: bool = Field(False, description="Block network commands") - queue_mode: str = Field("steer", deprecated=True, description="Deprecated. Queue mode is now determined by message timing.") + temperature: Annotated[float | None, Field(ge=0.0, le=2.0, description="Temperature")] = None + max_tokens: Annotated[int | None, Field(gt=0, description="Max tokens")] = None + model_kwargs: Annotated[dict[str, Any], Field(default_factory=dict, description="Extra kwargs for init_chat_model")] = Field( + default_factory=dict + ) + context_limit: Annotated[int, Field(ge=0, description="Context window limit in tokens (0 = auto-detect from model)")] = 0 + enable_audit_log: Annotated[bool, Field(description="Enable audit logging")] = True + allowed_extensions: Annotated[list[str] | None, Field(description="Allowed extensions (None = all)")] = None + block_dangerous_commands: Annotated[bool, Field(description="Block dangerous commands")] = True + block_network_commands: Annotated[bool, Field(description="Block network commands")] = False + queue_mode: Annotated[str, Field(deprecated=True, description="Deprecated. Queue mode is now determined by message timing.")] = "steer" # ============================================================================ @@ -48,11 +50,11 @@ class PruningConfig(BaseModel): Field names match SessionPruner constructor for direct passthrough. 
""" - enabled: bool = Field(True, description="Enable message pruning") - soft_trim_chars: int = Field(3000, gt=0, description="Soft-trim tool results longer than this") - hard_clear_threshold: int = Field(10000, gt=0, description="Hard-clear tool results longer than this") - protect_recent: int = Field(3, gt=0, description="Keep last N tool messages untrimmed") - trim_tool_results: bool = Field(True, description="Trim large tool results") + enabled: Annotated[bool, Field(description="Enable message pruning")] = True + soft_trim_chars: Annotated[int, Field(gt=0, description="Soft-trim tool results longer than this")] = 3000 + hard_clear_threshold: Annotated[int, Field(gt=0, description="Hard-clear tool results longer than this")] = 10000 + protect_recent: Annotated[int, Field(gt=0, description="Keep last N tool messages untrimmed")] = 3 + trim_tool_results: Annotated[bool, Field(description="Trim large tool results")] = True class CompactionConfig(BaseModel): @@ -61,17 +63,17 @@ class CompactionConfig(BaseModel): Field names match ContextCompactor constructor for direct passthrough. """ - enabled: bool = Field(True, description="Enable context compaction") - reserve_tokens: int = Field(16384, gt=0, description="Reserve space for new messages") - keep_recent_tokens: int = Field(20000, gt=0, description="Keep recent messages verbatim") - min_messages: int = Field(20, gt=0, description="Minimum messages before compaction") + enabled: Annotated[bool, Field(description="Enable context compaction")] = True + reserve_tokens: Annotated[int, Field(gt=0, description="Reserve space for new messages")] = 16384 + keep_recent_tokens: Annotated[int, Field(gt=0, description="Keep recent messages verbatim")] = 20000 + min_messages: Annotated[int, Field(gt=0, description="Minimum messages before compaction")] = 20 class MemoryConfig(BaseModel): """Memory management configuration.""" - pruning: PruningConfig = Field(default_factory=PruningConfig) - compaction: CompactionConfig = Field(default_factory=CompactionConfig) + pruning: PruningConfig = Field(default_factory=lambda: PruningConfig()) + compaction: CompactionConfig = Field(default_factory=lambda: CompactionConfig()) # ============================================================================ @@ -83,13 +85,13 @@ class ReadFileConfig(BaseModel): """Configuration for read_file tool.""" enabled: bool = True - max_file_size: int = Field(10485760, gt=0, description="Max file size in bytes (10MB)") + max_file_size: Annotated[int, Field(gt=0, description="Max file size in bytes (10MB)")] = 10485760 class FileSystemToolsConfig(BaseModel): """Configuration for filesystem tools.""" - read_file: ReadFileConfig = Field(default_factory=ReadFileConfig) + read_file: ReadFileConfig = Field(default_factory=lambda: ReadFileConfig()) write_file: bool = True edit_file: bool = True list_dir: bool = True @@ -99,20 +101,20 @@ class FileSystemConfig(BaseModel): """Configuration for filesystem middleware.""" enabled: bool = True - tools: FileSystemToolsConfig = Field(default_factory=FileSystemToolsConfig) + tools: FileSystemToolsConfig = Field(default_factory=lambda: FileSystemToolsConfig()) class GrepConfig(BaseModel): """Configuration for Grep tool.""" enabled: bool = True - max_file_size: int = Field(10485760, gt=0, description="Max file size in bytes (10MB)") + max_file_size: Annotated[int, Field(gt=0, description="Max file size in bytes (10MB)")] = 10485760 class SearchToolsConfig(BaseModel): """Configuration for search tools.""" - grep: GrepConfig = 
Field(default_factory=GrepConfig) + grep: GrepConfig = Field(default_factory=lambda: GrepConfig()) glob: bool = True @@ -120,52 +122,52 @@ class SearchConfig(BaseModel): """Configuration for search middleware.""" enabled: bool = True - tools: SearchToolsConfig = Field(default_factory=SearchToolsConfig) + tools: SearchToolsConfig = Field(default_factory=lambda: SearchToolsConfig()) class WebSearchConfig(BaseModel): """Configuration for web_search tool.""" enabled: bool = True - max_results: int = Field(5, gt=0, description="Max search results") - tavily_api_key: str | None = Field(None, description="Tavily API key") - exa_api_key: str | None = Field(None, description="Exa API key") - firecrawl_api_key: str | None = Field(None, description="Firecrawl API key") + max_results: Annotated[int, Field(gt=0, description="Max search results")] = 5 + tavily_api_key: Annotated[str | None, Field(description="Tavily API key")] = None + exa_api_key: Annotated[str | None, Field(description="Exa API key")] = None + firecrawl_api_key: Annotated[str | None, Field(description="Firecrawl API key")] = None class FetchConfig(BaseModel): """Configuration for Fetch tool (AI extraction mode).""" enabled: bool = True - jina_api_key: str | None = Field(None, description="Jina AI API key") + jina_api_key: Annotated[str | None, Field(description="Jina AI API key")] = None class WebToolsConfig(BaseModel): """Configuration for web tools.""" - web_search: WebSearchConfig = Field(default_factory=WebSearchConfig) - fetch: FetchConfig = Field(default_factory=FetchConfig) + web_search: WebSearchConfig = Field(default_factory=lambda: WebSearchConfig()) + fetch: FetchConfig = Field(default_factory=lambda: FetchConfig()) class WebConfig(BaseModel): """Configuration for web middleware.""" enabled: bool = True - timeout: int = Field(15, gt=0, description="Request timeout in seconds") - tools: WebToolsConfig = Field(default_factory=WebToolsConfig) + timeout: Annotated[int, Field(gt=0, description="Request timeout in seconds")] = 15 + tools: WebToolsConfig = Field(default_factory=lambda: WebToolsConfig()) class RunCommandConfig(BaseModel): """Configuration for run_command tool.""" enabled: bool = True - default_timeout: int = Field(120, gt=0, description="Default timeout in seconds") + default_timeout: Annotated[int, Field(gt=0, description="Default timeout in seconds")] = 120 class CommandToolsConfig(BaseModel): """Configuration for command tools.""" - run_command: RunCommandConfig = Field(default_factory=RunCommandConfig) + run_command: RunCommandConfig = Field(default_factory=lambda: RunCommandConfig()) command_status: bool = True @@ -173,14 +175,14 @@ class CommandConfig(BaseModel): """Configuration for command middleware.""" enabled: bool = True - tools: CommandToolsConfig = Field(default_factory=CommandToolsConfig) + tools: CommandToolsConfig = Field(default_factory=lambda: CommandToolsConfig()) class SpillBufferConfig(BaseModel): """Configuration for SpillBuffer middleware.""" enabled: bool = True - default_threshold: int = Field(50_000, gt=0, description="Default spill threshold in bytes") + default_threshold: Annotated[int, Field(gt=0, description="Default spill threshold in bytes")] = 50_000 thresholds: dict[str, int] = Field( default_factory=lambda: { "Grep": 20_000, @@ -196,11 +198,11 @@ class SpillBufferConfig(BaseModel): class ToolsConfig(BaseModel): """Tools configuration.""" - filesystem: FileSystemConfig = Field(default_factory=FileSystemConfig) - search: SearchConfig = Field(default_factory=SearchConfig) - web: 
WebConfig = Field(default_factory=WebConfig) - command: CommandConfig = Field(default_factory=CommandConfig) - spill_buffer: SpillBufferConfig = Field(default_factory=SpillBufferConfig) + filesystem: FileSystemConfig = Field(default_factory=lambda: FileSystemConfig()) + search: SearchConfig = Field(default_factory=lambda: SearchConfig()) + web: WebConfig = Field(default_factory=lambda: WebConfig()) + command: CommandConfig = Field(default_factory=lambda: CommandConfig()) + spill_buffer: SpillBufferConfig = Field(default_factory=lambda: SpillBufferConfig()) tool_modes: dict[str, str] = Field( default_factory=dict, description="Per-tool mode overrides: tool_name -> 'inline' | 'deferred'", @@ -215,6 +217,10 @@ class ToolsConfig(BaseModel): class MCPServerConfig(BaseModel): """Configuration for a single MCP server.""" + transport: str | None = Field( + None, + description="MCP transport type: stdio | streamable_http | sse | websocket", + ) command: str | None = Field(None, description="Command to run the MCP server") args: list[str] = Field(default_factory=list, description="Command arguments") env: dict[str, str] = Field(default_factory=dict, description="Environment variables") @@ -271,13 +277,13 @@ class LeonSettings(BaseModel): """ # Runtime behavior (replaces APIConfig model-identity fields) - runtime: RuntimeConfig = Field(default_factory=RuntimeConfig, description="Runtime behavior config") + runtime: RuntimeConfig = Field(default_factory=lambda: RuntimeConfig(), description="Runtime behavior config") # Core configuration groups - memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory management") - tools: ToolsConfig = Field(default_factory=ToolsConfig, description="Tools configuration") - mcp: MCPConfig = Field(default_factory=MCPConfig, description="MCP configuration") - skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration") + memory: MemoryConfig = Field(default_factory=lambda: MemoryConfig(), description="Memory management") + tools: ToolsConfig = Field(default_factory=lambda: ToolsConfig(), description="Tools configuration") + mcp: MCPConfig = Field(default_factory=lambda: MCPConfig(), description="MCP configuration") + skills: SkillsConfig = Field(default_factory=lambda: SkillsConfig(), description="Skills configuration") # Agent configuration system_prompt: str | None = Field(None, description="Custom system prompt") diff --git a/config/types.py b/config/types.py index 9731d5aff..0c49458fd 100644 --- a/config/types.py +++ b/config/types.py @@ -20,10 +20,12 @@ class AgentConfig(BaseModel): class McpServerConfig(BaseModel): """Single MCP server entry from .mcp.json.""" + transport: str | None = None command: str | None = None args: list[str] = Field(default_factory=list) env: dict[str, str] = Field(default_factory=dict) url: str | None = None + instructions: str | None = None allowed_tools: list[str] | None = None disabled: bool = False diff --git a/core/agents/communication/delivery.py b/core/agents/communication/delivery.py index c14ee6025..c79e4c121 100644 --- a/core/agents/communication/delivery.py +++ b/core/agents/communication/delivery.py @@ -1,22 +1,30 @@ """Chat delivery — enqueues lightweight notifications for agent threads. v3: no full message text injected. Agent must chat_read to see content. -ChatService._deliver_to_agents calls the delivery function for each -non-sender agent entity. +MessagingService._deliver_to_agents calls the delivery function for each +non-sender agent member. 
""" from __future__ import annotations +import functools import logging from typing import Any -from storage.contracts import EntityRow +from storage.contracts import MemberRow logger = logging.getLogger(__name__) +def _resolve_recipient_thread_id(app: Any, recipient_id: str) -> str | None: + thread = app.state.thread_repo.get_by_user_id(recipient_id) + if thread is None: + return None + return thread["id"] + + def make_chat_delivery_fn(app: Any): - """Create a delivery callback for ChatService. + """Create a delivery callback for MessagingService. Uses qm.enqueue() + wake_handler to route notifications. No more route_fn injection from backend layer. @@ -27,7 +35,8 @@ def make_chat_delivery_fn(app: Any): logger.info("[delivery] make_chat_delivery_fn: loop=%s", loop) def _deliver( - entity: EntityRow, + recipient_id: str, + member: MemberRow, content: str, sender_name: str, chat_id: str, @@ -35,27 +44,30 @@ def _deliver( sender_avatar_url: str | None = None, signal: str | None = None, ) -> None: - logger.info("[delivery] _deliver called: entity=%s, thread=%s", entity.id, entity.thread_id) + logger.info("[delivery] _deliver called: recipient=%s member=%s", recipient_id, member.id) future = asyncio.run_coroutine_threadsafe( - _async_deliver(app, entity, sender_name, chat_id, sender_id, sender_avatar_url, signal=signal), + _async_deliver(app, recipient_id, member, sender_name, chat_id, sender_id, sender_avatar_url, signal=signal), loop, ) - def _on_done(f): - exc = f.exception() - if exc: - logger.error("[delivery] async delivery failed for %s: %s", entity.id, exc, exc_info=exc) - else: - logger.info("[delivery] async delivery completed for %s", entity.id) - - future.add_done_callback(_on_done) + future.add_done_callback(functools.partial(_log_delivery_result, recipient_id)) return _deliver +def _log_delivery_result(member_id: str, f: Any) -> None: + """Done-callback for async delivery futures.""" + exc = f.exception() + if exc: + logger.error("[delivery] async delivery failed for %s: %s", member_id, exc, exc_info=exc) + else: + logger.info("[delivery] async delivery completed for %s", member_id) + + async def _async_deliver( app: Any, - entity: EntityRow, + recipient_id: str, + member: MemberRow, sender_name: str, chat_id: str, sender_id: str, @@ -64,25 +76,22 @@ async def _async_deliver( ) -> None: """Enqueue chat notification to an agent's brain thread. - @@@v3-notification-only — no message content. Agent calls chat_read to see it. + @@@v3-notification-only — no message content. Agent calls read_messages to see it. """ - # @@@context-isolation — clear inherited LangChain ContextVar so the recipient - # agent's astream doesn't inherit the sender's StreamMessagesHandler callbacks. from langchain_core.runnables.config import var_child_runnable_config var_child_runnable_config.set(None) - logger.info("[delivery] _async_deliver: entity=%s thread=%s from=%s", entity.id, entity.thread_id, sender_name) + # @@@thread-delivery-route - delivery target must come from the recipient social handle, + # never from the template default-thread shortcut. 
+ thread_id = _resolve_recipient_thread_id(app, recipient_id) + logger.info("[delivery] _async_deliver: recipient=%s member=%s thread=%s from=%s", recipient_id, member.id, thread_id, sender_name) from core.runtime.middleware.queue.formatters import format_chat_notification - if not entity.thread_id: - logger.warning("Entity %s has no thread_id, skipping delivery", entity.id) + if not thread_id: + logger.warning("Recipient %s has no thread, skipping delivery", recipient_id) return - thread_id = entity.thread_id - - # @@@cold-wake — ensure agent + wake_handler exist before enqueue. - # Without this, enqueue on an unvisited thread has no handler to wake the agent. from backend.web.services.agent_pool import get_or_create_agent, resolve_thread_sandbox from backend.web.services.streaming_service import _ensure_thread_handlers @@ -90,13 +99,11 @@ async def _async_deliver( agent = await get_or_create_agent(app, sandbox_type, thread_id=thread_id) _ensure_thread_handlers(agent, thread_id, app) - # @@@typing-lifecycle - start typing indicator typing_tracker = getattr(app.state, "typing_tracker", None) if typing_tracker is not None: - typing_tracker.start_chat(thread_id, chat_id, entity.id) + typing_tracker.start_chat(thread_id, chat_id, recipient_id) - # Unread count for this recipient - unread_count = app.state.chat_message_repo.count_unread(chat_id, entity.id) + unread_count = app.state.messaging_service.count_unread(chat_id, recipient_id) formatted = format_chat_notification(sender_name, chat_id, unread_count, signal=signal) diff --git a/core/agents/registry.py b/core/agents/registry.py index f74f4f4ec..d6f492f34 100644 --- a/core/agents/registry.py +++ b/core/agents/registry.py @@ -1,4 +1,4 @@ -"""Agent Registry — SQLite-backed agent_id -> thread_id mapping. +"""Agent Registry — Supabase-backed agent_id -> thread_id mapping. @@@id-based — all lookups use agent_id, never name. Name is stored for display only. @@ -8,9 +8,9 @@ import asyncio from dataclasses import dataclass -from pathlib import Path +from typing import Any -from backend.web.core.storage_factory import make_agent_registry_repo +from storage.runtime import build_agent_registry_repo @dataclass @@ -23,17 +23,42 @@ class AgentEntry: subagent_type: str | None = None -class AgentRegistry: - """SQLite-backed registry mapping agent_ids to thread IDs. 
+class _InMemoryAgentRegistryRepo: + """Noop in-memory fallback when Supabase is unavailable (tests/CLI).""" + + def __init__(self) -> None: + self._rows: dict[str, tuple] = {} + + def register( + self, *, agent_id: str, name: str, thread_id: str, status: str, parent_agent_id: str | None = None, subagent_type: str | None = None + ) -> None: + self._rows[agent_id] = (agent_id, name, thread_id, status, parent_agent_id, subagent_type) + + def get_by_id(self, agent_id: str) -> tuple | None: + return self._rows.get(agent_id) + + def list_running_by_name(self, name: str) -> list[tuple]: + return [r for r in self._rows.values() if r[1] == name and r[3] == "running"] + + def get_latest_by_name_and_parent(self, name: str, parent_agent_id: str | None) -> tuple | None: + matches = [r for r in self._rows.values() if r[1] == name and r[4] == parent_agent_id] + return matches[-1] if matches else None + + def update_status(self, agent_id: str, status: str) -> None: + if agent_id in self._rows: + old = self._rows[agent_id] + self._rows[agent_id] = (old[0], old[1], old[2], status, old[4], old[5]) - Persisted at ~/.leon/agent_registry.db - """ + def list_running(self) -> list[tuple]: + return [r for r in self._rows.values() if r[3] == "running"] - DEFAULT_DB_PATH = None # resolved by storage_factory - def __init__(self, db_path: Path | None = None): +class AgentRegistry: + """Supabase-backed registry mapping agent_ids to thread IDs.""" + + def __init__(self, repo: Any = None): self._lock = asyncio.Lock() - self._repo = make_agent_registry_repo() + self._repo = repo or build_agent_registry_repo() async def register(self, entry: AgentEntry) -> None: async with self._lock: @@ -59,6 +84,33 @@ async def get_by_id(self, agent_id: str) -> AgentEntry | None: subagent_type=row[5], ) + async def list_running_by_name(self, name: str) -> list[AgentEntry]: + rows = self._repo.list_running_by_name(name) + return [ + AgentEntry( + agent_id=row[0], + name=row[1], + thread_id=row[2], + status=row[3], + parent_agent_id=row[4], + subagent_type=row[5], + ) + for row in rows + ] + + async def get_latest_by_name_and_parent(self, name: str, parent_agent_id: str | None) -> AgentEntry | None: + row = self._repo.get_latest_by_name_and_parent(name, parent_agent_id) + if row is None: + return None + return AgentEntry( + agent_id=row[0], + name=row[1], + thread_id=row[2], + status=row[3], + parent_agent_id=row[4], + subagent_type=row[5], + ) + async def update_status(self, agent_id: str, status: str) -> None: async with self._lock: self._repo.update_status(agent_id, status) diff --git a/core/agents/service.py b/core/agents/service.py index e7baff89b..a35da5d37 100644 --- a/core/agents/service.py +++ b/core/agents/service.py @@ -11,89 +11,305 @@ import asyncio import json import logging +import os +import time import uuid +from collections.abc import Awaitable, Callable from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any, cast +from config.loader import AgentLoader from core.agents.registry import AgentEntry, AgentRegistry -from core.runtime.middleware.queue.formatters import format_background_notification -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.middleware.queue.formatters import ( + format_agent_message, + format_background_notification, + format_progress_notification, +) +from core.runtime.permissions import ToolPermissionContext +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema +from core.runtime.state import 
BootstrapConfig, ToolUseContext +from core.runtime.tool_result import tool_error, tool_permission_request, tool_success logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from core.runtime.agent import LeonAgent -AGENT_SCHEMA = { - "name": "Agent", - "description": ( - "Launch a new agent to handle complex tasks autonomously. " - "Use subagent_type to select a specialized agent, or omit for default. " - "Agents run independently with their own tool stack." + +EventEmitter = Callable[[dict[str, Any]], Awaitable[None] | None] +ChildAgentFactory = Callable[..., "LeonAgent"] + + +def _resolve_default_child_agent_factory() -> ChildAgentFactory: + from core.runtime.agent import create_leon_agent + + return cast(ChildAgentFactory, create_leon_agent) + + +# ── Sub-agent tool filtering (CC alignment) ────────────────────────────────── +# Tools that sub-agents must never access (prevents controlling parent). +AGENT_DISALLOWED: set[str] = {"TaskOutput", "TaskStop", "Agent"} + +# Per-type allowed tool sets. Tools not in the set are blocked. +EXPLORE_ALLOWED: set[str] = {"Read", "Grep", "Glob", "list_dir", "WebSearch", "WebFetch", "tool_search"} +PLAN_ALLOWED: set[str] = EXPLORE_ALLOWED # plan agents are also read-only +BASH_ALLOWED: set[str] = {"Bash", "Read", "Grep", "Glob", "list_dir", "tool_search"} + + +def _get_tool_filters(subagent_type: str) -> tuple[set[str], set[str] | None]: + """Return (extra_blocked_tools, allowed_tools) for a sub-agent type. + + For explore/plan/bash: use allowed_tools whitelist (ToolRegistry skips unmatched). + For general: only block AGENT_DISALLOWED, no whitelist. + """ + agent_type = subagent_type.lower() + allowed_map: dict[str, set[str]] = { + "explore": EXPLORE_ALLOWED, + "plan": PLAN_ALLOWED, + "bash": BASH_ALLOWED, + } + + if agent_type in allowed_map: + return AGENT_DISALLOWED, allowed_map[agent_type] + + # general: only block parent-controlling tools, no whitelist + return AGENT_DISALLOWED, None + + +def _get_subagent_agent_name(subagent_type: str) -> str: + return subagent_type.lower() + + +def _resolve_subagent_model( + workspace_root: Path, + subagent_type: str, + requested_model: str | None, + inherited_model: str, + fallback_model: str | None = None, +) -> str: + def _is_inherit_marker(value: str | None) -> bool: + return value is None or value.lower() in {"default", "inherit"} + + env_model = os.getenv("CLAUDE_CODE_SUBAGENT_MODEL") + if env_model: + return env_model + if requested_model and not _is_inherit_marker(requested_model): + return requested_model + + agent_def = AgentLoader(workspace_root=workspace_root).load_all_agents().get(_get_subagent_agent_name(subagent_type)) + if agent_def and agent_def.model: + return agent_def.model + + if inherited_model and not _is_inherit_marker(inherited_model): + return inherited_model + if fallback_model and not _is_inherit_marker(fallback_model): + return fallback_model + return inherited_model + + +def _normalize_child_workspace_prompt(prompt: str, workspace_root: Path) -> str: + workspace_text = str(workspace_root) + for suffix in ("current working directory", "working directory"): + prompt = prompt.replace(f"{workspace_text}/{suffix}", workspace_text) + return prompt + + +def _filter_fork_messages(messages: list) -> list: + """Filter parent messages for forkContext sub-agent spawning. + + Equivalent to CC's yF0: removes assistant messages whose tool_use blocks + have no matching tool_result in a subsequent user message (orphan tool_use). + Orphan tool_use blocks cause Anthropic API validation errors. 
+ """ + # Collect all tool_use_ids that have a corresponding tool_result + answered: set[str] = set() + for msg in messages: + # ToolMessage or user message with tool_result content + tool_call_id = getattr(msg, "tool_call_id", None) + if tool_call_id: + answered.add(tool_call_id) + content = getattr(msg, "content", None) + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_result": + tid = block.get("tool_use_id") or block.get("tool_call_id") + if tid: + answered.add(tid) + + result = [] + for msg in messages: + content = getattr(msg, "content", None) + if isinstance(content, list): + tool_uses = [b for b in content if isinstance(b, dict) and b.get("type") == "tool_use"] + if tool_uses and any(b.get("id") not in answered for b in tool_uses): + continue # skip assistant message with unanswered tool_use + result.append(msg) + return result + + +AGENT_SCHEMA = make_tool_schema( + name="Agent", + description=( + "Launch a sub-agent for independent task execution. " + "Types: explore (read-only codebase search), plan (architecture design, read-only), " + "bash (shell commands only), general (broad tool access except Agent, TaskOutput, and TaskStop). " + "Use for: multi-step tasks, parallel work, tasks needing isolation. " + "Do NOT use for simple file reads or single grep searches — use the tools directly." ), - "parameters": { - "type": "object", - "properties": { - "subagent_type": { - "type": "string", - "description": "Type of agent to spawn (e.g. 'Explore', 'Coder'). Omit for general-purpose.", - }, - "prompt": { - "type": "string", - "description": "Task for the agent", - }, - "name": { - "type": "string", - "description": "Name for the agent (used for SendMessage routing)", - }, - "description": { - "type": "string", - "description": ( - "Short description of what agent will do. Required when run_in_background is true; " - "shown in the background task indicator." - ), - }, - "run_in_background": { - "type": "boolean", - "default": False, - "description": "Fire-and-forget: return immediately with task_id instead of waiting for completion", - }, - "max_turns": { - "type": "integer", - "description": "Maximum turns the agent can take", - }, + properties={ + "subagent_type": { + "type": "string", + "enum": ["explore", "plan", "general", "bash"], + "description": "Type of agent to spawn. Omit for general-purpose.", + }, + "prompt": { + "type": "string", + "description": "Task for the agent", + }, + "name": { + "type": "string", + "description": "Optional display name for the spawned agent", + }, + "description": { + "type": "string", + "description": ( + "Short description of what agent will do. Required when run_in_background is true; shown in the background task indicator." + ), + }, + "run_in_background": { + "type": "boolean", + "default": False, + "description": "Fire-and-forget: return immediately with task_id instead of waiting for completion", + }, + "model": { + "type": "string", + "description": "Optional sub-agent model override. Priority: env > this field > agent frontmatter > inherit.", + }, + "max_turns": { + "type": "integer", + "description": "Maximum turns the agent can take", + }, + "fork_context": { + "type": "boolean", + "default": False, + "description": ( + "Inherit parent conversation history as read-only context. " + "Use when the sub-agent needs background from the parent's work. 
" + "Adds a ### ENTERING SUB-AGENT ROUTINE ### marker so the sub-agent " + "knows which messages are context vs its actual task." + ), }, - "required": ["prompt"], }, -} - -TASK_OUTPUT_SCHEMA = { - "name": "TaskOutput", - "description": "Get the output of a background agent task by its task_id.", - "parameters": { - "type": "object", - "properties": { - "task_id": { - "type": "string", - "description": "The task ID returned when starting a background agent", - }, + required=["prompt", "description"], +) + +TASK_OUTPUT_SCHEMA = make_tool_schema( + name="TaskOutput", + description=( + "Get output of a background task (agent or bash). Blocks until task completes by default. Returns full text output or error." + ), + properties={ + "task_id": { + "type": "string", + "description": "The task ID returned when starting a background agent", + }, + "block": { + "type": "boolean", + "default": True, + "description": "Whether to wait for completion. Use false for a non-blocking status check.", + }, + "timeout": { + "type": "integer", + "default": 30000, + "minimum": 0, + "maximum": 600000, + "description": "Maximum wait time in milliseconds when block=true (default: 30000, max: 600000).", + }, + }, + required=["task_id"], +) + +TASK_STOP_SCHEMA = make_tool_schema( + name="TaskStop", + description="Cancel a running background task. Sends cancellation signal; task may take a moment to stop.", + properties={ + "task_id": { + "type": "string", + "description": "The task ID to stop", + }, + }, + required=["task_id"], +) + +SEND_MESSAGE_SCHEMA = make_tool_schema( + name="SendMessage", + description="Send a queued message to another running agent by name. Delivered before that agent's next model turn.", + properties={ + "target_name": { + "type": "string", + "description": "Display name of the running target agent", + }, + "message": { + "type": "string", + "description": "Message body to deliver", + }, + "sender_name": { + "type": "string", + "description": "Optional sender label for the delivered message", }, - "required": ["task_id"], }, -} - -TASK_STOP_SCHEMA = { - "name": "TaskStop", - "description": "Stop a running background agent task.", - "parameters": { - "type": "object", - "properties": { - "task_id": { - "type": "string", - "description": "The task ID to stop", + required=["target_name", "message"], +) + +ASK_USER_QUESTION_SCHEMA = make_tool_schema( + name="AskUserQuestion", + description=( + "Ask the user one or more structured questions when progress requires their choice or clarification. " + "Use for genuine ambiguity, preference selection, or approval that needs an explicit answer before continuing." 
+ ), + properties={ + "questions": { + "type": "array", + "description": "Questions to present to the user.", + "minItems": 1, + "items": { + "type": "object", + "properties": { + "header": {"type": "string", "description": "Short UI label for the question."}, + "question": {"type": "string", "description": "Full question text shown to the user."}, + "multiSelect": { + "type": "boolean", + "default": False, + "description": "Whether the user may pick multiple options.", + }, + "options": { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "properties": { + "label": {"type": "string"}, + "description": {"type": "string"}, + "preview": {"type": "string"}, + }, + "required": ["label", "description"], + }, + }, + }, + "required": ["header", "question", "options"], }, }, - "required": ["task_id"], + "annotations": { + "type": "object", + "description": "Optional structured annotations kept with the question request.", + }, + "metadata": { + "type": "object", + "description": "Optional metadata describing the source of the question request.", + }, }, -} + required=["questions"], +) class _RunningTask: @@ -150,6 +366,33 @@ def get_result(self) -> str | None: BackgroundRun = _RunningTask | _BashBackgroundRun +def _background_run_running_message(running: BackgroundRun) -> str: + return "Command is still running." if isinstance(running, _BashBackgroundRun) else "Agent is still running." + + +def _background_run_result_status(result: str | None) -> str: + return "error" if (result and result.startswith("")) else "completed" + + +async def _wait_for_background_run(running: BackgroundRun, timeout_ms: int) -> bool: + timeout_s = max(timeout_ms, 0) / 1000.0 + if isinstance(running, _RunningTask): + try: + await asyncio.wait_for(asyncio.shield(running.task), timeout=timeout_s) + return True + except TimeoutError: + return running.is_done + + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout_s + while True: + if running.is_done: + return True + if loop.time() >= deadline: + return False + await asyncio.sleep(0.1) + + class AgentService: """Registers Agent, TaskOutput, TaskStop tools into ToolRegistry. @@ -170,11 +413,23 @@ def __init__( model_name: str, queue_manager: Any | None = None, shared_runs: dict[str, BackgroundRun] | None = None, + background_progress_interval_s: float = 30.0, + thread_repo: Any = None, + member_repo: Any = None, + web_app: Any = None, + child_agent_factory: ChildAgentFactory | None = None, ): self._agent_registry = agent_registry self._workspace_root = workspace_root self._model_name = model_name self._queue_manager = queue_manager + self._background_progress_interval_s = background_progress_interval_s + self._thread_repo = thread_repo + self._member_repo = member_repo + self._web_app = web_app + self._child_agent_factory = child_agent_factory or _resolve_default_child_agent_factory() + self._parent_bootstrap: BootstrapConfig | None = None + self._parent_tool_context: Any | None = None # Shared with CommandService so TaskOutput covers both bash and agent runs. 
self._tasks: dict[str, BackgroundRun] = shared_runs if shared_runs is not None else {} @@ -185,6 +440,7 @@ def __init__( schema=AGENT_SCHEMA, handler=self._handle_agent, source="AgentService", + search_hint="launch sub-agent spawn parallel task independent", ) ) tool_registry.register( @@ -194,6 +450,9 @@ def __init__( schema=TASK_OUTPUT_SCHEMA, handler=self._handle_task_output, source="AgentService", + search_hint="get background task output result poll", + is_read_only=True, + is_concurrency_safe=True, ) ) tool_registry.register( @@ -203,8 +462,74 @@ def __init__( schema=TASK_STOP_SCHEMA, handler=self._handle_task_stop, source="AgentService", + search_hint="stop cancel background task agent", + ) + ) + tool_registry.register( + ToolEntry( + name="SendMessage", + mode=ToolMode.INLINE, + schema=SEND_MESSAGE_SCHEMA, + handler=self._handle_send_message, + source="AgentService", + search_hint="send message running agent delivery queue", ) ) + tool_registry.register( + ToolEntry( + name="AskUserQuestion", + mode=ToolMode.INLINE, + schema=ASK_USER_QUESTION_SCHEMA, + handler=self._handle_ask_user_question, + source="AgentService", + search_hint="ask user question clarification choice preference", + is_read_only=True, + is_concurrency_safe=True, + ) + ) + + @staticmethod + def _normalize_child_sandbox(sandbox_type: str | None) -> str | None: + return None if not sandbox_type or sandbox_type == "local" else sandbox_type + + def _ensure_subagent_thread_metadata( + self, + *, + thread_id: str, + parent_thread_id: str | None, + agent_name: str, + model_name: str, + ) -> None: + if self._thread_repo is None or self._member_repo is None or not parent_thread_id: + return + existing_thread = self._thread_repo.get_by_id(thread_id) + if existing_thread is not None: + return + + parent_thread = self._thread_repo.get_by_id(parent_thread_id) + if parent_thread is None: + return + + member_id = parent_thread["member_id"] + member = self._member_repo.get_by_id(member_id) + if member is None: + return + + created_at = time.time() + branch_index = self._thread_repo.get_next_branch_index(member_id) + sandbox_type = parent_thread.get("sandbox_type") or "local" + cwd = parent_thread.get("cwd") + self._thread_repo.create( + thread_id=thread_id, + member_id=member_id, + user_id=thread_id, + sandbox_type=sandbox_type, + cwd=cwd, + created_at=created_at, + model=model_name or parent_thread.get("model"), + is_main=False, + branch_index=branch_index, + ) async def _handle_agent( self, @@ -213,15 +538,22 @@ async def _handle_agent( name: str | None = None, description: str | None = None, run_in_background: bool = False, + model: str | None = None, max_turns: int | None = None, - ) -> str: + fork_context: bool = False, + tool_context: ToolUseContext | None = None, + ) -> Any: """Spawn an independent LeonAgent and run it with the given prompt.""" from sandbox.thread_context import get_current_thread_id task_id = uuid.uuid4().hex[:8] agent_name = name or f"agent-{task_id}" - thread_id = f"subagent-{task_id}" parent_thread_id = get_current_thread_id() + existing_child = None + lookup_existing_child = getattr(self._agent_registry, "get_latest_by_name_and_parent", None) + if name and parent_thread_id and lookup_existing_child is not None: + existing_child = await lookup_existing_child(name, parent_thread_id) + thread_id = existing_child.thread_id if existing_child is not None and existing_child.status != "running" else f"subagent-{task_id}" # Register in AgentRegistry immediately entry = AgentEntry( @@ -233,6 +565,12 @@ async 
def _handle_agent( subagent_type=subagent_type, ) await self._agent_registry.register(entry) + self._ensure_subagent_thread_metadata( + thread_id=thread_id, + parent_thread_id=parent_thread_id, + agent_name=agent_name, + model_name=model or self._model_name, + ) # Create async task (independent LeonAgent runs inside) task = asyncio.create_task( @@ -243,33 +581,57 @@ async def _handle_agent( prompt, subagent_type, max_turns, + model=model, description=description or "", run_in_background=run_in_background, + fork_context=fork_context, + parent_tool_context=tool_context, ) ) if run_in_background: # True fire-and-forget: track in self._tasks for TaskOutput/TaskStop running = _RunningTask(task=task, agent_id=task_id, thread_id=thread_id, description=description or "") self._tasks[task_id] = running - return json.dumps( - { + return tool_success( + json.dumps( + { + "task_id": task_id, + "agent_name": agent_name, + "thread_id": thread_id, + "status": "running", + "message": "Agent started in background. Use TaskOutput to get result.", + }, + ensure_ascii=False, + ), + metadata={ "task_id": task_id, - "agent_name": agent_name, - "thread_id": thread_id, - "status": "running", - "message": "Agent started in background. Use TaskOutput to get result.", + "subagent_thread_id": thread_id, + "description": description or agent_name, }, - ensure_ascii=False, ) # Default: parent blocks until sub-agent completes (does not block frontend event loop) try: result = await task await self._agent_registry.update_status(task_id, "completed") - return result + return tool_success( + result, + metadata={ + "task_id": task_id, + "subagent_thread_id": thread_id, + "description": description or agent_name, + }, + ) except Exception as e: await self._agent_registry.update_status(task_id, "error") - return f"Agent failed: {e}" + return tool_error( + f"Agent failed: {e}", + metadata={ + "task_id": task_id, + "subagent_thread_id": thread_id, + "description": description or agent_name, + }, + ) async def _run_agent( self, @@ -279,8 +641,11 @@ async def _run_agent( prompt: str, subagent_type: str, max_turns: int | None, + model: str | None = None, description: str = "", run_in_background: bool = False, + fork_context: bool = False, + parent_tool_context: ToolUseContext | None = None, ) -> str: """Create and run an independent LeonAgent, collect its text output.""" # Isolate this sub-agent from the parent's LangChain callback chain. 
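
Before the next hunk, a simplified, self-contained sketch of the priority chain `_resolve_subagent_model` implements: env `CLAUDE_CODE_SUBAGENT_MODEL` first, then the explicit `model` argument, then agent frontmatter, then the inherited parent model. The frontmatter lookup is stubbed to a plain argument here; the real helper resolves it through `AgentLoader`:

```python
# Simplified sketch of the sub-agent model priority chain (frontmatter is a
# plain argument here instead of an AgentLoader lookup).
import os

def resolve_model_sketch(requested: str | None, frontmatter: str | None, inherited: str) -> str:
    def is_inherit(value: str | None) -> bool:
        # "default" / "inherit" markers mean: keep walking the chain.
        return value is None or value.lower() in {"default", "inherit"}

    env_model = os.getenv("CLAUDE_CODE_SUBAGENT_MODEL")
    if env_model:
        return env_model
    if not is_inherit(requested):
        return requested
    if frontmatter:
        return frontmatter
    return inherited

assert resolve_model_sketch(None, None, "parent-model") == "parent-model"
assert resolve_model_sketch("inherit", "haiku", "parent-model") == "haiku"
assert resolve_model_sketch("opus", "haiku", "parent-model") == "opus"
```
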
@@ -294,48 +659,164 @@ async def _run_agent( var_child_runnable_config.set(None) - # Lazy import avoids circular dependency (agent.py imports AgentService) - from core.runtime.agent import create_leon_agent from sandbox.thread_context import get_current_thread_id, set_current_thread_id parent_thread_id = get_current_thread_id() + self._ensure_subagent_thread_metadata( + thread_id=thread_id, + parent_thread_id=parent_thread_id, + agent_name=agent_name, + model_name=model or self._model_name, + ) # emit_fn is set if EventBus is available; used for task lifecycle SSE events - emit_fn = None + emit_fn: EventEmitter | None = None try: from backend.web.event_bus import get_event_bus - event_bus = get_event_bus() - emit_fn = event_bus.make_emitter( - thread_id=parent_thread_id, - agent_id=task_id, - agent_name=agent_name, - ) + if parent_thread_id: + event_bus = get_event_bus() + emit_fn = event_bus.make_emitter( + thread_id=parent_thread_id, + agent_id=task_id, + agent_name=agent_name, + ) except ImportError: pass # backend not available in standalone core usage - agent = None + agent: LeonAgent | None = None + progress_task: asyncio.Task | None = None + progress_stop: asyncio.Event | None = None + child_bootstrap_start_cost = 0.0 + child_bootstrap_start_tool_duration_ms = 0 try: - agent = create_leon_agent( - model_name=self._model_name, - workspace_root=self._workspace_root, - verbose=False, - ) + # Sub-agent context trimming: each spawn creates a fresh LeonAgent + # with its own _build_system_prompt(). No CLAUDE.md content or + # gitStatus is injected into the prompt pipeline (core/runtime/prompts + # has no such injection). Therefore explore/plan/bash sub-agents + # already run lightweight — no extra trimming is needed. + # + # Try to use context fork from parent agent's BootstrapConfig. + # Falls back to create_leon_agent when bootstrap is not available. + # Compute tool filtering for this sub-agent type + extra_blocked, allowed = _get_tool_filters(subagent_type) + agent_name_for_role = _get_subagent_agent_name(subagent_type) + + try: + from core.runtime.fork import create_subagent_context + from core.runtime.fork import fork_context as fork_bootstrap + + # Parent bootstrap is stored on the ToolUseContext or agent instance. + # AgentService stores workspace_root and model_name directly; use those + # to check if a richer bootstrap is available via a shared reference. + # _parent_bootstrap is injected by LeonAgent when building AgentService. 
+ parent_bootstrap = getattr(self, "_parent_bootstrap", None) + child_tool_context = None + if parent_tool_context is not None: + child_tool_context = create_subagent_context(parent_tool_context) + child_bootstrap = child_tool_context.bootstrap + elif parent_bootstrap is not None: + child_bootstrap = fork_bootstrap(parent_bootstrap) + selected_model = _resolve_subagent_model( + self._workspace_root, + subagent_type, + model, + child_bootstrap.model_name, + self._model_name, + ) + agent = self._child_agent_factory( + model_name=selected_model, + workspace_root=child_bootstrap.workspace_root, + sandbox=self._normalize_child_sandbox(getattr(child_bootstrap, "sandbox_type", None)), + agent=agent_name_for_role, + web_app=self._web_app, + extra_blocked_tools=extra_blocked, + allowed_tools=allowed, + verbose=False, + ) + else: + raise AttributeError("no parent bootstrap") + child_bootstrap_start_cost = float(getattr(child_bootstrap, "total_cost_usd", 0.0)) + child_bootstrap_start_tool_duration_ms = int(getattr(child_bootstrap, "total_tool_duration_ms", 0)) + if parent_tool_context is not None: + # @@@sa-05-subagent-policy-resolution + # Role-specific tool envelopes and model priority order must + # be resolved explicitly here instead of leaking through + # prompt text or whichever defaults happen to win later. + selected_model = _resolve_subagent_model( + self._workspace_root, + subagent_type, + model, + child_bootstrap.model_name, + self._model_name, + ) + agent = self._child_agent_factory( + model_name=selected_model, + workspace_root=child_bootstrap.workspace_root, + sandbox=self._normalize_child_sandbox(getattr(child_bootstrap, "sandbox_type", None)), + agent=agent_name_for_role, + web_app=self._web_app, + extra_blocked_tools=extra_blocked, + allowed_tools=allowed, + verbose=False, + ) + # @@@sa-04-child-bootstrap-wiring + # Keep the forked bootstrap/context handoff behind an explicit + # LeonAgent API so AgentService stops reaching into QueryLoop + # internals directly. + assert agent is not None + agent.apply_forked_child_context( + child_bootstrap, + tool_context=child_tool_context, + ) + except (AttributeError, ImportError): + inherited_model = getattr(parent_tool_context.bootstrap, "model_name", None) if parent_tool_context else None + selected_model = _resolve_subagent_model( + self._workspace_root, + subagent_type, + model, + inherited_model or self._model_name, + self._model_name, + ) + agent = self._child_agent_factory( + model_name=selected_model, + workspace_root=self._workspace_root, + sandbox=self._normalize_child_sandbox( + getattr(parent_tool_context.bootstrap, "sandbox_type", None) if parent_tool_context else None + ), + agent=agent_name_for_role, + web_app=self._web_app, + extra_blocked_tools=extra_blocked, + allowed_tools=allowed, + verbose=False, + ) # In async context LeonAgent defers checkpointer init; call ainit() to # ensure state is persisted (and loadable via GET /api/threads/{thread_id}). + assert agent is not None await agent.ainit() + # @@@subagent-prompt-path-sanitize - Parent models sometimes satisfy + # "use absolute paths" by appending natural-language cwd labels onto the + # real workspace path. Normalize the obvious fake suffix before dispatch. 
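
For illustration, the sanitizer's effect on a doctored path (the workspace path and prompt below are invented):

```python
# Example of the @@@subagent-prompt-path-sanitize behavior: the helper strips
# the natural-language cwd label some parent models append to the real
# workspace path before the prompt is dispatched to the child.
from pathlib import Path

root = Path("/work/leon")
prompt = "Review /work/leon/current working directory/README.md for typos"
assert _normalize_child_workspace_prompt(prompt, root) == "Review /work/leon/README.md for typos"
```
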
+ child_workspace_root = Path(getattr(agent, "workspace_root", self._workspace_root)) + prompt = _normalize_child_workspace_prompt(prompt, child_workspace_root) + + if parent_thread_id and parent_thread_id != thread_id: + from sandbox.manager import bind_thread_to_existing_thread_lease + + bind_thread_to_existing_thread_lease(thread_id, parent_thread_id) # Wire child agent events to the parent's EventBus subscription # so the parent SSE stream shows sub-agent activity. if emit_fn is not None: - if hasattr(agent, "runtime") and hasattr(agent.runtime, "bind_thread"): - agent.runtime.bind_thread(activity_sink=emit_fn) + runtime = getattr(agent, "runtime", None) + if runtime is not None and hasattr(runtime, "bind_thread"): + runtime.bind_thread(activity_sink=emit_fn) set_current_thread_id(thread_id) # Notify frontend: task started if emit_fn is not None: - await emit_fn( + emission = emit_fn( { "event": "task_start", "data": json.dumps( @@ -350,38 +831,95 @@ async def _run_agent( ), } ) + if asyncio.iscoroutine(emission): + await emission config = {"configurable": {"thread_id": thread_id}} output_parts: list[str] = [] + latest_progress = description or agent_name + + if run_in_background and self._queue_manager and parent_thread_id and self._background_progress_interval_s > 0: + progress_stop = asyncio.Event() + progress_task = asyncio.create_task( + self._emit_background_progress( + task_id=task_id, + agent_name=agent_name, + parent_thread_id=parent_thread_id, + latest_progress=lambda: latest_progress, + stop_event=progress_stop, + ) + ) - async for chunk in agent.agent.astream( - {"messages": [{"role": "user", "content": prompt}]}, - config=config, - stream_mode="updates", - ): - for _, node_update in chunk.items(): - if not isinstance(node_update, dict): - continue - msgs = node_update.get("messages", []) - if not isinstance(msgs, list): - msgs = [msgs] - for msg in msgs: - if msg.__class__.__name__ == "AIMessage": - content = getattr(msg, "content", "") - if isinstance(content, str) and content: - output_parts.append(content) - elif isinstance(content, list): - for block in content: - if isinstance(block, dict) and block.get("type") == "text": - text = block.get("text", "") - if text: - output_parts.append(text) + # Build initial input — with or without forked parent context + if fork_context: + from sandbox.thread_context import get_current_messages + + # @@@pt-04-fork-context-source + # The Agent tool already has an explicit parent ToolUseContext on + # the live ToolRunner path. Forked sub-agents must prefer that + # concrete message snapshot over ambient ContextVar state, or the + # direct runner path silently drops parent context. 
+ parent_msgs = list(parent_tool_context.messages) if parent_tool_context is not None else get_current_messages() + fork_marker = ( + "\n\n### ENTERING SUB-AGENT ROUTINE ###\n" + "Messages above are from the parent thread (read-only context).\n" + "Only complete the specific task assigned below.\n\n" + ) + initial_messages: list = [ + *_filter_fork_messages(parent_msgs), + {"role": "user", "content": fork_marker + prompt}, + ] + else: + initial_messages = [{"role": "user", "content": prompt}] + + if self._web_app is not None: + from backend.web.services.streaming_service import run_child_thread_live + + result = await run_child_thread_live( + agent, + thread_id, + prompt, + self._web_app, + input_messages=initial_messages, + ) + if result: + output_parts.append(result) + latest_progress = self._summarize_progress(result, description or agent_name) + else: + async for chunk in agent.agent.astream( + {"messages": initial_messages}, + config=config, + stream_mode="updates", + ): + for _, node_update in chunk.items(): + if not isinstance(node_update, dict): + continue + msgs = node_update.get("messages", []) + if not isinstance(msgs, list): + msgs = [msgs] + for msg in msgs: + if msg.__class__.__name__ == "AIMessage": + content = getattr(msg, "content", "") + if isinstance(content, str) and content: + output_parts.append(content) + latest_progress = self._summarize_progress(content, description or agent_name) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + if text: + output_parts.append(text) + latest_progress = self._summarize_progress(text, description or agent_name) await self._agent_registry.update_status(task_id, "completed") result = "\n".join(output_parts) or "(Agent completed with no text output)" + if progress_stop is not None: + progress_stop.set() + if progress_task is not None: + await progress_task # Notify frontend: task done if emit_fn is not None: - await emit_fn( + emission = emit_fn( { "event": "task_done", "data": json.dumps( @@ -393,6 +931,8 @@ async def _run_agent( ), } ) + if asyncio.iscoroutine(emission): + await emission # Queue notification only for background runs — blocking callers already # received the result as the tool's return value; sending a notification # would trigger a spurious new parent turn. 
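
A small, self-contained check of the orphan-tool_use rule that `_filter_fork_messages` enforces before parent history is forked. `SimpleNamespace` stands in for LangChain message objects; real callers pass the parent's actual messages:

```python
# Illustrative check of _filter_fork_messages: an assistant message whose
# tool_use id never received a matching tool_result is dropped before the
# parent history is handed to the forked sub-agent.
from types import SimpleNamespace as Msg

answered = Msg(content=[{"type": "tool_use", "id": "call-a"}])
orphan = Msg(content=[{"type": "tool_use", "id": "call-b"}])
result_a = Msg(content=[{"type": "tool_result", "tool_use_id": "call-a"}])

kept = _filter_fork_messages([answered, orphan, result_a])
assert answered in kept and result_a in kept
assert orphan not in kept  # unanswered tool_use would fail API validation
```
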
@@ -402,18 +942,23 @@ async def _run_agent( task_id=task_id, status="completed", summary=label, + result=result, description=label, ) self._queue_manager.enqueue(notification, parent_thread_id, notification_type="agent") return result except Exception: + if progress_stop is not None: + progress_stop.set() + if progress_task is not None: + await progress_task logger.exception("[AgentService] Agent %s failed", agent_name) await self._agent_registry.update_status(task_id, "error") # Notify frontend: task error if emit_fn is not None: try: - await emit_fn( + emission = emit_fn( { "event": "task_error", "data": json.dumps( @@ -425,6 +970,8 @@ async def _run_agent( ), } ) + if asyncio.iscoroutine(emission): + await emission except Exception: pass if run_in_background and self._queue_manager and parent_thread_id: @@ -433,6 +980,7 @@ async def _run_agent( task_id=task_id, status="error", summary=label, + result="Agent failed", description=label, ) self._queue_manager.enqueue(notification, parent_thread_id, notification_type="agent") @@ -440,37 +988,252 @@ async def _run_agent( finally: if agent is not None: try: - agent.close() + self._merge_child_bootstrap_accumulators( + getattr(self, "_parent_bootstrap", None), + getattr(agent, "_bootstrap", None), + child_bootstrap_start_cost=child_bootstrap_start_cost, + child_bootstrap_start_tool_duration_ms=child_bootstrap_start_tool_duration_ms, + ) + if hasattr(agent, "_agent_service") and hasattr(agent._agent_service, "cleanup_background_runs"): + await agent._agent_service.cleanup_background_runs() + # @@@web-child-persistence - web child threads are user-visible + # thread surfaces. Closing the LeonAgent here marks runtime + # terminated and drops its live/checkpoint bridge right after + # completion, so the child tab collapses to an empty shell. + if self._web_app is None: + # @@@subagent-sandbox-close-skip - Child agents can share the + # parent's lease; closing the child sandbox here can pause the + # shared lease mid-owner-turn. + agent.close(cleanup_sandbox=False) except Exception: pass - async def _handle_task_output(self, task_id: str) -> str: + @staticmethod + def _merge_child_bootstrap_accumulators( + parent_bootstrap: Any, + child_bootstrap: Any, + *, + child_bootstrap_start_cost: float, + child_bootstrap_start_tool_duration_ms: int, + ) -> None: + if parent_bootstrap is None or child_bootstrap is None or parent_bootstrap is child_bootstrap: + return + # @@@sa-03-bootstrap-rollup + # Sub-agent loops start from a forked bootstrap snapshot. At join time we + # need to preserve both the parent's concurrent growth and the child's + # post-fork delta instead of letting one side overwrite the other. 
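
Worked numbers for the roll-up this comment describes (the implementation continues just below; all values are invented):

```python
# Join-time accumulator roll-up: the parent keeps growth that happened while
# the child ran, and the child contributes only its post-fork delta.
from types import SimpleNamespace

parent = SimpleNamespace(total_cost_usd=0.30, total_tool_duration_ms=4_000)
child = SimpleNamespace(total_cost_usd=0.25, total_tool_duration_ms=2_500)

# The child's forked snapshot started at $0.10 and 1_000 ms.
AgentService._merge_child_bootstrap_accumulators(
    parent,
    child,
    child_bootstrap_start_cost=0.10,
    child_bootstrap_start_tool_duration_ms=1_000,
)
assert abs(parent.total_cost_usd - 0.45) < 1e-9  # 0.30 + (0.25 - 0.10)
assert parent.total_tool_duration_ms == 5_500    # 4_000 + (2_500 - 1_000)
```
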
+ child_cost_delta = max( + 0.0, + float(getattr(child_bootstrap, "total_cost_usd", 0.0)) - child_bootstrap_start_cost, + ) + child_tool_duration_delta = max( + 0, + int(getattr(child_bootstrap, "total_tool_duration_ms", 0)) - child_bootstrap_start_tool_duration_ms, + ) + parent_bootstrap.total_cost_usd = float(getattr(parent_bootstrap, "total_cost_usd", 0.0)) + child_cost_delta + parent_bootstrap.total_tool_duration_ms = int(getattr(parent_bootstrap, "total_tool_duration_ms", 0)) + child_tool_duration_delta + + @staticmethod + def _summarize_progress(text: str, fallback: str) -> str: + collapsed = " ".join(text.split()).strip() + if not collapsed: + return fallback + return collapsed[:120] + + async def _emit_background_progress( + self, + *, + task_id: str, + agent_name: str, + parent_thread_id: str, + latest_progress: Any, + stop_event: asyncio.Event, + ) -> None: + # @@@sa-06-progress-loop - keep prompt-facing coordinator updates on the + # real thread delivery queue instead of inventing a detached parallel channel. + while True: + try: + await asyncio.wait_for(stop_event.wait(), timeout=self._background_progress_interval_s) + return + except TimeoutError: + pass + + if self._queue_manager is None: + return + + notification = format_progress_notification( + task_id, + latest_progress(), + step="running", + ) + self._queue_manager.enqueue( + notification, + parent_thread_id, + notification_type="agent", + source="system", + sender_name=agent_name, + ) + + async def _handle_task_output(self, task_id: str, block: bool = True, timeout: int = 30_000) -> str: """Get output of a background agent task.""" running = self._tasks.get(task_id) if not running: return f"Error: task '{task_id}' not found" + if not block: + if not running.is_done: + return json.dumps( + { + "task_id": task_id, + "status": "running", + "message": _background_run_running_message(running), + }, + ensure_ascii=False, + ) + + result = running.get_result() + return json.dumps( + { + "task_id": task_id, + "status": _background_run_result_status(result), + "result": result, + }, + ensure_ascii=False, + ) + + if not running.is_done: + completed = await _wait_for_background_run(running, min(timeout, 600_000)) + if not completed and not running.is_done: + return json.dumps( + { + "task_id": task_id, + "status": "timeout", + "message": _background_run_running_message(running), + }, + ensure_ascii=False, + ) + if not running.is_done: return json.dumps( { "task_id": task_id, "status": "running", - "message": "Agent is still running.", + "message": _background_run_running_message(running), }, ensure_ascii=False, ) result = running.get_result() - status = "error" if (result and result.startswith("")) else "completed" return json.dumps( { "task_id": task_id, - "status": status, + "status": _background_run_result_status(result), "result": result, }, ensure_ascii=False, ) + async def _handle_send_message( + self, + target_name: str, + message: str, + sender_name: str | None = None, + ) -> str: + if self._queue_manager is None: + return "SendMessage requires queue_manager" + + matches = await self._agent_registry.list_running_by_name(target_name) + if not matches: + return f"Running agent '{target_name}' not found" + if len(matches) > 1: + return ( + f"Running agent name '{target_name}' is ambiguous. " + "Use a unique name before calling SendMessage." 
+ ) + target = matches[0] + + delivered = format_agent_message(sender_name or "agent", message) + self._queue_manager.enqueue( + delivered, + target.thread_id, + notification_type="agent", + source="system", + sender_name=sender_name or "agent", + ) + return f"Message sent to {target.name}." + + async def _handle_ask_user_question( + self, + questions: list[dict[str, Any]], + annotations: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + tool_context: ToolUseContext | None = None, + ) -> Any: + if tool_context is None or tool_context.request_permission is None: + return tool_error("AskUserQuestion requires an interactive owner resolver") + + payload: dict[str, Any] = {"questions": questions} + if annotations is not None: + payload["annotations"] = annotations + if metadata is not None: + payload["metadata"] = metadata + + request_result = tool_context.request_permission( + "AskUserQuestion", + payload, + ToolPermissionContext(is_read_only=True, is_destructive=False), + None, + "Please answer the following questions so Leon can continue.", + ) + request_id = request_result.get("request_id") if isinstance(request_result, dict) else request_result + if not isinstance(request_id, str) or not request_id: + return tool_error("AskUserQuestion could not create a user-facing request") + + return tool_permission_request( + "User input required to continue.", + metadata={ + "decision": "ask", + "request_id": request_id, + "request_kind": "ask_user_question", + }, + ) + + async def _stop_background_run(self, task_id: str, running: BackgroundRun) -> None: + if isinstance(running, _RunningTask): + was_running = not running.task.done() + if was_running: + running.task.cancel() + try: + await running.task + except asyncio.CancelledError: + pass + await self._agent_registry.update_status(running.agent_id, "error") + self._tasks.pop(task_id, None) + return + + if not running.is_done: + process = getattr(running._cmd, "process", None) + wait = getattr(process, "wait", None) if process is not None else None + terminate = getattr(process, "terminate", None) if process is not None else None + kill = getattr(process, "kill", None) if process is not None else None + + if callable(terminate): + terminate() + if callable(wait): + wait_fn = cast(Callable[[], Awaitable[Any]], wait) + try: + await asyncio.wait_for(wait_fn(), timeout=1.0) + except TimeoutError: + if callable(kill): + kill() + await wait_fn() + + self._tasks.pop(task_id, None) + + async def cleanup_background_runs(self) -> None: + for task_id, running in list(self._tasks.items()): + await self._stop_background_run(task_id, running) + async def _handle_task_stop(self, task_id: str) -> str: """Stop a running background agent task.""" running = self._tasks.get(task_id) @@ -480,6 +1243,5 @@ async def _handle_task_stop(self, task_id: str) -> str: if running.is_done: return f"Task {task_id} already completed" - running.task.cancel() - await self._agent_registry.update_status(running.agent_id, "error") + await self._stop_background_run(task_id, running) return f"Task {task_id} cancelled" diff --git a/core/operations.py b/core/operations.py index c0a471b33..768e49859 100644 --- a/core/operations.py +++ b/core/operations.py @@ -2,10 +2,8 @@ from contextvars import ContextVar from dataclasses import dataclass -from pathlib import Path from storage.models import FileOperationRow -from storage.providers.sqlite.file_operation_repo import SQLiteFileOperationRepo # Context variable for tracking current thread (TUI only; web uses 
sandbox.thread_context) current_thread_id: ContextVar[str] = ContextVar("current_thread_id", default="") @@ -31,16 +29,8 @@ class FileOperation: class FileOperationRecorder: """Records file operations for time travel rollback""" - def __init__(self, db_path: Path | str | None = None, repo=None): - # @@@repo-injection - web path injects Supabase repo; TUI falls back to SQLite via db_path. - if repo is not None: - self._repo = repo - return - if db_path is None: - db_path = Path.home() / ".leon" / "leon.db" - self.db_path = Path(db_path) - self.db_path.parent.mkdir(parents=True, exist_ok=True) - self._repo = SQLiteFileOperationRepo(self.db_path) + def __init__(self, repo=None): + self._repo = repo def record( self, @@ -52,7 +42,9 @@ def record( after_content: str, changes: list[dict] | None = None, ) -> str: - """Record a file operation""" + """Record a file operation. Noop if no repo configured.""" + if self._repo is None: + return "" return self._repo.record( thread_id=thread_id, checkpoint_id=checkpoint_id, @@ -64,35 +56,42 @@ def record( ) def get_operations_for_thread(self, thread_id: str, status: str = "applied") -> list[FileOperation]: - """Get all operations for a thread""" + if self._repo is None: + return [] rows = self._repo.get_operations_for_thread(thread_id, status=status) return [self._to_file_operation(row) for row in rows] def get_operations_after_checkpoint(self, thread_id: str, checkpoint_id: str) -> list[FileOperation]: - """Get operations after a specific checkpoint (for rollback)""" + if self._repo is None: + return [] rows = self._repo.get_operations_after_checkpoint(thread_id, checkpoint_id) return [self._to_file_operation(row) for row in rows] def get_operations_between_checkpoints(self, thread_id: str, from_checkpoint_id: str, to_checkpoint_id: str) -> list[FileOperation]: - """Get operations between two checkpoints (exclusive of from, inclusive of to)""" + if self._repo is None: + return [] rows = self._repo.get_operations_between_checkpoints(thread_id, from_checkpoint_id, to_checkpoint_id) return [self._to_file_operation(row) for row in rows] def get_operations_for_checkpoint(self, thread_id: str, checkpoint_id: str) -> list[FileOperation]: - """Get all operations for a specific checkpoint""" + if self._repo is None: + return [] rows = self._repo.get_operations_for_checkpoint(thread_id, checkpoint_id) return [self._to_file_operation(row) for row in rows] def count_operations_for_checkpoint(self, thread_id: str, checkpoint_id: str) -> int: - """Count operations for a specific checkpoint""" + if self._repo is None: + return 0 return self._repo.count_operations_for_checkpoint(thread_id, checkpoint_id) def mark_reverted(self, operation_ids: list[str]) -> None: - """Mark operations as reverted""" + if self._repo is None: + return self._repo.mark_reverted(operation_ids) def delete_thread_operations(self, thread_id: str) -> int: - """Delete all operations for a thread""" + if self._repo is None: + return 0 return self._repo.delete_thread_operations(thread_id) def _to_file_operation(self, row: FileOperationRow) -> FileOperation: diff --git a/core/runner.py b/core/runner.py index 6c3902e3c..fddd6b135 100644 --- a/core/runner.py +++ b/core/runner.py @@ -153,7 +153,7 @@ def _print_memory_stats(self, status: dict) -> None: def _process_chunk(self, chunk: dict, result: dict) -> None: """Process streaming chunk, extract tool calls and response""" - for node_name, node_update in chunk.items(): + for _node_name, node_update in chunk.items(): if not isinstance(node_update, dict): 
continue diff --git a/core/runtime/abort.py b/core/runtime/abort.py new file mode 100644 index 000000000..f95ca4e2f --- /dev/null +++ b/core/runtime/abort.py @@ -0,0 +1,48 @@ +"""Minimal abort controller tree for runtime lifecycle wiring.""" + +from __future__ import annotations + +from collections.abc import Callable + + +class AbortController: + def __init__(self) -> None: + self._aborted = False + self._listeners: dict[int, Callable[[], None]] = {} + self._next_listener_id = 0 + + def abort(self) -> None: + if self._aborted: + return + self._aborted = True + listeners = list(self._listeners.values()) + self._listeners.clear() + for listener in listeners: + listener() + + def is_aborted(self) -> bool: + return self._aborted + + def on_abort(self, listener: Callable[[], None]) -> Callable[[], None]: + if self._aborted: + listener() + return lambda: None + + listener_id = self._next_listener_id + self._next_listener_id += 1 + self._listeners[listener_id] = listener + + def unsubscribe() -> None: + self._listeners.pop(listener_id, None) + + return unsubscribe + + +def create_child_abort_controller(parent: AbortController | None) -> AbortController: + child = AbortController() + if parent is None: + return child + + unsubscribe = parent.on_abort(child.abort) + child.on_abort(unsubscribe) + return child diff --git a/core/runtime/agent.py b/core/runtime/agent.py index e4d7299c6..7c17ad2e9 100644 --- a/core/runtime/agent.py +++ b/core/runtime/agent.py @@ -18,17 +18,16 @@ All paths must be absolute. Full security mechanisms and audit logging. """ +import asyncio +import concurrent.futures +import inspect +import logging import os -import threading from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any -from langchain.agents import create_agent from langchain.chat_models import init_chat_model from langchain_core.messages import SystemMessage -from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver - -from config.schema import DEFAULT_MODEL # Load .env file _env_file = Path(__file__).parent / ".env" @@ -53,6 +52,11 @@ # Import file operation recorder for time travel from core.operations import get_recorder # noqa: E402 + +# New architecture: ToolRegistry + ToolRunner + Services +from core.runtime.cleanup import CleanupRegistry # noqa: E402 +from core.runtime.loop import QueryLoop # noqa: E402 +from core.runtime.middleware.mcp_instructions import McpInstructionsDeltaMiddleware # noqa: E402 from core.runtime.middleware.memory import MemoryMiddleware # noqa: E402 from core.runtime.middleware.monitor import MonitorMiddleware, apply_usage_patches # noqa: E402 from core.runtime.middleware.prompt_caching import PromptCachingMiddleware # noqa: E402 @@ -60,10 +64,9 @@ # Middleware imports (migrated paths) from core.runtime.middleware.spill_buffer import SpillBufferMiddleware # noqa: E402 - -# New architecture: ToolRegistry + ToolRunner + Services -from core.runtime.registry import ToolRegistry # noqa: E402 +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema # noqa: E402 from core.runtime.runner import ToolRunner # noqa: E402 +from core.runtime.state import AppState, BootstrapConfig # noqa: E402 from core.runtime.validator import ToolValidator # noqa: E402 # Hooks (used by Services) @@ -71,7 +74,9 @@ from core.tools.command.hooks.file_access_logger import FileAccessLoggerHook # noqa: E402 from core.tools.command.hooks.file_permission import FilePermissionHook # noqa: E402 from core.tools.command.service import CommandService # noqa: 
E402 +from core.tools.cron.service import CronToolService # noqa: E402 from core.tools.filesystem.service import FileSystemService # noqa: E402 +from core.tools.mcp_resources.service import McpResourceToolService # noqa: E402 from core.tools.search.service import SearchService # noqa: E402 from core.tools.skills.service import SkillsService # noqa: E402 from core.tools.task.service import TaskService # noqa: E402 @@ -82,10 +87,44 @@ from core.tools.web.service import WebService # noqa: E402 from storage.container import StorageContainer # noqa: E402 +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sandbox import Sandbox + # @@@langchain-anthropic-streaming-usage-regression apply_usage_patches() +def _make_mcp_tool_entry(tool) -> ToolEntry: + schema_model = getattr(tool, "tool_call_schema", None) + if schema_model is not None and hasattr(schema_model, "model_json_schema"): + parameters = schema_model.model_json_schema() + else: + parameters = { + "type": "object", + "properties": getattr(tool, "args", {}) or {}, + } + + async def mcp_handler(**kwargs): + if hasattr(tool, "ainvoke"): + return await tool.ainvoke(kwargs) + return await asyncio.to_thread(tool.invoke, kwargs) + + return ToolEntry( + name=tool.name, + mode=ToolMode.INLINE, + schema=make_tool_schema( + name=tool.name, + description=getattr(tool, "description", "") or tool.name, + properties={}, + parameter_overrides=parameters, + ), + handler=mcp_handler, + source="mcp", + ) + + class LeonAgent: """ Leon Agent - AI Coding Assistant @@ -108,6 +147,7 @@ def __init__( workspace_root: str | Path | None = None, *, agent: str | None = None, + bundle_dir: str | Path | None = None, allowed_file_extensions: list[str] | None = None, block_dangerous_commands: bool | None = None, block_network_commands: bool | None = None, @@ -119,9 +159,15 @@ def __init__( jina_api_key: str | None = None, sandbox: Any = None, storage_container: StorageContainer | None = None, + thread_repo: Any = None, + member_repo: Any = None, queue_manager: MessageQueueManager | None = None, chat_repos: dict | None = None, + web_app: Any = None, extra_allowed_paths: list[str] | None = None, + extra_blocked_tools: set[str] | None = None, + allowed_tools: set[str] | None = None, + permission_resolver_scope: str = "none", verbose: bool = False, ): """ @@ -138,7 +184,10 @@ def __init__( enable_audit_log: Whether to enable audit logging enable_web_tools: Whether to enable web search and content fetching tools sandbox: Sandbox instance, name string, or None for local + thread_repo: Optional thread metadata repo for backend-integrated subagent registration + member_repo: Optional member repo for backend-integrated subagent registration queue_manager: Shared MessageQueueManager instance (created if not provided) + permission_resolver_scope: Permission request surface for this agent ("none" or "thread") verbose: Whether to output detailed logs (default False) """ self.agent_id: str | None = None @@ -146,11 +195,22 @@ def __init__( self.extra_allowed_paths = extra_allowed_paths self.queue_manager = queue_manager or MessageQueueManager() self._chat_repos: dict | None = chat_repos + self._thread_repo = thread_repo + self._member_repo = member_repo + self._web_app = web_app + self._session_started = False + self._session_ended = False + self._closing = False + self._closed = False + requested_sandbox_name = sandbox if isinstance(sandbox, str) else getattr(sandbox, "name", None) + self._explicit_model_name = model_name is not None # New config system mode self.config, 
self.models_config = self._load_config( agent_name=agent, + bundle_dir=bundle_dir, workspace_root=workspace_root, + sandbox_name=requested_sandbox_name, model_name=model_name, api_key=api_key, allowed_file_extensions=allowed_file_extensions, @@ -167,8 +227,9 @@ def __init__( from config.schema import DEFAULT_MODEL # noqa: E402 active_model = DEFAULT_MODEL - # Member model override: agent.md's model field takes precedence over global config - if hasattr(self, "_agent_override") and self._agent_override and self._agent_override.model: + # Agent frontmatter model applies only when the caller did not explicitly + # request a model at construction time. + if not self._explicit_model_name and hasattr(self, "_agent_override") and self._agent_override and self._agent_override.model: active_model = self._agent_override.model resolved_model, model_overrides = self.models_config.resolve_model(active_model) self.model_name = resolved_model @@ -177,6 +238,7 @@ def __init__( # Resolve API key (prefer resolved provider from mapping) provider_name = self._resolve_provider_name(resolved_model, model_overrides) p = self.models_config.get_provider(provider_name) if provider_name else None + self._explicit_api_key = api_key is not None self.api_key = api_key or (p.api_key if p else None) or self.models_config.get_api_key() if not self.api_key: @@ -213,63 +275,74 @@ def __init__( } # Initialize checkpointer and MCP tools - self._aiosqlite_conn, mcp_tools = self._init_async_components() + self.checkpointer = None + _conn, mcp_tools = self._init_async_components() # If in async context (running loop detected), _init_async_components # skips init and returns (None, []). Distinguish from Postgres path # which also returns conn=None but DID initialize successfully. - self._needs_async_init = self._aiosqlite_conn is None and self.checkpointer is None + self._needs_async_init = self.checkpointer is None # Set checkpointer to None if in async context (will be initialized later) if self._needs_async_init: self.checkpointer = None # Initialize ToolRegistry and Services (new architecture) - self._tool_registry = ToolRegistry(blocked_tools=self._get_member_blocked_tools()) + blocked = self._get_member_blocked_tools() + if extra_blocked_tools: + blocked = blocked | extra_blocked_tools + self._tool_registry = ToolRegistry( + blocked_tools=blocked, + allowed_tools=allowed_tools, + ) self._init_services() + self._register_mcp_tools(mcp_tools) # Build middleware stack middleware = self._build_middleware_stack() - # Ensure ToolNode is created (middleware tools need at least one BaseTool) + # Ensure the bound model still sees at least one BaseTool-compatible entry. 
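+        # The placeholder below covers the empty-tool edge: it is injected only when
+        # neither MCP nor middleware contributes a BaseTool, so the model bind never
+        # receives an empty tools list.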
if not mcp_tools and not self._has_middleware_tools(middleware): mcp_tools = [self._create_placeholder_tool()] - # Build system prompt - self.system_prompt = self._build_system_prompt() - custom_prompt = self.config.system_prompt - if custom_prompt: - self.system_prompt += f"\n\n**Custom Instructions:**\n{custom_prompt}" - - # @@@entity-identity — inject chat identity so agent knows who it is in the social layer - if self._chat_repos: - repos = self._chat_repos - uid = repos.get("user_id") - owner_uid = repos.get("owner_user_id", "") - if uid: - entity_repo = repos.get("entity_repo") - entity = entity_repo.get_by_id(uid) if entity_repo else None - member_repo = repos.get("member_repo") - owner_row = member_repo.get_by_id(owner_uid) if member_repo and owner_uid else None - name = entity.name if entity else uid - owner_name = owner_row.name if owner_row else "unknown" - self.system_prompt += ( - f"\n\n**Chat Identity:**\n" - f"- Your name: {name}\n" - f"- Your user_id: {uid}\n" - f"- Your owner: {owner_name} (user_id: {owner_uid})\n" - f"- When you receive a chat notification, READ the message with chat_read(), " - f"then REPLY with chat_send(). Your text output goes to your owner's thread, " - f"not to the chat — only chat_send() delivers to the other party.\n" - ) + self._system_prompt_section_cache: dict[str, str] = {} + self.system_prompt = self._compose_system_prompt() - # Create agent - self.agent = create_agent( + # Build BootstrapConfig for sub-agent forking + self._bootstrap = BootstrapConfig( + workspace_root=self.workspace_root, + original_cwd=Path.cwd(), + project_root=self.workspace_root, + cwd=self.workspace_root, + model_name=self.model_name, + api_key=self.api_key, + sandbox_type=self._sandbox.name, + permission_resolver_scope=permission_resolver_scope, + block_dangerous_commands=self.block_dangerous_commands, + block_network_commands=self.block_network_commands, + enable_audit_log=self.enable_audit_log, + enable_web_tools=self.enable_web_tools, + allowed_file_extensions=self.allowed_file_extensions, + extra_allowed_paths=self.extra_allowed_paths, + model_provider=self._current_model_config.get("model_provider"), + base_url=self._current_model_config.get("base_url"), + ) + self._app_state = AppState() + self.app_state = self._app_state + # Inject bootstrap into AgentService so sub-agents can fork from it + if hasattr(self, "_agent_service"): + self._agent_service._parent_bootstrap = self._bootstrap + + # Create agent via QueryLoop (replaces LangGraph create_agent) + self.agent = QueryLoop( model=self.model, - tools=mcp_tools, system_prompt=SystemMessage(content=[{"type": "text", "text": self.system_prompt}]), middleware=middleware, - checkpointer=self.checkpointer if not self._needs_async_init else None, + checkpointer=self.checkpointer, + registry=self._tool_registry, + app_state=self._app_state, + runtime=self._monitor_middleware.runtime, + bootstrap=self._bootstrap, ) # Get runtime from MonitorMiddleware @@ -286,13 +359,45 @@ def __init__( print("[LeonAgent] Initialized successfully") print(f"[LeonAgent] Workspace: {self.workspace_root}") print(f"[LeonAgent] Audit log: {self.enable_audit_log}") - if self._needs_async_init: + if self.checkpointer is None: print("[LeonAgent] Note: Async components need initialization via ainit()") - # Mark agent as ready (if not needing async init) - if not self._needs_async_init: + # Wire CleanupRegistry for priority-ordered resource teardown + self._cleanup_registry = CleanupRegistry() + 
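+        # Lower priority runs first: sandbox teardown (2) precedes monitor
+        # termination (3), MCP client close (4), and the SQLite no-op (5).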
self._cleanup_registry.register(self._cleanup_sandbox, priority=2) + self._cleanup_registry.register(self._mark_terminated, priority=3) + self._cleanup_registry.register(self._cleanup_mcp_client, priority=4) + self._cleanup_registry.register(self._cleanup_sqlite_connection, priority=5) + + # Mark agent as ready (checkpointer is None when async init still pending) + if self.checkpointer is not None: self._monitor_middleware.mark_ready() + @property + def sandbox(self) -> "Sandbox": + # @@@public-sandbox-surface - integration callers already drive fs/shell through + # agent.sandbox; make that contract explicit instead of relying on a private attr. + return self._sandbox + + def apply_forked_child_context( + self, + bootstrap: BootstrapConfig, + *, + tool_context: Any | None = None, + ) -> None: + # @@@subagent-fork-wiring + # AgentService should not reach through LeonAgent and mutate QueryLoop + # internals directly. Keep the child bootstrap + abort-controller wiring + # behind one explicit LeonAgent seam. + self._bootstrap = bootstrap + self.agent._bootstrap = bootstrap + if hasattr(self, "_agent_service"): + self._agent_service._parent_bootstrap = bootstrap + if tool_context is not None: + self._agent_service._parent_tool_context = tool_context + if tool_context is not None: + self.agent._tool_abort_controller = tool_context.abort_controller + async def ainit(self): """Complete async initialization (call this if initialized in async context). @@ -300,22 +405,28 @@ async def ainit(self): agent = LeonAgent(sandbox=sandbox) await agent.ainit() """ - if not self._needs_async_init: - return # Already initialized - - # Initialize async components - self._aiosqlite_conn = await self._init_checkpointer() - _mcp_tools = await self._init_mcp_tools() + if self.checkpointer is None: + # Initialize async components + await self._init_checkpointer() + _mcp_tools = await self._init_mcp_tools() + self._register_mcp_tools(_mcp_tools) + + # Update agent with checkpointer + self.agent.checkpointer = self.checkpointer + if hasattr(self, "_memory_middleware"): + # @@@late-checkpointer-fanout - async bringup creates the saver after + # middleware construction, so QueryLoop and MemoryMiddleware must be + # rewired together or rebuild/persistence surfaces drift apart. + self._memory_middleware.checkpointer = self.checkpointer - # Update agent with checkpointer - self.agent.checkpointer = self.checkpointer + self._monitor_middleware.mark_ready() - # Mark as initialized - self._needs_async_init = False - self._monitor_middleware.mark_ready() + if self.verbose: + print("[LeonAgent] Async initialization completed") - if self.verbose: - print("[LeonAgent] Async initialization completed") + if not self._session_started: + await self._run_session_hooks("SessionStart") + self._session_started = True def _init_async_components(self) -> tuple[Any, list]: """Initialize async components (checkpointer and MCP tools). 
@@ -339,24 +450,31 @@ def _init_async_components(self) -> tuple[Any, list]: self._event_loop = loop # Initialize components - conn = loop.run_until_complete(self._init_checkpointer()) + loop.run_until_complete(self._init_checkpointer()) mcp_tools = loop.run_until_complete(self._init_mcp_tools()) - # DON'T close the loop - let it persist for aiosqlite - # The loop will be cleaned up when Python exits - return conn, mcp_tools + return None, mcp_tools def _has_middleware_tools(self, middleware: list) -> bool: """Check if any middleware has BaseTool instances.""" return any(getattr(m, "tools", None) for m in middleware) + def _register_mcp_tools(self, mcp_tools: list) -> None: + if not mcp_tools: + return + for tool in mcp_tools: + try: + self._tool_registry.register(_make_mcp_tool_entry(tool)) + except Exception as exc: + logger.warning("[LeonAgent] Failed to register MCP tool %s: %s", getattr(tool, "name", ""), exc) + def _create_placeholder_tool(self): - """Create placeholder tool to ensure ToolNode is created.""" + """Create placeholder tool so the bound model still has a BaseTool.""" from langchain_core.tools import tool @tool def _placeholder() -> str: - """Internal placeholder - ensures ToolNode is created for middleware tools.""" + """Internal placeholder for the empty-tool edge.""" return "" return _placeholder @@ -391,10 +509,26 @@ def _get_member_blocked_tools(self) -> set[str]: return blocked + def _get_mcp_server_configs(self) -> dict[str, Any]: + if hasattr(self, "_agent_bundle") and self._agent_bundle and self._agent_bundle.mcp: + return {name: srv for name, srv in self._agent_bundle.mcp.items() if not srv.disabled} + return self.config.mcp.servers + + def _get_mcp_instruction_blocks(self) -> dict[str, str]: + blocks: dict[str, str] = {} + for name, cfg in self._get_mcp_server_configs().items(): + instructions = getattr(cfg, "instructions", None) + if not isinstance(instructions, str) or not instructions.strip(): + continue + blocks[name] = instructions.strip() + return blocks + def _load_config( self, agent_name: str | None, + bundle_dir: str | Path | None, workspace_root: str | Path | None, + sandbox_name: str | None, model_name: str | None, api_key: str | None, allowed_file_extensions: list[str] | None, @@ -410,8 +544,14 @@ def _load_config( """ # Build CLI overrides for runtime config cli_overrides: dict = {} - - if workspace_root is not None: + use_workspace_override = sandbox_name in (None, "", "local") + + if workspace_root is not None and use_workspace_override: + # @@@remote-sandbox-config-root + # Remote child agents may inherit a sandbox cwd like /home/daytona, + # which is valid inside the sandbox but not on the host. Feeding that + # path into LeonSettings makes config validation fail before sandbox + # init ever runs, so only local sandboxes pin workspace_root here. cli_overrides["workspace_root"] = str(workspace_root) # Runtime overrides go into "runtime" section @@ -441,8 +581,14 @@ def _load_config( models_loader = ModelsLoader(workspace_root=workspace_root) models_config = models_loader.load(cli_overrides=models_cli if models_cli else None) + # @@@bundle-dir-wins - member-backed top-level agents need their own bundle even when + # no explicit agent type name is passed through the thread runtime wiring. 
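+        # When bundle_dir and agent_name are both provided, bundle_dir wins and the
+        # named-agent lookup in the elif branch below is skipped entirely.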
+ if bundle_dir is not None: + bundle_path = Path(bundle_dir).expanduser().resolve() + self._agent_bundle = loader.load_bundle(bundle_path) + self._agent_override = self._agent_bundle.agent.model_copy(update={"source_dir": bundle_path}) # If agent specified, load agent definition to override system_prompt and tools - if agent_name: + elif agent_name: all_agents = loader.load_all_agents() agent_def = all_agents.get(agent_name) if not agent_def: @@ -609,7 +755,16 @@ def _build_model_kwargs(self) -> dict: # Get credentials from the resolved provider p = self.models_config.get_provider(provider) if provider else None - base_url = (p.base_url if p else None) or self.models_config.get_base_url() + env_base_url = os.getenv("ANTHROPIC_BASE_URL") or os.getenv("OPENAI_BASE_URL") + + # @@@explicit-api-key-base-url + # Real-model verification must not be silently redirected to a provider + # config endpoint when the caller explicitly injected credentials for a + # different OpenAI-compatible endpoint. + if self._explicit_api_key and env_base_url: + base_url = env_base_url + else: + base_url = (p.base_url if p else None) or self.models_config.get_base_url() if base_url: kwargs["base_url"] = self._normalize_base_url(base_url, provider) @@ -714,12 +869,71 @@ def update_observation(self, **overrides) -> None: if self.verbose: print(f"[LeonAgent] Observation updated: active={self._observation_config.active}") - def close(self): - """Clean up resources.""" - self._cleanup_sandbox() - self._mark_terminated() - self._cleanup_mcp_client() - self._cleanup_sqlite_connection() + def close(self, *, cleanup_sandbox: bool = True): + """Clean up resources via CleanupRegistry (priority-ordered). + + Falls back to direct cleanup if CleanupRegistry is not initialized. + """ + # @@@close-idempotent - child agents may explicitly skip sandbox cleanup + # and later still hit __del__ on GC; never let a second close silently + # re-enable default sandbox teardown on a shared lease. 
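+        # _closing also guards re-entry while close() is still running: __del__ can
+        # fire mid-teardown and would otherwise re-run the cleanup steps.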
+ if getattr(self, "_closed", False) or getattr(self, "_closing", False): + return + + self._closing = True + session_end_error: Exception | None = None + try: + if getattr(self, "_session_started", False) and not getattr(self, "_session_ended", False): + try: + self._run_async_cleanup(lambda: self._run_session_hooks("SessionEnd"), "SessionEnd hooks") + except Exception as exc: + session_end_error = exc + finally: + self._session_ended = True + + if hasattr(self, "_cleanup_registry") and cleanup_sandbox: + self._run_async_cleanup(self._cleanup_registry.run_cleanup, "CleanupRegistry") + else: + # Fallback for edge cases where __init__ did not complete fully + cleanup_steps = [ + ("monitor", self._mark_terminated), + ("MCP client", self._cleanup_mcp_client), + ("SQLite connection", self._cleanup_sqlite_connection), + ] + if cleanup_sandbox: + cleanup_steps.insert(0, ("sandbox", self._cleanup_sandbox)) + + for step_name, step_fn in cleanup_steps: + try: + step_fn() + except Exception as e: + print(f"[LeonAgent] {step_name} cleanup error: {e}") + + if session_end_error is not None: + raise session_end_error + finally: + self._closed = True + self._closing = False + + def _build_session_hook_payload(self, event: str) -> dict[str, Any]: + return { + "event": event, + "session_id": self._bootstrap.session_id, + "workspace_root": str(self.workspace_root), + "cwd": str(self._bootstrap.cwd or self.workspace_root), + "sandbox": self._sandbox.name, + } + + async def _run_session_hooks(self, event: str) -> None: + hooks = self._app_state.get_session_hooks(event) + if not hooks: + return + + payload = self._build_session_hook_payload(event) + for hook in hooks: + result = hook(payload) + if inspect.isawaitable(result): + await result def _cleanup_sandbox(self) -> None: """Clean up sandbox resources.""" @@ -734,32 +948,29 @@ def _mark_terminated(self) -> None: if hasattr(self, "_monitor_middleware"): self._monitor_middleware.mark_terminated() + _CLEANUP_TIMEOUT: float = 10.0 # seconds; prevents hanging on stuck I/O + @staticmethod def _run_async_cleanup(coro_factory, label: str) -> None: import asyncio try: - running_loop = asyncio.get_running_loop() + asyncio.get_running_loop() except RuntimeError: - running_loop = None - - if running_loop is None: asyncio.run(coro_factory()) return - error: list[Exception] = [] - - def _runner() -> None: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(asyncio.run, coro_factory()) try: - asyncio.run(coro_factory()) + future.result(timeout=LeonAgent._CLEANUP_TIMEOUT) + except concurrent.futures.TimeoutError: + raise RuntimeError( + f"{label} cleanup timed out after {LeonAgent._CLEANUP_TIMEOUT}s — " + f"possible stuck I/O; resource abandoned to prevent hang" + ) except Exception as exc: - error.append(exc) - - thread = threading.Thread(target=_runner, daemon=True) - thread.start() - thread.join() - if error: - raise RuntimeError(f"{label} cleanup failed: {error[0]}") from error[0] + raise RuntimeError(f"{label} cleanup failed: {exc}") from exc def _cleanup_mcp_client(self) -> None: """Clean up MCP client.""" @@ -767,35 +978,15 @@ def _cleanup_mcp_client(self) -> None: return try: - self._run_async_cleanup(lambda: self._mcp_client.close(), "MCP client") + close_fn = getattr(self._mcp_client, "close", None) + if callable(close_fn): + self._run_async_cleanup(close_fn, "MCP client") except Exception as e: print(f"[LeonAgent] MCP cleanup error: {e}") self._mcp_client = None def _cleanup_sqlite_connection(self) -> None: - """Clean up 
SQLite connection. - - Properly closes aiosqlite connection using asyncio.run() to avoid - hanging on process exit. - """ - if not hasattr(self, "_aiosqlite_conn") or not self._aiosqlite_conn: - return - - try: - import asyncio - - # Close the connection asynchronously - async def _close(): - if self._aiosqlite_conn: - await self._aiosqlite_conn.close() - - # Use asyncio.run() to properly close the connection - asyncio.run(_close()) - except Exception: - # Ignore errors during cleanup - pass - finally: - self._aiosqlite_conn = None + """No-op: SQLite checkpointer removed; Postgres cleanup handled by _pg_saver_ctx.""" def __del__(self): self.close() @@ -830,11 +1021,19 @@ def _build_middleware_stack(self) -> list: if memory_enabled: self._add_memory_middleware(middleware) - # 4. Steering — injects queued messages before model call + # 4. MCP instructions delta — thread-scoped reminder when MCP guidance changes + middleware.append( + McpInstructionsDeltaMiddleware( + get_instruction_blocks=self._get_mcp_instruction_blocks, + get_app_state=lambda: self.app_state, + ) + ) + + # 5. Steering — injects queued messages before model call self._steering_middleware = SteeringMiddleware(queue_manager=self.queue_manager) middleware.append(self._steering_middleware) - # 5. ToolRunner (innermost — routes all ToolRegistry-registered tool calls) + # 6. ToolRunner (innermost — routes all ToolRegistry-registered tool calls) self._tool_runner = ToolRunner( registry=self._tool_registry, validator=ToolValidator(), @@ -843,7 +1042,7 @@ def _build_middleware_stack(self) -> list: # 0. SpillBuffer (outermost — catches oversized tool outputs) # Must be inserted at index 0 AFTER building the list: - # LangChain wraps middlewares as "first = outermost". + # QueryLoop composes middleware so the first entry remains outermost. 
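+        # Inserting at index 0 therefore makes SpillBuffer wrap every middleware
+        # appended above, ToolRunner included.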
if self.config.tools.spill_buffer.enabled: spill_cfg = self.config.tools.spill_buffer middleware.insert( @@ -993,6 +1192,17 @@ def _init_services(self) -> None: workspace_root=self.workspace_root, ) + # Cron tools (DEFERRED - backed by existing panel cron_jobs substrate) + self._cron_tool_service = CronToolService( + registry=self._tool_registry, + ) + + self._mcp_resource_tool_service = McpResourceToolService( + registry=self._tool_registry, + client_fn=lambda: getattr(self, "_mcp_client", None), + server_configs_fn=self._get_mcp_server_configs, + ) + # ToolSearch (INLINE - always available for discovering DEFERRED tools) self._tool_search_service = ToolSearchService( registry=self._tool_registry, @@ -1005,8 +1215,12 @@ def _init_services(self) -> None: agent_registry=self._agent_registry, workspace_root=self.workspace_root, model_name=self.model_name, + thread_repo=self._thread_repo, + member_repo=self._member_repo, queue_manager=self.queue_manager, shared_runs=self._background_runs, + web_app=self._web_app, + child_agent_factory=create_leon_agent, ) # Team coordination (TeamCreate/TeamDelete — deferred mode) @@ -1023,51 +1237,37 @@ def _init_services(self) -> None: except ImportError: self._taskboard_service = None - # @@@chat-tools - register chat tools for agents with user identity + # @@@chat-tools - register chat tools for agents with user identity (v2 messaging) if self._chat_repos: repos = self._chat_repos - user_id = repos.get("user_id") - owner_user_id = repos.get("owner_user_id", "") - if user_id: - from core.agents.communication.chat_tool_service import ChatToolService + chat_identity_id = repos.get("chat_identity_id") or repos.get("user_id") + owner_id = repos.get("owner_id", "") + if chat_identity_id: + from messaging.tools.chat_tool_service import ChatToolService - # @@@lazy-runtime — runtime isn't set yet at _init_services() time. - # Pass a callable that resolves runtime lazily at tool call time. 
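+                # v2 wiring injects the messaging repos directly; the lazy runtime_fn
+                # indirection removed above is no longer needed.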
self._chat_tool_service = ChatToolService( registry=self._tool_registry, - user_id=user_id, - owner_user_id=owner_user_id, - entity_repo=repos.get("entity_repo"), - chat_service=repos.get("chat_service"), - chat_entity_repo=repos.get("chat_entity_repo"), - chat_message_repo=repos.get("chat_message_repo"), + chat_identity_id=chat_identity_id, + owner_id=owner_id, + messaging_service=repos.get("messaging_service"), + chat_member_repo=repos.get("chat_member_repo"), + messages_repo=repos.get("messages_repo"), member_repo=repos.get("member_repo"), - chat_event_bus=repos.get("chat_event_bus"), - runtime_fn=lambda: getattr(self, "runtime", None), + thread_repo=self._thread_repo, + relationship_repo=repos.get("relationship_repo"), ) - # @@@wechat-tools — register WeChat tools via lazy connection lookup - owner_uid = self._chat_repos.get("owner_user_id", "") if self._chat_repos else "" - if owner_uid: - try: - from core.tools.wechat.service import WeChatToolService - - def _get_wechat_conn(uid=owner_uid): - """Lazy lookup — returns None if registry not on app.state yet.""" - try: - from backend.web.main import app - - registry = getattr(app.state, "wechat_registry", None) - return registry.get(uid) if registry else None - except Exception: - return None + # LSP tools — DEFERRED, always registered, multilspy checked at call time + self._lsp_service = None + try: + from core.tools.lsp.service import LSPService - self._wechat_tool_service = WeChatToolService( - registry=self._tool_registry, - connection_fn=_get_wechat_conn, - ) - except ImportError: - self._wechat_tool_service = None + self._lsp_service = LSPService( + registry=self._tool_registry, + workspace_root=self.workspace_root, + ) + except Exception as e: + logger.debug("[LeonAgent] LSPService init skipped: %s", e) if self.verbose: all_tools = self._tool_registry.list_all() @@ -1078,11 +1278,7 @@ def _get_wechat_conn(uid=owner_uid): async def _init_mcp_tools(self) -> list: mcp_enabled = self.config.mcp.enabled - # Use member bundle MCP config if available, else fall back to global config - if hasattr(self, "_agent_bundle") and self._agent_bundle and self._agent_bundle.mcp: - mcp_servers = {name: srv for name, srv in self._agent_bundle.mcp.items() if not srv.disabled} - else: - mcp_servers = self.config.mcp.servers + mcp_servers = self._get_mcp_server_configs() if not mcp_enabled or not mcp_servers: return [] @@ -1091,10 +1287,21 @@ async def _init_mcp_tools(self) -> list: configs = {} for name, cfg in mcp_servers.items(): + transport = getattr(cfg, "transport", None) if cfg.url: - config = {"transport": "streamable_http", "url": cfg.url} + # @@@mcp-transport-honesty - api-04 requires explicit transport + # config to survive loader -> runtime. URL-based MCP is not + # always streamable_http; websocket/sse must stay explicit. + config = { + "transport": transport or "streamable_http", + "url": cfg.url, + } else: - config = {"transport": "stdio", "command": cfg.command, "args": cfg.args} + config = { + "transport": transport or "stdio", + "command": cfg.command, + "args": cfg.args, + } if cfg.env: config["env"] = cfg.env configs[name] = config @@ -1129,31 +1336,20 @@ async def _init_mcp_tools(self) -> list: async def _init_checkpointer(self): """Initialize async checkpointer for conversation persistence. - Uses Postgres (via Supabase) when LEON_STORAGE_STRATEGY=supabase, - otherwise falls back to local SQLite. + Requires LEON_POSTGRES_URL to be set (Supabase Postgres). 
""" - strategy = os.getenv("LEON_STORAGE_STRATEGY", "sqlite") pg_url = os.getenv("LEON_POSTGRES_URL") + if not pg_url: + raise RuntimeError("LEON_POSTGRES_URL is required for checkpointer initialization") - if strategy == "supabase" and pg_url: - from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver - - # from_conn_string is an async context manager; enter it and keep - # the reference so the connection pool stays open for the agent's lifetime. - self._pg_saver_ctx = AsyncPostgresSaver.from_conn_string(pg_url) - self.checkpointer = await self._pg_saver_ctx.__aenter__() - await self.checkpointer.setup() - return None # no SQLite conn to track - else: - from storage.providers.sqlite.kernel import connect_sqlite_async + from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver - db_path = self.db_path - db_path.parent.mkdir(parents=True, exist_ok=True) - conn = await connect_sqlite_async(db_path) - self.checkpointer = AsyncSqliteSaver(conn) - await self.checkpointer.setup() - return conn - return conn + # from_conn_string is an async context manager; enter it and keep + # the reference so the connection pool stays open for the agent's lifetime. + self._pg_saver_ctx = AsyncPostgresSaver.from_conn_string(pg_url) + self.checkpointer = await self._pg_saver_ctx.__aenter__() + await self.checkpointer.setup() + return None # no SQLite conn to track def _is_tool_allowed(self, tool) -> bool: # Extract original tool name without mcp__ prefix @@ -1190,155 +1386,109 @@ def _build_system_prompt(self) -> str: return prompt - def _build_context_section(self) -> str: - """Build the context section based on sandbox mode.""" - if self._sandbox.name != "local": - env_label = self._sandbox.env_label - working_dir = self._sandbox.working_dir - if self._sandbox.name == "docker": - mode_label = "Sandbox (isolated local container)" - else: - mode_label = "Sandbox (isolated cloud environment)" - return f"""- Environment: {env_label} -- Working Directory: {working_dir} -- Mode: {mode_label}""" - else: - import platform - - os_name = platform.system() - if os_name == "Windows": - shell_name = "powershell" - else: - shell_name = os.environ.get("SHELL", "/bin/bash").split("/")[-1] - return f"""- Workspace: `{self.workspace_root}` -- OS: {os_name} -- Shell: {shell_name} -- Mode: Local""" + def _compose_system_prompt(self) -> str: + prompt = self._build_system_prompt() - def _build_rules_section(self) -> str: - """Build shared rules section for all modes.""" - is_sandbox = self._sandbox.name != "local" - working_dir = self._sandbox.working_dir if is_sandbox else self.workspace_root + custom_prompt = self.config.system_prompt + if custom_prompt: + prompt += f"\n\n**Custom Instructions:**\n{custom_prompt}" - rules = [] + # @@@chat-identity — inject chat identity so agent knows who it is in the social layer + if self._chat_repos: + repos = self._chat_repos + uid = repos.get("chat_identity_id") or repos.get("user_id") + owner_uid = repos.get("owner_id", "") + if uid: + member_repo = repos.get("member_repo") + self_member = member_repo.get_by_id(uid) if member_repo else None + if self_member is None and member_repo and self._thread_repo is not None: + thread = self._thread_repo.get_by_user_id(uid) + member_id = thread.get("member_id") if thread else None + if member_id: + self_member = member_repo.get_by_id(member_id) + owner_row = member_repo.get_by_id(owner_uid) if member_repo and owner_uid else None + name = self_member.name if self_member else uid + owner_name = owner_row.name if owner_row else "unknown" + 
prompt += ( + f"\n\n**Chat Identity:**\n" + f"- Your name: {name}\n" + f"- Your chat identity id: {uid}\n" + f"- The chat tools still use the parameter name user_id for legacy reasons.\n" + f"- Your owner: {owner_name} (human user_id: {owner_uid})\n" + f"- When you receive a chat notification, you MUST read it with chat_read() before deciding what to do.\n" + f"- If that notification already gives you a chat_id, prefer using that exact chat_id directly.\n" + f"- If you reply to the other party, you MUST call chat_send(). Never claim you replied unless chat_send() succeeded.\n" + f"- Your normal text output goes to your owner's thread, not to the chat — only chat_send() delivers to the other party.\n" + ) + return prompt - # Rule 1: Environment-specific - if is_sandbox: - if self._sandbox.name == "docker": - location_rule = "All file and command operations run in a local Docker container, NOT on the user's host filesystem." - else: - location_rule = "All file and command operations run in a remote sandbox, NOT on the user's local machine." - rules.append(f"1. **Sandbox Environment**: {location_rule} The sandbox is an isolated Linux environment.") - else: - rules.append("1. **Workspace**: File operations are restricted to: " + str(self.workspace_root)) + def _invalidate_system_prompt_cache(self) -> None: + self._system_prompt_section_cache.clear() - # Rule 2: Absolute paths - rules.append(f"""2. **Absolute Paths**: All file paths must be absolute paths. - - ✅ Correct: `{working_dir}/project/test.py` - - ❌ Wrong: `test.py` or `./test.py`""") + def _get_cached_prompt_section(self, key: str, builder) -> str: + cached = self._system_prompt_section_cache.get(key) + if cached is not None: + return cached + value = builder() + self._system_prompt_section_cache[key] = value + return value - # Rule 3: Security - if is_sandbox: - rules.append("3. **Security**: The sandbox is isolated. You can install packages, run any commands, and modify files freely.") - else: - rules.append("3. **Security**: Dangerous commands are blocked. All operations are logged.") + def _build_context_section(self) -> str: + from core.runtime.prompts import build_context_section + + def _build() -> str: + is_sandbox = self._sandbox.name != "local" + if is_sandbox: + return build_context_section( + sandbox_name=self._sandbox.name, + sandbox_env_label=self._sandbox.env_label, + sandbox_working_dir=self._sandbox.working_dir, + ) + import platform - # Rule 4: Tool priority - rules.append( - """4. **Tool Priority**: When a built-in tool and an MCP tool (`mcp__*`) have the same functionality, use the built-in tool.""" - ) + os_name = platform.system() + shell_name = "powershell" if os_name == "Windows" else os.environ.get("SHELL", "/bin/bash").split("/")[-1] + return build_context_section( + sandbox_name="local", + workspace_root=str(self.workspace_root), + os_name=os_name, + shell_name=shell_name, + ) - # Rule 5: Dedicated tools over shell - rules.append("""5. **Use Dedicated Tools Instead of Shell Commands**: Do NOT use `Bash` for tasks that have dedicated tools: - - File search → use `Grep` (NOT `rg`, `grep`, or `find` via Bash) - - File listing → use `Glob` (NOT `find` or `ls` via Bash) - - File reading → use `Read` (NOT `cat`, `head`, `tail` via Bash) - - File editing → use `Edit` (NOT `sed` or `awk` via Bash) - - Reserve `Bash` for: git, package managers, build tools, tests, and other system operations.""") + return self._get_cached_prompt_section("context", _build) - # Rule 6: Background task description - rules.append("""6. 
**Background Task Description**: When using `Bash` or `Agent` with `run_in_background: true`, always include a clear `description` parameter. # noqa: E501 - - The description is shown to the user in the background task indicator. - - Keep it concise (5–10 words), action-oriented, e.g. "Run test suite", "Analyze API codebase". - - Without a description, the raw command or agent name is shown, which is hard to read.""") + def _build_rules_section(self) -> str: + from core.runtime.prompts import build_rules_section + + def _build() -> str: + is_sandbox = self._sandbox.name != "local" + working_dir = self._sandbox.working_dir if is_sandbox else str(self.workspace_root) + return build_rules_section( + is_sandbox=is_sandbox, + sandbox_name=self._sandbox.name, + working_dir=working_dir, + workspace_root=str(self.workspace_root), + spill_buffer_enabled=self.config.tools.spill_buffer.enabled, + spill_keep_recent=self.config.memory.pruning.protect_recent, + ) - return "\n\n".join(rules) + return self._get_cached_prompt_section("rules", _build) def _build_base_prompt(self) -> str: - """Build the base system prompt (context + rules), shared by all modes.""" - context = self._build_context_section() - rules = self._build_rules_section() - - return f"""You are a highly capable AI assistant with access to file and system tools. - -**Context:** -{context} + from core.runtime.prompts import build_base_prompt -**Important Rules:** - -{rules} -""" + return self._get_cached_prompt_section( + "base_prompt", + lambda: build_base_prompt(self._build_context_section(), self._build_rules_section()), + ) def _build_common_prompt_sections(self) -> str: - """Build common prompt sections for both sandbox and local modes.""" - prompt = """ -**Agent Tool (Sub-agent Orchestration):** - -Use the Agent tool to launch specialized sub-agents for complex tasks: -- `explore`: Read-only codebase exploration. Use for: finding files, searching code, understanding implementations. -- `plan`: Design implementation plans. Use for: architecture decisions, multi-step planning. -- `bash`: Execute shell commands. Use for: git operations, running tests, system commands. -- `general`: Full tool access. Use for: independent multi-step tasks requiring file modifications. 
- -When to use Agent: -- Open-ended searches that may require multiple rounds of exploration -- Tasks that can run independently while you continue other work -- Complex operations that benefit from specialized focus - -When NOT to use Agent: -- Simple file reads (use Read directly) -- Specific searches with known patterns (use Grep directly) -- Quick operations that don't need isolation - -**Todo Tools (Task Management):** - -Use Todo tools to track progress on complex, multi-step tasks: -- `TaskCreate`: Create a new task with subject, description, and activeForm (present continuous for spinner) -- `TaskList`: View all tasks and their status -- `TaskGet`: Get full details of a specific task -- `TaskUpdate`: Update task status (pending → in_progress → completed) or details - -When to use Todo: -- Complex tasks with 3+ distinct steps -- When the user provides multiple tasks to complete -- To show progress on non-trivial work - -When NOT to use Todo: -- Single, straightforward tasks -- Trivial operations that don't need tracking -""" - - # Add Skills section if skills are enabled - skills_enabled = self.config.skills.enabled and self.config.skills.paths - - if skills_enabled: - prompt += """ -**Skills (Specialized Knowledge):** + from core.runtime.prompts import build_common_sections -Use the `load_skill` tool to access specialized domain knowledge and workflows: -- Skills provide focused instructions for specific tasks (e.g., TDD, debugging, git workflows) -- Call `load_skill(skill_name)` to load a skill's content into context -- Available skills are listed in the load_skill tool description - -When to use load_skill: -- When you need specialized guidance for a specific workflow -- To access domain-specific best practices -- When the user mentions a skill by name (e.g., "use TDD skill") - -Progressive disclosure: Skills are loaded on-demand to save tokens. -""" - - return prompt + return self._get_cached_prompt_section( + "common_sections", + lambda: build_common_sections(bool(self.config.skills.enabled and self.config.skills.paths)), + ) def invoke(self, message: str, thread_id: str = "default") -> dict: """Invoke agent with a message (sync version). 
@@ -1388,6 +1538,174 @@ async def ainvoke(self, message: str, thread_id: str = "default") -> dict: self._monitor_middleware.mark_error(e) raise + async def astream( + self, + message: str, + thread_id: str = "default", + stream_mode: str | list[str] = "updates", + max_budget_usd: float | None = None, + ): + """Stream agent output through a caller-owned LeonAgent surface.""" + try: + async for chunk in self.agent.astream( + {"messages": [{"role": "user", "content": message}]}, + config={"configurable": {"thread_id": thread_id}}, + stream_mode=stream_mode, + ): + yield chunk + if max_budget_usd is not None and self.runtime.cost > max_budget_usd: + raise RuntimeError(f"max_budget_usd exceeded: cost={self.runtime.cost:.6f} budget={max_budget_usd:.6f}") + except Exception as e: + self._monitor_middleware.mark_error(e) + raise + + async def aclear_thread(self, thread_id: str = "default") -> None: + """Clear turn-scoped state for a thread while preserving session accumulators.""" + try: + await self.agent.aclear(thread_id) + self._invalidate_system_prompt_cache() + self.system_prompt = self._compose_system_prompt() + self.agent.system_prompt = SystemMessage(content=[{"type": "text", "text": self.system_prompt}]) + except Exception as e: + self._monitor_middleware.mark_error(e) + raise + + def clear_thread(self, thread_id: str = "default") -> None: + """Sync wrapper for aclear_thread().""" + import asyncio + + async def _aclear(): + await self.aclear_thread(thread_id) + + try: + if hasattr(self, "_event_loop") and self._event_loop: + self._event_loop.run_until_complete(_aclear()) + else: + asyncio.run(_aclear()) + except Exception as e: + self._monitor_middleware.mark_error(e) + raise + + def get_pending_permission_requests(self, thread_id: str | None = None) -> list[dict]: + requests = list(self._app_state.pending_permission_requests.values()) + if thread_id is not None: + requests = [item for item in requests if item.get("thread_id") == thread_id] + return requests + + def get_thread_permission_rules(self, thread_id: str | None = None) -> dict[str, Any]: + state = self._app_state.tool_permission_context + return { + "thread_id": thread_id, + "scope": "session", + "managed_only": state.allowManagedPermissionRulesOnly, + "rules": { + "allow": list(state.alwaysAllowRules.get("session", [])), + "deny": list(state.alwaysDenyRules.get("session", [])), + "ask": list(state.alwaysAskRules.get("session", [])), + }, + } + + def add_thread_permission_rule(self, thread_id: str, *, behavior: str, tool_name: str) -> bool: + if self._app_state.tool_permission_context.allowManagedPermissionRulesOnly: + return False + + def _update(state: AppState) -> AppState: + permission_state = state.tool_permission_context.model_copy(deep=True) + for bucket in ( + permission_state.alwaysAllowRules.setdefault("session", []), + permission_state.alwaysDenyRules.setdefault("session", []), + permission_state.alwaysAskRules.setdefault("session", []), + ): + while tool_name in bucket: + bucket.remove(tool_name) + target_bucket = { + "allow": permission_state.alwaysAllowRules.setdefault("session", []), + "deny": permission_state.alwaysDenyRules.setdefault("session", []), + "ask": permission_state.alwaysAskRules.setdefault("session", []), + }[behavior] + if tool_name not in target_bucket: + target_bucket.append(tool_name) + return state.model_copy(update={"tool_permission_context": permission_state}) + + self._app_state.set_state(_update) + return True + + def remove_thread_permission_rule(self, thread_id: str, *, behavior: str, 
tool_name: str) -> bool: + removed = False + + def _update(state: AppState) -> AppState: + nonlocal removed + permission_state = state.tool_permission_context.model_copy(deep=True) + bucket = { + "allow": permission_state.alwaysAllowRules.setdefault("session", []), + "deny": permission_state.alwaysDenyRules.setdefault("session", []), + "ask": permission_state.alwaysAskRules.setdefault("session", []), + }[behavior] + if tool_name in bucket: + bucket.remove(tool_name) + removed = True + return state.model_copy(update={"tool_permission_context": permission_state}) + + self._app_state.set_state(_update) + return removed + + def resolve_permission_request( + self, + request_id: str, + *, + decision: str, + message: str | None = None, + answers: list[dict[str, Any]] | None = None, + annotations: dict[str, Any] | None = None, + ) -> bool: + pending = self._app_state.pending_permission_requests.get(request_id) + if pending is None: + return False + + resolved = dict(self._app_state.resolved_permission_requests) + payload = { + **pending, + "decision": decision, + "message": message or pending.get("message"), + } + if answers is not None: + payload["answers"] = answers + if annotations is not None: + payload["annotations"] = annotations + resolved[request_id] = payload + still_pending = dict(self._app_state.pending_permission_requests) + still_pending.pop(request_id, None) + self._app_state.set_state( + lambda prev: prev.model_copy( + update={ + "pending_permission_requests": still_pending, + "resolved_permission_requests": resolved, + } + ) + ) + return True + + def drop_permission_request(self, request_id: str) -> bool: + had_pending = request_id in self._app_state.pending_permission_requests + had_resolved = request_id in self._app_state.resolved_permission_requests + if not had_pending and not had_resolved: + return False + + def _drop(state: AppState) -> AppState: + pending = dict(state.pending_permission_requests) + resolved = dict(state.resolved_permission_requests) + pending.pop(request_id, None) + resolved.pop(request_id, None) + return state.model_copy( + update={ + "pending_permission_requests": pending, + "resolved_permission_requests": resolved, + } + ) + + self._app_state.set_state(_drop) + return True + def get_response(self, message: str, thread_id: str = "default", **kwargs) -> str: """Get agent's text response. @@ -1411,7 +1729,7 @@ def cleanup(self): def create_leon_agent( - model_name: str = DEFAULT_MODEL, + model_name: str | None = None, api_key: str | None = None, workspace_root: str | Path | None = None, sandbox: Any = None, @@ -1421,7 +1739,7 @@ def create_leon_agent( """Create Leon Agent. Args: - model_name: Model name + model_name: Model name. None means "let LeonAgent resolve defaults". 
api_key: API key workspace_root: Workspace directory sandbox: Sandbox instance, name string, or None for local diff --git a/core/runtime/checkpoint_store.py b/core/runtime/checkpoint_store.py new file mode 100644 index 000000000..1a27ada07 --- /dev/null +++ b/core/runtime/checkpoint_store.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Protocol + + +@dataclass(frozen=True) +class ThreadCheckpointState: + messages: list + tool_permission_context: dict[str, Any] + pending_permission_requests: dict[str, dict[str, Any]] + resolved_permission_requests: dict[str, dict[str, Any]] + memory_compaction_state: dict[str, Any] + mcp_instruction_state: dict[str, Any] + + +class CheckpointStore(Protocol): + async def load(self, thread_id: str) -> ThreadCheckpointState | None: ... + + async def save(self, thread_id: str, state: ThreadCheckpointState) -> None: ... diff --git a/core/runtime/cleanup.py b/core/runtime/cleanup.py new file mode 100644 index 000000000..d55600684 --- /dev/null +++ b/core/runtime/cleanup.py @@ -0,0 +1,116 @@ +"""CleanupRegistry — priority-ordered async cleanup for LeonAgent lifecycle. + +Aligned with CC Pattern 5: Lifecycle & Cleanup. +Priority numbers: lower = runs first. +""" + +from __future__ import annotations + +import asyncio +import logging +import signal +from collections.abc import Awaitable, Callable +from itertools import groupby + +logger = logging.getLogger(__name__) + + +class CleanupRegistry: + """Registry of async cleanup functions executed in priority order on shutdown. + + Usage: + registry = CleanupRegistry() + registry.register(close_db, priority=1) + registry.register(close_sandbox, priority=2) + await registry.run_cleanup() + """ + + def __init__(self): + # List of (priority, fn) — not a dict because same priority can have multiple fns + self._entries: list[tuple[int, Callable[[], Awaitable[None] | None]]] = [] + self._timeout_s = 2.0 + self._cleanup_task: asyncio.Task[None] | None = None + self._shutdown_in_progress = False + self._signal_loop: asyncio.AbstractEventLoop | None = None + self._setup_signal_handlers() + + def register(self, fn: Callable[[], Awaitable[None] | None], priority: int = 5) -> Callable[[], None]: + """Register a cleanup function. + + Args: + fn: Sync or async callable that releases resources. + priority: Execution order — lower number runs first (1 before 2). + """ + entry = (priority, fn) + self._entries.append(entry) + + def unregister() -> None: + try: + self._entries.remove(entry) + except ValueError: + return + + return unregister + + async def run_cleanup(self) -> None: + """Execute all registered cleanup functions in priority order. + + Different priority tiers run in order. Entries inside the same priority + tier run concurrently so one slow cleanup does not serialize its peers. 
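+
+        Illustrative ordering (hypothetical fns): after register(close_db, 1),
+        register(flush_logs, 1) and register(close_sandbox, 2), the two
+        priority-1 entries are gathered concurrently, and close_sandbox only
+        starts once both have settled or timed out.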
+ """ + if self._cleanup_task is not None: + await asyncio.shield(self._cleanup_task) + return + + async def _run_all() -> None: + sorted_entries = sorted(self._entries, key=lambda x: x[0]) + for priority, grouped_entries in groupby(sorted_entries, key=lambda x: x[0]): + await asyncio.gather( + *(self._run_entry(priority, fn) for _, fn in grouped_entries), + return_exceptions=True, + ) + + self._shutdown_in_progress = True + self._cleanup_task = asyncio.create_task(_run_all()) + await asyncio.shield(self._cleanup_task) + + def is_shutting_down(self) -> bool: + return self._shutdown_in_progress + + async def _run_entry(self, priority: int, fn: Callable[[], Awaitable[None] | None]) -> None: + try: + result = fn() + if asyncio.iscoroutine(result): + await asyncio.wait_for(result, timeout=self._timeout_s) + except TimeoutError: + logger.warning("CleanupRegistry: cleanup fn %s timed out after %.2fs", fn, self._timeout_s) + except Exception: + logger.exception("CleanupRegistry: error in cleanup fn %s (priority=%d)", fn, priority) + + def _setup_signal_handlers(self) -> None: + """Register SIGINT/SIGTERM handlers to trigger async cleanup.""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + return # No running loop yet — signal handlers set up later + self._signal_loop = loop + + signals = [signal.SIGINT, signal.SIGTERM] + if hasattr(signal, "SIGHUP"): + signals.append(signal.SIGHUP) + + for sig in signals: + try: + loop.add_signal_handler(sig, self._handle_signal) + except (NotImplementedError, RuntimeError): + # Windows or non-main thread — skip signal handler setup + pass + + def _handle_signal(self) -> None: + loop = self._signal_loop + if loop is None: + return + if loop.is_running(): + loop.create_task(self.run_cleanup()) + return + loop.run_until_complete(self.run_cleanup()) diff --git a/core/runtime/errors.py b/core/runtime/errors.py index 74ffbfc1e..591ff3090 100644 --- a/core/runtime/errors.py +++ b/core/runtime/errors.py @@ -1,4 +1,13 @@ class InputValidationError(Exception): """Tool parameter validation failed.""" - pass + def __init__( + self, + message: str, + *, + error_code: str | None = None, + details: list[dict[str, object]] | None = None, + ) -> None: + super().__init__(message) + self.error_code = error_code + self.details = [] if details is None else details diff --git a/core/runtime/fork.py b/core/runtime/fork.py new file mode 100644 index 000000000..c3992cf74 --- /dev/null +++ b/core/runtime/fork.py @@ -0,0 +1,91 @@ +"""Context fork for sub-agent spawning. + +When a sub-agent is spawned, it inherits workspace/model/permission configuration +from the parent but gets its own isolated messages and session identity. + +Aligned with CC createSubagentContext() field-by-field fork table. +""" + +from __future__ import annotations + +import copy +import uuid + +from .abort import create_child_abort_controller +from .state import BootstrapConfig, ToolUseContext + + +def fork_context(parent: BootstrapConfig) -> BootstrapConfig: + """Create a child BootstrapConfig for a sub-agent. + + Inherits all workspace identity, model settings, and security flags + from parent. Generates a fresh session_id and sets parent_session_id. + Messages, cost, and turn_count live in AppState — not here. 
+ """ + return BootstrapConfig( + workspace_root=parent.workspace_root, + original_cwd=parent.original_cwd, + project_root=parent.project_root, + cwd=parent.cwd, + model_name=parent.model_name, + api_key=parent.api_key, + sandbox_type=parent.sandbox_type, + block_dangerous_commands=parent.block_dangerous_commands, + block_network_commands=parent.block_network_commands, + enable_audit_log=parent.enable_audit_log, + enable_web_tools=parent.enable_web_tools, + allowed_file_extensions=parent.allowed_file_extensions, + extra_allowed_paths=parent.extra_allowed_paths, + max_turns=parent.max_turns, + # Fresh session identity + session_id=uuid.uuid4().hex, + parent_session_id=parent.session_id, + total_cost_usd=parent.total_cost_usd, + total_tool_duration_ms=parent.total_tool_duration_ms, + # Model settings + model_provider=parent.model_provider, + base_url=parent.base_url, + context_limit=parent.context_limit, + ) + + +def create_subagent_context( + parent: ToolUseContext, + *, + share_set_app_state: bool = False, +) -> ToolUseContext: + """Create a minimally isolated ToolUseContext for sub-agents. + + Default contract: + - bootstrap: fresh fork + - set_app_state: NO-OP + - set_app_state_for_tasks: always reaches the root/session store + - turn-local refs: fresh + - file cache/messages: cloned snapshots + """ + read_file_state = parent.read_file_state + if hasattr(read_file_state, "clone") and callable(read_file_state.clone): + cloned_read_file_state = read_file_state.clone() + else: + # @@@sa-04-read-file-state-clone + # Subagent fork boundaries must isolate nested file cache state too; + # a shallow dict copy leaks child edits back into the parent cache. + cloned_read_file_state = copy.deepcopy(read_file_state) + return ToolUseContext( + bootstrap=fork_context(parent.bootstrap), + get_app_state=parent.get_app_state, + set_app_state=parent.set_app_state if share_set_app_state else (lambda updater: None), + set_app_state_for_tasks=parent.set_app_state_for_tasks or parent.set_app_state, + refresh_tools=parent.refresh_tools, + can_use_tool=parent.can_use_tool, + request_permission=parent.request_permission, + consume_permission_resolution=parent.consume_permission_resolution, + read_file_state=cloned_read_file_state, + loaded_nested_memory_paths=set(), + discovered_skill_names=set(), + discovered_tool_names=set(), + nested_memory_attachment_triggers=set(), + abort_controller=create_child_abort_controller(getattr(parent, "abort_controller", None)), + messages=list(parent.messages), + thread_id=parent.thread_id, + ) diff --git a/core/runtime/langgraph_checkpoint_store.py b/core/runtime/langgraph_checkpoint_store.py new file mode 100644 index 000000000..7e4c1e210 --- /dev/null +++ b/core/runtime/langgraph_checkpoint_store.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import inspect +from typing import Any, cast + +from .checkpoint_store import ThreadCheckpointState + + +class LangGraphCheckpointStore: + def __init__(self, saver: Any): + self._saver = saver + + async def load(self, thread_id: str) -> ThreadCheckpointState | None: + checkpoint = await self._aget_checkpoint(thread_id) + if checkpoint is None: + return None + channel_values = dict(checkpoint.get("channel_values", {}) or {}) + return ThreadCheckpointState( + messages=list(channel_values.get("messages", [])), + tool_permission_context=dict(channel_values.get("tool_permission_context", {}) or {}), + pending_permission_requests=dict(channel_values.get("pending_permission_requests", {}) or {}), + 
resolved_permission_requests=dict(channel_values.get("resolved_permission_requests", {}) or {}), + memory_compaction_state=dict(channel_values.get("memory_compaction_state", {}) or {}), + mcp_instruction_state=dict(channel_values.get("mcp_instruction_state", {}) or {}), + ) + + async def save(self, thread_id: str, state: ThreadCheckpointState) -> None: + from langgraph.checkpoint.base import CheckpointMetadata, create_checkpoint, empty_checkpoint + + existing_checkpoint = await self._aget_checkpoint(thread_id) + checkpoint = create_checkpoint( + self._normalize_checkpoint_for_write(existing_checkpoint, empty_checkpoint), + None, + len(state.messages), + ) + checkpoint["channel_values"] = { + "messages": state.messages, + "tool_permission_context": state.tool_permission_context, + "pending_permission_requests": state.pending_permission_requests, + "resolved_permission_requests": state.resolved_permission_requests, + "memory_compaction_state": state.memory_compaction_state, + "mcp_instruction_state": state.mcp_instruction_state, + } + new_versions: dict[str, Any] = {} + get_next_version = getattr(self._saver, "get_next_version", None) + if callable(get_next_version): + current_versions = dict(checkpoint.get("channel_versions", {}) or {}) + for channel_name in checkpoint["channel_values"]: + new_versions[channel_name] = get_next_version(current_versions.get(channel_name), None) + checkpoint["channel_versions"] = {**current_versions, **new_versions} + checkpoint["updated_channels"] = list(new_versions) + metadata: CheckpointMetadata = { + "source": "loop", + "step": len(state.messages), + } + await self._saver.aput(self._checkpoint_config(thread_id), checkpoint, metadata, new_versions) + + async def _aget_checkpoint(self, thread_id: str) -> dict[str, Any] | None: + cfg = self._checkpoint_config(thread_id) + aget_tuple = getattr(self._saver, "aget_tuple", None) + if callable(aget_tuple): + checkpoint_tuple_result = aget_tuple(cfg) + checkpoint_tuple = await checkpoint_tuple_result if inspect.isawaitable(checkpoint_tuple_result) else checkpoint_tuple_result + checkpoint_value = getattr(checkpoint_tuple, "checkpoint", None) + if isinstance(checkpoint_value, dict): + return checkpoint_value + aget = getattr(self._saver, "aget", None) + if callable(aget): + checkpoint_result = aget(cfg) + checkpoint_value = await checkpoint_result if inspect.isawaitable(checkpoint_result) else checkpoint_result + if isinstance(checkpoint_value, dict): + return cast(dict[str, Any], checkpoint_value) + return None + + @staticmethod + def _normalize_checkpoint_for_write(raw_checkpoint: Any, empty_checkpoint_factory: Any) -> Any: + checkpoint = empty_checkpoint_factory() + if not isinstance(raw_checkpoint, dict): + return checkpoint + # @@@checkpoint-shape-normalization - local/simple savers often persist only + # channel_values, while LangGraph savers expect the full checkpoint shape. + # Normalize both into one writable base contract before versioning. 
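+        # e.g. a minimal saver that persisted only {"channel_values": {...}}
+        # is merged into empty_checkpoint()'s full skeleton below: missing
+        # fields keep their defaults, dict/list fields are shallow-copied.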
+ for key, default_value in checkpoint.items(): + if key not in raw_checkpoint: + continue + value = raw_checkpoint[key] + if isinstance(default_value, dict): + checkpoint[key] = dict(value or {}) + elif isinstance(default_value, list): + checkpoint[key] = list(value or []) + else: + checkpoint[key] = value + return checkpoint + + @staticmethod + def _checkpoint_config(thread_id: str) -> dict[str, Any]: + return {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} diff --git a/core/runtime/loop.py b/core/runtime/loop.py new file mode 100644 index 000000000..be8136735 --- /dev/null +++ b/core/runtime/loop.py @@ -0,0 +1,2268 @@ +"""QueryLoop — self-managing agentic tool loop replacing LangGraph create_agent. + +Implements CC Pattern 1: Agentic Tool Loop (queryLoop). + +Design: +- AsyncGenerator that alternates LLM sampling and tool execution. +- Exposes the same .astream(input, config, stream_mode) interface as CompiledStateGraph. +- Middleware chain (SpillBuffer/Monitor/PromptCaching/Memory/Steering/ToolRunner) is + preserved exactly — awrap_model_call and awrap_tool_call pass through in order. +- is_concurrency_safe tools execute in parallel; others execute serially. +- Checkpointer (AsyncSqliteSaver) stores/restores message history across calls. +""" + +from __future__ import annotations + +import asyncio +import copy +import inspect +import json +import logging +import re +import uuid +from collections.abc import AsyncGenerator, Awaitable, Callable +from dataclasses import dataclass +from enum import StrEnum +from types import SimpleNamespace +from typing import Any + +from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, RemoveMessage, SystemMessage, ToolMessage + +from core.runtime.middleware import ( + AgentMiddleware, + ModelRequest, + ModelResponse, + ToolCallRequest, +) + +from .abort import AbortController +from .checkpoint_store import CheckpointStore, ThreadCheckpointState +from .langgraph_checkpoint_store import LangGraphCheckpointStore +from .permissions import ToolPermissionContext, evaluate_permission_rules +from .registry import ToolMode, ToolRegistry +from .state import AppState, BootstrapConfig, ToolPermissionState, ToolUseContext +from .validator import _required_sets_match + +logger = logging.getLogger(__name__) + +_NOOP_HANDLER: Any = None # placeholder for innermost "handler" in middleware chain +_ESCALATED_MAX_OUTPUT_TOKENS = 64000 +_FLOOR_OUTPUT_TOKENS = 3000 +_CONTEXT_OVERFLOW_SAFETY_BUFFER = 1000 +_TRANSIENT_API_MAX_RETRIES = 3 +_TRANSIENT_API_BASE_DELAY_SECONDS = 0.5 +_PROMPT_TOO_LONG_NOTICE_TEXT = "Prompt is too long. Automatic recovery exhausted. Clear the thread or start a new one." 
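+# Recovery-budget sketch (constants above; the exact delay curve lives in
+# _retry_delay_seconds): a transient API error is retried up to
+# _TRANSIENT_API_MAX_RETRIES times, with delays derived from
+# _TRANSIENT_API_BASE_DELAY_SECONDS, before the run ends as model_error.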
+ + +class TerminalReason(StrEnum): + completed = "completed" + aborted_streaming = "aborted_streaming" + aborted_tools = "aborted_tools" + model_error = "model_error" + max_turns = "max_turns" + prompt_too_long = "prompt_too_long" + blocking_limit = "blocking_limit" + image_error = "image_error" + hook_stopped = "hook_stopped" + stop_hook_prevented = "stop_hook_prevented" + + +class ContinueReason(StrEnum): + next_turn = "next_turn" + api_retry = "api_retry" + collapse_drain_retry = "collapse_drain_retry" + reactive_compact_retry = "reactive_compact_retry" + max_output_tokens_escalate = "max_output_tokens_escalate" + max_output_tokens_recovery = "max_output_tokens_recovery" + stop_hook_blocking = "stop_hook_blocking" + token_budget_continuation = "token_budget_continuation" + + +@dataclass(frozen=True) +class TerminalState: + reason: TerminalReason + turn_count: int + error: str | None = None + + +@dataclass(frozen=True) +class ContinueState: + reason: ContinueReason + + +@dataclass(frozen=True) +class _ModelErrorRecoveryResult: + messages: list + transition: ContinueState | None + max_output_tokens_recovery_count: int + has_attempted_reactive_compact: bool + max_output_tokens_override: int | None + transient_api_retry_count: int + terminal: TerminalState | None + + +@dataclass(frozen=True) +class _ModelErrorContext: + exc: Exception + error_text: str + thread_id: str + messages: list + turn: int + transition: ContinueState | None + max_output_tokens_recovery_count: int + has_attempted_reactive_compact: bool + max_output_tokens_override: int | None + transient_api_retry_count: int + + +@dataclass +class _TrackedTool: + order: int + tool_call: dict[str, Any] + is_concurrency_safe: bool + status: str = "queued" + task: asyncio.Task[None] | None = None + result: ToolMessage | None = None + + +class StreamingToolExecutor: + def __init__( + self, + *, + execute_tool: Callable[[dict[str, Any], ToolUseContext | None], Awaitable[ToolMessage]], + is_concurrency_safe: Callable[[dict[str, Any]], bool], + lookup_tool: Callable[[str], Any | None], + tool_context: ToolUseContext | None, + ): + self._execute_tool = execute_tool + self._is_concurrency_safe = is_concurrency_safe + self._lookup_tool = lookup_tool + self._tool_context = tool_context + self._tracked: list[_TrackedTool] = [] + self._discarded = False + + def _tool_name(self, tool_call: dict[str, Any]) -> str: + return tool_call.get("name") or tool_call.get("function", {}).get("name", "") + + async def add_tool(self, tool_call: dict[str, Any]) -> None: + if self._discarded: + return + name = self._tool_name(tool_call) + if self._lookup_tool(name) is None: + self._tracked.append( + _TrackedTool( + order=len(self._tracked), + tool_call=tool_call, + is_concurrency_safe=False, + status="completed", + result=self._tool_error(tool_call, f"Tool '{name}' not found"), + ) + ) + return + tracked = _TrackedTool( + order=len(self._tracked), + tool_call=tool_call, + is_concurrency_safe=self._is_concurrency_safe(tool_call), + ) + self._tracked.append(tracked) + self._process_queue() + + async def get_completed_results(self) -> list[ToolMessage]: + await asyncio.sleep(0) + self._process_queue() + ready: list[ToolMessage] = [] + for tracked in self._tracked: + if tracked.status == "yielded": + continue + if tracked.status == "completed" and tracked.result is not None: + tracked.status = "yielded" + ready.append(tracked.result) + continue + break + return ready + + async def drain_remaining(self) -> list[ToolMessage]: + while True: + self._process_queue() + 
running = [tracked.task for tracked in self._tracked if tracked.status == "executing" and tracked.task is not None]
+            if not running:
+                break
+            await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED)
+            self._process_queue()
+        remaining: list[ToolMessage] = []
+        for tracked in self._tracked:
+            if tracked.status == "yielded":
+                continue
+            if tracked.status == "completed" and tracked.result is not None:
+                tracked.status = "yielded"
+                remaining.append(tracked.result)
+        return remaining
+
+    async def discard(self, reason: str) -> list[ToolMessage]:
+        # @@@streaming-tool-discard
+        # ql-05 must not leave orphaned tool tasks behind when streaming exits
+        # early. Synthetic error emission is still a later hardening pass, but
+        # task cleanup itself must happen now.
+        self._discarded = True
+        running: list[asyncio.Task[None]] = []
+        for tracked in self._tracked:
+            if tracked.status == "queued":
+                tracked.status = "completed"
+                tracked.result = self._synthetic_error(tracked.tool_call, reason)
+                continue
+            if tracked.status == "executing" and tracked.task is not None:
+                tracked.task.cancel()
+                running.append(tracked.task)
+        if running:
+            await asyncio.gather(*running, return_exceptions=True)
+        for tracked in self._tracked:
+            if tracked.status == "executing":
+                tracked.status = "completed"
+                tracked.result = self._synthetic_error(tracked.tool_call, reason)
+        return await self.drain_remaining()
+
+    def _process_queue(self) -> None:
+        if self._discarded:
+            return
+        for tracked in self._tracked:
+            if tracked.status != "queued":
+                continue
+            if not self._can_execute(tracked):
+                break
+            tracked.status = "executing"
+            tracked.task = asyncio.create_task(self._run_tool(tracked))
+
+    def _can_execute(self, tracked: _TrackedTool) -> bool:
+        executing = [item for item in self._tracked if item.status == "executing"]
+        if not executing:
+            return True
+        if not tracked.is_concurrency_safe:
+            return False
+        return all(item.is_concurrency_safe for item in executing)
+
+    async def _run_tool(self, tracked: _TrackedTool) -> None:
+        # @@@streaming-tool-task-exit
+        # ql-05 cannot let middleware-level exceptions disappear into a dead
+        # task. Every tool_use must resolve to a ToolMessage, and queue
+        # progression must re-run immediately when a task exits.
+        try:
+            tracked.result = await self._execute_tool(tracked.tool_call, self._tool_context)
+            tracked.status = "completed"
+        except asyncio.CancelledError:
+            raise
+        except Exception as exc:
+            tracked.result = self._tool_error(tracked.tool_call, str(exc))
+            tracked.status = "completed"
+        finally:
+            if self._should_abort_siblings(tracked):
+                await self._abort_siblings(
+                    excluding=tracked,
+                    reason="sibling aborted after bash error",
+                )
+            if not self._discarded:
+                self._process_queue()
+
+    def _should_abort_siblings(self, tracked: _TrackedTool) -> bool:
+        if tracked.result is None:
+            return False
+        return self._tool_name(tracked.tool_call).lower() == "bash" and "<tool_use_error>" in tracked.result.content
+
+    async def _abort_siblings(self, *, excluding: _TrackedTool, reason: str) -> None:
+        # @@@bash-sibling-abort
+        # Claude Code only fans out this abort for bash failures. Keep it
+        # local to the current executor iteration so the parent loop survives
+        # and later turns can continue with explicit tool errors.
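+        # Hypothetical trace: [bash (failed), grep (executing), read (queued)]
+        # ends with grep's task cancelled, read resolved to a _tool_error, and
+        # only the failing bash keeping its original result.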
+        self._discarded = True
+        running: list[asyncio.Task[None]] = []
+        for tracked in self._tracked:
+            if tracked is excluding or tracked.status in {"completed", "yielded"}:
+                continue
+            if tracked.status == "queued":
+                tracked.status = "completed"
+                tracked.result = self._tool_error(tracked.tool_call, reason)
+                continue
+            if tracked.status == "executing" and tracked.task is not None:
+                tracked.task.cancel()
+                running.append(tracked.task)
+        if running:
+            await asyncio.gather(*running, return_exceptions=True)
+        for tracked in self._tracked:
+            if tracked is excluding or tracked.status != "executing":
+                continue
+            tracked.status = "completed"
+            tracked.result = self._tool_error(tracked.tool_call, reason)
+
+    def _synthetic_error(self, tool_call: dict[str, Any], reason: str) -> ToolMessage:
+        return self._tool_error(
+            tool_call,
+            f"streaming discarded: {reason}",
+        )
+
+    def _tool_error(self, tool_call: dict[str, Any], error_text: str) -> ToolMessage:
+        # Wrap errors in <tool_use_error> tags so _should_abort_siblings (and
+        # other consumers) can detect tool failures in-band.
+        return ToolMessage(
+            content=f"<tool_use_error>{error_text}</tool_use_error>",
+            tool_call_id=tool_call.get("id", ""),
+            name=self._tool_name(tool_call),
+        )
+
+
+class QueryLoop:
+    """Self-managing query loop replacing create_agent.
+
+    The .astream() method is an AsyncGenerator that yields dicts compatible
+    with LangGraph's stream_mode="updates":
+        {"agent": {"messages": [AIMessage(...)]}}
+        {"tools": {"messages": [ToolMessage(...), ...]}}
+
+    The checkpointer attribute is set post-construction (mirrors create_agent pattern).
+    """
+
+    @property
+    def checkpointer(self) -> Any:
+        return self._checkpointer
+
+    @checkpointer.setter
+    def checkpointer(self, value: Any) -> None:
+        self._checkpointer = value
+        self._checkpoint_store = LangGraphCheckpointStore(value) if value is not None else None
+
+    def __init__(
+        self,
+        model: Any,
+        system_prompt: SystemMessage,
+        middleware: list[AgentMiddleware],
+        checkpointer: Any,
+        registry: ToolRegistry,
+        app_state: AppState | None = None,
+        runtime: Any = None,
+        bootstrap: BootstrapConfig | None = None,
+        refresh_tools: Any = None,
+        max_turns: int = 100,
+    ):
+        self.model = model
+        self.system_prompt = system_prompt
+        self.middleware = middleware
+        self.checkpointer = checkpointer
+        self._checkpoint_store: CheckpointStore | None
+        self._registry = registry
+        self._app_state = app_state
+        self._runtime = runtime
+        self._bootstrap = bootstrap
+        self._refresh_tools = refresh_tools
+        self._memory_middleware = next(
+            (mw for mw in middleware if hasattr(mw, "compact_boundary_index")),
+            None,
+        )
+        # @@@sa-02-session-tool-refs
+        # These refs must survive across turns within the same loop/session,
+        # while turn-local attachment triggers stay ephemeral per ToolUseContext.
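+        # e.g. a file snapshot cached by a read tool in turn 1 stays visible
+        # to an edit in turn 5, while nested_memory_attachment_triggers is
+        # rebuilt as a fresh set() in every _build_tool_use_context call.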
+ self._tool_read_file_state: dict[str, Any] = {} + self._tool_loaded_nested_memory_paths: set[str] = set() + self._tool_discovered_skill_names: set[str] = set() + self._tool_discovered_tool_names_by_thread: dict[str, set[str]] = {} + self._tool_abort_controller = AbortController() + self.max_turns = max_turns + self.last_terminal: TerminalState | None = None + self.last_continue: ContinueState | None = None + + # ------------------------------------------------------------------------- + # Public streaming interface (LangGraph-compatible) + # ------------------------------------------------------------------------- + + async def query( + self, + input: dict, + config: dict | None = None, + ) -> AsyncGenerator[dict[str, Any], None]: + """Raw loop generator with an explicit final terminal event.""" + config = config or {} + thread_id = config.get("configurable", {}).get("thread_id", "default") + + # Set thread context so MemoryMiddleware can find thread_id via ContextVar + from sandbox.thread_context import set_current_thread_id + + set_current_thread_id(thread_id) + + # Load message history and thread-scoped runtime state from checkpointer + persisted = await self._hydrate_thread_state_from_checkpoint(thread_id) + messages = list(persisted["messages"]) + self._restore_discovered_tool_names_from_messages(thread_id, messages) + + # Parse and append new input messages + new_msgs = self._parse_input(input) + messages.extend(new_msgs) + self._sync_app_state(messages=messages, turn_count=0) + + terminal: TerminalState | None = None + transition: ContinueState | None = None + pending_system_notices: list[HumanMessage] = [] + max_output_tokens_recovery_count = 0 + has_attempted_reactive_compact = False + max_output_tokens_override: int | None = None + transient_api_retry_count = 0 + + turn = 0 + try: + while turn < self.max_turns: + turn += 1 + tool_context = self._build_tool_use_context(messages, thread_id=thread_id) + + messages_for_query, injected_messages = await self._build_query_messages(messages, config) + if injected_messages: + # @@@steer-persist - queue/steer messages accepted before the + # next model call must become durable conversation state, not + # request-only hints, or later replay/history lies about what + # the user actually said mid-run. 
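+                    # e.g. a steering message accepted while turn 2 is still
+                    # streaming joins `messages` here, so checkpoint saves and
+                    # later history replays include it.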
+ messages.extend(injected_messages) + self._sync_app_state(messages=messages, turn_count=turn) + self._sync_tool_context_messages(tool_context, messages_for_query) + + # --- Call model through middleware chain --- + streamed_tool_results: list[ToolMessage] = [] + pending_tool_results: list[ToolMessage] = [] + used_streaming_overlap = False + response: ModelResponse | None = None + ai_msg: AIMessage | None = None + tool_calls: list[dict[str, Any]] = [] + try: + if self._can_stream_tools(): + used_streaming_overlap = True + async for stream_event in self._stream_model_with_tool_overlap( + messages_for_query, + config, + thread_id=thread_id, + tool_context=tool_context, + max_output_tokens_override=max_output_tokens_override, + ): + if stream_event["type"] == "message_chunk": + yield {"message_chunk": stream_event["chunk"]} + continue + if stream_event["type"] == "tools": + chunk_messages = stream_event["messages"] + streamed_tool_results.extend(chunk_messages) + yield {"tools": {"messages": chunk_messages}} + continue + response = stream_event["response"] + ai_msg = stream_event["ai_message"] + tool_calls = stream_event["tool_calls"] + pending_tool_results = stream_event["remaining_tool_results"] + else: + response = await self._invoke_model( + messages_for_query, + config, + thread_id=thread_id, + max_output_tokens_override=max_output_tokens_override, + ) + except Exception as exc: + self._collect_memory_system_notices(pending_system_notices) + handled = await self._handle_model_error_recovery( + exc=exc, + thread_id=thread_id, + messages=messages, + turn=turn, + transition=transition, + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + has_attempted_reactive_compact=has_attempted_reactive_compact, + max_output_tokens_override=max_output_tokens_override, + transient_api_retry_count=transient_api_retry_count, + ) + if handled is not None: + messages = handled.messages + transition = handled.transition + max_output_tokens_recovery_count = handled.max_output_tokens_recovery_count + has_attempted_reactive_compact = handled.has_attempted_reactive_compact + max_output_tokens_override = handled.max_output_tokens_override + transient_api_retry_count = handled.transient_api_retry_count + if handled.terminal is not None: + terminal = handled.terminal + break + self._sync_app_state(messages=messages, turn_count=turn) + continue + terminal = TerminalState( + reason=TerminalReason.model_error, + turn_count=turn, + error=str(exc), + ) + break + + if response is None or ai_msg is None: + ai_messages = [m for m in (response.result if response else []) if isinstance(m, AIMessage)] + if not ai_messages: + # No AI message — unexpected; treat as terminal + terminal = TerminalState( + reason=TerminalReason.model_error, + turn_count=turn, + error="model returned no AIMessage", + ) + break + ai_msg = ai_messages[0] + self._collect_memory_system_notices(pending_system_notices) + self._sync_tool_context_messages( + tool_context, + response.request_messages or messages_for_query, + ) + + truncated = self._handle_truncated_response_recovery( + ai_msg=ai_msg, + messages=messages, + turn=turn, + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + max_output_tokens_override=max_output_tokens_override, + ) + if truncated is not None: + messages = truncated["messages"] + transition = truncated["transition"] + max_output_tokens_recovery_count = truncated["max_output_tokens_recovery_count"] + max_output_tokens_override = truncated["max_output_tokens_override"] + 
self._sync_app_state(messages=messages, turn_count=turn) + if truncated["yield_ai"]: + yield {"agent": {"messages": [ai_msg]}} + if truncated["terminal"] is not None: + terminal = truncated["terminal"] + break + continue + + self._sync_app_state(messages=messages, turn_count=turn) + + if not tool_calls: + tool_calls = getattr(ai_msg, "tool_calls", None) or [] + if not tool_calls: + # Also check additional_kwargs for older message formats + tool_calls = ai_msg.additional_kwargs.get("tool_calls", []) + + if not tool_calls and not self._ai_message_has_visible_content(ai_msg): + terminal_followthrough_notice = self._get_terminal_followthrough_notice(messages) + if terminal_followthrough_notice is not None: + ai_msg = self._build_terminal_followthrough_fallback(terminal_followthrough_notice) + else: + chat_followthrough_notice = self._get_chat_followthrough_notice(messages) + if chat_followthrough_notice is not None: + ai_msg = self._build_chat_followthrough_fallback(chat_followthrough_notice) + + # Yield agent update (stream_mode="updates" format) + yield {"agent": {"messages": [ai_msg]}} + + if not tool_calls: + # No tool calls → agent is done + if self._ai_message_has_visible_content(ai_msg): + messages.append(ai_msg) + terminal = TerminalState( + reason=TerminalReason.completed, + turn_count=turn, + ) + break + + # Expose current messages for forkContext sub-agent spawning + from sandbox.thread_context import set_current_messages + + set_current_messages(messages + [ai_msg]) + + if used_streaming_overlap: + if pending_tool_results: + yield {"tools": {"messages": pending_tool_results}} + tool_results = streamed_tool_results + pending_tool_results + else: + # --- Execute tools through middleware chain --- + try: + tool_results = await self._execute_tools(tool_calls, response, tool_context) + except Exception as exc: + terminal = TerminalState( + reason=TerminalReason.aborted_tools, + turn_count=turn, + error=str(exc), + ) + break + + # Yield tools update + yield {"tools": {"messages": tool_results}} + + # Advance message history for next turn + messages.append(ai_msg) + messages.extend(tool_results) + if self._tool_results_include_permission_request(tool_results): + terminal = TerminalState( + reason=TerminalReason.completed, + turn_count=turn, + ) + self._sync_app_state(messages=messages, turn_count=turn) + break + await self._refresh_tools_between_turns(tool_context) + transition = ContinueState(reason=ContinueReason.next_turn) + max_output_tokens_recovery_count = 0 + has_attempted_reactive_compact = False + max_output_tokens_override = None + transient_api_retry_count = 0 + self._sync_app_state(messages=messages, turn_count=turn) + except asyncio.CancelledError: + # @@@cancel-persists-live-state - accepted user input from the + # current run must not evaporate just because the run is cancelled + # before the next terminal save. 
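+            # Sketch: a Ctrl-C mid-run lands here; messages, including this
+            # run's accepted input, are saved before the cancellation
+            # propagates, so the next query() on this thread_id rehydrates them.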
+ messages = self._append_system_notices(messages, pending_system_notices) + await self._save_messages(thread_id, messages) + self._sync_app_state(messages=messages, turn_count=turn) + raise + + if terminal is None: + terminal = TerminalState( + reason=TerminalReason.max_turns, + turn_count=turn, + ) + + # Persist message history + self._collect_memory_system_notices(pending_system_notices) + visible_terminal_error = self._build_visible_terminal_error_message(terminal, messages) + if visible_terminal_error is not None: + messages.append(visible_terminal_error) + terminal_notice = self._build_terminal_notice(terminal) + if terminal_notice is not None: + pending_system_notices.append(terminal_notice) + messages = self._append_system_notices(messages, pending_system_notices) + await self._save_messages(thread_id, messages) + self._sync_app_state(messages=messages, turn_count=turn) + self.last_terminal = terminal + self.last_continue = transition + yield {"terminal": terminal, "transition": transition} + + def _make_streaming_tool_executor(self, *, tool_context: ToolUseContext | None) -> StreamingToolExecutor: + return StreamingToolExecutor( + execute_tool=self._execute_single_tool, + is_concurrency_safe=self._tool_is_concurrency_safe, + lookup_tool=self._registry.get, + tool_context=tool_context, + ) + + async def astream( + self, + input: dict, + config: dict | None = None, + stream_mode: str | list[str] = "updates", + ) -> AsyncGenerator[Any, None]: + """Stream agent execution chunks compatible with LangGraph stream modes.""" + requested_modes = [stream_mode] if isinstance(stream_mode, str) else list(stream_mode) + emitted_live_agent_chunks = False + async for event in self.query(input, config=config): + if "terminal" in event: + terminal = event["terminal"] + if terminal is not None and terminal.reason is not TerminalReason.completed: + # @@@astream-terminal-loud-fail + # query() always emits a terminal event, but caller-facing + # astream() must not turn runtime failures into a silent empty + # iterator. Propagate non-completed terminals back to the caller. + raise RuntimeError(self._terminal_error_text(terminal)) + continue + if isinstance(stream_mode, str): + if "message_chunk" in event: + continue + yield event + continue + + if "message_chunk" in event: + if "messages" in requested_modes: + yield ( + "messages", + ( + event["message_chunk"], + {"langgraph_node": "agent"}, + ), + ) + emitted_live_agent_chunks = True + continue + + if "messages" in requested_modes and "agent" in event: + if not emitted_live_agent_chunks: + for msg in event["agent"].get("messages", []): + if not isinstance(msg, AIMessage): + continue + yield ( + "messages", + ( + AIMessageChunk(**msg.model_dump(exclude={"type"})), + {"langgraph_node": "agent"}, + ), + ) + emitted_live_agent_chunks = False + + if "updates" in requested_modes: + yield ("updates", event) + + async def ainvoke( + self, + input: dict, + config: dict | None = None, + stream_mode: str = "updates", + ) -> dict[str, Any]: + """Drain query and return messages plus explicit terminal state.""" + drained_messages: list[Any] = [] + terminal: TerminalState | None = None + transition: ContinueState | None = None + + # @@@ainvoke-drains-astream + # QueryLoop is generator-first. ainvoke exists only as a compatibility + # adapter for callers like LeonAgent.invoke/ainvoke and must not invent + # a separate execution path. 
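+        # e.g. LeonAgent.invoke()/ainvoke() funnel through here: agent/tools
+        # events are flattened into one messages list, and the terminal and
+        # transition values are returned verbatim rather than re-derived.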
+ async for event in self.query(input, config=config): + if "terminal" in event: + terminal = event["terminal"] + transition = event.get("transition") + continue + for section in ("agent", "tools"): + drained_messages.extend(event.get(section, {}).get("messages", [])) + + return { + "messages": drained_messages, + "reason": terminal.reason.value if terminal else TerminalReason.completed.value, + "terminal": terminal, + "transition": transition, + } + + async def aget_state(self, config: dict | None = None) -> Any: + """Minimal graph-state bridge for backend/web callers.""" + config = config or {} + thread_id = config.get("configurable", {}).get("thread_id", "default") + if self._is_runtime_active(): + # @@@active-state-no-clobber - caller surfaces like /permissions and + # /history can poll during an active run. Rehydrating from stale + # checkpoint here would erase live thread-scoped permission state. + values = self._snapshot_live_thread_state(thread_id) + return SimpleNamespace(values=values) + values = await self._hydrate_thread_state_from_checkpoint(thread_id) + return SimpleNamespace(values=values) + + async def aupdate_state( + self, + config: dict | None, + input_data: dict[str, Any] | None, + as_node: str | None = None, + ) -> Any: + """Minimal graph-state update bridge for resumed-thread callers.""" + config = config or {} + input_data = input_data or {} + thread_id = config.get("configurable", {}).get("thread_id", "default") + messages = await self._load_messages(thread_id) + raw_updates = input_data.get("messages", []) + + # @@@ql-06-state-bridge - backend/web still speaks the old graph-state + # contract. Only the live caller shapes are supported here: append + # resumed start messages, or apply RemoveMessage-based repairs before + # appending replacement messages. 
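+        # Hypothetical repair call:
+        #   await loop.aupdate_state(cfg, {"messages": [RemoveMessage(id=bad_id), fixed_msg]})
+        # drops bad_id from history, then appends fixed_msg.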
+ if as_node == "__start__": + messages.extend(self._parse_input({"messages": raw_updates})) + else: + updates = raw_updates if isinstance(raw_updates, list) else [raw_updates] + remove_ids = {update.id for update in updates if isinstance(update, RemoveMessage) and getattr(update, "id", None)} + if remove_ids: + messages = [message for message in messages if getattr(message, "id", None) not in remove_ids] + messages.extend(update for update in updates if not isinstance(update, RemoveMessage)) + + await self._save_messages(thread_id, messages) + current_turn_count = self._app_state.turn_count if self._app_state is not None else 0 + self._sync_app_state(messages=messages, turn_count=current_turn_count) + self._restore_discovered_tool_names_from_messages(thread_id, messages) + return await self.aget_state(config) + + async def apersist_state(self, thread_id: str) -> None: + """Persist the current thread-scoped loop/app state to the checkpointer.""" + messages = list(self._app_state.messages) if self._app_state is not None else await self._load_messages(thread_id) + await self._save_messages(thread_id, messages) + + # ------------------------------------------------------------------------- + # Model invocation through middleware chain + # ------------------------------------------------------------------------- + + async def _invoke_model( + self, + messages: list, + config: dict, + *, + thread_id: str = "default", + max_output_tokens_override: int | None = None, + ) -> ModelResponse: + """Call model through the full middleware chain (awrap_model_call).""" + + async def innermost_handler(request: ModelRequest) -> ModelResponse: + """Actual model call — innermost of the chain.""" + tools = request.tools or [] + model = request.model + + # Bind tools to model if any + if tools: + try: + bound = model.bind_tools(tools) + except Exception: + bound = model + else: + bound = model + + if max_output_tokens_override is not None and hasattr(bound, "bind"): + try: + bound = bound.bind(max_tokens=max_output_tokens_override) + except Exception: + pass + + # Build message list: system + conversation + call_messages = [] + if request.system_message: + call_messages.append(request.system_message) + call_messages.extend(request.messages) + + result = await bound.ainvoke(call_messages) + if not isinstance(result, list): + result = [result] + return ModelResponse(result=result, request_messages=list(request.messages)) + + # Build ModelRequest + inline_schemas = self._registry.get_inline_schemas(self._get_discovered_tool_names(thread_id)) + request = ModelRequest( + model=self.model, + messages=messages, + system_message=self.system_prompt, + tools=inline_schemas, + ) + + # Walk middleware chain outside-in: each wraps the next. + # Only include middleware that actually overrides awrap_model_call OR wrap_model_call + # (not just inherits the base-class NotImplementedError stub). 
+ handler = innermost_handler + for mw in reversed(self.middleware): + if _mw_overrides_model_call(mw): + handler = _make_model_wrapper(mw, handler) + + return await handler(request) + + def _bind_model( + self, + model: Any, + tools: list | None, + *, + max_output_tokens_override: int | None = None, + ) -> Any: + if tools: + try: + bound = model.bind_tools(tools) + except Exception: + bound = model + else: + bound = model + + if max_output_tokens_override is not None and hasattr(bound, "bind"): + try: + bound = bound.bind(max_tokens=max_output_tokens_override) + except Exception: + pass + return bound + + def _can_stream_tools(self) -> bool: + stream_fn = getattr(self.model, "astream", None) + if not callable(stream_fn): + return False + return type(self.model).__module__ != "unittest.mock" + + async def _prepare_streaming_request( + self, + messages: list, + *, + thread_id: str, + ) -> ModelRequest: + inline_schemas = self._registry.get_inline_schemas(self._get_discovered_tool_names(thread_id)) + request = ModelRequest( + model=self.model, + messages=messages, + system_message=self.system_prompt, + tools=inline_schemas, + ) + + async def prepare_handler(request: ModelRequest) -> ModelResponse: + return ModelResponse( + result=[], + request_messages=list(request.messages), + prepared_request=request, + ) + + handler = prepare_handler + for mw in reversed(self.middleware): + if _mw_overrides_model_call(mw): + handler = _make_model_wrapper(mw, handler) + + response = await handler(request) + return response.prepared_request or request + + async def _stream_model_with_tool_overlap( + self, + messages: list, + config: dict, + *, + thread_id: str, + tool_context: ToolUseContext | None, + max_output_tokens_override: int | None, + ) -> AsyncGenerator[dict[str, Any], None]: + prepared_request = await self._prepare_streaming_request(messages, thread_id=thread_id) + bound = self._bind_model( + prepared_request.model, + prepared_request.tools, + max_output_tokens_override=max_output_tokens_override, + ) + + call_messages = [] + if prepared_request.system_message: + call_messages.append(prepared_request.system_message) + call_messages.extend(prepared_request.messages) + + executor = self._make_streaming_tool_executor(tool_context=tool_context) + aggregate: AIMessageChunk | None = None + seen_tool_ids: set[str] = set() + streamed_tool_calls: list[dict[str, Any]] = [] + + try: + async for chunk in bound.astream(call_messages): + if isinstance(chunk, AIMessage): + chunk = AIMessageChunk(**chunk.model_dump(exclude={"type"})) + elif not isinstance(chunk, AIMessageChunk): + continue + + # @@@stream-chunk-snapshot + # Some providers reuse and mutate the same chunk object across + # yields. Snapshot before yielding/aggregating so the final + # AIMessage cannot collapse to the last empty chunk. 
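+                # e.g. a provider that mutates one shared chunk object across
+                # yields ("Hel", then "lo", then "") would otherwise leave
+                # `aggregate` reflecting only the final mutation; the copy
+                # freezes each yield as delivered.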
+ chunk = AIMessageChunk(**chunk.model_dump(exclude={"type"})) + if ( + aggregate is not None + and getattr(chunk, "chunk_position", None) == "last" + and not chunk.content + and not getattr(chunk, "tool_calls", None) + and not getattr(chunk, "invalid_tool_calls", None) + and not getattr(chunk, "tool_call_chunks", None) + and getattr(chunk, "usage_metadata", None) == getattr(aggregate, "usage_metadata", None) + ): + chunk = chunk.model_copy(update={"usage_metadata": None}) + aggregate = chunk if aggregate is None else aggregate + chunk + + yield {"type": "message_chunk", "chunk": chunk} + + tool_call_chunks = getattr(aggregate, "tool_call_chunks", None) or [] + for tool_call in getattr(aggregate, "tool_calls", None) or []: + ready_tool_call = self._normalize_stream_tool_call(tool_call, tool_call_chunks) + if ready_tool_call is None: + continue + call_id = ready_tool_call.get("id") + if not call_id or call_id in seen_tool_ids: + continue + seen_tool_ids.add(call_id) + streamed_tool_calls.append(ready_tool_call) + await executor.add_tool(ready_tool_call) + + completed = await executor.get_completed_results() + if completed: + yield {"type": "tools", "messages": completed} + except Exception: + discarded = await executor.discard(reason="streaming_error") + if discarded: + yield {"type": "tools", "messages": discarded} + raise + + if aggregate is None: + raise RuntimeError("streaming model returned no AIMessageChunk") + + ai_message = AIMessage(**aggregate.model_dump(exclude={"type"})) + self._notify_stream_response(prepared_request, ai_message) + remaining = await executor.drain_remaining() + yield { + "type": "done", + "response": ModelResponse(result=[ai_message], request_messages=list(prepared_request.messages)), + "ai_message": ai_message, + "tool_calls": list(streamed_tool_calls), + "remaining_tool_results": remaining, + } + + def _notify_stream_response(self, request: ModelRequest, ai_message: AIMessage) -> None: + req_dict = {"messages": request.messages} + resp_dict = {"messages": [ai_message]} + for mw in self.middleware: + dispatch = getattr(mw, "_dispatch_monitors", None) + if callable(dispatch): + dispatch("on_response", req_dict, resp_dict) + + async def _build_query_messages(self, messages: list, config: dict) -> tuple[list, list]: + return await self._apply_before_model(list(messages), config) + + async def _apply_before_model(self, messages: list, config: dict) -> tuple[list, list]: + """Run middleware before_model/abefore_model hooks on the live path.""" + current_messages = list(messages) + injected_messages: list[Any] = [] + state = {"messages": current_messages} + + for mw in self.middleware: + update: dict[str, Any] | None = None + abefore = getattr(mw, "abefore_model", None) + before = getattr(mw, "before_model", None) + + if callable(abefore): + maybe_update = abefore(state=state, runtime=None, config=config) + if inspect.isawaitable(maybe_update): + maybe_update = await maybe_update + update = maybe_update if isinstance(maybe_update, dict) else None + elif callable(before): + maybe_update = before(state=state, runtime=None, config=config) + update = maybe_update if isinstance(maybe_update, dict) else None + + if not update: + continue + + new_messages = update.get("messages") + if new_messages: + if not isinstance(new_messages, list): + new_messages = [new_messages] + current_messages.extend(new_messages) + injected_messages.extend(new_messages) + state["messages"] = current_messages + + return current_messages, injected_messages + + def _sync_app_state(self, messages: 
list, turn_count: int) -> None: + """Keep runtime AppState aligned with the loop's live state.""" + if self._app_state is None: + return + + snapshot = list(messages) + current_cost = self._read_runtime_cost() + bootstrap_cost = self._bootstrap.total_cost_usd if self._bootstrap is not None else 0.0 + cumulative_cost = max(current_cost, self._app_state.total_cost, bootstrap_cost) + compact_boundary_index = self._read_compact_boundary_index() + + # @@@sa-03-cost-accumulator-monotonic + # /clear must preserve session accumulators, so loop sync cannot let a + # lower per-run observation overwrite the accumulated session total. + if self._bootstrap is not None: + self._bootstrap.total_cost_usd = cumulative_cost + + # @@@app-state-sync + # ql-02 needs the loop's local lifecycle to write back into AppState, + # but we still do not have compaction yet. Clamp the boundary so the + # store stays coherent without pretending compaction exists. + def _update(state: AppState) -> AppState: + return state.model_copy( + update={ + "messages": snapshot, + "turn_count": turn_count, + "total_cost": cumulative_cost, + "compact_boundary_index": compact_boundary_index, + } + ) + + self._app_state.set_state(_update) + + def _read_runtime_cost(self) -> float: + if self._runtime is None: + return self._app_state.total_cost if self._app_state is not None else 0.0 + try: + return float(self._runtime.cost) + except Exception: + return self._app_state.total_cost if self._app_state is not None else 0.0 + + def _read_compact_boundary_index(self) -> int: + if self._memory_middleware is None: + return 0 + try: + boundary = int(getattr(self._memory_middleware, "compact_boundary_index", 0)) + except Exception: + return 0 + return max(boundary, 0) + + def _get_discovered_tool_names(self, thread_id: str) -> set[str]: + # @@@dt-03-thread-scoped-deferred-tools - deferred discovery must stay + # isolated per thread_id, or one thread's tool_search silently changes + # another thread's inline schema surface on the next turn. 
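+        # e.g. (hypothetical tools) thread "a" discovering pdf_reader via
+        # tool_search exposes it inline for thread "a" only; thread "b" keeps
+        # its own, smaller inline schema surface.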
+ return self._tool_discovered_tool_names_by_thread.setdefault(thread_id, set()) + + def _restore_discovered_tool_names_from_messages( + self, + thread_id: str, + messages: list, + ) -> None: + discovered: set[str] = set() + for message in messages: + if not isinstance(message, ToolMessage) or getattr(message, "name", None) != "tool_search": + continue + content = getattr(message, "content", None) + if not isinstance(content, str): + continue + try: + payload = json.loads(content) + except Exception: + continue + if not isinstance(payload, list): + continue + for item in payload: + if not isinstance(item, dict): + continue + name = item.get("name") + if not isinstance(name, str): + continue + entry = self._registry.get(name) + if entry is not None and entry.mode == ToolMode.DEFERRED: + discovered.add(name) + self._tool_discovered_tool_names_by_thread[thread_id] = discovered + + def _build_tool_use_context(self, messages: list, *, thread_id: str = "default") -> ToolUseContext | None: + if self._bootstrap is None or self._app_state is None: + return None + has_permission_resolver = self._bootstrap.permission_resolver_scope != "none" + return ToolUseContext( + bootstrap=self._bootstrap, + get_app_state=self._app_state.get_state, + set_app_state=self._app_state.set_state, + refresh_tools=self._refresh_tools, + can_use_tool=lambda name, args, permission_context, request: self._default_can_use_tool( + name=name, + permission_context=permission_context, + ), + request_permission=( + lambda name, args, context, request, message: self._request_permission( + thread_id=thread_id, + name=name, + args=args, + message=message, + ) + ) + if has_permission_resolver + else None, + consume_permission_resolution=lambda name, args, context, request: self._consume_permission_resolution( + thread_id=thread_id, + name=name, + args=args, + ), + read_file_state=self._tool_read_file_state, + loaded_nested_memory_paths=self._tool_loaded_nested_memory_paths, + discovered_skill_names=self._tool_discovered_skill_names, + discovered_tool_names=self._get_discovered_tool_names(thread_id), + nested_memory_attachment_triggers=set(), + abort_controller=self._tool_abort_controller, + messages=list(messages), + thread_id=thread_id, + ) + + def _default_can_use_tool( + self, + *, + name: str, + permission_context: ToolPermissionContext, + ) -> dict[str, Any] | None: + if self._app_state is None: + return None + permission_state = self._app_state.tool_permission_context + merged_context = ToolPermissionContext( + is_read_only=permission_context.is_read_only, + is_destructive=permission_context.is_destructive, + alwaysAllowRules=permission_state.alwaysAllowRules, + alwaysDenyRules=permission_state.alwaysDenyRules, + alwaysAskRules=permission_state.alwaysAskRules, + allowManagedPermissionRulesOnly=permission_state.allowManagedPermissionRulesOnly, + ) + decision = evaluate_permission_rules(name, merged_context) + if ( + decision is not None + and decision.get("decision") == "ask" + and self._bootstrap is not None + and self._bootstrap.permission_resolver_scope == "none" + ): + # @@@permission-headless-fail-loud - ask is only a real product mode + # when this run has an owner-facing resolver. Otherwise fail loudly + # instead of creating a dead-end pending request in hidden state. + return { + "decision": "deny", + "message": f"{decision.get('message')}. 
No interactive permission resolver is available for this run.", + } + return decision + + def _request_permission( + self, + *, + thread_id: str, + name: str, + args: dict[str, Any], + message: str | None, + ) -> str | None: + if self._app_state is None: + return None + + request_id = uuid.uuid4().hex[:8] + payload = { + "request_id": request_id, + "thread_id": thread_id, + "tool_name": name, + "args": copy.deepcopy(args), + "message": message, + } + + def _store(state: AppState) -> AppState: + pending = dict(state.pending_permission_requests) + pending[request_id] = payload + return state.model_copy(update={"pending_permission_requests": pending}) + + self._app_state.set_state(_store) + return request_id + + def _consume_permission_resolution( + self, + *, + thread_id: str, + name: str, + args: dict[str, Any], + ) -> dict[str, Any] | None: + if self._app_state is None: + return None + + resolved_items = list(self._app_state.resolved_permission_requests.items()) + matched_id: str | None = None + matched_payload: dict[str, Any] | None = None + for request_id, payload in resolved_items: + if payload.get("thread_id") != thread_id: + continue + if payload.get("tool_name") != name: + continue + if payload.get("args") != args: + continue + matched_id = request_id + matched_payload = payload + break + + if matched_id is None or matched_payload is None: + return None + + def _consume(state: AppState) -> AppState: + resolved = dict(state.resolved_permission_requests) + resolved.pop(matched_id, None) + return state.model_copy(update={"resolved_permission_requests": resolved}) + + self._app_state.set_state(_consume) + return { + "decision": matched_payload.get("decision"), + "message": matched_payload.get("message"), + } + + def _sync_tool_context_messages( + self, + tool_context: ToolUseContext | None, + messages: list, + ) -> None: + if tool_context is None: + return + tool_context.messages = list(messages) + + async def _refresh_tools_between_turns(self, tool_context: ToolUseContext | None) -> None: + refresh = self._refresh_tools + if refresh is None and tool_context is not None: + refresh = tool_context.refresh_tools + if refresh is None: + return + result = refresh() + if inspect.isawaitable(result): + await result + + async def _handle_model_error_recovery( + self, + *, + exc: Exception, + thread_id: str, + messages: list, + turn: int, + transition: ContinueState | None, + max_output_tokens_recovery_count: int, + has_attempted_reactive_compact: bool, + max_output_tokens_override: int | None, + transient_api_retry_count: int, + ) -> _ModelErrorRecoveryResult | None: + ctx = _ModelErrorContext( + exc=exc, + error_text=str(exc).lower(), + thread_id=thread_id, + messages=messages, + turn=turn, + transition=transition, + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + has_attempted_reactive_compact=has_attempted_reactive_compact, + max_output_tokens_override=max_output_tokens_override, + transient_api_retry_count=transient_api_retry_count, + ) + for strategy in self._model_error_recovery_strategies(): + result = await strategy(ctx) + if result is not None: + return result + return None + + def _model_error_recovery_strategies(self) -> tuple[Callable[[_ModelErrorContext], Awaitable[_ModelErrorRecoveryResult | None]], ...]: + return ( + self._try_context_overflow_escalate, + self._try_transient_api_retry, + self._try_max_output_tokens_recovery, + self._try_prompt_too_long_collapse_drain, + self._try_prompt_too_long_reactive_compact, + self._try_prompt_too_long_terminal, + ) + + 
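+    # Strategy contract sketch: each _try_* coroutine returns None to defer
+    # to the next strategy, or a _ModelErrorRecoveryResult; terminal=None
+    # means "retry the turn with the returned knobs", a set terminal means
+    # "stop the run with that reason".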
async def _try_context_overflow_escalate(self, ctx: _ModelErrorContext) -> _ModelErrorRecoveryResult | None: + parsed_overflow = self._parse_context_overflow_override(str(ctx.exc)) + if parsed_overflow is None: + return None + return _ModelErrorRecoveryResult( + messages=ctx.messages, + transition=ContinueState(reason=ContinueReason.max_output_tokens_escalate), + max_output_tokens_recovery_count=ctx.max_output_tokens_recovery_count, + has_attempted_reactive_compact=ctx.has_attempted_reactive_compact, + max_output_tokens_override=parsed_overflow, + transient_api_retry_count=ctx.transient_api_retry_count, + terminal=None, + ) + + async def _try_transient_api_retry(self, ctx: _ModelErrorContext) -> _ModelErrorRecoveryResult | None: + if not self._is_transient_api_error(ctx.exc, ctx.error_text): + return None + if ctx.transient_api_retry_count >= _TRANSIENT_API_MAX_RETRIES: + return None + delay_seconds = self._retry_delay_seconds(ctx.exc, ctx.transient_api_retry_count) + if delay_seconds > 0: + await asyncio.sleep(delay_seconds) + return _ModelErrorRecoveryResult( + messages=ctx.messages, + transition=ContinueState(reason=ContinueReason.api_retry), + max_output_tokens_recovery_count=ctx.max_output_tokens_recovery_count, + has_attempted_reactive_compact=ctx.has_attempted_reactive_compact, + max_output_tokens_override=ctx.max_output_tokens_override, + transient_api_retry_count=ctx.transient_api_retry_count + 1, + terminal=None, + ) + + async def _try_max_output_tokens_recovery(self, ctx: _ModelErrorContext) -> _ModelErrorRecoveryResult | None: + if "max_output_tokens" not in ctx.error_text: + return None + if ctx.max_output_tokens_override is None: + return _ModelErrorRecoveryResult( + messages=ctx.messages, + transition=ContinueState(reason=ContinueReason.max_output_tokens_escalate), + max_output_tokens_recovery_count=ctx.max_output_tokens_recovery_count, + has_attempted_reactive_compact=ctx.has_attempted_reactive_compact, + max_output_tokens_override=_ESCALATED_MAX_OUTPUT_TOKENS, + transient_api_retry_count=ctx.transient_api_retry_count, + terminal=None, + ) + if ctx.max_output_tokens_recovery_count < 3: + recovered_messages = list(ctx.messages) + recovered_messages.append( + HumanMessage( + content="Output token limit hit. 
Resume directly with no apology or recap.", + ) + ) + return _ModelErrorRecoveryResult( + messages=recovered_messages, + transition=ContinueState(reason=ContinueReason.max_output_tokens_recovery), + max_output_tokens_recovery_count=ctx.max_output_tokens_recovery_count + 1, + has_attempted_reactive_compact=ctx.has_attempted_reactive_compact, + max_output_tokens_override=ctx.max_output_tokens_override, + transient_api_retry_count=ctx.transient_api_retry_count, + terminal=None, + ) + return _ModelErrorRecoveryResult( + messages=ctx.messages, + transition=ContinueState(reason=ContinueReason.max_output_tokens_recovery), + max_output_tokens_recovery_count=ctx.max_output_tokens_recovery_count, + has_attempted_reactive_compact=ctx.has_attempted_reactive_compact, + max_output_tokens_override=ctx.max_output_tokens_override, + transient_api_retry_count=ctx.transient_api_retry_count, + terminal=TerminalState( + reason=TerminalReason.model_error, + turn_count=ctx.turn, + error=str(ctx.exc), + ), + ) + + async def _try_prompt_too_long_collapse_drain(self, ctx: _ModelErrorContext) -> _ModelErrorRecoveryResult | None: + if not self._is_prompt_too_long_error(ctx.error_text): + return None + if ctx.transition is not None and ctx.transition.reason is ContinueReason.collapse_drain_retry: + return None + drained = await self._recover_from_overflow(ctx.messages) + if drained is None or drained["committed"] <= 0: + return None + return _ModelErrorRecoveryResult( + messages=drained["messages"], + transition=ContinueState(reason=ContinueReason.collapse_drain_retry), + max_output_tokens_recovery_count=ctx.max_output_tokens_recovery_count, + has_attempted_reactive_compact=ctx.has_attempted_reactive_compact, + max_output_tokens_override=ctx.max_output_tokens_override, + transient_api_retry_count=ctx.transient_api_retry_count, + terminal=None, + ) + + async def _try_prompt_too_long_reactive_compact(self, ctx: _ModelErrorContext) -> _ModelErrorRecoveryResult | None: + if not self._is_prompt_too_long_error(ctx.error_text): + return None + if ctx.has_attempted_reactive_compact: + return None + compacted = await self._force_reactive_compact(ctx.messages, thread_id=ctx.thread_id) + if compacted is None: + return None + return _ModelErrorRecoveryResult( + messages=compacted, + transition=ContinueState(reason=ContinueReason.reactive_compact_retry), + max_output_tokens_recovery_count=ctx.max_output_tokens_recovery_count, + has_attempted_reactive_compact=True, + max_output_tokens_override=ctx.max_output_tokens_override, + transient_api_retry_count=ctx.transient_api_retry_count, + terminal=None, + ) + + async def _try_prompt_too_long_terminal(self, ctx: _ModelErrorContext) -> _ModelErrorRecoveryResult | None: + if not self._is_prompt_too_long_error(ctx.error_text): + return None + return _ModelErrorRecoveryResult( + messages=ctx.messages, + transition=ctx.transition, + max_output_tokens_recovery_count=ctx.max_output_tokens_recovery_count, + has_attempted_reactive_compact=ctx.has_attempted_reactive_compact, + max_output_tokens_override=ctx.max_output_tokens_override, + transient_api_retry_count=ctx.transient_api_retry_count, + terminal=TerminalState( + reason=TerminalReason.prompt_too_long, + turn_count=ctx.turn, + error=str(ctx.exc), + ), + ) + + @staticmethod + def _parse_context_overflow_override(error_message: str) -> int | None: + match = re.search( + r"input length and `max_tokens` exceed context limit: (\d+) \+ (\d+) > (\d+)", + error_message, + ) + if match is None: + return None + input_tokens = int(match.group(1)) + 
context_limit = int(match.group(3)) + available_context = max(0, context_limit - input_tokens - _CONTEXT_OVERFLOW_SAFETY_BUFFER) + if available_context < _FLOOR_OUTPUT_TOKENS: + return None + return max(_FLOOR_OUTPUT_TOKENS, available_context) + + @staticmethod + def _is_transient_api_error(exc: Exception, error_text: str) -> bool: + status = getattr(exc, "status", None) + return status in {429, 529} or '"type":"overloaded_error"' in error_text + + @staticmethod + def _retry_delay_seconds(exc: Exception, transient_api_retry_count: int) -> float: + headers = getattr(exc, "headers", None) or {} + # @@@retry-after-shape + # Test doubles use plain dict headers while SDK errors expose a Headers-like + # object. Keep this probe shape-tolerant so the loop can honor retry-after + # without forcing a specific exception class. + if hasattr(headers, "get"): + retry_after = headers.get("retry-after") + else: + retry_after = None + try: + if retry_after is not None: + return max(0.0, float(retry_after)) + except (TypeError, ValueError): + pass + return _TRANSIENT_API_BASE_DELAY_SECONDS * (2**transient_api_retry_count) + + def _handle_truncated_response_recovery( + self, + *, + ai_msg: AIMessage, + messages: list, + turn: int, + max_output_tokens_recovery_count: int, + max_output_tokens_override: int | None, + ) -> dict[str, Any] | None: + if not self._is_max_output_truncated(ai_msg): + return None + + if max_output_tokens_override is None: + return { + "messages": messages, + "transition": ContinueState(reason=ContinueReason.max_output_tokens_escalate), + "max_output_tokens_recovery_count": max_output_tokens_recovery_count, + "max_output_tokens_override": _ESCALATED_MAX_OUTPUT_TOKENS, + "yield_ai": False, + "terminal": None, + } + + if max_output_tokens_recovery_count < 3: + recovered_messages = list(messages) + recovered_messages.append(ai_msg) + recovered_messages.append( + HumanMessage( + content="Output token limit hit. 
Resume directly with no apology or recap.", + ) + ) + return { + "messages": recovered_messages, + "transition": ContinueState(reason=ContinueReason.max_output_tokens_recovery), + "max_output_tokens_recovery_count": max_output_tokens_recovery_count + 1, + "max_output_tokens_override": max_output_tokens_override, + "yield_ai": False, + "terminal": None, + } + + surfaced_messages = list(messages) + surfaced_messages.append(ai_msg) + return { + "messages": surfaced_messages, + "transition": ContinueState(reason=ContinueReason.max_output_tokens_recovery), + "max_output_tokens_recovery_count": max_output_tokens_recovery_count, + "max_output_tokens_override": max_output_tokens_override, + "yield_ai": True, + "terminal": TerminalState( + reason=TerminalReason.model_error, + turn_count=turn, + error="max_output_tokens", + ), + } + + async def _force_reactive_compact(self, messages: list, *, thread_id: str) -> list | None: + if self._memory_middleware is None: + return None + compact = getattr(self._memory_middleware, "compact_messages_for_recovery", None) + if not callable(compact): + return None + signature = inspect.signature(compact) + if "thread_id" in signature.parameters: + compacted = compact(messages, thread_id=thread_id) + else: + compacted = compact(messages) + if not inspect.isawaitable(compacted): + raise TypeError("compact_messages_for_recovery must return an awaitable") + return await compacted + + async def _recover_from_overflow(self, messages: list) -> dict[str, Any] | None: + # @@@collapse-drain-single-shot + # ql-04 needs collapse-drain and reactive-compact to stay as separate + # phases. The drain hook is optional, but if present it only gets one + # chance before prompt-too-long falls through to reactive compaction. + for middleware in self.middleware: + recover = getattr(middleware, "recover_from_overflow", None) + if not callable(recover): + continue + drained = recover(messages) + if inspect.isawaitable(drained): + drained = await drained + if drained is None: + return None + committed = int(getattr(drained, "get", lambda *_: 0)("committed", 0)) + updated_messages = getattr(drained, "get", lambda *_: None)("messages") + if committed <= 0 or not isinstance(updated_messages, list): + return None + return {"committed": committed, "messages": list(updated_messages)} + return None + + @staticmethod + def _is_prompt_too_long_error(error_text: str) -> bool: + return ( + "prompt is too long" in error_text + or "prompt too long" in error_text + or "context length" in error_text + or "maximum context length" in error_text + ) + + @staticmethod + def _is_max_output_truncated(message: AIMessage) -> bool: + response_metadata = getattr(message, "response_metadata", None) or {} + additional_kwargs = getattr(message, "additional_kwargs", None) or {} + finish_reason = ( + response_metadata.get("finish_reason") + or response_metadata.get("stop_reason") + or additional_kwargs.get("finish_reason") + or additional_kwargs.get("stop_reason") + ) + return finish_reason in {"length", "max_tokens", "max_output_tokens"} + + # ------------------------------------------------------------------------- + # Tool execution through middleware chain + # ------------------------------------------------------------------------- + + async def _execute_tools( + self, + tool_calls: list, + model_response: ModelResponse, + tool_context: ToolUseContext | None, + ) -> list[ToolMessage]: + """Execute tool calls respecting concurrency safety, via middleware chain.""" + results: dict[int, ToolMessage] = {} + + async 
def execute_batch(batch: list[tuple[int, dict]]) -> None: + if not batch: + return + batch_results = await asyncio.gather( + *[self._execute_single_tool(tool_call, tool_context) for _, tool_call in batch], + return_exceptions=True, + ) + for (idx, tool_call), result in zip(batch, batch_results): + if isinstance(result, BaseException): + results[idx] = ToolMessage( + content=f"{result}", + tool_call_id=tool_call.get("id", ""), + name=tool_call.get("name", ""), + ) + continue + if not isinstance(result, ToolMessage): + raise TypeError(f"Tool executor returned unexpected result type: {type(result)!r}") + results[idx] = result + + safe_batch: list[tuple[int, dict]] = [] + for idx, tool_call in enumerate(tool_calls): + # @@@tool-order-boundary + # te-01 needs the non-streaming path to keep the same queue barrier + # semantics as the streaming executor: contiguous safe tools may fan + # out together, but any unsafe tool flushes the batch and blocks the + # next safe tool until it finishes. + if self._tool_is_concurrency_safe(tool_call): + safe_batch.append((idx, tool_call)) + continue + + await execute_batch(safe_batch) + safe_batch = [] + try: + results[idx] = await self._execute_single_tool(tool_call, tool_context) + except Exception as exc: + results[idx] = ToolMessage( + content=f"{exc}", + tool_call_id=tool_call.get("id", ""), + name=tool_call.get("name", ""), + ) + + await execute_batch(safe_batch) + return [results[i] for i in range(len(tool_calls))] + + async def _execute_single_tool( + self, + tool_call: dict, + tool_context: ToolUseContext | None, + ) -> ToolMessage: + name = tool_call.get("name") or tool_call.get("function", {}).get("name", "") + call_id = tool_call.get("id", "") + args = tool_call.get("args", {}) or tool_call.get("function", {}).get("arguments", {}) + + if isinstance(args, str): + import json + + try: + args = json.loads(args) + except Exception: + args = {} + + normalized_call = {"name": name, "args": args, "id": call_id} + tc_request = ToolCallRequest( + tool_call=normalized_call, + tool=None, + state=tool_context, + runtime=self._runtime, # type: ignore[arg-type] + ) + + async def innermost_tool_handler(req: ToolCallRequest) -> ToolMessage: + tc = req.tool_call + t_name = tc.get("name", "") + t_id = tc.get("id", "") + t_args = tc.get("args", {}) + entry = self._registry.get(t_name) + if entry is None: + return ToolMessage( + content=f"Tool '{t_name}' not found", + tool_call_id=t_id, + name=t_name, + ) + try: + import asyncio as _asyncio + + if _asyncio.iscoroutinefunction(entry.handler): + result = await entry.handler(**t_args) + else: + result = await _asyncio.to_thread(entry.handler, **t_args) + return ToolMessage(content=str(result), tool_call_id=t_id, name=t_name) + except Exception as e: + return ToolMessage( + content=f"{e}", + tool_call_id=t_id, + name=t_name, + ) + + tool_handler = innermost_tool_handler + for mw in reversed(self.middleware): + if _mw_overrides_tool_call(mw): + tool_handler = _make_tool_wrapper(mw, tool_handler) + + return await tool_handler(tc_request) + + def _tool_is_concurrency_safe(self, tool_call: dict) -> bool: + name = tool_call.get("name") or tool_call.get("function", {}).get("name", "") + entry = self._registry.get(name) + if entry is None: + return False + safety = entry.is_concurrency_safe + if callable(safety): + args = tool_call.get("args", {}) + if isinstance(args, str): + try: + import json as _json + + args = _json.loads(args) + except Exception: + args = {} + try: + return bool(safety(args if isinstance(args, dict) else 
{})) + except Exception: + return False + return bool(safety) + + def _tool_call_is_ready(self, tool_call: dict) -> bool: + name = tool_call.get("name") or tool_call.get("function", {}).get("name", "") + entry = self._registry.get(name) + if entry is None: + return True + + args = tool_call.get("args", {}) + if isinstance(args, str): + try: + import json as _json + + args = _json.loads(args) + except Exception: + return False + if not isinstance(args, dict): + return False + + schema = entry.get_schema() or {} + parameters = schema.get("parameters", {}) if isinstance(schema, dict) else {} + return _required_sets_match(parameters, args) if isinstance(parameters, dict) else True + + def _normalize_stream_tool_call( + self, + tool_call: dict, + tool_call_chunks: list[dict[str, Any]], + ) -> dict[str, Any] | None: + call_id = tool_call.get("id") + name = tool_call.get("name") or tool_call.get("function", {}).get("name", "") + args: Any = tool_call.get("args", {}) + if isinstance(args, str): + try: + import json as _json + + args = _json.loads(args) + except Exception: + args = {} + + raw_arg_chunks: list[str] = [] + for chunk in tool_call_chunks: + if chunk.get("id") != call_id: + continue + if chunk.get("name"): + name = chunk["name"] + raw_args = chunk.get("args") + if raw_args in (None, ""): + continue + if isinstance(raw_args, str): + raw_arg_chunks.append(raw_args) + else: + args = raw_args + + if raw_arg_chunks: + try: + import json as _json + + args = _json.loads("".join(raw_arg_chunks)) + except Exception: + return None + + normalized = {"name": name, "args": args, "id": call_id} + if not self._tool_call_is_ready(normalized): + return None + return normalized + + # ------------------------------------------------------------------------- + # Checkpointer persistence + # ------------------------------------------------------------------------- + + async def _load_messages(self, thread_id: str) -> list: + """Load message history from checkpointer (if available).""" + state = await self._load_thread_checkpoint_state(thread_id) + return list(state.messages) if state is not None else [] + + async def _load_thread_checkpoint_state(self, thread_id: str) -> ThreadCheckpointState | None: + if self._checkpoint_store is None: + return None + try: + return await self._checkpoint_store.load(thread_id) + except Exception: + logger.debug("QueryLoop: could not load checkpoint for thread %s", thread_id) + return None + + async def _load_checkpoint_channel_values(self, thread_id: str) -> dict[str, Any]: + """Compatibility helper for tests and bridge callers that still inspect channel_values.""" + state = await self._load_thread_checkpoint_state(thread_id) + if state is None: + return {} + return { + "messages": list(state.messages), + "tool_permission_context": dict(state.tool_permission_context), + "pending_permission_requests": dict(state.pending_permission_requests), + "resolved_permission_requests": dict(state.resolved_permission_requests), + "memory_compaction_state": dict(state.memory_compaction_state), + "mcp_instruction_state": dict(state.mcp_instruction_state), + } + + def _thread_permission_state_snapshot( + self, + thread_id: str, + ) -> tuple[dict[str, Any], dict[str, dict[str, Any]], dict[str, dict[str, Any]]]: + if self._app_state is None: + return {}, {}, {} + + permission_context = copy.deepcopy(self._app_state.tool_permission_context.model_dump()) + pending = { + key: copy.deepcopy(value) + for key, value in self._app_state.pending_permission_requests.items() + if value.get("thread_id") 
== thread_id + } + resolved = { + key: copy.deepcopy(value) + for key, value in self._app_state.resolved_permission_requests.items() + if value.get("thread_id") == thread_id + } + return permission_context, pending, resolved + + def _thread_memory_state_snapshot(self, thread_id: str) -> dict[str, Any]: + if self._memory_middleware is None: + return {} + snapshot = getattr(self._memory_middleware, "snapshot_thread_state", None) + if not callable(snapshot): + return {} + raw_snapshot = snapshot(thread_id) or {} + if not isinstance(raw_snapshot, dict): + return {} + return {str(key): value for key, value in raw_snapshot.items()} + + def _thread_mcp_instruction_state_snapshot(self, thread_id: str) -> dict[str, Any]: + if self._app_state is None: + return {} + announced_blocks = dict(self._app_state.announced_mcp_instruction_blocks.get(thread_id, {})) + return {"announced_blocks": announced_blocks} + + def _is_runtime_active(self) -> bool: + current_state = getattr(self._runtime, "current_state", None) + return getattr(current_state, "value", current_state) == "active" + + def _snapshot_live_thread_state(self, thread_id: str) -> dict[str, Any]: + messages = list(self._app_state.messages) if self._app_state is not None else [] + permission_context, pending, resolved = self._thread_permission_state_snapshot(thread_id) + memory_state = self._thread_memory_state_snapshot(thread_id) + return { + "messages": messages, + "tool_permission_context": permission_context, + "pending_permission_requests": pending, + "resolved_permission_requests": resolved, + "memory_compaction_state": memory_state, + "mcp_instruction_state": self._thread_mcp_instruction_state_snapshot(thread_id), + } + + def _restore_thread_permission_state( + self, + thread_id: str, + *, + permission_context: dict[str, Any], + pending: dict[str, dict[str, Any]], + resolved: dict[str, dict[str, Any]], + ) -> None: + if self._app_state is None: + return + + # @@@permission-checkpoint-bridge - pending/resolved permission requests + # are thread-scoped runtime state, not display-only metadata. They must + # survive checkpoint replay so backend/UI surfaces stay honest after an + # idle reload or agent recreation. 
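+        # Restoration is thread-scoped: _update below replaces only this
+        # thread's pending/resolved slices and leaves other threads' entries
+        # untouched.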
+ def _update(state: AppState) -> AppState: + kept_pending = {key: value for key, value in state.pending_permission_requests.items() if value.get("thread_id") != thread_id} + kept_pending.update(copy.deepcopy(pending)) + kept_resolved = {key: value for key, value in state.resolved_permission_requests.items() if value.get("thread_id") != thread_id} + kept_resolved.update(copy.deepcopy(resolved)) + return state.model_copy( + update={ + "tool_permission_context": ToolPermissionState.model_validate(copy.deepcopy(permission_context)), + "pending_permission_requests": kept_pending, + "resolved_permission_requests": kept_resolved, + } + ) + + self._app_state.set_state(_update) + + def _restore_thread_memory_state( + self, + thread_id: str, + *, + memory_state: dict[str, Any], + ) -> None: + if self._memory_middleware is None: + return + restore = getattr(self._memory_middleware, "restore_thread_state", None) + if callable(restore): + restore(thread_id, memory_state) + + def _restore_thread_mcp_instruction_state( + self, + thread_id: str, + *, + mcp_instruction_state: dict[str, Any], + ) -> None: + if self._app_state is None: + return + announced_blocks = mcp_instruction_state.get("announced_blocks", {}) + if not isinstance(announced_blocks, dict): + announced_blocks = {} + kept = {key: value for key, value in self._app_state.announced_mcp_instruction_blocks.items() if key != thread_id} + kept[thread_id] = {name: block for name, block in announced_blocks.items() if isinstance(name, str) and isinstance(block, str)} + self._app_state.announced_mcp_instruction_blocks = kept + + async def _hydrate_thread_state_from_checkpoint(self, thread_id: str) -> dict[str, Any]: + checkpoint_state = await self._load_thread_checkpoint_state(thread_id) + messages = list(checkpoint_state.messages) if checkpoint_state is not None else [] + permission_context = dict(checkpoint_state.tool_permission_context) if checkpoint_state is not None else {} + pending = dict(checkpoint_state.pending_permission_requests) if checkpoint_state is not None else {} + resolved = dict(checkpoint_state.resolved_permission_requests) if checkpoint_state is not None else {} + memory_state = dict(checkpoint_state.memory_compaction_state) if checkpoint_state is not None else {} + mcp_instruction_state = dict(checkpoint_state.mcp_instruction_state) if checkpoint_state is not None else {} + turn_count = self._app_state.turn_count if self._app_state is not None else 0 + self._sync_app_state(messages=messages, turn_count=turn_count) + self._restore_thread_permission_state( + thread_id, + permission_context=permission_context, + pending=pending, + resolved=resolved, + ) + self._restore_thread_memory_state( + thread_id, + memory_state=memory_state, + ) + self._restore_thread_mcp_instruction_state( + thread_id, + mcp_instruction_state=mcp_instruction_state, + ) + return { + "messages": messages, + "tool_permission_context": permission_context, + "pending_permission_requests": pending, + "resolved_permission_requests": resolved, + "memory_compaction_state": memory_state, + "mcp_instruction_state": mcp_instruction_state, + } + + async def _save_messages(self, thread_id: str, messages: list) -> None: + """Persist message history to checkpointer.""" + if self._checkpoint_store is None: + return + try: + permission_context, pending_requests, resolved_requests = self._thread_permission_state_snapshot(thread_id) + memory_state = self._thread_memory_state_snapshot(thread_id) + mcp_instruction_state = self._thread_mcp_instruction_state_snapshot(thread_id) + 
await self._checkpoint_store.save( + thread_id, + ThreadCheckpointState( + messages=list(messages), + tool_permission_context=permission_context, + pending_permission_requests=pending_requests, + resolved_permission_requests=resolved_requests, + memory_compaction_state=memory_state, + mcp_instruction_state=mcp_instruction_state, + ), + ) + except Exception: + logger.debug("QueryLoop: could not save checkpoint for thread %s", thread_id, exc_info=True) + + def _collect_memory_system_notices(self, pending_notices: list[HumanMessage]) -> None: + if self._memory_middleware is None: + return + consume_many = getattr(self._memory_middleware, "consume_pending_notices", None) + notices: list[dict[str, Any]] = [] + if callable(consume_many): + maybe_notices = consume_many() + if isinstance(maybe_notices, list): + notices = [notice for notice in maybe_notices if isinstance(notice, dict)] + else: + consume_one = getattr(self._memory_middleware, "consume_latest_compaction_notice", None) + if callable(consume_one): + notice = consume_one() + if isinstance(notice, dict): + notices = [notice] + for notice in notices: + pending_notices.append( + HumanMessage( + content=str(notice.get("content") or ""), + metadata={ + "source": "system", + "notification_type": str(notice.get("notification_type") or "compact"), + "compact_boundary_index": int(notice.get("compact_boundary_index") or 0), + }, + ) + ) + + def _append_system_notices(self, messages: list, notices: list[HumanMessage]) -> list: + if not notices: + return messages + # @@@compact-notice-persist - compaction changes the model-visible + # boundary, but the notice is for the owner surface only. Persist it + # after the run settles so replay stays honest without perturbing the + # same run's next model call. + return list(messages) + list(notices) + + def _build_terminal_notice(self, terminal: TerminalState | None) -> HumanMessage | None: + # @@@terminal-recovery-notice - recovery exhaustion must survive cold + # rebuilds. Persist one owner-visible system notice instead of leaving + # prompt-too-long as a hot-stream-only error. 
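+        # Only prompt_too_long is persisted this way; every other non-completed
+        # terminal reason still surfaces via _build_visible_terminal_error_message.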
+ if terminal is None or terminal.reason is not TerminalReason.prompt_too_long: + return None + return HumanMessage( + content=_PROMPT_TOO_LONG_NOTICE_TEXT, + metadata={"source": "system"}, + ) + + def _terminal_error_text(self, terminal: TerminalState) -> str: + if terminal.reason is TerminalReason.prompt_too_long: + return _PROMPT_TOO_LONG_NOTICE_TEXT + return terminal.error or terminal.reason.value + + def _build_visible_terminal_error_message( + self, + terminal: TerminalState, + messages: list[Any], + ) -> AIMessage | None: + if terminal.reason is TerminalReason.completed: + return None + error_text = self._terminal_error_text(terminal).strip() + if not error_text: + return None + last_message = messages[-1] if messages else None + if isinstance(last_message, AIMessage) and self._ai_message_has_visible_content(last_message): + return None + return AIMessage(content=f"Error: {error_text}") + + async def aclear(self, thread_id: str) -> None: + """Clear turn-scoped state for a thread while preserving session accumulators.""" + await self._save_messages(thread_id, []) + + self._tool_read_file_state.clear() + self._tool_loaded_nested_memory_paths.clear() + self._tool_discovered_skill_names.clear() + self._tool_discovered_tool_names_by_thread.pop(thread_id, None) + + if self._memory_middleware is not None: + summary_store = getattr(self._memory_middleware, "summary_store", None) + if summary_store is not None: + # @@@clear-thread-clears-summary-store - api-05 requires /clear + # to wipe replayable compaction state, not just in-memory cache. + summary_store.delete_thread_summaries(thread_id) + if hasattr(self._memory_middleware, "_cached_summary"): + setattr(self._memory_middleware, "_cached_summary", None) + if hasattr(self._memory_middleware, "_summary_restored"): + setattr(self._memory_middleware, "_summary_restored", False) + if hasattr(self._memory_middleware, "_summary_thread_id"): + setattr(self._memory_middleware, "_summary_thread_id", None) + if hasattr(self._memory_middleware, "_compact_up_to_index"): + setattr(self._memory_middleware, "_compact_up_to_index", 0) + clear_thread_state = getattr(self._memory_middleware, "clear_thread_state", None) + if callable(clear_thread_state): + clear_thread_state(thread_id) + + if self._app_state is not None: + preserved_total_cost = self._app_state.total_cost + preserved_tool_overrides = dict(self._app_state.tool_overrides) + pending_requests = { + key: value for key, value in self._app_state.pending_permission_requests.items() if value.get("thread_id") != thread_id + } + resolved_requests = { + key: value for key, value in self._app_state.resolved_permission_requests.items() if value.get("thread_id") != thread_id + } + + def _reset(state: AppState) -> AppState: + return state.model_copy( + update={ + "messages": [], + "turn_count": 0, + "total_cost": preserved_total_cost, + "compact_boundary_index": 0, + "tool_overrides": preserved_tool_overrides, + "pending_permission_requests": pending_requests, + "resolved_permission_requests": resolved_requests, + } + ) + + self._app_state.set_state(_reset) + + await self._save_messages(thread_id, []) + + if self._bootstrap is not None: + old_session_id = self._bootstrap.session_id + self._bootstrap.parent_session_id = old_session_id + self._bootstrap.session_id = uuid.uuid4().hex + + # ------------------------------------------------------------------------- + # Input parsing + # ------------------------------------------------------------------------- + + @staticmethod + def _parse_input(input: dict | 
None) -> list:
+        """Convert input dict to list of LangChain message objects."""
+        if input is None:
+            return []
+        raw_messages = input.get("messages", [])
+        result = []
+        for msg in raw_messages:
+            if hasattr(msg, "content"):
+                result.append(msg)
+            elif isinstance(msg, dict):
+                role = msg.get("role", "user")
+                content = msg.get("content", "")
+                if role == "user":
+                    result.append(HumanMessage(content=content))
+                elif role == "assistant":
+                    result.append(AIMessage(content=content))
+                else:
+                    result.append(HumanMessage(content=content))
+        return result
+
+    @staticmethod
+    def _ai_message_has_visible_content(message: AIMessage) -> bool:
+        content = getattr(message, "content", None)
+        if isinstance(content, str):
+            return content.strip() != ""
+        if isinstance(content, list):
+            for item in content:
+                if isinstance(item, str) and item.strip():
+                    return True
+                if isinstance(item, dict) and str(item.get("text", "")).strip():
+                    return True
+            return False
+        return bool(content)
+
+    @staticmethod
+    def _tool_results_include_permission_request(tool_results: list[ToolMessage]) -> bool:
+        for tool_result in tool_results:
+            additional_kwargs = getattr(tool_result, "additional_kwargs", None) or {}
+            meta = additional_kwargs.get("tool_result_meta")
+            if isinstance(meta, dict) and meta.get("kind") == "permission_request":
+                return True
+        return False
+
+    @staticmethod
+    def _get_terminal_followthrough_notice(messages: list[Any]) -> HumanMessage | None:
+        if not messages:
+            return None
+        last_message = messages[-1]
+        if last_message.__class__.__name__ != "HumanMessage":
+            return None
+        metadata = getattr(last_message, "metadata", None) or {}
+        if metadata.get("source") != "system":
+            return None
+        if metadata.get("notification_type") not in {"agent", "command"}:
+            return None
+        content = getattr(last_message, "content", "")
+        text = content if isinstance(content, str) else str(content)
+        if "CommandNotification" not in text and "task-notification" not in text:
+            return None
+        return last_message
+
+    @staticmethod
+    def _get_chat_followthrough_notice(messages: list[Any]) -> HumanMessage | None:
+        if not messages:
+            return None
+        last_message = messages[-1]
+        if last_message.__class__.__name__ != "HumanMessage":
+            return None
+        metadata = getattr(last_message, "metadata", None) or {}
+        if metadata.get("source") != "external":
+            return None
+        if metadata.get("notification_type") != "chat":
+            return None
+        content = getattr(last_message, "content", "")
+        text = content if isinstance(content, str) else str(content)
+        if "New message from" not in text or "read_messages(chat_id=" not in text:
+            return None
+        return last_message
+
+    @classmethod
+    def _build_terminal_followthrough_fallback(cls, notice: HumanMessage) -> AIMessage:
+        metadata = getattr(notice, "metadata", None) or {}
+        notification_type = str(metadata.get("notification_type") or "task")
+        content = getattr(notice, "content", "")
+        text = content if isinstance(content, str) else str(content)
+        status_match = re.search(r"<status>(.*?)</status>", text, flags=re.IGNORECASE | re.DOTALL)
+        status = status_match.group(1).strip().lower() if status_match else ""
+        subject = "command" if notification_type == "command" else "agent"
+        # @@@terminal-followthrough-fallback - terminal background notifications
+        # must never collapse into notice-only durable history when the model
+        # reentry stays silent; surface the silence explicitly instead.
+        if status == "completed":
+            reply = f"Background {subject} completed, but the followthrough assistant reply was empty."
+ elif status == "cancelled": + reply = f"Background {subject} was cancelled, but the followthrough assistant reply was empty." + elif status == "error": + reply = f"Background {subject} failed, but the followthrough assistant reply was empty." + else: + reply = f"Background {subject} update arrived, but the followthrough assistant reply was empty." + return AIMessage(content=reply) + + @classmethod + def _build_chat_followthrough_fallback(cls, notice: HumanMessage) -> AIMessage: + content = getattr(notice, "content", "") + text = content if isinstance(content, str) else str(content) + chat_id_match = re.search(r'read_messages\(chat_id="([^"]+)"\)', text) + if chat_id_match: + chat_id = chat_id_match.group(1) + reply = ( + f"I received a chat notification, but the followthrough assistant reply was empty. " + f'Read it with read_messages(chat_id="{chat_id}") before deciding whether to reply.' + ) + else: + reply = "I received a chat notification, but the followthrough assistant reply was empty." + return AIMessage(content=reply) + + +# ------------------------------------------------------------------------- +# Closure helpers (avoid late-binding bugs in loop-built lambdas) +# ------------------------------------------------------------------------- + + +def _make_model_wrapper(mw: AgentMiddleware, next_handler): + """Build an awrap_model_call wrapper that correctly closes over mw and next_handler.""" + + async def wrapper(request: ModelRequest) -> ModelResponse: + return await mw.awrap_model_call(request, next_handler) + + return wrapper + + +def _make_tool_wrapper(mw: AgentMiddleware, next_handler): + """Build an awrap_tool_call wrapper that correctly closes over mw and next_handler.""" + + async def wrapper(request: ToolCallRequest) -> ToolMessage: + return await mw.awrap_tool_call(request, next_handler) + + return wrapper + + +# ------------------------------------------------------------------------- +# Middleware override detection helpers +def _mw_overrides_model_call(mw: AgentMiddleware) -> bool: + """True if mw actually overrides awrap_model_call (not just inherits the base stub).""" + mw_type = type(mw) + own_fn = mw_type.__dict__.get("awrap_model_call") + if own_fn is not None: + return True + own_sync = mw_type.__dict__.get("wrap_model_call") + return own_sync is not None + + +def _mw_overrides_tool_call(mw: AgentMiddleware) -> bool: + """True if mw actually overrides awrap_tool_call (not just inherits the base stub).""" + mw_type = type(mw) + own_fn = mw_type.__dict__.get("awrap_tool_call") + if own_fn is not None: + return True + own_sync = mw_type.__dict__.get("wrap_tool_call") + return own_sync is not None diff --git a/core/runtime/middleware/__init__.py b/core/runtime/middleware/__init__.py index e69de29bb..f777a7fde 100644 --- a/core/runtime/middleware/__init__.py +++ b/core/runtime/middleware/__init__.py @@ -0,0 +1,79 @@ +"""Local runtime middleware protocol and request/response types. + +This replaces the phantom `langchain.agents.middleware.types` dependency for +the current runtime stack. 
+""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, replace +from typing import Any, ClassVar + +from langchain_core.messages import ToolMessage + + +@dataclass(frozen=True) +class ModelRequest: + model: Any + messages: list + system_message: Any = None + tools: list | None = None + + def override(self, **changes: Any) -> ModelRequest: + return replace(self, **changes) + + +@dataclass(frozen=True) +class ModelResponse: + result: list + request_messages: list | None = None + prepared_request: ModelRequest | None = None + + +ModelCallResult = ModelResponse + + +@dataclass(frozen=True) +class ToolCallRequest: + tool_call: dict + tool: Any = None + state: Any = None + runtime: Any = None + + def override(self, **changes: Any) -> ToolCallRequest: + return replace(self, **changes) + + +class AgentMiddleware: + """Minimal chain-of-responsibility middleware base for the runtime stack.""" + + tools: ClassVar[tuple[Any, ...]] = () + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + return handler(request) + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelResponse: + return await handler(request) + + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage], + ) -> ToolMessage: + return handler(request) + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage]], + ) -> ToolMessage: + return await handler(request) diff --git a/core/runtime/middleware/mcp_instructions.py b/core/runtime/middleware/mcp_instructions.py new file mode 100644 index 000000000..7cff4c7cb --- /dev/null +++ b/core/runtime/middleware/mcp_instructions.py @@ -0,0 +1,80 @@ +"""Thread-scoped MCP instruction delta injection. + +Mycel does not have CC's attachment plane. 
Keep this contract smaller: +- MCP server configs may carry `instructions` +- the loop stores which server names have already been announced per thread +- on the next turn after a change, inject one delta SystemMessage +""" + +from __future__ import annotations + +import json +from collections.abc import Callable +from typing import Any + +from langchain_core.messages import SystemMessage + +from core.runtime.middleware import AgentMiddleware +from core.runtime.state import AppState + +_DELTA_TAG = "mcp_instructions_delta" + + +def _format_instruction_block(server_name: str, instructions: str) -> str: + return f"## {server_name}\n{instructions.strip()}" + + +def _render_delta_message(*, added: dict[str, str], removed: list[str]) -> SystemMessage: + payload = { + "added_names": sorted(added), + "removed_names": sorted(removed), + } + blocks = [ + "", + f"<{_DELTA_TAG}>{json.dumps(payload, ensure_ascii=False)}", + "MCP server instructions changed for this thread.", + ] + if added: + blocks.append("Use the newly available MCP instructions below for subsequent turns:") + blocks.extend(_format_instruction_block(name, added[name]) for name in sorted(added)) + if removed: + blocks.append("The following MCP servers are no longer active for this thread:") + blocks.extend(f"- {name}" for name in sorted(removed)) + blocks.append("") + return SystemMessage(content="\n".join(blocks)) + + +class McpInstructionsDeltaMiddleware(AgentMiddleware): + """Injects MCP instruction deltas once per thread when the connected set changes.""" + + def __init__( + self, + *, + get_instruction_blocks: Callable[[], dict[str, str]], + get_app_state: Callable[[], AppState | None], + ) -> None: + self._get_instruction_blocks = get_instruction_blocks + self._get_app_state = get_app_state + + def before_model(self, state: dict[str, Any], runtime: Any = None, config: dict[str, Any] | None = None) -> dict[str, Any] | None: + app_state = self._get_app_state() + if app_state is None: + return None + + config = config or {} + thread_id = config.get("configurable", {}).get("thread_id", "default") + current_blocks = {name: block for name, block in self._get_instruction_blocks().items() if block.strip()} + announced_blocks = { + name: block + for name, block in app_state.announced_mcp_instruction_blocks.get(thread_id, {}).items() + if isinstance(name, str) and isinstance(block, str) and block.strip() + } + + added_names = sorted(name for name, block in current_blocks.items() if announced_blocks.get(name) != block) + removed_names = sorted(name for name in announced_blocks if name not in current_blocks) + if not added_names and not removed_names: + return None + + app_state.announced_mcp_instruction_blocks[thread_id] = dict(current_blocks) + added = {name: current_blocks[name] for name in added_names} + return {"messages": [_render_delta_message(added=added, removed=removed_names)]} diff --git a/core/runtime/middleware/memory/compactor.py b/core/runtime/middleware/memory/compactor.py index 67599b534..defbb7221 100644 --- a/core/runtime/middleware/memory/compactor.py +++ b/core/runtime/middleware/memory/compactor.py @@ -10,13 +10,22 @@ from langchain_core.messages import HumanMessage, SystemMessage +# CC L4b Legacy Compact: system prompt is simple (~200 tokens) — NOT inherited from parent. +# Using a distinct simple system prompt prevents reusing the parent conversation's cache +# (different system prompt → different prefix hash), and reduces input token cost. 
+COMPACT_SYSTEM_PROMPT = "You are a helpful AI assistant tasked with summarizing conversations." + SUMMARY_PROMPT = """\ -Provide a detailed summary for continuing our conversation. Include: -1. Key decisions made and their rationale -2. Files created, modified, or read and their current state -3. Errors encountered and how they were resolved -4. Outstanding tasks and current progress -5. Important context that would be needed to continue the work +Summarize this conversation in the following 9 sections: +1. Request/Intent — what the user asked for +2. Technical Concepts — key technologies and approaches discussed +3. Files/Code — files created or modified and their current state +4. Errors — errors encountered and how they were resolved +5. Problem Solving — decisions made and rationale +6. User Messages — key user inputs and feedback +7. Pending Tasks — unfinished work +8. Current Work — what was actively being done at the end +9. Next Step — the immediate next action needed Be concise but retain all information needed to continue seamlessly.""" SPLIT_TURN_PREFIX_PROMPT = """\ @@ -80,19 +89,41 @@ def split_messages(self, messages: list[Any]) -> tuple[list[Any], list[Any]]: return messages[:split_idx], messages[split_idx:] - async def compact(self, messages_to_summarize: list[Any], model: Any) -> str: + async def compact( + self, + messages_to_summarize: list[Any], + model: Any, + compact_boundary: int = 0, + ) -> str: """Generate a summary of the given messages using the LLM. + Aligned with CC L4b Legacy Compact: + - Uses COMPACT_SYSTEM_PROMPT (simple, ~200 tokens — NOT parent system prompt) + - No tools passed (extended thinking disabled, tools=[]) + - Slices from compact_boundary forward + - max_tokens capped at 20000 (CC max summary output) + Returns plain text summary string. """ - # Build the summarization request + # Slice from compact_boundary forward (CC: from last compact_boundary marker) + if compact_boundary > 0 and compact_boundary < len(messages_to_summarize): + messages_to_summarize = messages_to_summarize[compact_boundary:] + formatted = self._format_messages_for_summary(messages_to_summarize) + # CC L4b: system prompt is simple — does NOT inherit parent's system prompt. + # No tools, no extended thinking. 
summary_messages = [ - SystemMessage(content=SUMMARY_PROMPT), - HumanMessage(content=f"Here is the conversation to summarize:\n\n{formatted}"), + SystemMessage(content=COMPACT_SYSTEM_PROMPT), + HumanMessage(content=f"Summarize this conversation:\n\n{formatted}\n\n{SUMMARY_PROMPT}"), ] - response = await model.ainvoke(summary_messages) + # Bind max_tokens=20000 (CC max summary output), no tools + try: + bound_model = model.bind(max_tokens=20000) + except Exception: + bound_model = model + + response = await bound_model.ainvoke(summary_messages) return response.content if hasattr(response, "content") else str(response) def _estimate_msg_tokens(self, msg: Any) -> int: diff --git a/core/runtime/middleware/memory/middleware.py b/core/runtime/middleware/memory/middleware.py index 8775e1c21..c4d4f2362 100644 --- a/core/runtime/middleware/memory/middleware.py +++ b/core/runtime/middleware/memory/middleware.py @@ -7,19 +7,22 @@ from __future__ import annotations +import json import logging from collections.abc import Awaitable, Callable from pathlib import Path from typing import Any -from langchain.agents.middleware.types import ( +from langchain_core.messages import SystemMessage + +from core.runtime.checkpoint_store import CheckpointStore +from core.runtime.langgraph_checkpoint_store import LangGraphCheckpointStore +from core.runtime.middleware import ( AgentMiddleware, ModelCallResult, ModelRequest, ModelResponse, ) -from langchain_core.messages import SystemMessage - from storage.contracts import SummaryRepo from .compactor import ContextCompactor @@ -27,6 +30,7 @@ from .summary_store import SummaryStore logger = logging.getLogger(__name__) +_COMPACTION_BREAKER_THRESHOLD = 3 class MemoryMiddleware(AgentMiddleware): @@ -36,7 +40,7 @@ class MemoryMiddleware(AgentMiddleware): Layer 2 (Compaction): LLM summarization when context exceeds threshold """ - tools = [] # no tools injected + tools = () # no tools injected def __init__( self, @@ -75,6 +79,8 @@ def __init__( # Persistent storage summary_db_path = db_path or Path.home() / ".leon" / "leon.db" self.summary_store = SummaryStore(summary_db_path, summary_repo=summary_repo) if (db_path or summary_repo) else None + self._checkpointer: Any = None + self._checkpoint_store: CheckpointStore | None = None self.checkpointer = checkpointer # Injected references (set by agent.py after construction) @@ -86,6 +92,10 @@ def __init__( self._cached_summary: str | None = None self._compact_up_to_index: int = 0 self._summary_restored: bool = False + self._summary_thread_id: str | None = None + self._pending_owner_notices: list[dict[str, Any]] = [] + self._compaction_failure_counts_by_thread: dict[str, int] = {} + self._compaction_breaker_open_by_thread: dict[str, bool] = {} if verbose: print("[MemoryMiddleware] Initialized") @@ -101,6 +111,15 @@ def set_model(self, model: Any, model_config: dict[str, Any] | None = None) -> N self._model = model self._model_config = model_config + @property + def checkpointer(self) -> Any: + return self._checkpointer + + @checkpointer.setter + def checkpointer(self, value: Any) -> None: + self._checkpointer = value + self._checkpoint_store = LangGraphCheckpointStore(value) if value is not None else None + @property def _resolved_model(self) -> Any: """Return model with config bound so it uses the correct model/provider.""" @@ -125,6 +144,10 @@ def set_runtime(self, runtime: Any) -> None: """Inject AgentRuntime reference (called by agent.py).""" self._runtime = runtime + @property + def compact_boundary_index(self) -> int: + 
return self._compact_up_to_index + # ========== AgentMiddleware interface ========== async def awrap_model_call( @@ -134,13 +157,18 @@ async def awrap_model_call( ) -> ModelCallResult: messages = list(request.messages) original_count = len(messages) + thread_id = self._extract_thread_id(request) # Restore summary from store if not already done if not self._summary_restored and self.summary_store: - thread_id = self._extract_thread_id(request) if thread_id: await self._restore_summary_from_store(thread_id) self._summary_restored = True + self._summary_thread_id = thread_id + elif self.summary_store and thread_id and self._summary_thread_id != thread_id: + await self._restore_summary_from_store(thread_id) + self._summary_restored = True + self._summary_thread_id = thread_id sys_tokens = self._estimate_system_tokens(request) @@ -173,8 +201,9 @@ async def awrap_model_call( ) if self.compactor.should_compact(estimated, self._context_limit, self._compaction_threshold) and self._model: - thread_id = self._extract_thread_id(request) - messages = await self._do_compact(messages, thread_id) + compacted = await self._attempt_compaction(messages, thread_id=thread_id) + if compacted is not None: + messages = compacted elif self._cached_summary and self._compact_up_to_index > 0: if self._compact_up_to_index <= len(messages): summary_msg = SystemMessage(content=f"[Conversation Summary]\n{self._cached_summary}") @@ -190,7 +219,14 @@ async def awrap_model_call( final_tokens = self._estimate_tokens(messages) + sys_tokens print(f"[Memory] Final: {len(messages)} msgs (~{final_tokens} tokens) sent to LLM (original: {original_count} msgs)") - return await handler(request.override(messages=messages)) + response = await handler(request.override(messages=messages)) + if response.request_messages is None: + return ModelResponse( + result=response.result, + request_messages=list(messages), + prepared_request=response.prepared_request, + ) + return response async def _do_compact(self, messages: list[Any], thread_id: str | None = None) -> list[Any]: """Execute compaction: summarize old messages, return compacted list.""" @@ -219,6 +255,9 @@ async def _do_compact(self, messages: list[Any], thread_id: str | None = None) - self._cached_summary = summary_text self._compact_up_to_index = len(messages) - len(to_keep) + self._summary_restored = True + self._summary_thread_id = thread_id + self._record_compaction_notice() if self.summary_store and thread_id: try: @@ -257,6 +296,7 @@ async def force_compact(self, messages: list[Any]) -> dict[str, Any] | None: summary_text = await self.compactor.compact(to_summarize, self._resolved_model) self._cached_summary = summary_text self._compact_up_to_index = len(messages) - len(to_keep) + self._record_compaction_notice() return { "stats": { "summarized": len(to_summarize), @@ -267,6 +307,24 @@ async def force_compact(self, messages: list[Any]) -> dict[str, Any] | None: if self._runtime: self._runtime.set_flag("is_compacting", False) + async def compact_messages_for_recovery(self, messages: list[Any], thread_id: str | None = None) -> list[Any] | None: + """Force a compaction pass and return the compacted message list.""" + if not self._model: + return None + + pruned = self.pruner.prune(messages) + to_summarize, to_keep = self.compactor.split_messages(pruned) + if len(to_summarize) < 2: + return None + + return await self._attempt_compaction( + pruned, + thread_id=thread_id or self._current_thread_id(), + respect_breaker=False, + record_failures=False, + clear_breaker_on_success=True, 
+ ) + def _estimate_tokens(self, messages: list[Any]) -> int: """Estimate total tokens for messages (chars // 2).""" total = 0 @@ -306,6 +364,110 @@ def _extract_thread_id(self, request: ModelRequest) -> str | None: return configurable.get("thread_id") return getattr(configurable, "thread_id", None) if configurable else None + def consume_pending_notices(self) -> list[dict[str, Any]]: + notices = list(self._pending_owner_notices) + self._pending_owner_notices.clear() + return notices + + def snapshot_thread_state(self, thread_id: str) -> dict[str, Any]: + return { + "failure_count": int(self._compaction_failure_counts_by_thread.get(thread_id, 0)), + "breaker_open": bool(self._compaction_breaker_open_by_thread.get(thread_id, False)), + } + + def restore_thread_state(self, thread_id: str, state: dict[str, Any] | None) -> None: + payload = dict(state or {}) + failure_count = int(payload.get("failure_count") or 0) + breaker_open = bool(payload.get("breaker_open", False)) + if failure_count > 0: + self._compaction_failure_counts_by_thread[thread_id] = failure_count + else: + self._compaction_failure_counts_by_thread.pop(thread_id, None) + if breaker_open: + self._compaction_breaker_open_by_thread[thread_id] = True + else: + self._compaction_breaker_open_by_thread.pop(thread_id, None) + + def clear_thread_state(self, thread_id: str) -> None: + self._compaction_failure_counts_by_thread.pop(thread_id, None) + self._compaction_breaker_open_by_thread.pop(thread_id, None) + + def _record_compaction_notice(self) -> None: + content = f"Conversation compacted. Earlier {self._compact_up_to_index} message(s) are now represented by a summary." + self._queue_owner_notice( + { + "content": content, + "notification_type": "compact", + "compact_boundary_index": self._compact_up_to_index, + } + ) + + def _current_thread_id(self) -> str | None: + from sandbox.thread_context import get_current_thread_id + + return get_current_thread_id() + + async def _attempt_compaction( + self, + messages: list[Any], + *, + thread_id: str | None, + respect_breaker: bool = True, + record_failures: bool = True, + clear_breaker_on_success: bool = False, + ) -> list[Any] | None: + # @@@compaction-breaker-scope - match cc-src's narrower boundary: + # the breaker blocks later automatic compaction attempts, but reactive + # recovery may still try once and clear the breaker on success. 
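+        # Concretely: _COMPACTION_BREAKER_THRESHOLD (3) consecutive failures
+        # open the breaker; compact_messages_for_recovery bypasses it with
+        # respect_breaker=False and closes it via clear_breaker_on_success=True.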
+ if respect_breaker and thread_id and self._compaction_breaker_open_by_thread.get(thread_id, False): + return None + try: + compacted = await self._do_compact(messages, thread_id) + except Exception as exc: + logger.error("[Memory] Compaction failed for thread %s: %s", thread_id or "", exc) + if record_failures: + self._record_compaction_failure(thread_id, exc) + return None + self._record_compaction_success(thread_id, clear_breaker=clear_breaker_on_success) + return compacted + + def _record_compaction_success(self, thread_id: str | None, *, clear_breaker: bool = False) -> None: + if not thread_id: + return + self._compaction_failure_counts_by_thread.pop(thread_id, None) + if clear_breaker: + self._compaction_breaker_open_by_thread.pop(thread_id, None) + + def _record_compaction_failure(self, thread_id: str | None, exc: Exception) -> None: + if not thread_id: + return + failures = int(self._compaction_failure_counts_by_thread.get(thread_id, 0)) + 1 + self._compaction_failure_counts_by_thread[thread_id] = failures + if failures < _COMPACTION_BREAKER_THRESHOLD or self._compaction_breaker_open_by_thread.get(thread_id, False): + return + self._compaction_breaker_open_by_thread[thread_id] = True + self._queue_owner_notice( + { + "content": "Automatic compaction disabled for this thread after repeated failures. Clear the thread or start a new one.", + "notification_type": "compact_breaker", + "failure_count": failures, + "error": str(exc), + } + ) + + def _queue_owner_notice(self, notice: dict[str, Any]) -> None: + self._pending_owner_notices.append(dict(notice)) + if self._runtime and hasattr(self._runtime, "emit_activity_event"): + # @@@memory-owner-notices - compaction boundary and breaker state are + # owner-facing runtime facts, so stream and cold rebuild must share + # the same notice payload instead of inventing separate surfaces. 
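+            # The event's "data" field is the queued notice dict serialized as
+            # JSON, i.e. the same payload consume_pending_notices() hands to
+            # the durable-history path.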
+ self._runtime.emit_activity_event( + { + "event": "notice", + "data": json.dumps(notice, ensure_ascii=False), + } + ) + async def _restore_summary_from_store(self, thread_id: str) -> None: """Restore summary from SummaryStore.""" if not thread_id: @@ -314,6 +476,10 @@ async def _restore_summary_from_store(self, thread_id: str) -> None: ) try: + if self.summary_store is None: + return + self._cached_summary = None + self._compact_up_to_index = 0 summary_data = self.summary_store.get_latest_summary(thread_id) if not summary_data: @@ -332,6 +498,7 @@ async def _restore_summary_from_store(self, thread_id: str) -> None: self._cached_summary = summary_data.summary_text self._compact_up_to_index = summary_data.compact_up_to_index + self._summary_thread_id = thread_id if self.verbose: print( @@ -342,21 +509,25 @@ async def _restore_summary_from_store(self, thread_id: str) -> None: ) except Exception as e: + self._cached_summary = None + self._compact_up_to_index = 0 logger.error(f"[Memory] Failed to restore summary: {e}") async def _rebuild_summary_from_checkpointer(self, thread_id: str) -> None: """Rebuild summary from checkpointer when store data is corrupted.""" try: + if self.summary_store is None or self._checkpoint_store is None: + return if self.verbose: print(f"[Memory] Rebuilding summary from checkpointer for thread {thread_id}...") - checkpoint = self.checkpointer.get({"configurable": {"thread_id": thread_id}}) - if not checkpoint: + checkpoint_state = await self._checkpoint_store.load(thread_id) + if checkpoint_state is None: if self.verbose: print("[Memory] No checkpoint found, skipping rebuild") return - messages = checkpoint.get("channel_values", {}).get("messages", []) + messages = list(checkpoint_state.messages) if not messages: if self.verbose: print("[Memory] No messages in checkpoint, skipping rebuild") diff --git a/core/runtime/middleware/memory/summary_store.py b/core/runtime/middleware/memory/summary_store.py index 6fcff004c..553d162fa 100644 --- a/core/runtime/middleware/memory/summary_store.py +++ b/core/runtime/middleware/memory/summary_store.py @@ -64,8 +64,9 @@ def __init__(self, db_path: Path | None = None, summary_repo: SummaryRepo | None if summary_repo is not None: self._repo = summary_repo else: + resolved_db_path = self.db_path # @@@connect_injection - keep _connect as an indirection point so existing retry/rollback tests can patch it. 
- self._repo = SQLiteSummaryRepo(db_path, connect_fn=lambda p: _connect(p)) + self._repo = SQLiteSummaryRepo(resolved_db_path, connect_fn=lambda p: _connect(Path(p))) self._ensure_tables() def _ensure_tables(self) -> None: @@ -126,6 +127,8 @@ def save_summary( logger.error(f"[SummaryStore] Save failed after {max_retries} attempts: {e}") raise + raise RuntimeError("Summary save loop exited without returning or raising") + def get_latest_summary( self, thread_id: str, diff --git a/core/runtime/middleware/monitor/cost.py b/core/runtime/middleware/monitor/cost.py index 4b09c2a51..08615af02 100644 --- a/core/runtime/middleware/monitor/cost.py +++ b/core/runtime/middleware/monitor/cost.py @@ -112,7 +112,7 @@ def _load_cache() -> tuple[dict[str, dict[str, str]], dict[str, int], dict[str, if not cache_path.exists(): return None try: - data = json.loads(cache_path.read_text()) + data = json.loads(cache_path.read_text(encoding="utf-8")) if time.time() - data.get("timestamp", 0) > _CACHE_TTL: return None models = data.get("models", {}) @@ -128,7 +128,7 @@ def _save_cache(models: dict[str, dict[str, str]], context_limits: dict[str, int try: _CACHE_PATH.parent.mkdir(parents=True, exist_ok=True) data = {"timestamp": time.time(), "models": models, "context_limits": context_limits, "providers": providers} - _CACHE_PATH.write_text(json.dumps(data)) + _CACHE_PATH.write_text(json.dumps(data), encoding="utf-8") except Exception: pass @@ -163,11 +163,17 @@ def fetch_openrouter_pricing() -> dict[str, dict[str, Decimal]]: cached = _load_cache() if cached: models_raw, ctx, provs = cached - _pricing_data = _deserialize_costs(models_raw) - _context_limits = ctx - _model_providers = provs - _initialized = True - return _pricing_data + cached_costs = _deserialize_costs(models_raw) + # @@@pricing-cache-integrity - older CI caches can carry context/provider + # metadata with an empty model-pricing payload, which makes cost + # calculation silently degrade while context-limit tests still pass. + # Treat that cache as invalid and fall through to bundled/API reload. + if cached_costs: + _pricing_data = cached_costs + _context_limits = ctx + _model_providers = provs + _initialized = True + return _pricing_data _pricing_data = _fetch_from_openrouter() or _load_bundled() _initialized = True @@ -219,7 +225,10 @@ def _load_bundled() -> dict[str, dict[str, Decimal]]: if not _BUNDLED_PATH.exists(): return {} try: - data = json.loads(_BUNDLED_PATH.read_text()) + # @@@bundled-models-utf8 - Windows runners do not default to UTF-8. + # The bundled OpenRouter snapshot contains non-ASCII descriptions, so + # implicit decoding can fail and silently collapse pricing/context data. 
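# --- Illustrative sketch (not part of the patch) ----------------------------
# What the explicit encoding below guards against: Path.read_text() without
# an encoding argument falls back to the locale's preferred encoding, which
# is often cp1252 on Windows runners, so a UTF-8 JSON snapshot containing
# non-ASCII text can raise UnicodeDecodeError or decode to mojibake. The
# file name here is invented for the demo.

import json
from pathlib import Path

def load_snapshot(path: Path) -> dict:
    # Pin the encoding for files that were written as UTF-8.
    return json.loads(path.read_text(encoding="utf-8"))

demo = Path("snapshot-demo.json")
demo.write_text(json.dumps({"desc": "Précis 模型"}, ensure_ascii=False), encoding="utf-8")
assert load_snapshot(demo)["desc"] == "Précis 模型"
demo.unlink()
# ---------------------------------------------------------------------------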
+ data = json.loads(_BUNDLED_PATH.read_text(encoding="utf-8"))
 result: dict[str, dict[str, Decimal]] = {}
 ctx_result: dict[str, int] = {}
 prov_result: dict[str, str] = {}
diff --git a/core/runtime/middleware/monitor/middleware.py b/core/runtime/middleware/monitor/middleware.py
index 218ebcd06..adff96818 100644
--- a/core/runtime/middleware/monitor/middleware.py
+++ b/core/runtime/middleware/monitor/middleware.py
@@ -3,7 +3,7 @@
 from collections.abc import Awaitable, Callable
 from typing import Any
-from langchain.agents.middleware.types import (
+from core.runtime.middleware import (
 AgentMiddleware,
 ModelCallResult,
 ModelRequest,
@@ -25,7 +25,7 @@ class MonitorMiddleware(AgentMiddleware):
 Provides aggregated monitoring data to AgentRuntime.
 """
- tools = [] # no tools injected
+ tools = () # no tools injected
 def __init__(self, context_limit: int = 0, model_name: str = "", verbose: bool = False):
 self.verbose = verbose
@@ -113,6 +113,9 @@ async def awrap_model_call(
 self._state_monitor.mark_error(e)
 raise
+ if response.prepared_request is not None:
+ return response
+
 messages = response.result if hasattr(response, "result") else [response]
 resp_dict = {"messages": messages}
diff --git a/core/runtime/middleware/monitor/token_monitor.py b/core/runtime/middleware/monitor/token_monitor.py
index 255092704..7071d0141 100644
--- a/core/runtime/middleware/monitor/token_monitor.py
+++ b/core/runtime/middleware/monitor/token_monitor.py
@@ -1,8 +1,11 @@
 """Token usage monitoring (6 itemized counters)"""
+from __future__ import annotations
+
 from typing import Any
 from .base import BaseMonitor
+from .cost import CostCalculator
 class TokenMonitor(BaseMonitor):
@@ -24,7 +27,7 @@ def __init__(self):
 self.total_tokens = 0 # grand total
 # Cost calculator (injected by MonitorMiddleware)
- self.cost_calculator = None
+ self.cost_calculator: CostCalculator | None = None
 def on_request(self, request: dict[str, Any]) -> None:
 """Pre-request: no-op (call_count is incremented in on_response)"""
diff --git a/core/runtime/middleware/prompt_caching/__init__.py b/core/runtime/middleware/prompt_caching/__init__.py
index 87f4e92b4..361b124a8 100644
--- a/core/runtime/middleware/prompt_caching/__init__.py
+++ b/core/runtime/middleware/prompt_caching/__init__.py
@@ -1,8 +1,8 @@
 """Anthropic prompt caching middleware.
 Requires:
- - `langchain`: For agent middleware framework
- - `langchain-anthropic`: For `ChatAnthropic` model (already a dependency)
+ - local `core.runtime.middleware` protocol types
+ - `langchain-anthropic`: For `ChatAnthropic` model
 """
 from collections.abc import Awaitable, Callable
@@ -10,9 +10,10 @@
 from warnings import warn
 from langchain_anthropic.chat_models import ChatAnthropic
+from langchain_core.messages import SystemMessage
 try:
-    from langchain.agents.middleware.types import (
+    from core.runtime.middleware import (
 AgentMiddleware,
 ModelCallResult,
 ModelRequest,
@@ -20,9 +21,9 @@
 )
 except ImportError as e:
 msg = (
- "AnthropicPromptCachingMiddleware requires 'langchain' to be installed. "
- "This middleware is designed for use with LangChain agents. "
- "Install it with: pip install langchain"
+ "AnthropicPromptCachingMiddleware requires the local "
+ "'core.runtime.middleware' protocol definitions and "
+ "'langchain-anthropic' to be importable."
 )
 raise ImportError(msg) from e
@@ -32,7 +33,7 @@ class PromptCachingMiddleware(AgentMiddleware):
 Optimizes API usage by caching conversation prefixes for Anthropic models.
- Requires both `langchain` and `langchain-anthropic` packages to be installed.
+ Requires the local runtime middleware protocol plus `langchain-anthropic`.
Learn more about Anthropic prompt caching [here](https://platform.claude.com/docs/en/build-with-claude/prompt-caching). @@ -68,6 +69,26 @@ def __init__( self.min_messages_to_cache = min_messages_to_cache self.unsupported_model_behavior = unsupported_model_behavior + def _apply_system_cache(self, request: ModelRequest) -> ModelRequest: + """Add cache_control to the first (static) block of system_message. + + Anthropic prompt caching requires cache_control on the system content + blocks, not on messages. Marking the first block caches the entire + static system prefix (identity + tool rules) across sessions. + """ + sm = request.system_message + if sm is None: + return request + content = sm.content + if isinstance(content, str): + new_content: list = [{"type": "text", "text": content, "cache_control": {"type": self.type}}] + elif isinstance(content, list) and content: + first = {**content[0], "cache_control": {"type": self.type}} + new_content = [first, *content[1:]] + else: + return request + return request.override(system_message=SystemMessage(content=new_content)) + def _should_apply_caching(self, request: ModelRequest) -> bool: """Check if caching should be applied to the request. @@ -112,12 +133,7 @@ def wrap_model_call( """ if not self._should_apply_caching(request): return handler(request) - - new_model_settings = { - **request.model_settings, - "cache_control": {"type": self.type, "ttl": self.ttl}, - } - return handler(request.override(model_settings=new_model_settings)) + return handler(self._apply_system_cache(request)) async def awrap_model_call( self, @@ -135,12 +151,7 @@ async def awrap_model_call( """ if not self._should_apply_caching(request): return await handler(request) - - new_model_settings = { - **request.model_settings, - "cache_control": {"type": self.type, "ttl": self.ttl}, - } - return await handler(request.override(model_settings=new_model_settings)) + return await handler(self._apply_system_cache(request)) __all__ = ["PromptCachingMiddleware"] diff --git a/core/runtime/middleware/queue/__init__.py b/core/runtime/middleware/queue/__init__.py index f3d08f337..cf97229dc 100644 --- a/core/runtime/middleware/queue/__init__.py +++ b/core/runtime/middleware/queue/__init__.py @@ -2,7 +2,12 @@ from storage.contracts import QueueItem -from .formatters import format_background_notification, format_chat_notification, format_wechat_message +from .formatters import ( + format_agent_message, + format_background_notification, + format_chat_notification, + format_progress_notification, +) from .manager import MessageQueueManager from .middleware import SteeringMiddleware @@ -10,7 +15,8 @@ "MessageQueueManager", "QueueItem", "SteeringMiddleware", + "format_agent_message", "format_background_notification", "format_chat_notification", - "format_wechat_message", + "format_progress_notification", ] diff --git a/core/runtime/middleware/queue/formatters.py b/core/runtime/middleware/queue/formatters.py index 1e7821187..85034f7b4 100644 --- a/core/runtime/middleware/queue/formatters.py +++ b/core/runtime/middleware/queue/formatters.py @@ -11,13 +11,51 @@ def format_chat_notification(sender_name: str, chat_id: str, unread_count: int, signal: str | None = None) -> str: - """Lightweight notification — agent must chat_read to see content. + """Lightweight notification — agent must read_messages to see content. @@@v3-notification-only — no message content injected. Agent calls - chat_read(chat_id=...) to read, then chat_send() to reply. + read_messages(chat_id=...) 
to read, then send_message() to reply. """ signal_hint = f" [signal: {signal}]" if signal and signal != "open" else "" - return f"\nNew message from {sender_name} in chat {chat_id} ({unread_count} unread).{signal_hint}\n" + return ( + "\n" + f"New message from {sender_name} in chat {chat_id} ({unread_count} unread).{signal_hint}\n" + f'Read it with read_messages(chat_id="{chat_id}").\n' + f'Reply with send_message(chat_id="{chat_id}", content="...").\n' + "Prefer using this exact chat_id directly.\n" + "Do not treat your normal assistant text as a chat reply.\n" + "" + ) + + +def format_agent_message(sender_name: str, message: str) -> str: + """Format inter-agent delivery for steering injection on the next turn.""" + return ( + "\n" + "\n" + f" {escape(sender_name)}\n" + f" {escape(message)}\n" + "\n" + "" + ) + + +def format_progress_notification( + agent_id: str, + description: str, + *, + step: str = "running", +) -> str: + """Format background worker progress for coordinator-style prompt injection.""" + return ( + "\n" + "\n" + f" {escape(agent_id)}\n" + f" {escape(step)}\n" + f" {escape(description)}\n" + "\n" + "" + ) def format_background_notification( @@ -31,7 +69,7 @@ def format_background_notification( """Format background task completion as system-reminder XML.""" parts = [ "", - "", + "", f" {task_id}", f" {status}", ] @@ -44,29 +82,11 @@ def format_background_notification( parts.append(f" {escape(truncated)}") if usage: parts.append(f" {json.dumps(usage)}") - parts.append("") + parts.append("") parts.append("") return "\n".join(parts) -def format_wechat_message(sender_name: str, user_id: str, text: str) -> str: - """Format incoming WeChat message for thread delivery. - - Agent sees: full message with user_id metadata (needed for wechat_send reply). - Frontend sees: just the message text (system-reminder stripped). 
- """ - return ( - f"{text}\n" - "\n" - "\n" - f" {escape(sender_name)}\n" - f" {escape(user_id)}\n" - "\n" - 'To reply, use wechat_send(user_id="' + escape(user_id) + '", text="...").\n' - "" - ) - - def format_command_notification( command_id: str, status: Literal["completed", "failed"], diff --git a/core/runtime/middleware/queue/manager.py b/core/runtime/middleware/queue/manager.py index fd155b94d..f7ea1466f 100644 --- a/core/runtime/middleware/queue/manager.py +++ b/core/runtime/middleware/queue/manager.py @@ -11,7 +11,7 @@ from collections.abc import Callable from pathlib import Path -from storage.contracts import QueueItem, QueueRepo +from storage.contracts import NotificationType, QueueItem, QueueRepo logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def enqueue( self, content: str, thread_id: str, - notification_type: str = "steer", + notification_type: NotificationType = "steer", source: str | None = None, sender_id: str | None = None, sender_name: str | None = None, diff --git a/core/runtime/middleware/queue/middleware.py b/core/runtime/middleware/queue/middleware.py index ccb9c30be..714d0bd54 100644 --- a/core/runtime/middleware/queue/middleware.py +++ b/core/runtime/middleware/queue/middleware.py @@ -10,30 +10,65 @@ from collections.abc import Awaitable, Callable from typing import Any -from langchain_core.messages import HumanMessage, ToolMessage +from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import RunnableConfig -try: - from langchain.agents.middleware.types import ( - AgentMiddleware, - ModelCallResult, - ModelRequest, - ModelResponse, - ToolCallRequest, +from core.runtime.middleware import ( + AgentMiddleware, + ModelCallResult, + ModelRequest, + ModelResponse, + ToolCallRequest, +) +from core.runtime.notifications import is_terminal_background_notification + +from .manager import MessageQueueManager + +logger = logging.getLogger(__name__) + +_STEER_NON_PREEMPTIVE_SYSTEM_NOTE = ( + "Steer requests accepted during an active run are non-preemptive. " + "If any tool call from the interrupted run already started, it was allowed to finish and its side effects may " + "already have happened. Do not claim that prior work was interrupted, prevented, cancelled, or rolled back. " + "Treat the steer as instructions for what to do next after that completed work, and answer honestly about any " + "side effects that may already exist." 
+) + + +def _is_terminal_background_notification(item: Any) -> bool: + return is_terminal_background_notification( + getattr(item, "content", None), + source="system", + notification_type=getattr(item, "notification_type", None), ) -except ImportError: - class AgentMiddleware: - pass - ModelRequest = Any - ModelResponse = Any - ModelCallResult = Any - ToolCallRequest = Any +def _is_owner_steer_message(message: Any) -> bool: + if message.__class__.__name__ != "HumanMessage": + return False + metadata = getattr(message, "metadata", {}) or {} + return bool(metadata.get("is_steer") or (metadata.get("source") == "owner" and metadata.get("notification_type") == "steer")) -from .manager import MessageQueueManager -logger = logging.getLogger(__name__) +def _apply_steer_contract(request: ModelRequest) -> ModelRequest: + if not any(_is_owner_steer_message(message) for message in request.messages): + return request + + system_message = request.system_message + if system_message is None: + return request.override(system_message=SystemMessage(content=_STEER_NON_PREEMPTIVE_SYSTEM_NOTE)) + + content = getattr(system_message, "content", None) + if isinstance(content, str): + if _STEER_NON_PREEMPTIVE_SYSTEM_NOTE in content: + return request + # @@@steer-honesty-contract - mid-run steer stays a real user message in + # durable history, but the live model call also needs an explicit + # non-preemptive contract so it cannot overclaim that already-started + # tool work was stopped or never produced side effects. + return request.override(system_message=SystemMessage(content=f"{content}\n\n{_STEER_NON_PREEMPTIVE_SYSTEM_NOTE}")) + + return request.override(messages=[SystemMessage(content=_STEER_NON_PREEMPTIVE_SYSTEM_NOTE), *request.messages]) class SteeringMiddleware(AgentMiddleware): @@ -66,6 +101,20 @@ async def awrap_tool_call( """Async pure passthrough — never skip tool calls.""" return await handler(request) + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + return handler(_apply_steer_contract(request)) + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelCallResult: + return await handler(_apply_steer_contract(request)) + def before_model( self, state: Any, @@ -79,7 +128,27 @@ def before_model( return None items = self._queue_manager.drain_all(thread_id) - rt = self._agent_runtime + inject_now = [] + deferred = [] + for item in items: + if _is_terminal_background_notification(item): + deferred.append(item) + else: + inject_now.append(item) + # @@@followup-defer - terminal background notifications must never be + # injected inline into an active run. Their stable contract is a + # dedicated followthrough notice-only turn, regardless of the current + # run source. + for item in deferred: + self._queue_manager.enqueue( + item.content, + thread_id, + notification_type=item.notification_type, + source=item.source, + sender_id=item.sender_id, + sender_name=item.sender_name, + ) + items = inject_now if not items: return None @@ -109,14 +178,15 @@ def before_model( # breaks the turn at the steer injection point. # user_message is NOT emitted here — wake_handler already did it # at enqueue time (@@@steer-instant-feedback). 
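# --- Illustrative sketch (not part of the patch) ----------------------------
# The @@@followup-defer partition earlier in this hunk, reduced to plain
# data: terminal background notifications are split out of the drained queue
# and re-enqueued instead of being injected inline. Item is a stand-in for
# QueueItem; the terminal flag stands in for
# is_terminal_background_notification(item).

from dataclasses import dataclass

@dataclass
class Item:
    content: str
    terminal: bool

def partition(items: list[Item]) -> tuple[list[Item], list[Item]]:
    inject_now = [i for i in items if not i.terminal]
    deferred = [i for i in items if i.terminal]
    return inject_now, deferred

drained = [Item("steer: stop and summarize", False), Item("background task done", True)]
now, later = partition(drained)
assert [i.content for i in now] == ["steer: stop and summarize"]
assert [i.content for i in later] == ["background task done"]  # re-enqueued for its own turn
# ---------------------------------------------------------------------------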
- if has_steer and rt and hasattr(rt, "emit_activity_event"): - rt.emit_activity_event( + agent_runtime = self._agent_runtime + if has_steer and agent_runtime and hasattr(agent_runtime, "emit_activity_event"): + agent_runtime.emit_activity_event( { "event": "run_done", "data": json.dumps({"thread_id": thread_id}), } ) - rt.emit_activity_event( + agent_runtime.emit_activity_event( { "event": "run_start", "data": json.dumps({"thread_id": thread_id, "showing": True}), diff --git a/core/runtime/middleware/spill_buffer/middleware.py b/core/runtime/middleware/spill_buffer/middleware.py index ca519cb27..66390718d 100644 --- a/core/runtime/middleware/spill_buffer/middleware.py +++ b/core/runtime/middleware/spill_buffer/middleware.py @@ -2,28 +2,16 @@ from __future__ import annotations +import json +import mimetypes +import posixpath from collections.abc import Awaitable, Callable from pathlib import Path from typing import Any from langchain_core.messages import ToolMessage -try: - from langchain.agents.middleware.types import ( - AgentMiddleware, - ModelRequest, - ModelResponse, - ToolCallRequest, - ) -except ImportError: - - class AgentMiddleware: # type: ignore[no-redef] - pass - - ModelRequest = Any - ModelResponse = Any - ToolCallRequest = Any - +from core.runtime.middleware import AgentMiddleware, ModelRequest, ModelResponse, ToolCallRequest from core.tools.filesystem.backend import FileSystemBackend from .spill import spill_if_needed @@ -57,6 +45,53 @@ def __init__( self.thresholds: dict[str, int] = thresholds or {} self.default_threshold = default_threshold + def _rewrite_mcp_blocks(self, content: Any, *, tool_call_id: str) -> Any: + if not isinstance(content, list): + return content + + lines: list[str] = [] + saw_mcp_blocks = False + + for index, block in enumerate(content): + if not isinstance(block, dict): + return content + + kind = block.get("type") + if kind == "text": + lines.append(str(block.get("text", ""))) + continue + + saw_mcp_blocks = True + mime_type = str(block.get("mime_type") or "application/octet-stream") + guessed_ext = mimetypes.guess_extension(mime_type.split(";", 1)[0].strip()) or ".bin" + + if isinstance(block.get("base64"), str): + payload_path = posixpath.join( + self.workspace_root, + ".leon", + "tool-results", + f"{tool_call_id}-{index}{guessed_ext}.base64", + ) + # @@@mcp-binary-handoff - api-04 keeps Leon's sandbox/file + # abstraction by persisting encoded payloads through fs_backend + # instead of writing host-local bytes behind the sandbox's back. 
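# --- Illustrative sketch (not part of the patch) ----------------------------
# How the payload path above is derived for an MCP binary block, using only
# the stdlib. The workspace root and tool_call_id values are invented.

import mimetypes
import posixpath

def payload_path(workspace_root: str, tool_call_id: str, index: int, mime_type: str) -> str:
    # Strip any ";charset=..." parameter before guessing an extension, and
    # fall back to .bin for unknown types, mirroring the middleware logic.
    ext = mimetypes.guess_extension(mime_type.split(";", 1)[0].strip()) or ".bin"
    return posixpath.join(
        workspace_root, ".leon", "tool-results", f"{tool_call_id}-{index}{ext}.base64"
    )

assert payload_path("/ws", "call_1", 0, "image/png") == "/ws/.leon/tool-results/call_1-0.png.base64"
assert payload_path("/ws", "call_1", 1, "application/x-unknown").endswith("-1.bin.base64")
# ---------------------------------------------------------------------------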
+ write_result = self.fs_backend.write_file(payload_path, block["base64"]) + if hasattr(write_result, "success") and not write_result.success: + raise RuntimeError(write_result.error or f"failed to persist MCP payload to {payload_path}") + lines.append(f"MCP binary content ({mime_type}) saved to {payload_path} as base64 payload.") + continue + + if isinstance(block.get("url"), str): + lines.append(f"MCP {kind} content available at {block['url']} ({mime_type})") + continue + + lines.append(json.dumps(block, ensure_ascii=False, default=str)) + + if not saw_mcp_blocks: + text_only = "\n".join(line for line in lines if line) + return text_only if text_only else content + return "\n".join(line for line in lines if line) + # -- model call: pass-through ------------------------------------------ def wrap_model_call( @@ -81,6 +116,19 @@ def _maybe_spill(self, request: ToolCallRequest, result: ToolMessage) -> ToolMes if tool_name in SKIP_TOOLS: return result + source = result.additional_kwargs.get("tool_result_meta", {}).get("source") + normalized_content = result.content + if source == "mcp": + normalized_content = self._rewrite_mcp_blocks( + normalized_content, + tool_call_id=request.tool_call.get("id", "unknown"), + ) + if normalized_content is not result.content: + result = result.model_copy(update={"content": normalized_content}) + + if isinstance(result.content, str) and not result.content.strip(): + return result.model_copy(update={"content": f"({tool_name} completed with no output)"}) + threshold = self.thresholds.get(tool_name, self.default_threshold) tool_call_id = request.tool_call.get("id", "unknown") @@ -93,10 +141,10 @@ def _maybe_spill(self, request: ToolCallRequest, result: ToolMessage) -> ToolMes ) if spilled is not result.content: - return ToolMessage( - content=spilled, - tool_call_id=result.tool_call_id, - ) + # @@@spill-message-preservation - replacing content must not discard + # metadata/name/id; te-03 is about persisted handoff, not rebuilding + # a thinner ToolMessage shell. + return result.model_copy(update={"content": spilled}) return result def wrap_tool_call( diff --git a/core/runtime/middleware/spill_buffer/spill.py b/core/runtime/middleware/spill_buffer/spill.py index 8246a4f33..58cfa470e 100644 --- a/core/runtime/middleware/spill_buffer/spill.py +++ b/core/runtime/middleware/spill_buffer/spill.py @@ -2,7 +2,7 @@ from __future__ import annotations -import os +import posixpath from typing import Any from core.tools.filesystem.backend import FileSystemBackend @@ -10,6 +10,14 @@ PREVIEW_BYTES = 2048 +def _format_preview(content: str) -> str: + preview = content[:PREVIEW_BYTES] + cutoff = preview.rfind("\n") + if cutoff >= PREVIEW_BYTES // 2: + return preview[:cutoff] + return preview + + def spill_if_needed( content: Any, threshold_bytes: int, @@ -36,8 +44,8 @@ def spill_if_needed( if size <= threshold_bytes: return content - spill_dir = os.path.join(workspace_root, ".leon", "tool-results") - spill_path = os.path.join(spill_dir, f"{tool_call_id}.txt") + spill_dir = posixpath.join(workspace_root, ".leon", "tool-results") + spill_path = posixpath.join(spill_dir, f"{tool_call_id}.txt") write_note = "" try: @@ -50,10 +58,15 @@ def spill_if_needed( write_note = f"\n\n(Warning: failed to save full output to disk: {exc})" spill_path = "" - preview = content[:PREVIEW_BYTES] + # @@@persisted-output-wrapper - te-03 is about durable handoff semantics, + # not "shorter string". 
The model must see an explicit persisted artifact + # boundary plus the re-read path, otherwise we silently amputate context. + preview = _format_preview(content) return ( - f"Output too large ({size} bytes). Full output saved to: {spill_path}" - f"\n\nUse read_file to view specific sections with offset and limit parameters." - f"\n\nPreview (first {PREVIEW_BYTES} bytes):\n{preview}\n..." - f"{write_note}" + f'' + f"\nSize: {size} bytes" + f"\nUse read_file to inspect the full persisted output." + f"\nPreview (first {PREVIEW_BYTES} bytes):\n{preview}\n..." + f"{write_note}\n" + f"" ) diff --git a/core/runtime/notifications.py b/core/runtime/notifications.py new file mode 100644 index 000000000..f70ffc1fa --- /dev/null +++ b/core/runtime/notifications.py @@ -0,0 +1,13 @@ +from __future__ import annotations + + +def is_terminal_background_notification( + content: str | None, + *, + source: str | None, + notification_type: str | None, +) -> bool: + if source != "system" or notification_type not in {"agent", "command"}: + return False + text = content or "" + return "" in text or "" in text diff --git a/core/runtime/permissions.py b/core/runtime/permissions.py new file mode 100644 index 000000000..37c182ed7 --- /dev/null +++ b/core/runtime/permissions.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +PERMISSION_RULE_SOURCES = ( + "userSettings", + "projectSettings", + "localSettings", + "flagSettings", + "policySettings", + "cliArg", + "session", +) + + +@dataclass(frozen=True) +class ToolPermissionContext: + is_read_only: bool + is_destructive: bool = False + # @@@camelcase-permission-surface - external state/routes already speak this camelCase shape. + alwaysAllowRules: dict[str, list[str]] | None = None # noqa: N815 + alwaysDenyRules: dict[str, list[str]] | None = None # noqa: N815 + alwaysAskRules: dict[str, list[str]] | None = None # noqa: N815 + allowManagedPermissionRulesOnly: bool = False # noqa: N815 + + +def can_auto_approve(context: ToolPermissionContext) -> bool: + return context.is_read_only and not context.is_destructive + + +def _active_sources(context: ToolPermissionContext) -> tuple[str, ...]: + if context.allowManagedPermissionRulesOnly: + return ("policySettings",) + return PERMISSION_RULE_SOURCES + + +def _extract_tool_name(rule: str) -> str: + rule = rule.strip() + open_paren = rule.find("(") + return rule if open_paren == -1 else rule[:open_paren] + + +def _find_matching_rule( + rule_buckets: dict[str, list[str]] | None, + tool_name: str, + *, + sources: tuple[str, ...], +) -> str | None: + if not rule_buckets: + return None + for source in sources: + for rule in rule_buckets.get(source, []): + if _extract_tool_name(rule) == tool_name: + return rule + return None + + +def evaluate_permission_rules( + tool_name: str, + context: ToolPermissionContext, +) -> dict[str, Any] | None: + sources = _active_sources(context) + + deny_rule = _find_matching_rule(context.alwaysDenyRules, tool_name, sources=sources) + if deny_rule is not None: + return {"decision": "deny", "message": f"Permission denied by rule: {deny_rule}"} + + ask_rule = _find_matching_rule(context.alwaysAskRules, tool_name, sources=sources) + if ask_rule is not None: + return {"decision": "ask", "message": f"Permission required by rule: {ask_rule}"} + + allow_rule = _find_matching_rule(context.alwaysAllowRules, tool_name, sources=sources) + if allow_rule is not None: + return {"decision": "allow", "message": f"Permission allowed by rule: 
{allow_rule}"} + + return None diff --git a/core/runtime/prompts.py b/core/runtime/prompts.py new file mode 100644 index 000000000..6077cf371 --- /dev/null +++ b/core/runtime/prompts.py @@ -0,0 +1,217 @@ +"""System prompt builders — pure functions, no agent state. + +Extracted from LeonAgent so agent.py stays lean. + +Middleware Stack +- MemoryMiddleware: trims/compacts conversation context before model calls. +- MonitorMiddleware: aggregates runtime metrics and observes model execution. +- PromptCachingMiddleware: enables Anthropic prompt caching for eligible requests. +- SteeringMiddleware: drains queued messages and injects them before the next model call. +- SpillBufferMiddleware: spills oversized tool outputs to disk and replaces them with previews. +""" + +from __future__ import annotations + +from typing import NamedTuple + + +class RuleSpec(NamedTuple): + title: str + body: str + details: tuple[str, ...] = () + + +def _render_rule(index: int, rule: RuleSpec) -> str: + rendered = f"{index}. **{rule.title}**: {rule.body}" + if not rule.details: + return rendered + return rendered + "\n" + "\n".join(f" - {detail}" for detail in rule.details) + + +def _build_core_rules(*, is_sandbox: bool, sandbox_name: str, workspace_root: str, working_dir: str) -> list[RuleSpec]: + rules: list[RuleSpec] = [] + if is_sandbox: + if sandbox_name == "docker": + location_rule = "All file and command operations run in a local Docker container, NOT on the user's host filesystem." + else: + location_rule = "All file and command operations run in a remote sandbox, NOT on the user's local machine." + rules.append(RuleSpec("Sandbox Environment", f"{location_rule} The sandbox is an isolated Linux environment.")) + else: + rules.append(RuleSpec("Workspace", "File operations are restricted to: " + workspace_root)) + + rules.append( + RuleSpec( + "Absolute Paths", + "All file paths must be absolute paths.", + ( + f"Correct: `{working_dir}/project/test.py`", + "Wrong: `test.py` or `./test.py`", + ), + ) + ) + + if is_sandbox: + security = "The sandbox is isolated. You can install packages, run any commands, and modify files freely." + else: + security = "Dangerous commands are blocked. All operations are logged." 
+ rules.append(RuleSpec("Security", security)) + return rules + + +def _build_risk_rules() -> list[RuleSpec]: + return [ + RuleSpec( + "Risky Actions", + "Ask before destructive, hard-to-reverse, or shared-state actions.", + ( + "Examples: deleting files, force-pushing, dropping tables, killing unfamiliar processes, modifying shared infrastructure.", + "If you see unexpected state, investigate before deleting or overwriting it.", + ), + ), + RuleSpec( + "No URL Guessing", + "Do not guess URLs unless the user provided them or you are confident they are directly relevant to programming help.", + ), + RuleSpec( + "Minimal Change", + "Do not add features, refactor code, or make speculative abstractions beyond what the task requires.", + ( + "Don't create helpers, utilities, or abstractions for one-time operations.", + "Don't add error handling, fallbacks, or validation for scenarios that can't happen.", + ), + ), + ] + + +def _build_tool_preference_rules() -> list[RuleSpec]: + return [ + RuleSpec( + "Tool Priority", + "When a built-in tool and an MCP tool (`mcp__*`) have the same functionality, use the built-in tool.", + ), + RuleSpec( + "Tool Preference", + "Prefer dedicated tools over `Bash` when a built-in tool already matches the job.", + ( + "Use `Read` instead of `cat`, `head`, or `tail`.", + "Use `Edit` instead of shell text-munging for file edits.", + "Use `Write` instead of heredoc or echo redirection for file creation.", + "Use `Glob`/`Grep` for file discovery and content search before falling back to `Bash`.", + ), + ), + ] + + +def _build_interaction_rules() -> list[RuleSpec]: + return [] + + +def _build_function_result_clearing_rules(*, spill_buffer_enabled: bool, spill_keep_recent: int) -> list[RuleSpec]: + if not spill_buffer_enabled: + return [] + return [ + RuleSpec( + "Function Result Clearing", + f"Old tool results may be cleared from context to free up space. 
The {spill_keep_recent} most recent results are always kept.", + ( + "When working with tool results, write down any important information " + "you might need later in your response, as the original tool result " + "may be cleared later.", + ), + ) + ] + + +def _build_rule_specs( + *, + is_sandbox: bool, + sandbox_name: str, + workspace_root: str, + working_dir: str, + spill_buffer_enabled: bool, + spill_keep_recent: int, +) -> list[RuleSpec]: + rules: list[RuleSpec] = [] + rules.extend( + _build_core_rules( + is_sandbox=is_sandbox, + sandbox_name=sandbox_name, + workspace_root=workspace_root, + working_dir=working_dir, + ) + ) + rules.extend(_build_risk_rules()) + rules.extend(_build_tool_preference_rules()) + rules.extend( + _build_function_result_clearing_rules( + spill_buffer_enabled=spill_buffer_enabled, + spill_keep_recent=spill_keep_recent, + ) + ) + rules.extend(_build_interaction_rules()) + return rules + + +def build_context_section( + *, + sandbox_name: str, + sandbox_env_label: str = "", + sandbox_working_dir: str = "", + workspace_root: str = "", + os_name: str = "", + shell_name: str = "", +) -> str: + if sandbox_name != "local": + mode_label = "Sandbox (isolated local container)" if sandbox_name == "docker" else "Sandbox (isolated cloud environment)" + return f"""- Environment: {sandbox_env_label} +- Working Directory: {sandbox_working_dir} +- Mode: {mode_label}""" + return f"""- Workspace: `{workspace_root}` +- OS: {os_name} +- Shell: {shell_name} +- Mode: Local""" + + +def build_rules_section( + *, + is_sandbox: bool, + sandbox_name: str = "", + working_dir: str, + workspace_root: str, + spill_buffer_enabled: bool = False, + spill_keep_recent: int = 0, +) -> str: + rule_specs = _build_rule_specs( + is_sandbox=is_sandbox, + sandbox_name=sandbox_name, + workspace_root=workspace_root, + working_dir=working_dir, + spill_buffer_enabled=spill_buffer_enabled, + spill_keep_recent=spill_keep_recent, + ) + return "\n\n".join(_render_rule(index, rule) for index, rule in enumerate(rule_specs, start=1)) + + +def build_base_prompt(context: str, rules: str) -> str: + return f"""You are a highly capable AI assistant with access to file and system tools. 
+ +**Context:** +{context} + +**Important Rules:** + +{rules} +""" + + +_AGENT_TOOL_SECTION = """ +**Sub-agent Types:** +- `explore`: Read-only codebase exploration (Grep, Glob, Read only) +- `plan`: Architecture design and planning (read-only tools) +- `bash`: Shell command execution (Bash + read tools) +- `general`: Full tool access for independent multi-step tasks +""" + + +def build_common_sections(skills_enabled: bool) -> str: + return _AGENT_TOOL_SECTION diff --git a/core/runtime/registry.py b/core/runtime/registry.py index f6a87f008..4b9de4ccb 100644 --- a/core/runtime/registry.py +++ b/core/runtime/registry.py @@ -1,11 +1,46 @@ from __future__ import annotations from collections.abc import Awaitable, Callable +from copy import deepcopy from dataclasses import dataclass from enum import Enum +from typing import Any, NotRequired, Required, TypedDict -Handler = Callable[..., str] | Callable[..., Awaitable[str]] -SchemaProvider = dict | Callable[[], dict] +from core.runtime.tool_result import ToolResultEnvelope + +type ToolSchema = dict[str, Any] +type ToolHandlerResult = str | ToolResultEnvelope +type ToolArgs = dict[str, Any] +type ToolPropertySchema = dict[str, Any] +type ToolProperties = dict[str, ToolPropertySchema] + +type Handler = Callable[..., ToolHandlerResult] | Callable[..., Awaitable[ToolHandlerResult]] +type SchemaProvider = ToolSchema | Callable[[], ToolSchema] +type ConcurrencySafety = bool | Callable[[ToolArgs], bool] +type ToolInputValidator = Callable[[ToolArgs, Any], ToolArgs | None] | Callable[[ToolArgs, Any], Awaitable[ToolArgs | None]] + + +class _ToolEntryDefaults(TypedDict): + search_hint: str + is_concurrency_safe: ConcurrencySafety + is_read_only: bool + is_destructive: bool + context_schema: ToolSchema | None + validate_input: ToolInputValidator | None + + +class _ToolEntryBuildArgs(TypedDict, total=False): + name: Required[str] + mode: Required[ToolMode] + schema: Required[SchemaProvider] + handler: Required[Handler] + source: Required[str] + search_hint: NotRequired[str] + is_concurrency_safe: NotRequired[ConcurrencySafety] + is_read_only: NotRequired[bool] + is_destructive: NotRequired[bool] + context_schema: NotRequired[ToolSchema | None] + validate_input: NotRequired[ToolInputValidator | None] class ToolMode(Enum): @@ -20,11 +55,50 @@ class ToolEntry: schema: SchemaProvider handler: Handler source: str - - def get_schema(self) -> dict: + search_hint: str = "" # 3-10 word capability description for ToolSearch matching + is_concurrency_safe: ConcurrencySafety = False # fail-closed: assume not safe + is_read_only: bool = False # fail-closed: assume write operation + is_destructive: bool = False # advisory metadata for permission/UI layers + context_schema: ToolSchema | None = None # fields this tool needs from ToolUseContext + validate_input: ToolInputValidator | None = None + + def get_schema(self) -> ToolSchema: return self.schema() if callable(self.schema) else self.schema +TOOL_DEFAULTS: _ToolEntryDefaults = { + "search_hint": "", + "is_concurrency_safe": False, + "is_read_only": False, + "is_destructive": False, + "context_schema": None, + "validate_input": None, +} + + +def make_tool_schema( + *, + name: str, + description: str, + properties: ToolProperties, + required: list[str] | None = None, + parameter_overrides: ToolSchema | None = None, +) -> ToolSchema: + parameters: ToolSchema = { + "type": "object", + "properties": properties, + } + if required: + parameters["required"] = required + if parameter_overrides: + 
parameters.update(parameter_overrides) + return { + "name": name, + "description": description, + "parameters": parameters, + } + + class ToolRegistry: """Central registry for all tools. @@ -55,23 +129,70 @@ def register(self, entry: ToolEntry) -> None: def get(self, name: str) -> ToolEntry | None: return self._tools.get(name) - def get_inline_schemas(self) -> list[dict]: - return [e.get_schema() for e in self._tools.values() if e.mode == ToolMode.INLINE] - - def search(self, query: str) -> list[ToolEntry]: - """Return all matching tools (including inline) for tool_search.""" - q = query.lower() - results = [] - for entry in self._tools.values(): + def get_inline_schemas(self, discovered_tool_names: set[str] | None = None) -> list[dict]: + discovered_tool_names = discovered_tool_names or set() + return [ + self._sanitize_schema_for_model(e.get_schema()) + for e in self._tools.values() + if e.mode == ToolMode.INLINE or e.name in discovered_tool_names + ] + + def _sanitize_schema_for_model(self, schema: dict) -> dict: + # @@@tool-schema-sanitize - runtime-only schema metadata is useful for + # validator/readiness, but provider tool schemas must stay within the + # subset the live model API accepts. + def _walk(value: Any) -> Any: + if isinstance(value, dict): + return {key: _walk(child) for key, child in value.items() if not (isinstance(key, str) and key.startswith("x-leon-"))} + if isinstance(value, list): + return [_walk(item) for item in value] + return value + + return _walk(deepcopy(schema)) + + def search(self, query: str, *, modes: set[ToolMode] | None = None) -> list[ToolEntry]: + """Return matching tools with ranked relevance. + + Supports ``select:Name1,Name2`` for exact selection. + Otherwise ranks by: search_hint > name > description. + """ + q = query.strip() + entries = [entry for entry in self._tools.values() if modes is None or entry.mode in modes] + + # --- select: exact lookup --- + if q.lower().startswith("select:"): + names = [n.strip() for n in q[len("select:") :].split(",") if n.strip()] + results = [self._tools[n] for n in names if n in self._tools and (modes is None or self._tools[n].mode in modes)] + return results + + # --- keyword search with ranking --- + keywords = q.lower().split() + if not keywords: + return list(entries) + + scored: list[tuple[int, ToolEntry]] = [] + for entry in entries: schema = entry.get_schema() - name = schema.get("name", "") - desc = schema.get("description", "") - if q in name.lower() or q in desc.lower(): - results.append(entry) - # If no match, return all - if not results: - results = list(self._tools.values()) - return results + name_lower = entry.name.lower() + hint_lower = entry.search_hint.lower() + desc_lower = schema.get("description", "").lower() + + score = 0 + for kw in keywords: + if kw in hint_lower: + score += 3 + if kw in name_lower: + score += 2 + if kw in desc_lower: + score += 1 + if score > 0: + scored.append((score, entry)) + + if not scored: + return [] + + scored.sort(key=lambda x: x[0], reverse=True) + return [entry for _, entry in scored] def list_all(self) -> list[ToolEntry]: return list(self._tools.values()) diff --git a/core/runtime/runner.py b/core/runtime/runner.py index ade917216..15fffb02c 100644 --- a/core/runtime/runner.py +++ b/core/runtime/runner.py @@ -1,23 +1,44 @@ from __future__ import annotations import asyncio +import copy +import inspect import json import logging +import threading from collections.abc import Awaitable, Callable +from typing import Any, cast -from 
langchain.agents.middleware.types import ( +from langchain_core.messages import ToolMessage + +from core.runtime.middleware import ( AgentMiddleware, ModelRequest, ModelResponse, ToolCallRequest, ) -from langchain_core.messages import ToolMessage from .errors import InputValidationError +from .permissions import ToolPermissionContext from .registry import ToolRegistry +from .tool_result import ( + ToolResultEnvelope, + materialize_tool_message, + tool_error, + tool_permission_denied, + tool_permission_request, + tool_success, +) from .validator import ToolValidator logger = logging.getLogger(__name__) +DEFAULT_ASYNC_HOOK_TIMEOUT_S = 15.0 + + +class _ToolSpecificValidationError(Exception): + def __init__(self, message: str, error_code: str | None = None): + super().__init__(message) + self.error_code = error_code class ToolRunner(AgentMiddleware): @@ -48,9 +69,9 @@ def _inject_tools(self, request: ModelRequest) -> ModelRequest: def _extract_call_info(self, request: ToolCallRequest) -> tuple[str, dict, str]: tool_call = request.tool_call - name = tool_call.get("name") + name = tool_call.get("name") or "" args = tool_call.get("args", {}) - call_id = tool_call.get("id", "") + call_id = tool_call.get("id", "") or "" if isinstance(args, str): try: @@ -60,49 +81,612 @@ def _extract_call_info(self, request: ToolCallRequest) -> tuple[str, dict, str]: return name, args, call_id - def _validate_and_run(self, name: str, args: dict, call_id: str) -> ToolMessage: - entry = self._registry.get(name) - if entry is None: - return None # not our tool + @staticmethod + def _get_request_hook(request: ToolCallRequest, hook_name: str): + state = getattr(request, "state", None) + if state is None: + return None + if isinstance(state, dict): + hook = state.get(hook_name) + else: + hook = vars(state).get(hook_name) + if hook is None: + return None + if isinstance(hook, list): + return hook + return hook if callable(hook) else None - schema = entry.get_schema() + @staticmethod + async def _apply_result_hooks( + hook_or_hooks, + payload: ToolMessage | ToolResultEnvelope, + request: ToolCallRequest, + ) -> ToolMessage | ToolResultEnvelope: + if hook_or_hooks is None: + return payload + hooks = hook_or_hooks if isinstance(hook_or_hooks, list) else [hook_or_hooks] + current = payload + + async def _invoke(hook): + updated = hook(copy.deepcopy(payload), request) + if asyncio.iscoroutine(updated): + updated = await ToolRunner._await_async_hook_with_timeout( + request, + updated, + hook_name=getattr(hook, "__name__", type(hook).__name__), + ) + return updated + + for updated in await asyncio.gather(*(_invoke(hook) for hook in hooks)): + if updated is not None: + current = updated + return current + + def _normalize_result(self, result: Any) -> ToolResultEnvelope: + if isinstance(result, ToolResultEnvelope): + return result + return tool_success(result) + + @staticmethod + def _resolve_context_path(state: Any, path: str) -> Any: + current = state + for segment in path.split("."): + if segment == "app_state": + current = current.get_app_state() + continue + if isinstance(current, dict): + current = current[segment] + else: + current = getattr(current, segment) + return current + + @staticmethod + def _inject_handler_context(entry, args: dict, request: ToolCallRequest) -> dict: + state = getattr(request, "state", None) + if state is None: + return args try: - self._validator.validate(schema, args) - except InputValidationError as e: - return ToolMessage( - content=f"InputValidationError: {name} failed due to the following 
issue:\n{e}", - tool_call_id=call_id, - name=name, - ) + signature = inspect.signature(entry.handler) + except (TypeError, ValueError): + return args + accepts_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()) + injected = dict(args) + + context_schema = getattr(entry, "context_schema", None) or {} + if isinstance(context_schema, dict): + # @@@pt-02-context-schema-mapping + # Pattern 2 only becomes real once declared ToolUseContext field + # mappings are injected into handler kwargs on the live path. + for param_name, context_path in context_schema.items(): + if param_name in injected: + continue + if not accepts_kwargs and param_name not in signature.parameters: + continue + injected[param_name] = ToolRunner._resolve_context_path(state, context_path) + + if "tool_context" in injected: + return injected + if accepts_kwargs or "tool_context" in signature.parameters: + # @@@sa-04-tool-context-injection + # The sub-agent boundary only becomes real once the live ToolUseContext + # can cross the tool runner into handlers that explicitly opt in. + injected["tool_context"] = state + return injected + + @staticmethod + def _coerce_permission_response(result) -> tuple[str | None, str | None]: + if result is None: + return None, None + if isinstance(result, str): + return result, None + if isinstance(result, dict): + decision = result.get("decision") or result.get("permission") + message = result.get("message") + return decision, message + decision = getattr(result, "decision", None) or getattr(result, "permission", None) + message = getattr(result, "message", None) + return decision, message + + @staticmethod + def _permission_denied_result(decision: str, message: str | None) -> ToolResultEnvelope: + if decision == "ask": + text = message or "Permission required" + else: + text = message or "Permission denied" + return tool_permission_denied( + text, + metadata={"decision": decision, "error_type": "permission_resolution"}, + ) + + @staticmethod + def _permission_request_result(request_id: str, message: str | None) -> ToolResultEnvelope: + return tool_permission_request( + message or "Permission required", + metadata={ + "decision": "ask", + "request_id": request_id, + "error_type": "permission_resolution", + }, + ) + @staticmethod + def _materialize_permission_ask( + request_id: str | None, + message: str | None, + ) -> ToolResultEnvelope: + # @@@permission-ask-materialization + # Ask is only honest when a concrete request surface exists. Otherwise + # fail loudly as a deny so caller metadata matches the actual runtime. + if request_id is not None: + return ToolRunner._permission_request_result(request_id, message) + return ToolRunner._permission_denied_result("deny", message) + + @staticmethod + def _run_awaitable_sync(awaitable): + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(awaitable) + + result_box: list[Any] = [] + error_box: list[BaseException] = [] + + # @@@sync-awaitable-bridge - sync tool entrypoints still need to consume + # async permission checkers even when called from a live event loop. 
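# --- Illustrative sketch (not part of the patch) ----------------------------
# The bridge pattern in isolation: drive a coroutine to completion from sync
# code whether or not an event loop is already running in this thread. The
# real implementation below also propagates worker-thread exceptions; this
# sketch omits that for brevity.

import asyncio
import threading

def run_sync(awaitable):
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: plain asyncio.run is safe.
        return asyncio.run(awaitable)
    # A loop is already running, so asyncio.run would raise here. Drive the
    # coroutine on a throwaway loop in a worker thread and join it.
    box: list = []

    def worker() -> None:
        box.append(asyncio.run(awaitable))

    t = threading.Thread(target=worker, daemon=True)
    t.start()
    t.join()
    return box[0] if box else None

async def answer() -> int:
    return 42

assert run_sync(answer()) == 42
# ---------------------------------------------------------------------------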
+ def _runner() -> None: + try: + result_box.append(asyncio.run(awaitable)) + except BaseException as exc: # pragma: no cover - re-raised below + error_box.append(exc) + + thread = threading.Thread(target=_runner, daemon=True) + thread.start() + thread.join() + + if error_box: + raise error_box[0] + return result_box[0] if result_box else None + + @staticmethod + def _get_async_hook_timeout_s(request: ToolCallRequest) -> float: + state = getattr(request, "state", None) + if state is None: + return DEFAULT_ASYNC_HOOK_TIMEOUT_S + hook_timeout_ms = state.get("hook_timeout_ms") if isinstance(state, dict) else getattr(state, "hook_timeout_ms", None) + if isinstance(hook_timeout_ms, (int, float)) and hook_timeout_ms > 0: + return float(hook_timeout_ms) / 1000.0 + hook_timeout_s = state.get("hook_timeout_s") if isinstance(state, dict) else getattr(state, "hook_timeout_s", None) + if isinstance(hook_timeout_s, (int, float)) and hook_timeout_s > 0: + return float(hook_timeout_s) + return DEFAULT_ASYNC_HOOK_TIMEOUT_S + + @staticmethod + async def _await_async_hook_with_timeout( + request: ToolCallRequest, + awaitable, + *, + hook_name: str, + ): + timeout_s = ToolRunner._get_async_hook_timeout_s(request) + task = asyncio.create_task(awaitable) try: - result = entry.handler(**args) + return await asyncio.wait_for(task, timeout=timeout_s) + except TimeoutError: + logger.warning("Async hook %s timed out after %.3fs; ignoring hook result", hook_name, timeout_s) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + return None + + @staticmethod + def _await_async_hook_with_timeout_sync( + request: ToolCallRequest, + awaitable, + *, + hook_name: str, + ): + return ToolRunner._run_awaitable_sync( + ToolRunner._await_async_hook_with_timeout( + request, + awaitable, + hook_name=hook_name, + ) + ) + + @staticmethod + def _get_state_callable(request: ToolCallRequest, name: str): + state = getattr(request, "state", None) + if state is None: + return None + return state.get(name) if isinstance(state, dict) else getattr(state, name, None) + + async def _consume_permission_resolution_async( + self, + request: ToolCallRequest, + *, + name: str, + args: dict, + entry, + ) -> tuple[str | None, str | None]: + consumer = self._get_state_callable(request, "consume_permission_resolution") + if not callable(consumer): + return None, None + permission_context = ToolPermissionContext( + is_read_only=bool(getattr(entry, "is_read_only", False)), + is_destructive=bool(getattr(entry, "is_destructive", False)), + ) + result = consumer(name, args, permission_context, request) + if asyncio.iscoroutine(result): + result = await result + return self._coerce_permission_response(result) + + async def _request_permission_async( + self, + request: ToolCallRequest, + *, + name: str, + args: dict, + entry, + message: str | None, + ) -> str | None: + requester = self._get_state_callable(request, "request_permission") + if not callable(requester): + return None + permission_context = ToolPermissionContext( + is_read_only=bool(getattr(entry, "is_read_only", False)), + is_destructive=bool(getattr(entry, "is_destructive", False)), + ) + result = requester(name, args, permission_context, request, message) + if asyncio.iscoroutine(result): + result = await result + if isinstance(result, dict): + request_id = result.get("request_id") + return request_id if isinstance(request_id, str) else None + return result if isinstance(result, str) else None + + async def _run_tool_specific_validation_async(self, entry, args: dict, 
request: ToolCallRequest) -> dict: + validator = getattr(entry, "validate_input", None) + if validator is None: + return args + result = validator(dict(args), request) + if asyncio.iscoroutine(result): + result = await result + if result is None: + return args + if isinstance(result, dict): + if result.get("result") is False or result.get("ok") is False: + raise _ToolSpecificValidationError( + result.get("message") or "Tool-specific validation failed", + result.get("errorCode") or result.get("error_code"), + ) + return result + raise InputValidationError(str(result)) + + async def _run_pre_tool_use_async( + self, + request: ToolCallRequest, + *, + name: str, + args: dict, + entry, + ) -> tuple[dict, str | None, str | None]: + hooks = self._get_request_hook(request, "pre_tool_use") + if hooks is None: + return args, None, None + payload = {"name": name, "args": dict(args), "entry": entry} + permission: str | None = None + message: str | None = None + hook_list = hooks if isinstance(hooks, list) else [hooks] + + async def _invoke(hook): + updated = hook({"name": name, "args": dict(args), "entry": entry}, request) + if asyncio.iscoroutine(updated): + updated = await self._await_async_hook_with_timeout( + request, + updated, + hook_name=getattr(hook, "__name__", type(hook).__name__), + ) + return updated + + # @@@pt-06-hook-fanout + # Pattern 6 requires hooks to fan out instead of impersonating a + # middleware chain. We still fold results back in hook-list order so + # the aggregation stays deterministic. + for updated in await asyncio.gather(*(_invoke(hook) for hook in hook_list)): + if updated is None: + continue + if isinstance(updated, dict): + if "args" in updated: + next_args = updated["args"] + if isinstance(next_args, dict): + payload["args"] = {**payload["args"], **next_args} + else: + payload["args"] = next_args + if "name" in updated: + payload["name"] = updated["name"] + if "entry" in updated: + payload["entry"] = updated["entry"] + new_permission, new_message = self._coerce_permission_response(updated) + if new_permission == "deny" and permission != "deny": + permission = new_permission + message = new_message + elif new_permission == "ask" and permission not in {"deny", "ask"}: + permission = new_permission + message = new_message + elif new_permission == "allow" and permission is None: + permission = new_permission + message = new_message + return payload["args"], permission, message + + async def _run_permission_request_hooks_async( + self, + request: ToolCallRequest, + *, + name: str, + entry, + message: str | None, + ) -> tuple[str | None, str | None]: + hooks = self._get_request_hook(request, "permission_request_hooks") + if hooks is None: + return None, message + payload = {"name": name, "entry": entry, "message": message} + permission: str | None = None + hook_message = message + hook_list = hooks if isinstance(hooks, list) else [hooks] + + async def _invoke(hook): + updated = hook(payload, request) + if asyncio.iscoroutine(updated): + updated = await self._await_async_hook_with_timeout( + request, + updated, + hook_name=getattr(hook, "__name__", type(hook).__name__), + ) + return updated + + for updated in await asyncio.gather(*(_invoke(hook) for hook in hook_list)): + if updated is None: + continue + if isinstance(updated, dict): + new_permission, new_message = self._coerce_permission_response(updated) + if new_permission == "deny" and permission != "deny": + permission = new_permission + elif new_permission == "ask" and permission not in {"deny", "ask"}: + permission = 
new_permission + elif new_permission == "allow" and permission is None: + permission = new_permission + if new_message is not None: + hook_message = new_message + return permission, hook_message + + async def _resolve_permission_async( + self, + request: ToolCallRequest, + *, + name: str, + args: dict, + entry, + hook_permission: str | None, + hook_message: str | None, + ) -> ToolResultEnvelope | None: + if hook_permission == "deny": + return self._permission_denied_result("deny", hook_message) + + checker = self._get_state_callable(request, "can_use_tool") + rule_permission: str | None = None + rule_message: str | None = None + permission_context = ToolPermissionContext( + is_read_only=bool(getattr(entry, "is_read_only", False)), + is_destructive=bool(getattr(entry, "is_destructive", False)), + ) + if callable(checker): + result = checker(name, args, permission_context, request) if asyncio.iscoroutine(result): - result = asyncio.get_event_loop().run_until_complete(result) - return ToolMessage(content=str(result), tool_call_id=call_id, name=name) - except Exception as e: - logger.exception("Tool %s execution failed", name) - return ToolMessage( - content=f"{e}", - tool_call_id=call_id, + result = await result + rule_permission, rule_message = self._coerce_permission_response(result) + + # @@@permission-resolution-precedence - only consume one-shot approvals when current state still asks. + if rule_permission == "ask": + resolved_permission, resolved_message = await self._consume_permission_resolution_async( + request, + name=name, + args=args, + entry=entry, + ) + if resolved_permission == "allow": + return None + if resolved_permission in {"deny", "ask"}: + return self._permission_denied_result(resolved_permission, resolved_message) + request_hook_permission, request_hook_message = await self._run_permission_request_hooks_async( + request, name=name, + entry=entry, + message=rule_message, ) + if request_hook_permission == "allow": + return None + if request_hook_permission in {"deny", "ask"}: + return self._permission_denied_result(request_hook_permission, request_hook_message) + rule_message = request_hook_message + + if hook_permission == "allow": + if rule_permission in {"deny", "ask"}: + if rule_permission == "ask": + request_id = await self._request_permission_async( + request, + name=name, + args=args, + entry=entry, + message=rule_message, + ) + return self._materialize_permission_ask(request_id, rule_message) + return self._permission_denied_result(rule_permission, rule_message) + return None + + if rule_permission in {"deny", "ask"}: + if rule_permission == "ask": + request_id = await self._request_permission_async( + request, + name=name, + args=args, + entry=entry, + message=rule_message, + ) + return self._materialize_permission_ask(request_id, rule_message) + return self._permission_denied_result(rule_permission, rule_message) + return None - async def _validate_and_run_async(self, name: str, args: dict, call_id: str) -> ToolMessage | None: + def _materialize_result( + self, + envelope: ToolResultEnvelope, + *, + name: str, + call_id: str, + source: str, + ) -> ToolMessage: + return materialize_tool_message( + envelope, + tool_call_id=call_id, + name=name, + source=source, + ) + + @staticmethod + def _entry_source(entry) -> str: + return "mcp" if getattr(entry, "source", None) == "mcp" else "local" + + def _finalize_registered_result( + self, + envelope: ToolResultEnvelope, + *, + name: str, + call_id: str, + source: str, + ) -> ToolMessage | ToolResultEnvelope: + if source 
== "mcp": + return envelope + return self._materialize_result( + envelope, + name=name, + call_id=call_id, + source=source, + ) + + async def _finalize_tool_result_async( + self, + request: ToolCallRequest, + result: ToolMessage | ToolResultEnvelope, + *, + name: str, + call_id: str, + source: str, + ) -> ToolMessage: + if isinstance(result, ToolResultEnvelope): + hook_name = self._select_hook_name(result.kind) + hooks = self._get_request_hook(request, hook_name) + hooked = await self._apply_result_hooks(hooks, result, request) + if isinstance(hooked, ToolMessage): + return hooked + return self._materialize_result(hooked, name=name, call_id=call_id, source=source) + + meta = result.additional_kwargs.get("tool_result_meta", {}) + hook_name = self._select_hook_name(meta.get("kind")) + hooks = self._get_request_hook(request, hook_name) + hooked = await self._apply_result_hooks(hooks, result, request) + if isinstance(hooked, ToolMessage): + return hooked + return self._materialize_result(hooked, name=name, call_id=call_id, source=source) + + @staticmethod + def _select_hook_name(kind: str) -> str: + if kind == "error": + return "post_tool_use_failure" + if kind == "permission_denied": + return "permission_denied_hooks" + return "post_tool_use" + + @staticmethod + def _input_validation_metadata(error: InputValidationError) -> dict[str, object]: + metadata: dict[str, object] = {"error_type": "input_validation"} + if error.error_code: + metadata["error_code"] = error.error_code + if error.details: + metadata["error_details"] = error.details + return metadata + + async def _validate_and_run_async( + self, + request: ToolCallRequest, + name: str, + args: dict, + call_id: str, + ) -> ToolMessage | ToolResultEnvelope | None: entry = self._registry.get(name) if entry is None: return None + source = self._entry_source(entry) schema = entry.get_schema() try: self._validator.validate(schema, args) except InputValidationError as e: - return ToolMessage( - content=f"InputValidationError: {name} failed due to the following issue:\n{e}", - tool_call_id=call_id, + return self._finalize_registered_result( + tool_error( + f"InputValidationError: {name} failed due to the following issue:\n{e}", + metadata=self._input_validation_metadata(e), + ), + name=name, + call_id=call_id, + source=source, + ) + try: + args = await self._run_tool_specific_validation_async(entry, args, request) + except _ToolSpecificValidationError as e: + return self._finalize_registered_result( + tool_error( + f"ToolValidationError: {name} failed due to the following issue:\n{e}", + metadata={"error_type": "tool_input_validation", "error_code": e.error_code}, + ), + name=name, + call_id=call_id, + source=source, + ) + except InputValidationError as e: + return self._finalize_registered_result( + tool_error( + f"ToolValidationError: {name} failed due to the following issue:\n{e}", + metadata={"error_type": "tool_input_validation"}, + ), name=name, + call_id=call_id, + source=source, ) + args, hook_permission, hook_message = await self._run_pre_tool_use_async( + request, + name=name, + args=args, + entry=entry, + ) + permission_result = await self._resolve_permission_async( + request, + name=name, + args=args, + entry=entry, + hook_permission=hook_permission, + hook_message=hook_message, + ) + if permission_result is not None: + return self._finalize_registered_result( + permission_result, + name=name, + call_id=call_id, + source=source, + ) + + args = self._inject_handler_context(entry, args, request) try: if 
asyncio.iscoroutinefunction(entry.handler): result = await entry.handler(**args) @@ -113,13 +697,22 @@ async def _validate_and_run_async(self, name: str, args: dict, call_id: str) -> result = await asyncio.to_thread(entry.handler, **args) if asyncio.iscoroutine(result): result = await result - return ToolMessage(content=str(result), tool_call_id=call_id, name=name) + return self._finalize_registered_result( + self._normalize_result(result), + name=name, + call_id=call_id, + source=source, + ) except Exception as e: logger.exception("Tool %s execution failed", name) - return ToolMessage( - content=f"{e}", - tool_call_id=call_id, + return self._finalize_registered_result( + tool_error( + f"{e}", + metadata={"error_type": "tool_execution"}, + ), name=name, + call_id=call_id, + source=source, ) # -- Model call wrappers -- @@ -146,10 +739,26 @@ def wrap_tool_call( handler: Callable[[ToolCallRequest], ToolMessage], ) -> ToolMessage: name, args, call_id = self._extract_call_info(request) - result = self._validate_and_run(name, args, call_id) + entry = self._registry.get(name) + result: ToolMessage | ToolResultEnvelope | None = self._run_awaitable_sync( + self._validate_and_run_async(request, name, args, call_id) + ) if result is not None: - return result - return handler(request) + source = self._entry_source(entry) if entry is not None else "local" + return cast( + ToolMessage, + self._run_awaitable_sync( + self._finalize_tool_result_async( + request, + result, + name=name, + call_id=call_id, + source=source, + ) + ), + ) + upstream = handler(request) + return upstream async def awrap_tool_call( self, @@ -157,7 +766,32 @@ async def awrap_tool_call( handler: Callable[[ToolCallRequest], Awaitable[ToolMessage]], ) -> ToolMessage: name, args, call_id = self._extract_call_info(request) - result = await self._validate_and_run_async(name, args, call_id) + entry = self._registry.get(name) + source = self._entry_source(entry) if entry is not None else "local" + result = await self._validate_and_run_async(request, name, args, call_id) if result is not None: - return result - return await handler(request) + # @@@tool-result-ordering + # te-02 keeps local tools materialize-first, but registered MCP + # tools must stay envelope-first so post hooks can see and modify + # structured output before final ToolMessage creation. + return await self._finalize_tool_result_async( + request, + result, + name=name, + call_id=call_id, + source=source, + ) + + upstream = await handler(request) + post_tool_use = self._get_request_hook(request, "post_tool_use") + if isinstance(upstream, ToolResultEnvelope): + # MCP/upstream path: post hooks get first shot at the structured + # result, and only then do we materialize the ToolMessage. + hooked = await self._apply_result_hooks(post_tool_use, upstream, request) + if isinstance(hooked, ToolMessage): + return hooked + return self._materialize_result(hooked, name=name, call_id=call_id, source="mcp") + if isinstance(upstream, ToolMessage): + hooked = await self._apply_result_hooks(post_tool_use, upstream, request) + return hooked if isinstance(hooked, ToolMessage) else self._materialize_result(hooked, name=name, call_id=call_id, source="mcp") + return upstream diff --git a/core/runtime/state.py b/core/runtime/state.py new file mode 100644 index 000000000..80b53a4c2 --- /dev/null +++ b/core/runtime/state.py @@ -0,0 +1,172 @@ +"""Three-layer state models aligned with CC architecture. 
+ +Layer 1: BootstrapConfig — survives /clear, process-level constants +Layer 2: AppState — per-session mutable state (Zustand-style store) +Layer 3: ToolUseContext — per-turn, holds live closures to AppState +""" + +from __future__ import annotations + +import uuid +from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from .abort import AbortController +from .permissions import ToolPermissionContext + + +class ToolPermissionState(BaseModel): + # @@@camelcase-permission-surface - persisted/thread API surface already uses camelCase keys. + alwaysAllowRules: dict[str, list[str]] = Field(default_factory=dict) # noqa: N815 + alwaysDenyRules: dict[str, list[str]] = Field(default_factory=dict) # noqa: N815 + alwaysAskRules: dict[str, list[str]] = Field(default_factory=dict) # noqa: N815 + allowManagedPermissionRulesOnly: bool = False # noqa: N815 + + +class BootstrapConfig(BaseModel): + """Process-level configuration that survives /clear. + + Analogous to CC Bootstrap State (~85 fields). Contains workspace + identity, model config, security flags, and API credentials. + """ + + workspace_root: Path + original_cwd: Path | None = None + project_root: Path | None = None + cwd: Path | None = None + model_name: str + api_key: str | None = None + sandbox_type: str = "local" + permission_resolver_scope: str = "none" + + # Security flags (fail-closed defaults) + block_dangerous_commands: bool = True + block_network_commands: bool = False + enable_audit_log: bool = True + enable_web_tools: bool = False + + # File access + allowed_file_extensions: list[str] | None = None + extra_allowed_paths: list[str] | None = None + + # Turn limits + max_turns: int | None = None + + # Session identity + session_id: str = Field(default_factory=lambda: uuid.uuid4().hex) + parent_session_id: str | None = None + + # Session accumulators that survive turn-level resets + total_cost_usd: float = 0.0 + total_tool_duration_ms: int = 0 + + # Model settings + model_provider: str | None = None + base_url: str | None = None + context_limit: int | None = None + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def model_post_init(self, __context: Any) -> None: + self.workspace_root = Path(self.workspace_root) + self.original_cwd = Path(self.original_cwd) if self.original_cwd is not None else self.workspace_root + self.project_root = Path(self.project_root) if self.project_root is not None else self.workspace_root + self.cwd = Path(self.cwd) if self.cwd is not None else self.project_root + + +class AppState(BaseModel): + """Per-session mutable state. Analogous to CC AppState store. + + Implements a minimal Zustand-style store with getState/setState. + Not reactive — no subscriptions needed for Python backend. + """ + + messages: list = Field(default_factory=list) + turn_count: int = 0 + total_cost: float = 0.0 + compact_boundary_index: int = 0 + # Map of tool_name -> is_enabled (runtime overrides) + tool_overrides: dict[str, bool] = Field(default_factory=dict) + tool_permission_context: ToolPermissionState = Field(default_factory=ToolPermissionState) + pending_permission_requests: dict[str, dict[str, Any]] = Field(default_factory=dict) + resolved_permission_requests: dict[str, dict[str, Any]] = Field(default_factory=dict) + announced_mcp_instruction_blocks: dict[str, dict[str, str]] = Field(default_factory=dict) + # @@@session-hooks-not-watchers - keep this surface local and lifecycle-scoped. 
+ # File watching remains a later outer-layer concern so Leon keeps the + # filesystem + terminal core decoupled. + session_hooks: dict[str, list[Any]] = Field(default_factory=dict) + + def get_state(self) -> AppState: + return self + + def set_state(self, updater: Callable[[AppState], AppState]) -> AppState: + updated = updater(self) + # Mutate in place (Python idiom — no immutable constraint needed here) + for field_name in AppState.model_fields: + setattr(self, field_name, getattr(updated, field_name)) + return self + + def add_session_hook(self, event: str, hook: Any) -> None: + hooks = list(self.session_hooks.get(event, [])) + hooks.append(hook) + self.session_hooks[event] = hooks + + def remove_session_hook(self, event: str, hook: Any) -> None: + hooks = [candidate for candidate in self.session_hooks.get(event, []) if candidate != hook] + if hooks: + self.session_hooks[event] = hooks + else: + self.session_hooks.pop(event, None) + + def get_session_hooks(self, event: str) -> list[Any]: + return list(self.session_hooks.get(event, [])) + + +AppStateUpdater = Callable[[AppState], AppState] +AppStateGetter = Callable[[], AppState] +AppStateSetter = Callable[[AppStateUpdater], AppState | None] +RefreshToolsHook = Callable[[], Awaitable[None] | None] +PermissionDecision = dict[str, Any] | None +PermissionChecker = Callable[ + [str, dict[str, Any], ToolPermissionContext, object], + PermissionDecision | Awaitable[PermissionDecision], +] +PermissionRequester = Callable[ + [str, dict[str, Any], ToolPermissionContext, object, str | None], + str | dict[str, Any] | None | Awaitable[str | dict[str, Any] | None], +] +PermissionResolutionConsumer = Callable[ + [str, dict[str, Any], ToolPermissionContext, object], + PermissionDecision | Awaitable[PermissionDecision], +] + + +class ToolUseContext(BaseModel): + """Per-turn context bag. Analogous to CC ToolUseContext. + + Carries live closures to AppState so tools can read/mutate session state. + Sub-agents receive a NO-OP set_app_state to prevent write-through. 
+ """ + + bootstrap: BootstrapConfig + get_app_state: AppStateGetter = Field(exclude=True) + set_app_state: AppStateSetter = Field(exclude=True) + set_app_state_for_tasks: AppStateSetter | None = Field(default=None, exclude=True) + refresh_tools: RefreshToolsHook | None = Field(default=None, exclude=True) + can_use_tool: PermissionChecker | None = Field(default=None, exclude=True) + request_permission: PermissionRequester | None = Field(default=None, exclude=True) + consume_permission_resolution: PermissionResolutionConsumer | None = Field(default=None, exclude=True) + read_file_state: Any = Field(default_factory=dict, exclude=True) + loaded_nested_memory_paths: Any = Field(default_factory=set, exclude=True) + discovered_skill_names: Any = Field(default_factory=set, exclude=True) + discovered_tool_names: Any = Field(default_factory=set, exclude=True) + nested_memory_attachment_triggers: Any = Field(default_factory=set, exclude=True) + abort_controller: AbortController = Field(default_factory=AbortController, exclude=True) + messages: list = Field(default_factory=list) + thread_id: str = "default" + turn_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8]) + + model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/core/runtime/tool_result.py b/core/runtime/tool_result.py new file mode 100644 index 000000000..1ccd24288 --- /dev/null +++ b/core/runtime/tool_result.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from langchain_core.messages import ToolMessage + + +@dataclass +class ToolResultEnvelope: + kind: str + content: Any + is_error: bool = False + top_level_blocks: list[Any] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +def tool_success(content: Any, *, metadata: dict[str, Any] | None = None) -> ToolResultEnvelope: + return ToolResultEnvelope( + kind="success", + content=content, + metadata=dict(metadata or {}), + ) + + +def tool_error(content: str, *, metadata: dict[str, Any] | None = None) -> ToolResultEnvelope: + return ToolResultEnvelope( + kind="error", + content=content, + is_error=True, + metadata=dict(metadata or {}), + ) + + +def tool_permission_denied( + content: str, + *, + top_level_blocks: list[Any] | None = None, + metadata: dict[str, Any] | None = None, +) -> ToolResultEnvelope: + return ToolResultEnvelope( + kind="permission_denied", + content=content, + is_error=True, + top_level_blocks=list(top_level_blocks or []), + metadata=dict(metadata or {}), + ) + + +def tool_permission_request( + content: str, + *, + top_level_blocks: list[Any] | None = None, + metadata: dict[str, Any] | None = None, +) -> ToolResultEnvelope: + return ToolResultEnvelope( + kind="permission_request", + content=content, + top_level_blocks=list(top_level_blocks or []), + metadata=dict(metadata or {}), + ) + + +def materialize_tool_message( + envelope: ToolResultEnvelope, + *, + tool_call_id: str, + name: str, + source: str, +) -> ToolMessage: + additional_kwargs = { + "tool_result_meta": { + "kind": envelope.kind, + "source": source, + "top_level_blocks": list(envelope.top_level_blocks), + **dict(envelope.metadata), + } + } + return ToolMessage( + content=envelope.content, + tool_call_id=tool_call_id, + name=name, + additional_kwargs=additional_kwargs, + ) diff --git a/core/runtime/validator.py b/core/runtime/validator.py index 84e678d07..46fa6d963 100644 --- a/core/runtime/validator.py +++ b/core/runtime/validator.py @@ -1,8 +1,45 @@ import json 
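+# A hedged sketch of the required-set contract enforced below (the tool fields
+# are invented; the "x-leon-required-any-of" key and helper are from this diff):
+#
+#     parameters = {
+#         "type": "object",
+#         "properties": {"ticket_id": {"type": "string"}, "email": {"type": "string"}},
+#         "required": [],
+#         "x-leon-required-any-of": [["ticket_id"], ["email"]],
+#     }
+#     _required_sets_match(parameters, {"email": "a@b.c"})  # -> True
+#     _required_sets_match(parameters, {})                  # -> False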
+import re from .errors import InputValidationError +def _required_sets(parameters: dict, key: str) -> list[list[str]]: + value = parameters.get(key, []) + if not isinstance(value, list): + return [] + sets: list[list[str]] = [] + for item in value: + if isinstance(item, dict): + required = item.get("required", []) + else: + required = item + if isinstance(required, list): + sets.append([field for field in required if isinstance(field, str)]) + return sets + + +def _required_sets_match(parameters: dict, args: dict) -> bool: + required = parameters.get("required", []) + if any(field not in args for field in required): + return False + + # @@@required-set-contract - some tools need one of several identifier sets + # before they're valid. Keep that contract in runtime metadata so + # validator/readiness stay aligned without sending unsupported top-level + # anyOf/oneOf schema to live providers. + any_of = _required_sets(parameters, "x-leon-required-any-of") or _required_sets(parameters, "anyOf") + if any_of: + return any(all(field in args for field in required) for required in any_of) + + one_of = _required_sets(parameters, "x-leon-required-one-of") or _required_sets(parameters, "oneOf") + if one_of: + matches = [required for required in one_of if all(field in args for field in required)] + return len(matches) == 1 + + return True + + class ValidationResult: def __init__(self, ok: bool, params: dict): self.ok = ok @@ -13,14 +50,43 @@ class ToolValidator: """Three-phase tool argument validation.""" def validate(self, schema: dict, args: dict) -> ValidationResult: - properties = schema.get("parameters", {}).get("properties", {}) - required = schema.get("parameters", {}).get("required", []) + parameters = schema.get("parameters", {}) + properties = parameters.get("properties", {}) # Phase 1: required fields - missing = [f for f in required if f not in args] - if missing: - msgs = [f"The required parameter `{f}` is missing" for f in missing] - raise InputValidationError("\n".join(msgs)) + if not _required_sets_match(parameters, args): + required = parameters.get("required", []) + missing = [f for f in required if f not in args] + if missing: + details = [ + { + "field": field, + "error_code": "REQUIRED_FIELD_MISSING", + "message": f"The required parameter `{field}` is missing", + } + for field in missing + ] + raise InputValidationError( + "\n".join(detail["message"] for detail in details), + error_code="REQUIRED_FIELD_MISSING" if len(details) == 1 else "INPUT_CONSTRAINT_VIOLATION", + details=details, + ) + any_of = _required_sets(parameters, "x-leon-required-any-of") or _required_sets(parameters, "anyOf") + one_of = _required_sets(parameters, "x-leon-required-one-of") or _required_sets(parameters, "oneOf") + if any_of: + message = f"Arguments must satisfy one of these required sets: {any_of}" + raise InputValidationError( + message, + error_code="REQUIRED_SET_UNSATISFIED", + details=[{"error_code": "REQUIRED_SET_UNSATISFIED", "message": message}], + ) + if one_of: + message = f"Arguments must satisfy exactly one of these required sets: {one_of}" + raise InputValidationError( + message, + error_code="REQUIRED_SET_UNSATISFIED", + details=[{"error_code": "REQUIRED_SET_UNSATISFIED", "message": message}], + ) # Phase 2: type check for name, val in args.items(): @@ -28,12 +94,38 @@ def validate(self, schema: dict, args: dict) -> ValidationResult: expected = prop.get("type") if expected and not self._type_matches(val, expected): actual = type(val).__name__ - raise InputValidationError(f"The parameter 
`{name}` type is expected as `{expected}` but provided as `{actual}`") + message = f"The parameter `{name}` type is expected as `{expected}` but provided as `{actual}`" + raise InputValidationError( + message, + error_code="INVALID_TYPE", + details=[ + { + "field": name, + "error_code": "INVALID_TYPE", + "expected": expected, + "actual": actual, + "message": message, + } + ], + ) - # Phase 3: enum validation + # Phase 3: scalar constraints + issues = self._validate_scalar_constraints(properties, args) + if issues: + raise InputValidationError( + "\n".join(str(issue["message"]) for issue in issues), + error_code=str(issues[0]["error_code"]) if len(issues) == 1 else "INPUT_CONSTRAINT_VIOLATION", + details=issues, + ) + + # Phase 4: enum validation issues = self._validate_enum(properties, args) if issues: - raise InputValidationError(json.dumps(issues)) + raise InputValidationError( + json.dumps(issues), + error_code="INVALID_ENUM" if len(issues) == 1 else "INPUT_CONSTRAINT_VIOLATION", + details=issues, + ) return ValidationResult(ok=True, params=args) @@ -51,11 +143,77 @@ def _type_matches(self, val, expected: str) -> bool: return True return isinstance(val, expected_type) - def _validate_enum(self, properties: dict, args: dict) -> list: - issues = [] + def _validate_enum(self, properties: dict, args: dict) -> list[dict[str, object]]: + issues: list[dict[str, object]] = [] for name, val in args.items(): prop = properties.get(name, {}) enum_vals = prop.get("enum") if enum_vals and val not in enum_vals: - issues.append({"field": name, "expected": enum_vals, "got": val}) + issues.append( + { + "field": name, + "error_code": "INVALID_ENUM", + "expected": enum_vals, + "got": val, + "message": f"The parameter `{name}` must be one of {enum_vals}, got {val!r}", + } + ) + return issues + + def _validate_scalar_constraints(self, properties: dict, args: dict) -> list[dict[str, object]]: + issues: list[dict[str, object]] = [] + for name, val in args.items(): + prop = properties.get(name, {}) + if isinstance(val, str): + min_length = prop.get("minLength") + if isinstance(min_length, int) and len(val) < min_length: + issues.append( + { + "field": name, + "error_code": "STRING_TOO_SHORT", + "message": f"The parameter `{name}` must be at least {min_length} characters long", + "minimum": min_length, + } + ) + max_length = prop.get("maxLength") + if isinstance(max_length, int) and len(val) > max_length: + issues.append( + { + "field": name, + "error_code": "STRING_TOO_LONG", + "message": f"The parameter `{name}` must be at most {max_length} characters long", + "maximum": max_length, + } + ) + pattern = prop.get("pattern") + if isinstance(pattern, str) and re.search(pattern, val) is None: + issues.append( + { + "field": name, + "error_code": "PATTERN_MISMATCH", + "message": f"The parameter `{name}` must match pattern `{pattern}`", + "pattern": pattern, + } + ) + if isinstance(val, (int, float)) and not isinstance(val, bool): + minimum = prop.get("minimum") + if isinstance(minimum, (int, float)) and val < minimum: + issues.append( + { + "field": name, + "error_code": "NUMBER_TOO_SMALL", + "message": f"The parameter `{name}` must be at least {minimum}", + "minimum": minimum, + } + ) + maximum = prop.get("maximum") + if isinstance(maximum, (int, float)) and val > maximum: + issues.append( + { + "field": name, + "error_code": "NUMBER_TOO_LARGE", + "message": f"The parameter `{name}` must be at most {maximum}", + "maximum": maximum, + } + ) return issues diff --git a/core/runtime/visibility.py 
b/core/runtime/visibility.py index 5c1a31f5d..cd1e1467f 100644 --- a/core/runtime/visibility.py +++ b/core/runtime/visibility.py @@ -1,7 +1,8 @@ -"""Owner visibility — v3: everything is always visible. +"""Owner visibility helpers. -v2 had a two-layer context/showing state machine for private context. -v3 removes private context entirely — all messages are shown to the owner. +v3 default is "visible unless explicitly hidden". Some backend paths still emit +durable hidden owner messages (for example AskUserQuestion answer anchors), so +this layer must preserve an already-declared display contract. """ from __future__ import annotations @@ -11,23 +12,8 @@ _ALWAYS_SHOWING = {"showing": True} -def compute_visibility(source: str, is_steer: bool, context: str) -> tuple[bool, str]: - """Always visible. Kept for call-site compatibility during transition.""" - return True, "owner" - - -def message_visibility(context: str, tool_names: list[str] | None = None) -> dict[str, Any]: - """Always visible.""" - return _ALWAYS_SHOWING - - -def tool_event_visibility(context: str, tool_name: str) -> dict[str, Any]: - """Always visible.""" - return _ALWAYS_SHOWING - - def annotate_owner_visibility(messages: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], str]: - """Annotate every message as visible.""" + """Annotate messages as visible unless they already carry display metadata.""" for msg in messages: - msg["display"] = _ALWAYS_SHOWING + msg.setdefault("display", _ALWAYS_SHOWING) return messages, "owner" diff --git a/core/tools/command/base.py b/core/tools/command/base.py index e716420b2..7a1356081 100644 --- a/core/tools/command/base.py +++ b/core/tools/command/base.py @@ -4,7 +4,25 @@ This module re-exports for backward compatibility. """ +from __future__ import annotations + from sandbox.interfaces.executor import * # noqa: F401,F403 from sandbox.interfaces.executor import AsyncCommand, BaseExecutor, ExecuteResult __all__ = ["BaseExecutor", "ExecuteResult", "AsyncCommand"] + + +def describe_execution_exception(exc: Exception) -> str: + detail = str(exc).strip() + if detail: + return detail + return exc.__class__.__name__ + + +def require_subprocess_pipe[TPipe](pipe: TPipe | None, name: str) -> TPipe: + # @@@persistent-shell-pipe-contract - persistent shell executors only work + # when asyncio created real stdio pipes; fail loudly instead of pretending + # optional streams are always present. 
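+    # A hedged usage sketch (names illustrative, not from this diff): callers
+    # narrow the optional pipe once, then use the result unconditionally:
+    #     stdin = require_subprocess_pipe(proc.stdin, "stdin")
+    #     stdin.write(b"echo ok\n")
+    #     await stdin.drain()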
+ if pipe is None: + raise RuntimeError(f"Subprocess missing {name} pipe") + return pipe diff --git a/core/tools/command/bash/executor.py b/core/tools/command/bash/executor.py index d559970d0..c4c060f53 100644 --- a/core/tools/command/bash/executor.py +++ b/core/tools/command/bash/executor.py @@ -6,7 +6,7 @@ import os import uuid -from ..base import AsyncCommand, BaseExecutor, ExecuteResult +from ..base import AsyncCommand, BaseExecutor, ExecuteResult, require_subprocess_pipe _RUNNING_COMMANDS: dict[str, AsyncCommand] = {} @@ -35,8 +35,9 @@ async def _ensure_session(self, env: dict[str, str]) -> asyncio.subprocess.Proce cwd=self._current_cwd, ) # Disable PS1 prompt - self._session.stdin.write(b"export PS1=''\n") - await self._session.stdin.drain() + stdin = require_subprocess_pipe(self._session.stdin, "stdin") + stdin.write(b"export PS1=''\n") + await stdin.drain() return self._session async def _send_command(self, proc: asyncio.subprocess.Process, command: str) -> tuple[str, str, int]: @@ -44,14 +45,16 @@ async def _send_command(self, proc: asyncio.subprocess.Process, command: str) -> marker = f"__END_{uuid.uuid4().hex[:8]}__" full_cmd = f"{command}\necho {marker} $?\n" - proc.stdin.write(full_cmd.encode()) - await proc.stdin.drain() + stdin = require_subprocess_pipe(proc.stdin, "stdin") + stdout = require_subprocess_pipe(proc.stdout, "stdout") + stdin.write(full_cmd.encode()) + await stdin.drain() stdout_lines = [] exit_code = 0 while True: - line = await proc.stdout.readline() + line = await stdout.readline() if not line: break line_str = line.decode("utf-8", errors="replace") diff --git a/core/tools/command/hooks/dangerous_commands.py b/core/tools/command/hooks/dangerous_commands.py index 496251292..3abde2337 100644 --- a/core/tools/command/hooks/dangerous_commands.py +++ b/core/tools/command/hooks/dangerous_commands.py @@ -1,6 +1,7 @@ """Dangerous commands hook - blocks commands that may harm the system.""" import re +import shlex from pathlib import Path from typing import Any @@ -40,6 +41,32 @@ class DangerousCommandsHook(BashHook): r"\bssh\b", ] + DEFAULT_BLOCKED_BASE_COMMANDS = { + "rmdir", + "chmod", + "chown", + "sudo", + "su", + "kill", + "pkill", + "reboot", + "shutdown", + "mkfs", + "dd", + } + NETWORK_BASE_COMMANDS = { + "curl", + "wget", + "scp", + "sftp", + "rsync", + "ssh", + } + OPERATOR_TOKENS = {";", ";;", "&", "&&", "|", "||", "(", ")"} + ENV_ASSIGN_RE = re.compile(r"^[A-Za-z_]\w*=") + ANSI_C_QUOTE_RE = re.compile(r"\$'[^']*'") + LOCALE_QUOTE_RE = re.compile(r'\$"[^"]*"') + def __init__( self, workspace_root: Path | str | None = None, @@ -58,13 +85,140 @@ def __init__( patterns.extend(custom_blocked) self.compiled_patterns = [re.compile(p, re.IGNORECASE) for p in patterns] + self.blocked_base_commands = set(self.DEFAULT_BLOCKED_BASE_COMMANDS) + if block_network: + self.blocked_base_commands.update(self.NETWORK_BASE_COMMANDS) if verbose: print(f"[DangerousCommands] Loaded {len(self.compiled_patterns)} blocked command patterns") + @staticmethod + def _unquoted_command(command: str) -> str: + # @@@bash-hook-unquoted-scan - dangerous regexes should only inspect executable shell surface, + # not literal text inside quotes. 
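+    # A hedged example of the intended contract (assumed, not from a test in
+    # this diff):
+    #     _unquoted_command('echo "rm -rf /" && rm -rf /tmp')
+    #     # -> 'echo  && rm -rf /tmp'; the quoted payload never reaches the regexes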
+ pieces: list[str] = [] + in_single = False + in_double = False + escaped = False + + for char in command: + if escaped: + if not in_single and not in_double: + pieces.append(char) + escaped = False + continue + + if char == "\\" and not in_single: + if not in_double: + pieces.append(char) + escaped = True + continue + + if char == "'" and not in_double: + in_single = not in_single + continue + + if char == '"' and not in_single: + in_double = not in_double + continue + + if not in_single and not in_double and char == "#": + prev = pieces[-1] if pieces else "" + if not prev or prev.isspace(): + break + + if not in_single and not in_double: + pieces.append(char) + + return "".join(pieces) + + @classmethod + def _has_dangerous_rm_flags(cls, tokens: list[str], start: int) -> bool: + recursive = False + force = False + + for token in tokens[start:]: + if token in cls.OPERATOR_TOKENS: + break + if token == "--": + break + lowered = token.lower() + if lowered == "--recursive": + recursive = True + elif lowered == "--force": + force = True + elif lowered.startswith("-"): + short_flags = lowered[1:] + recursive = recursive or "r" in short_flags + force = force or "f" in short_flags + if recursive and force: + return True + + return False + + def _find_dangerous_command_word(self, command: str) -> str | None: + try: + lexer = shlex.shlex(command, posix=True, punctuation_chars=";&|()<>") + except ValueError: + return None + lexer.whitespace_split = True + lexer.commenters = "#" + tokens = list(lexer) + command_position = True + + for index, token in enumerate(tokens): + if token in self.OPERATOR_TOKENS: + command_position = True + continue + + if token in {"<", ">", ">>", "<<", "<<<", "<>", ">|", "&>", "2>", "1>"}: + command_position = False + continue + + if not command_position: + continue + + if self.ENV_ASSIGN_RE.match(token): + continue + + if token in self.blocked_base_commands: + return token + + if token == "rm" and self._has_dangerous_rm_flags(tokens, index + 1): + return "rm -rf" + + command_position = False + + return None + def check_command(self, command: str, context: dict[str, Any]) -> HookResult: + stripped = command.strip() + if self.ANSI_C_QUOTE_RE.search(stripped) or self.LOCALE_QUOTE_RE.search(stripped): + return HookResult.block_command( + error_message=( + f"❌ SECURITY ERROR: Dangerous command detected\n" + f" Command: {command[:100]}\n" + f" Reason: Obfuscated shell quoting is blocked for security reasons\n" + f" Pattern: raw_obfuscation:$quote\n" + f" 💡 If you need to perform this operation, ask the user for permission." + ) + ) + + dangerous_word = self._find_dangerous_command_word(stripped) + if dangerous_word is not None: + return HookResult.block_command( + error_message=( + f"❌ SECURITY ERROR: Dangerous command detected\n" + f" Command: {command[:100]}\n" + f" Reason: This command is blocked for security reasons\n" + f" Pattern: command_word:{dangerous_word}\n" + f" 💡 If you need to perform this operation, ask the user for permission." 
+ ) + ) + + scanned = self._unquoted_command(stripped) for pattern in self.compiled_patterns: - if pattern.search(command.strip()): + if pattern.search(scanned): return HookResult.block_command( error_message=( f"❌ SECURITY ERROR: Dangerous command detected\n" diff --git a/core/tools/command/hooks/loader.py b/core/tools/command/hooks/loader.py index d46ee78b9..449b2901c 100644 --- a/core/tools/command/hooks/loader.py +++ b/core/tools/command/hooks/loader.py @@ -39,13 +39,3 @@ def load_hooks( hooks.sort(key=lambda h: h.priority) print(f"[BashHooks] Total {len(hooks)} hooks loaded") return hooks - - -def discover_hooks() -> list[str]: - """Discover all available hook plugins without loading them.""" - hooks_dir = Path(__file__).parent - return [ - py_file.stem - for py_file in hooks_dir.glob("*.py") - if not py_file.name.startswith("_") and py_file.name not in ["base.py", "loader.py"] - ] diff --git a/core/tools/command/middleware.py b/core/tools/command/middleware.py index dcd6453a4..c01d2e71d 100644 --- a/core/tools/command/middleware.py +++ b/core/tools/command/middleware.py @@ -9,7 +9,7 @@ import json import logging from pathlib import Path -from typing import Any +from typing import Any, Literal from langchain.agents.middleware import AgentMiddleware, AgentState from langchain.agents.middleware.types import ModelRequest, ModelResponse @@ -18,7 +18,7 @@ from sandbox.shell_output import normalize_pty_result -from .base import AsyncCommand, BaseExecutor +from .base import AsyncCommand, BaseExecutor, describe_execution_exception from .dispatcher import get_executor, get_shell_info logger = logging.getLogger(__name__) @@ -203,7 +203,7 @@ async def _execute_blocking(self, command_line: str, work_dir: str | None, timeo env=self.env, ) except Exception as e: - return f"Error executing command: {e}" + return f"Error executing command: {describe_execution_exception(e)}" return result.to_tool_result() def set_agent(self, agent: Any) -> None: @@ -219,7 +219,7 @@ async def _execute_async(self, command_line: str, work_dir: str | None, timeout: env=self.env, ) except Exception as e: - return f"Error starting async command: {e}" + return f"Error starting async command: {describe_execution_exception(e)}" # Emit task_start event runtime = getattr(self._agent, "runtime", None) if self._agent else None @@ -319,7 +319,7 @@ async def _monitor_async_command(self, command_id: str, command_line: str, runti async def _inject_command_notification( self, command_id: str, - status: str, + status: Literal["completed", "failed"], exit_code: int, command_line: str, output: str, diff --git a/core/tools/command/service.py b/core/tools/command/service.py index 475289b9c..e1927b82b 100644 --- a/core/tools/command/service.py +++ b/core/tools/command/service.py @@ -15,11 +15,13 @@ import asyncio import json import logging +from collections.abc import Awaitable, Callable from pathlib import Path from typing import Any -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry -from core.tools.command.base import BaseExecutor +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema +from core.runtime.tool_result import ToolResultEnvelope, tool_permission_denied +from core.tools.command.base import BaseExecutor, describe_execution_exception from core.tools.command.dispatcher import get_executor logger = logging.getLogger(__name__) @@ -61,35 +63,39 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Bash", mode=ToolMode.INLINE, - schema={ - "name": "Bash", - 
"description": ("Execute shell command. OS auto-detects shell (mac->zsh, linux->bash, win->powershell)."), - "parameters": { - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "Command to execute", - }, - "description": { - "type": "string", - "description": ( - "Human-readable description of what this command does. " - "Required when run_in_background is true; shown in the background task indicator." - ), - }, - "run_in_background": { - "type": "boolean", - "description": "Run in background (default: false). Returns task ID for status queries.", - }, - "timeout": { - "type": "integer", - "description": "Timeout in milliseconds (default: 120000)", - }, + schema=make_tool_schema( + name="Bash", + description=( + "Execute shell command (zsh on macOS, bash on Linux, PowerShell on Windows). " + "Default timeout 120s (max 600s). Dangerous commands are blocked. " + "Prefer dedicated tools over Bash: Read over cat, Grep over grep/rg, Glob over find/ls, Edit over sed/awk." + ), + properties={ + "command": { + "type": "string", + "description": "Command to execute", + "minLength": 1, + }, + "description": { + "type": "string", + "description": ( + "Human-readable description of what this command does. " + "Required when run_in_background is true; shown in the background task indicator." + ), + }, + "run_in_background": { + "type": "boolean", + "description": "Run in background (default: false). Returns task ID for status queries.", + }, + "timeout": { + "type": "integer", + "description": "Timeout in milliseconds (default: 120000)", + "minimum": 1, + "maximum": 600000, }, - "required": ["command"], }, - }, + required=["command"], + ), handler=self._bash, source="CommandService", ) @@ -113,10 +119,13 @@ async def _bash( description: str = "", run_in_background: bool = False, timeout: int = DEFAULT_TIMEOUT_MS, - ) -> str: + ) -> str | ToolResultEnvelope: allowed, error_msg = self._check_hooks(command) if not allowed: - return error_msg + return tool_permission_denied( + error_msg, + metadata={"policy": "command_hook"}, + ) work_dir = None if self._executor.runtime_owns_cwd else str(self.workspace_root) timeout_secs = timeout / 1000.0 @@ -135,7 +144,7 @@ async def _execute_blocking(self, command: str, work_dir: str | None, timeout_se env=self.env, ) except Exception as e: - return f"Error executing command: {e}" + return f"Error executing command: {describe_execution_exception(e)}" return result.to_tool_result() async def _execute_async(self, command: str, work_dir: str | None, timeout_secs: float, description: str = "") -> str: @@ -146,7 +155,7 @@ async def _execute_async(self, command: str, work_dir: str | None, timeout_secs: env=self.env, ) except Exception as e: - return f"Error starting async command: {e}" + return f"Error starting async command: {describe_execution_exception(e)}" task_id = async_cmd.command_id @@ -156,7 +165,7 @@ async def _execute_async(self, command: str, work_dir: str | None, timeout_secs: self._background_runs[task_id] = _BashBackgroundRun(async_cmd, command, description=description) # Build emit_fn for SSE task lifecycle events - emit_fn = None + emit_fn: Callable[[dict[str, Any]], Awaitable[None] | None] | None = None parent_thread_id = None try: from backend.web.event_bus import get_event_bus @@ -178,7 +187,7 @@ async def _execute_async(self, command: str, work_dir: str | None, timeout_secs: # Emit task_start so the frontend dot lights up immediately if emit_fn is not None: - await emit_fn( + emission = emit_fn( { "event": 
"task_start", "data": json.dumps( @@ -193,6 +202,8 @@ async def _execute_async(self, command: str, work_dir: str | None, timeout_secs: ), } ) + if asyncio.iscoroutine(emission): + await emission if parent_thread_id: asyncio.create_task( @@ -207,7 +218,7 @@ async def _notify_bash_completion( async_cmd: Any, command: str, parent_thread_id: str, - emit_fn: Any = None, + emit_fn: Callable[[dict[str, Any]], Awaitable[None] | None] | None = None, description: str = "", ) -> None: """Poll until async command finishes, then enqueue CommandNotification.""" @@ -220,7 +231,7 @@ async def _notify_bash_completion( # Emit task_done so the frontend dot updates in real time if emit_fn is not None: try: - await emit_fn( + emission = emit_fn( { "event": "task_done", "data": json.dumps( @@ -232,6 +243,8 @@ async def _notify_bash_completion( ), } ) + if asyncio.iscoroutine(emission): + await emission except Exception: pass diff --git a/core/tools/command/zsh/executor.py b/core/tools/command/zsh/executor.py index 6990531aa..2d19be8ec 100644 --- a/core/tools/command/zsh/executor.py +++ b/core/tools/command/zsh/executor.py @@ -6,7 +6,7 @@ import os import uuid -from ..base import AsyncCommand, BaseExecutor, ExecuteResult +from ..base import AsyncCommand, BaseExecutor, ExecuteResult, require_subprocess_pipe _RUNNING_COMMANDS: dict[str, AsyncCommand] = {} @@ -35,8 +35,9 @@ async def _ensure_session(self, env: dict[str, str]) -> asyncio.subprocess.Proce cwd=self._current_cwd, ) # Disable PS1 prompt - self._session.stdin.write(b"export PS1=''\n") - await self._session.stdin.drain() + stdin = require_subprocess_pipe(self._session.stdin, "stdin") + stdin.write(b"export PS1=''\n") + await stdin.drain() return self._session async def _send_command(self, proc: asyncio.subprocess.Process, command: str) -> tuple[str, str, int]: @@ -44,14 +45,16 @@ async def _send_command(self, proc: asyncio.subprocess.Process, command: str) -> marker = f"__END_{uuid.uuid4().hex[:8]}__" full_cmd = f"{command}\necho {marker} $?\n" - proc.stdin.write(full_cmd.encode()) - await proc.stdin.drain() + stdin = require_subprocess_pipe(proc.stdin, "stdin") + stdout = require_subprocess_pipe(proc.stdout, "stdout") + stdin.write(full_cmd.encode()) + await stdin.drain() stdout_lines = [] exit_code = 0 while True: - line = await proc.stdout.readline() + line = await stdout.readline() if not line: break line_str = line.decode("utf-8", errors="replace") diff --git a/core/tools/cron/service.py b/core/tools/cron/service.py new file mode 100644 index 000000000..026c7d9be --- /dev/null +++ b/core/tools/cron/service.py @@ -0,0 +1,102 @@ +"""CronToolService — agent-callable cron job CRUD on top of existing backend service.""" + +from __future__ import annotations + +import json +from typing import Any + +from croniter import croniter + +from backend.web.services import cron_job_service +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema + +CRON_CREATE_SCHEMA = make_tool_schema( + name="CronCreate", + description="Create a cron job using the existing Mycel cron_jobs substrate.", + properties={ + "name": {"type": "string", "description": "Human-readable cron job name", "minLength": 1}, + "cron_expression": { + "type": "string", + "description": "Standard 5-field cron expression", + "minLength": 1, + }, + "description": {"type": "string", "description": "Optional cron job description"}, + "task_template": { + "type": "string", + "description": "JSON string template used when the cron job creates a task", + }, + "enabled": 
{"type": "boolean", "description": "Whether the cron job starts enabled"}, + }, + required=["name", "cron_expression"], +) + +CRON_DELETE_SCHEMA = make_tool_schema( + name="CronDelete", + description="Delete a cron job by ID.", + properties={ + "job_id": {"type": "string", "description": "Cron job ID returned by CronCreate", "minLength": 1}, + }, + required=["job_id"], +) + +CRON_LIST_SCHEMA = make_tool_schema( + name="CronList", + description="List all cron jobs in the current Mycel cron_jobs substrate.", + properties={}, +) + + +class CronToolService: + def __init__(self, registry: ToolRegistry): + self._register(registry) + + def _register(self, registry: ToolRegistry) -> None: + for name, schema, handler, read_only in [ + ("CronCreate", CRON_CREATE_SCHEMA, self._create, False), + ("CronDelete", CRON_DELETE_SCHEMA, self._delete, False), + ("CronList", CRON_LIST_SCHEMA, self._list, True), + ]: + registry.register( + ToolEntry( + name=name, + mode=ToolMode.DEFERRED, + schema=schema, + handler=handler, + source="CronToolService", + is_concurrency_safe=read_only, + is_read_only=read_only, + ) + ) + + def _create(self, **args: Any) -> str: + name = str(args.get("name", "")).strip() + cron_expression = str(args.get("cron_expression", "")).strip() + if not croniter.is_valid(cron_expression): + raise ValueError(f"Invalid cron expression: {cron_expression!r}") + + task_template = args.get("task_template", "{}") + if isinstance(task_template, str): + try: + json.loads(task_template) + except json.JSONDecodeError as exc: + raise ValueError("task_template must be valid JSON") from exc + + item = cron_job_service.create_cron_job( + name=name, + cron_expression=cron_expression, + description=str(args.get("description", "")), + task_template=task_template, + enabled=int(bool(args.get("enabled", True))), + ) + return json.dumps({"item": item}, ensure_ascii=False, indent=2) + + def _delete(self, **args: Any) -> str: + job_id = str(args.get("job_id", "")).strip() + ok = cron_job_service.delete_cron_job(job_id) + if not ok: + raise ValueError(f"Cron job not found: {job_id}") + return json.dumps({"ok": True, "id": job_id}, ensure_ascii=False, indent=2) + + def _list(self, **_args: Any) -> str: + items = cron_job_service.list_cron_jobs() + return json.dumps({"items": items, "total": len(items)}, ensure_ascii=False, indent=2) diff --git a/core/tools/filesystem/local_backend.py b/core/tools/filesystem/local_backend.py index 2bad2d45b..50bbe58a0 100644 --- a/core/tools/filesystem/local_backend.py +++ b/core/tools/filesystem/local_backend.py @@ -18,14 +18,16 @@ class LocalBackend(FileSystemBackend): def read_file(self, path: str) -> FileReadResult: p = Path(path) - content = p.read_text(encoding="utf-8") + with p.open("r", encoding="utf-8", newline="") as f: + content = f.read() return FileReadResult(content=content, size=p.stat().st_size) def write_file(self, path: str, content: str) -> FileWriteResult: try: p = Path(path) p.parent.mkdir(parents=True, exist_ok=True) - p.write_text(content, encoding="utf-8") + with p.open("w", encoding="utf-8", newline="") as f: + f.write(content) return FileWriteResult(success=True) except Exception as e: return FileWriteResult(success=False, error=str(e)) diff --git a/core/tools/filesystem/middleware.py b/core/tools/filesystem/middleware.py index 0844d892a..8519d30ea 100644 --- a/core/tools/filesystem/middleware.py +++ b/core/tools/filesystem/middleware.py @@ -13,8 +13,8 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable -from pathlib 
import Path +from collections.abc import Awaitable, Callable, Mapping +from pathlib import Path, PurePosixPath from typing import TYPE_CHECKING, Any from langchain.agents.middleware.types import ( @@ -33,6 +33,28 @@ from core.operations import FileOperationRecorder +def _remote_path(path: str | Path) -> PurePosixPath: + # @@@remote-posix-path-contract - Middleware callers still hand us sandbox + # POSIX paths even when tests run on Windows, so keep validation and + # workspace comparisons in POSIX space instead of host-native path rules. + return PurePosixPath(str(path).replace("\\", "/")) + + +type ResolvedPath = Path | PurePosixPath + + +def _require_resolved_path(resolved: ResolvedPath | None) -> ResolvedPath: + if resolved is None: + raise RuntimeError("Validated filesystem path unexpectedly missing") + return resolved + + +def _require_local_path(resolved: ResolvedPath) -> Path: + if not isinstance(resolved, Path): + raise RuntimeError(f"Expected local filesystem path, got remote path: {resolved}") + return resolved + + class FileSystemMiddleware(AgentMiddleware): """FileSystem Middleware - pure middleware implementation of file operations. @@ -80,7 +102,12 @@ def __init__( backend = LocalBackend() self.backend = backend - self.workspace_root = Path(workspace_root) if backend.is_remote else Path(workspace_root).resolve() + if backend.is_remote: + self.workspace_root: ResolvedPath = _remote_path(workspace_root) + else: + local_workspace_root = Path(workspace_root).resolve() + local_workspace_root.mkdir(parents=True, exist_ok=True) + self.workspace_root = local_workspace_root self.max_file_size = max_file_size self.allowed_extensions = allowed_extensions self.hooks = hooks or [] @@ -91,13 +118,10 @@ def __init__( "multi_edit": True, "list_dir": True, } - self._read_files: dict[Path, float | None] = {} + self._read_files: dict[Path | PurePosixPath, float | None] = {} self.operation_recorder = operation_recorder self.verbose = verbose - self.extra_allowed_paths: list[Path] = [Path(p) if backend.is_remote else Path(p).resolve() for p in (extra_allowed_paths or [])] - - if not backend.is_remote: - self.workspace_root.mkdir(parents=True, exist_ok=True) + self.extra_allowed_paths = [_remote_path(p) if backend.is_remote else Path(p).resolve() for p in (extra_allowed_paths or [])] if verbose: backend_name = type(backend).__name__ @@ -105,17 +129,20 @@ def __init__( if self.hooks: print(f"[FileSystemMiddleware] Loaded {len(self.hooks)} hooks") - def _validate_path(self, path: str, operation: str) -> tuple[bool, str, Path | None]: + def _validate_path(self, path: str, operation: str) -> tuple[bool, str, Path | PurePosixPath | None]: """Validate path for file operations. 
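+        A hedged illustration (paths invented): with a local backend rooted at
+        /ws, "/ws/a.txt" passes and yields (True, "", Path("/ws/a.txt")), while
+        a relative "a.txt" fails fast with
+        (False, "Path must be absolute: a.txt", None).
+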
Returns: (is_valid, error_message, resolved_path) """ - if not Path(path).is_absolute(): + if self.backend.is_remote: + if not _remote_path(path).is_absolute(): + return False, f"Path must be absolute: {path}", None + elif not Path(path).is_absolute(): return False, f"Path must be absolute: {path}", None try: - resolved = Path(path) if self.backend.is_remote else Path(path).resolve() + resolved = _remote_path(path) if self.backend.is_remote else Path(path).resolve() except Exception as e: return False, f"Invalid path: {path} ({e})", None @@ -146,7 +173,7 @@ def _validate_path(self, path: str, operation: str) -> tuple[bool, str, Path | N return True, "", resolved - def _check_file_staleness(self, resolved: Path) -> str | None: + def _check_file_staleness(self, resolved: Path | PurePosixPath) -> str | None: """Check if file has been modified since last read. Returns: @@ -165,7 +192,7 @@ def _check_file_staleness(self, resolved: Path) -> str | None: return None - def _update_file_tracking(self, resolved: Path) -> None: + def _update_file_tracking(self, resolved: Path | PurePosixPath) -> None: """Update mtime tracking after successful file operation.""" self._read_files[resolved] = self.backend.file_mtime(str(resolved)) @@ -203,7 +230,7 @@ def _record_operation( except Exception as e: raise RuntimeError(f"[FileSystemMiddleware] Failed to record operation: {e}") from e - def _count_lines(self, resolved: Path) -> int: + def _count_lines(self, resolved: Path | PurePosixPath) -> int: """Count total lines in a file (for error messages).""" try: raw = self.backend.read_file(str(resolved)) @@ -222,6 +249,7 @@ def _read_file_impl(self, file_path: str, offset: int = 0, limit: int | None = N if not is_valid: return ReadResult(file_path=file_path, file_type=None, error=error) # type: ignore[arg-type] + resolved = _require_resolved_path(resolved) file_size = self.backend.file_size(str(resolved)) # Absolute limit — always reject (even with offset/limit) @@ -265,7 +293,13 @@ def _read_file_impl(self, file_path: str, offset: int = 0, limit: int | None = N if isinstance(self.backend, LocalBackend): limits = ReadLimits() - result = read_file_dispatch(path=resolved, limits=limits, offset=offset if offset > 0 else None, limit=limit) + local_resolved = _require_local_path(resolved) + result = read_file_dispatch( + path=local_resolved, + limits=limits, + offset=offset if offset > 0 else None, + limit=limit, + ) if not result.error: self._update_file_tracking(resolved) return result @@ -314,6 +348,7 @@ def _write_file_impl(self, file_path: str, content: str) -> str: if not is_valid: return error + resolved = _require_resolved_path(resolved) if self.backend.file_exists(str(resolved)): return f"File already exists: {file_path}\nUse edit_file to modify existing files" @@ -342,6 +377,7 @@ def _edit_file_impl(self, file_path: str, old_string: str, new_string: str) -> s if not is_valid: return error + resolved = _require_resolved_path(resolved) if not self.backend.file_exists(str(resolved)): return f"File not found: {file_path}" @@ -388,6 +424,7 @@ def _multi_edit_impl(self, file_path: str, edits: list[dict[str, str]]) -> str: if not is_valid: return error + resolved = _require_resolved_path(resolved) if not self.backend.file_exists(str(resolved)): return f"File not found: {file_path}" @@ -435,6 +472,7 @@ def _list_dir_impl(self, directory_path: str) -> str: if not is_valid: return error + resolved = _require_resolved_path(resolved) if not self.backend.is_dir(str(resolved)): if self.backend.file_exists(str(resolved)): 
return f"Not a directory: {directory_path}" @@ -461,7 +499,7 @@ def _list_dir_impl(self, directory_path: str) -> str: except Exception as e: return f"Error listing directory: {e}" - def _get_tool_schemas(self) -> list[dict]: + def _get_tool_schemas(self) -> list[dict[str, Any]]: """获取文件系统工具 schema(sync/async 共享)""" return [ { @@ -571,12 +609,12 @@ def _get_tool_schemas(self) -> list[dict]: "parameters": { "type": "object", "properties": { - "directory_path": { + "path": { "type": "string", "description": "Absolute directory path (e.g., /path/to/dir). Do NOT use '.' or '..'", }, }, - "required": ["directory_path"], + "required": ["path"], }, }, }, @@ -602,7 +640,7 @@ async def awrap_model_call( tools.extend(self._get_tool_schemas()) return await handler(request.override(tools=tools)) - def _handle_tool_call(self, tool_call: dict) -> ToolMessage | None: + def _handle_tool_call(self, tool_call: Mapping[str, Any]) -> ToolMessage | None: """Handle filesystem tool calls. Returns ToolMessage if handled, None otherwise.""" tool_name = tool_call.get("name") args = tool_call.get("args", {}) @@ -633,7 +671,7 @@ def _handle_tool_call(self, tool_call: dict) -> ToolMessage | None: return ToolMessage(content=result, tool_call_id=tool_call_id) if tool_name == self.TOOL_LIST_DIR: - result = self._list_dir_impl(directory_path=args.get("directory_path", "")) + result = self._list_dir_impl(directory_path=args.get("path", "")) return ToolMessage(content=result, tool_call_id=tool_call_id) return None diff --git a/core/tools/filesystem/read/dispatcher.py b/core/tools/filesystem/read/dispatcher.py index f880e60e1..0119f424e 100644 --- a/core/tools/filesystem/read/dispatcher.py +++ b/core/tools/filesystem/read/dispatcher.py @@ -22,6 +22,7 @@ def read_file( limits: ReadLimits | None = None, offset: int | None = None, limit: int | None = None, + pages: str | None = None, ) -> ReadResult: """ Read file with type-specific handling. @@ -38,6 +39,7 @@ def read_file( limits: ReadLimits configuration (uses defaults if None) offset: Start line for text files (1-indexed) limit: Number of lines for text files + pages: Optional page range for document files, e.g. 
"1" or "3-5" Returns: ReadResult with content and metadata @@ -68,7 +70,8 @@ def read_file( return read_binary(path) if file_type == FileType.DOCUMENT: - return _read_document(path, limits, offset, limit) + start_page, limit_pages = _parse_pages_arg(pages, offset, limit) + return _read_document(path, limits, start_page, limit_pages) if file_type == FileType.NOTEBOOK: return read_notebook(path, limits, start_cell=offset, limit_cells=limit) @@ -79,6 +82,32 @@ def read_file( return read_text(path, limits, offset, limit) +def _parse_pages_arg( + pages: str | None, + offset: int | None, + limit: int | None, +) -> tuple[int | None, int | None]: + if pages is None: + return offset, limit + + raw = pages.strip() + if not raw: + raise ValueError("pages must not be empty") + + if "-" in raw: + start_raw, end_raw = raw.split("-", 1) + start_page = int(start_raw) + end_page = int(end_raw) + if start_page <= 0 or end_page < start_page: + raise ValueError(f"Invalid pages range: {pages}") + return start_page, end_page - start_page + 1 + + start_page = int(raw) + if start_page <= 0: + raise ValueError(f"Invalid page number: {pages}") + return start_page, 1 + + def _read_document( path: Path, limits: ReadLimits, diff --git a/core/tools/filesystem/read/readers/pdf.py b/core/tools/filesystem/read/readers/pdf.py index 6f43eabfa..9a1f58bb5 100644 --- a/core/tools/filesystem/read/readers/pdf.py +++ b/core/tools/filesystem/read/readers/pdf.py @@ -3,11 +3,14 @@ from __future__ import annotations from pathlib import Path +from typing import Any from core.tools.filesystem.read.types import FileType, ReadLimits, ReadResult +_pymupdf: Any | None = None + try: - import pymupdf + import pymupdf as _pymupdf HAS_PYMUPDF = True except ImportError: @@ -34,6 +37,8 @@ def read_pdf( """ if not HAS_PYMUPDF: return _no_pymupdf_result(path) + if _pymupdf is None: + raise RuntimeError("pymupdf import unexpectedly unavailable") stat = path.stat() result = ReadResult( @@ -43,7 +48,7 @@ def read_pdf( ) try: - doc = pymupdf.open(path) + doc = _pymupdf.open(path) except Exception as e: result.error = f"Error opening PDF: {e}" return result diff --git a/core/tools/filesystem/read/readers/pptx.py b/core/tools/filesystem/read/readers/pptx.py index 822f29a37..7f2dde962 100644 --- a/core/tools/filesystem/read/readers/pptx.py +++ b/core/tools/filesystem/read/readers/pptx.py @@ -3,6 +3,7 @@ from __future__ import annotations from pathlib import Path +from typing import Any, cast from core.tools.filesystem.read.types import FileType, ReadLimits, ReadResult @@ -43,7 +44,9 @@ def read_pptx( ) try: - prs = Presentation(path) + # @@@pptx-callable-seam - python-pptx exports Presentation as a factory function at runtime, + # but pyright sees a module-like surface here. Keep the third-party seam local. 
+ prs = cast(Any, Presentation)(str(path)) except Exception as e: result.error = f"Error opening PPTX: {e}" return result diff --git a/core/tools/filesystem/service.py b/core/tools/filesystem/service.py index a8cf1c9c6..ecfa0b7c5 100644 --- a/core/tools/filesystem/service.py +++ b/core/tools/filesystem/service.py @@ -10,18 +10,90 @@ from __future__ import annotations import logging -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +import tempfile +import threading +from collections import OrderedDict +from collections.abc import Sequence +from dataclasses import dataclass +from pathlib import Path, PurePosixPath +from typing import TYPE_CHECKING, Any, Literal + +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema +from core.runtime.tool_result import ToolResultEnvelope, tool_success from core.tools.filesystem.backend import FileSystemBackend from core.tools.filesystem.read import ReadLimits from core.tools.filesystem.read import read_file as read_file_dispatch +from core.tools.filesystem.read.readers.binary import IMAGE_EXTENSIONS, MAX_IMAGE_SIZE +from core.tools.filesystem.read.types import FileType, detect_file_type if TYPE_CHECKING: from core.operations import FileOperationRecorder logger = logging.getLogger(__name__) +DEFAULT_READ_STATE_CACHE_SIZE = 100 +ABSOLUTE_PATH_PATTERN = r"^(?:/|[A-Za-z]:[\\/])" +type ResolvedPath = Path | PurePosixPath +type ValidationResult = tuple[Literal[True], str, ResolvedPath] | tuple[Literal[False], str, None] + + +def _remote_path(path: str | Path) -> PurePosixPath: + # @@@remote-posix-path-contract - Remote filesystem tools operate on sandbox + # POSIX paths, not host-native paths. Preserve forward-slash semantics even + # when the host process is running on Windows. 
+ return PurePosixPath(str(path).replace("\\", "/")) + + +@dataclass +class _ReadFileState: + timestamp: float | None + is_partial: bool + + +class _ReadFileStateCache: + def __init__(self, max_entries: int = DEFAULT_READ_STATE_CACHE_SIZE): + self._max_entries = max_entries + self._entries: OrderedDict[ResolvedPath, _ReadFileState] = OrderedDict() + + @staticmethod + def make_state(*, timestamp: float | None, is_partial: bool) -> _ReadFileState: + return _ReadFileState(timestamp=timestamp, is_partial=is_partial) + + def get(self, path: ResolvedPath) -> _ReadFileState | None: + state = self._entries.get(path) + if state is None: + return None + self._entries.move_to_end(path) + return state + + def set(self, path: ResolvedPath, state: _ReadFileState) -> None: + self._entries[path] = state + self._entries.move_to_end(path) + while len(self._entries) > self._max_entries: + self._entries.popitem(last=False) + + def clone(self) -> _ReadFileStateCache: + clone = _ReadFileStateCache(max_entries=self._max_entries) + clone._entries = OrderedDict( + (path, _ReadFileState(timestamp=state.timestamp, is_partial=state.is_partial)) for path, state in self._entries.items() + ) + return clone + + def merge(self, other: _ReadFileStateCache) -> None: + for path, incoming in other._entries.items(): + existing = self._entries.get(path) + if existing is None or self._is_newer(incoming, existing): + self.set( + path, + _ReadFileState(timestamp=incoming.timestamp, is_partial=incoming.is_partial), + ) + + @staticmethod + def _is_newer(incoming: _ReadFileState, existing: _ReadFileState) -> bool: + if incoming.timestamp is None: + return False + if existing.timestamp is None: + return True + return incoming.timestamp >= existing.timestamp class FileSystemService: @@ -37,7 +109,9 @@ def __init__( hooks: list[Any] | None = None, operation_recorder: FileOperationRecorder | None = None, backend: FileSystemBackend | None = None, - extra_allowed_paths: list[str | Path] | None = None, + extra_allowed_paths: Sequence[str | Path] | None = None, + max_read_cache_entries: int = DEFAULT_READ_STATE_CACHE_SIZE, + max_edit_file_size: int | None = None, ): if backend is None: from core.tools.filesystem.local_backend import LocalBackend @@ -45,15 +119,17 @@ def __init__( backend = LocalBackend() self.backend = backend - self.workspace_root = Path(workspace_root) if backend.is_remote else Path(workspace_root).resolve() + self.workspace_root: ResolvedPath = _remote_path(workspace_root) if backend.is_remote else Path(workspace_root).resolve() self.max_file_size = max_file_size self.allowed_extensions = allowed_extensions self.hooks = hooks or [] - self._read_files: dict[Path, float | None] = {} + self._read_files = _ReadFileStateCache(max_entries=max_read_cache_entries) + self.max_edit_file_size = max_file_size if max_edit_file_size is None else max_edit_file_size self.operation_recorder = operation_recorder - self.extra_allowed_paths: list[Path] = [Path(p) if backend.is_remote else Path(p).resolve() for p in (extra_allowed_paths or [])] + self.extra_allowed_paths = [_remote_path(p) if backend.is_remote else Path(p).resolve() for p in (extra_allowed_paths or [])] + self._edit_critical_section = threading.Lock() - if not backend.is_remote: + if not backend.is_remote and isinstance(self.workspace_root, Path): self.workspace_root.mkdir(parents=True, exist_ok=True) self._register(registry) @@ -67,30 +143,42 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Read", mode=ToolMode.INLINE, - schema={ - "name": "Read", - 
"description": ("Read file content (text/code/images/PDF/PPTX/Notebook). Path must be absolute."), - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Absolute file path", - }, - "offset": { - "type": "integer", - "description": "Start line (1-indexed, optional)", - }, - "limit": { - "type": "integer", - "description": "Number of lines to read (optional)", - }, + schema=make_tool_schema( + name="Read", + description=( + "Read file content. Output uses cat -n format (line numbers starting at 1). " + "Default reads up to 2000 lines from start; use offset/limit for long files. " + "Supports images (PNG/JPG), PDF (use pages param for large PDFs), and Jupyter notebooks. " + "Path must be absolute." + ), + properties={ + "file_path": { + "type": "string", + "description": "Absolute file path", + "minLength": 1, + "pattern": ABSOLUTE_PATH_PATTERN, + }, + "offset": { + "type": "integer", + "description": "Start line (1-indexed, optional)", + }, + "limit": { + "type": "integer", + "description": "Number of lines to read (optional)", + }, + "pages": { + "type": "string", + "description": "Page range for PDF files (e.g. '1-5'). Max 20 pages per request.", }, - "required": ["file_path"], }, - }, + required=["file_path"], + ), handler=self._read_file, + validate_input=self._validate_read_args, source="FileSystemService", + search_hint="read view file content text code image PDF notebook", + is_read_only=True, + is_concurrency_safe=True, ) ) @@ -98,26 +186,27 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Write", mode=ToolMode.INLINE, - schema={ - "name": "Write", - "description": "Create new file. Path must be absolute. Fails if file exists.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Absolute file path", - }, - "content": { - "type": "string", - "description": "File content", - }, + schema=make_tool_schema( + name="Write", + description="Create or overwrite a file with full content. Forces LF line endings. Path must be absolute.", + properties={ + "file_path": { + "type": "string", + "description": "Absolute file path", + "minLength": 1, + "pattern": ABSOLUTE_PATH_PATTERN, + }, + "content": { + "type": "string", + "description": "File content", }, - "required": ["file_path", "content"], }, - }, + required=["file_path", "content"], + ), handler=self._write_file, + validate_input=self._validate_write_args, source="FileSystemService", + search_hint="create new file write content to disk", ) ) @@ -125,39 +214,39 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Edit", mode=ToolMode.INLINE, - schema={ - "name": "Edit", - "description": ( - "Edit existing file using exact string replacement. " - "MUST read file before editing. " - "old_string must be unique in file. " - "Set replace_all=true to replace all occurrences." + schema=make_tool_schema( + name="Edit", + description=( + "Edit file via exact string replacement. You MUST Read the file first. " + "old_string must match exactly one location (or use replace_all=true). " + "Does not support .ipynb files (use Write to overwrite full JSON). Path must be absolute." 
), - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Absolute file path", - }, - "old_string": { - "type": "string", - "description": "Exact string to replace", - }, - "new_string": { - "type": "string", - "description": "Replacement string", - }, - "replace_all": { - "type": "boolean", - "description": "Replace all occurrences (default: false)", - }, + properties={ + "file_path": { + "type": "string", + "description": "Absolute file path", + "minLength": 1, + "pattern": ABSOLUTE_PATH_PATTERN, + }, + "old_string": { + "type": "string", + "description": "Exact string to replace", + }, + "new_string": { + "type": "string", + "description": "Replacement string", + }, + "replace_all": { + "type": "boolean", + "description": "Replace all occurrences (default: false)", }, - "required": ["file_path", "old_string", "new_string"], }, - }, + required=["file_path", "old_string", "new_string"], + ), handler=self._edit_file, + validate_input=self._validate_edit_args, source="FileSystemService", + search_hint="edit modify replace string in existing file", ) ) @@ -165,22 +254,25 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="list_dir", mode=ToolMode.INLINE, - schema={ - "name": "list_dir", - "description": "List directory contents. Path must be absolute.", - "parameters": { - "type": "object", - "properties": { - "directory_path": { - "type": "string", - "description": "Absolute directory path", - }, + schema=make_tool_schema( + name="list_dir", + description="List directory contents (files and subdirectories, non-recursive). Path must be absolute.", + properties={ + "path": { + "type": "string", + "description": "Absolute directory path", + "minLength": 1, + "pattern": ABSOLUTE_PATH_PATTERN, }, - "required": ["directory_path"], }, - }, + required=["path"], + ), handler=self._list_dir, + validate_input=self._validate_list_dir_args, source="FileSystemService", + search_hint="list directory contents browse folder", + is_read_only=True, + is_concurrency_safe=True, ) ) @@ -188,12 +280,15 @@ def _register(self, registry: ToolRegistry) -> None: # Path validation (reused from middleware) # ------------------------------------------------------------------ - def _validate_path(self, path: str, operation: str) -> tuple[bool, str, Path | None]: - if not Path(path).is_absolute(): + def _validate_path(self, path: str, operation: str) -> ValidationResult: + if self.backend.is_remote: + if not _remote_path(path).is_absolute(): + return False, f"Path must be absolute: {path}", None + elif not Path(path).is_absolute(): return False, f"Path must be absolute: {path}", None try: - resolved = Path(path) if self.backend.is_remote else Path(path).resolve() + resolved = _remote_path(path) if self.backend.is_remote else Path(path).resolve() except Exception as e: return False, f"Invalid path: {path} ({e})", None @@ -224,10 +319,159 @@ def _validate_path(self, path: str, operation: str) -> tuple[bool, str, Path | N return True, "", resolved - def _check_file_staleness(self, resolved: Path) -> str | None: - if resolved not in self._read_files: - return "File has not been read yet. Read it first before writing to it." 
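The `_ReadFileStateCache` added earlier in this file's diff is a bounded LRU keyed by resolved path, so long sessions cannot grow the read-tracking map without limit. A minimal standalone sketch of the same `OrderedDict` eviction discipline (illustrative class, not the repo's):

```python
from collections import OrderedDict


class LRUCache:
    """Bounded mapping that evicts the least-recently-used entry."""

    def __init__(self, max_entries: int) -> None:
        self._max = max_entries
        self._entries: OrderedDict[str, int] = OrderedDict()

    def get(self, key: str) -> int | None:
        if key not in self._entries:
            return None
        self._entries.move_to_end(key)  # a hit marks the entry most-recently-used
        return self._entries[key]

    def set(self, key: str, value: int) -> None:
        self._entries[key] = value
        self._entries.move_to_end(key)
        while len(self._entries) > self._max:
            self._entries.popitem(last=False)  # drop the oldest entry


cache = LRUCache(2)
cache.set("/a", 1)
cache.set("/b", 2)
cache.get("/a")      # "/a" is now most recent
cache.set("/c", 3)   # evicts "/b", the least-recently-used key
assert cache.get("/b") is None and cache.get("/a") == 1
```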
- stored_mtime = self._read_files[resolved] + def _validation_error(self, message: str, error_code: str) -> dict[str, object]: + return { + "result": False, + "message": message, + "errorCode": error_code, + } + + def _path_validation_error(self, message: str) -> dict[str, object]: + # @@@filesystem-validation-codes - Keep the pre-execution path failure + # mapping centralized so the runner can surface stable structured + # codes instead of ad-hoc handler strings on the highest-traffic tools. + if message.startswith("Path must be absolute:"): + return self._validation_error(message, "PATH_NOT_ABSOLUTE") + if message.startswith("Invalid path:"): + return self._validation_error(message, "INVALID_PATH") + if message.startswith("Path outside workspace"): + return self._validation_error(message, "PATH_OUTSIDE_WORKSPACE") + if message.startswith("File type not allowed:"): + return self._validation_error(message, "FILE_TYPE_NOT_ALLOWED") + return self._validation_error(message, "INVALID_PATH") + + def _validate_existing_path(self, path: str, operation: str) -> tuple[dict[str, object] | None, ResolvedPath | None]: + is_valid, error, resolved = self._validate_path(path, operation) + if not is_valid: + return self._path_validation_error(error), None + assert resolved is not None + return None, resolved + + def _validation_message(self, error: dict[str, object]) -> str: + return str(error["message"]) + + def _read_preflight( + self, + *, + file_path: str, + offset: int = 0, + limit: int | None = None, + pages: str | None = None, + ) -> tuple[dict[str, object] | None, ResolvedPath | None]: + error, resolved = self._validate_existing_path(file_path, "read") + if error is not None: + return error, None + assert resolved is not None + + file_size = self.backend.file_size(str(resolved)) + if file_size is not None and file_size > self.max_file_size: + return ( + self._validation_error( + f"File too large: {file_size:,} bytes (max: {self.max_file_size:,} bytes)", + "FILE_TOO_LARGE", + ), + None, + ) + + has_pagination = offset > 0 or limit is not None or pages is not None + if not has_pagination and file_size is not None: + limits = ReadLimits() + if file_size > limits.max_size_bytes: + total_lines = self._count_lines(resolved) + return ( + self._validation_error( + ( + f"File content ({file_size:,} bytes) exceeds maximum allowed size ({limits.max_size_bytes:,} bytes).\n" + f"Use offset and limit parameters to read specific sections.\n" + f"Total lines: {total_lines}" + ), + "READ_REQUIRES_PAGINATION", + ), + None, + ) + estimated_tokens = file_size // 4 + if estimated_tokens > limits.max_tokens: + total_lines = self._count_lines(resolved) + return ( + self._validation_error( + ( + f"File content (~{estimated_tokens:,} tokens) exceeds maximum allowed tokens ({limits.max_tokens:,}).\n" + f"Use offset and limit parameters to read specific sections.\n" + f"Total lines: {total_lines}" + ), + "READ_REQUIRES_PAGINATION", + ), + None, + ) + + return None, resolved + + def _edit_preflight(self, *, file_path: str) -> tuple[dict[str, object] | None, ResolvedPath | None]: + error, resolved = self._validate_existing_path(file_path, "edit") + if error is not None: + return error, None + assert resolved is not None + + if resolved.suffix.lower() == ".ipynb": + return ( + self._validation_error( + "Notebook files (.ipynb) are not supported by Edit. 
Use Write to overwrite the full JSON.", + "NOTEBOOK_EDIT_UNSUPPORTED", + ), + None, + ) + + file_size = self.backend.file_size(str(resolved)) + if file_size is not None and file_size > self.max_edit_file_size: + return ( + self._validation_error( + f"File too large for Edit: {file_size:,} bytes (max: {self.max_edit_file_size:,} bytes)", + "FILE_TOO_LARGE", + ), + None, + ) + + return None, resolved + + def _list_dir_preflight(self, *, path: str) -> tuple[dict[str, object] | None, ResolvedPath | None]: + error, resolved = self._validate_existing_path(path, "list") + if error is not None: + return error, None + assert resolved is not None + if not self.backend.is_dir(str(resolved)): + if self.backend.file_exists(str(resolved)): + return self._validation_error(f"Not a directory: {path}", "NOT_A_DIRECTORY"), None + return self._validation_error(f"Directory not found: {path}", "DIRECTORY_NOT_FOUND"), None + return None, resolved + + def _validate_read_args(self, args: dict[str, Any], request: Any) -> dict[str, Any]: + error, _ = self._read_preflight( + file_path=args["file_path"], + offset=args.get("offset") or 0, + limit=args.get("limit"), + pages=args.get("pages"), + ) + return error or args + + def _validate_write_args(self, args: dict[str, Any], request: Any) -> dict[str, Any]: + error, _ = self._validate_existing_path(args["file_path"], "write") + return error or args + + def _validate_edit_args(self, args: dict[str, Any], request: Any) -> dict[str, Any]: + error, _ = self._edit_preflight(file_path=args["file_path"]) + return error or args + + def _validate_list_dir_args(self, args: dict[str, Any], request: Any) -> dict[str, Any]: + error, _ = self._list_dir_preflight(path=args["path"]) + return error or args + + def _check_file_staleness(self, resolved: ResolvedPath) -> str | None: + state = self._read_files.get(resolved) + if state is None: + return "File has not been read yet. Read the full file first before editing." + if state.is_partial: + return "File has only been read partially. Read the full file before editing." + stored_mtime = state.timestamp if stored_mtime is None: return None current_mtime = self.backend.file_mtime(str(resolved)) @@ -235,8 +479,70 @@ def _check_file_staleness(self, resolved: Path) -> str | None: return "File has been modified since last read. Read it again before editing." 
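The staleness rule above is an mtime handshake: remember the mtime observed at read time, then refuse the edit if the file has moved on since. A self-contained sketch of the same contract, with plain `os.stat` standing in for the backend abstraction (helper names are hypothetical):

```python
import os

read_mtimes: dict[str, float] = {}


def record_read(path: str) -> None:
    # Capture the mtime the reader actually saw.
    read_mtimes[path] = os.stat(path).st_mtime


def staleness_error(path: str) -> str | None:
    stored = read_mtimes.get(path)
    if stored is None:
        return "File has not been read yet."
    if os.stat(path).st_mtime > stored:
        return "File has been modified since last read."
    return None  # safe to edit
```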
return None - def _update_file_tracking(self, resolved: Path) -> None: - self._read_files[resolved] = self.backend.file_mtime(str(resolved)) + def _update_file_tracking( + self, + resolved: ResolvedPath, + *, + is_partial: bool, + file_type: FileType | None = None, + ) -> None: + if file_type is None: + file_type = self._detect_file_type(resolved) + if file_type not in {FileType.TEXT, FileType.NOTEBOOK}: + return + self._read_files.set( + resolved, + _ReadFileState( + timestamp=self.backend.file_mtime(str(resolved)), + is_partial=is_partial, + ), + ) + + def _normalize_write_content(self, content: str) -> str: + return content.replace("\r\n", "\n").replace("\r", "\n") + + def _read_result_is_partial(self, result) -> bool: + if getattr(result, "truncated", False): + return True + if getattr(result, "file_type", None) == FileType.TEXT: + start_line = getattr(result, "start_line", None) or 1 + total_lines = getattr(result, "total_lines", None) + end_line = getattr(result, "end_line", None) or total_lines or start_line + if total_lines is not None: + return start_line > 1 or end_line < total_lines + return False + + def _detect_file_type(self, resolved: ResolvedPath) -> FileType: + return detect_file_type(Path(str(resolved))) + + def _structured_media_success( + self, + *, + resolved: ResolvedPath, + file_type: FileType, + content_blocks: list[dict[str, str]], + ) -> ToolResultEnvelope: + return tool_success( + [ + { + "type": "text", + "text": (f"Read file: {resolved.name}\nSpecial content is attached below as structured blocks."), + }, + *content_blocks, + ], + metadata={"file_type": file_type.value}, + ) + + def _restore_special_result_identity( + self, + *, + result, + resolved: ResolvedPath, + temp_path: Path, + ) -> None: + result.file_path = str(resolved) + if isinstance(getattr(result, "content", None), str): + result.content = result.content.replace(str(temp_path), str(resolved)).replace(temp_path.name, resolved.name) def _record_operation( self, @@ -267,7 +573,7 @@ def _record_operation( except Exception as e: raise RuntimeError(f"[FileSystemService] Failed to record operation: {e}") from e - def _count_lines(self, resolved: Path) -> int: + def _count_lines(self, resolved: ResolvedPath) -> int: try: raw = self.backend.read_file(str(resolved)) return raw.content.count("\n") + 1 @@ -278,50 +584,86 @@ def _count_lines(self, resolved: Path) -> int: # Tool handlers # ------------------------------------------------------------------ - def _read_file(self, file_path: str, offset: int = 0, limit: int | None = None) -> str: - is_valid, error, resolved = self._validate_path(file_path, "read") - if not is_valid: - return error - - file_size = self.backend.file_size(str(resolved)) - - if file_size is not None and file_size > self.max_file_size: - return f"File too large: {file_size:,} bytes (max: {self.max_file_size:,} bytes)" - - has_pagination = offset > 0 or limit is not None - if not has_pagination and file_size is not None: - limits = ReadLimits() - if file_size > limits.max_size_bytes: - total_lines = self._count_lines(resolved) - return ( - f"File content ({file_size:,} bytes) exceeds maximum allowed size ({limits.max_size_bytes:,} bytes).\n" - f"Use offset and limit parameters to read specific sections.\n" - f"Total lines: {total_lines}" - ) - estimated_tokens = file_size // 4 - if estimated_tokens > limits.max_tokens: - total_lines = self._count_lines(resolved) - return ( - f"File content (~{estimated_tokens:,} tokens) exceeds maximum allowed tokens ({limits.max_tokens:,}).\n" - f"Use 
offset and limit parameters to read specific sections.\n" - f"Total lines: {total_lines}" - ) + def _read_file(self, file_path: str, offset: int = 0, limit: int | None = None, pages: str | None = None) -> str | ToolResultEnvelope: + error, resolved = self._read_preflight( + file_path=file_path, + offset=offset, + limit=limit, + pages=pages, + ) + if error is not None: + return self._validation_message(error) + assert resolved is not None from core.tools.filesystem.local_backend import LocalBackend if isinstance(self.backend, LocalBackend): + assert isinstance(resolved, Path) limits = ReadLimits() result = read_file_dispatch( path=resolved, limits=limits, offset=offset if offset > 0 else None, limit=limit, + pages=pages, ) if not result.error: - self._update_file_tracking(resolved) + self._update_file_tracking( + resolved, + is_partial=self._read_result_is_partial(result), + file_type=result.file_type, + ) + if result.content_blocks: + return self._structured_media_success( + resolved=resolved, + file_type=result.file_type, + content_blocks=result.content_blocks, + ) return result.format_output() try: + file_type = self._detect_file_type(resolved) + download_bytes = getattr(self.backend, "download_bytes", None) + if callable(download_bytes) and file_type in {FileType.BINARY, FileType.DOCUMENT}: + # @@@dt-02-remote-special-file-bridge + # Remote providers expose raw-byte download hooks. Reuse the + # same local dispatcher for binary/document reads instead of + # degrading special files into placeholder text. + raw_bytes = download_bytes(str(resolved)) + if not isinstance(raw_bytes, (bytes, bytearray)): + raise TypeError(f"Remote special-file download returned {type(raw_bytes).__name__}, expected bytes.") + raw_bytes = bytes(raw_bytes) + if ( + file_type == FileType.BINARY + and resolved.suffix.lstrip(".").lower() in IMAGE_EXTENSIONS + and len(raw_bytes) > MAX_IMAGE_SIZE + ): + return f"Image exceeds size limit: {len(raw_bytes)} bytes" + with tempfile.NamedTemporaryFile(suffix=resolved.suffix, delete=False) as tmp: + tmp.write(raw_bytes) + tmp_path = Path(tmp.name) + try: + result = read_file_dispatch( + path=tmp_path, + limits=ReadLimits(), + offset=offset if offset > 0 else None, + limit=limit, + pages=pages, + ) + finally: + tmp_path.unlink(missing_ok=True) + self._restore_special_result_identity( + result=result, + resolved=resolved, + temp_path=tmp_path, + ) + if result.content_blocks: + return self._structured_media_success( + resolved=resolved, + file_type=result.file_type, + content_blocks=result.content_blocks, + ) + return result.format_output() raw = self.backend.read_file(str(resolved)) lines = raw.content.split("\n") total_lines = len(lines) @@ -331,7 +673,10 @@ def _read_file(self, file_path: str, offset: int = 0, limit: int | None = None) selected = lines[start:end] numbered = [f"{start + i + 1:>6}\t{line}" for i, line in enumerate(selected)] content = "\n".join(numbered) - self._update_file_tracking(resolved) + self._update_file_tracking( + resolved, + is_partial=start > 0 or end < total_lines, + ) return content except Exception as e: return f"Error reading file: {e}" @@ -340,88 +685,102 @@ def _write_file(self, file_path: str, content: str) -> str: is_valid, error, resolved = self._validate_path(file_path, "write") if not is_valid: return error - - if self.backend.file_exists(str(resolved)): - return f"File already exists: {file_path}\nUse Edit to modify existing files" + assert resolved is not None try: - result = self.backend.write_file(str(resolved), content) + normalized 
= self._normalize_write_content(content) + result = self.backend.write_file(str(resolved), normalized) if not result.success: return f"Error writing file: {result.error}" - self._update_file_tracking(resolved) + self._update_file_tracking(resolved, is_partial=False) self._record_operation( operation_type="write", file_path=file_path, before_content=None, - after_content=content, + after_content=normalized, ) - lines = content.count("\n") + 1 + lines = normalized.count("\n") + 1 return f"File created: {file_path}\n Lines: {lines}\n Size: {len(content)} bytes" except Exception as e: return f"Error writing file: {e}" def _edit_file(self, file_path: str, old_string: str, new_string: str, replace_all: bool = False) -> str: - is_valid, error, resolved = self._validate_path(file_path, "edit") - if not is_valid: - return error - - if not self.backend.file_exists(str(resolved)): - return f"File not found: {file_path}" - - staleness_error = self._check_file_staleness(resolved) - if staleness_error: - return staleness_error - - if old_string == new_string: - return "Error: old_string and new_string are identical (no-op edit)" + error, resolved = self._edit_preflight(file_path=file_path) + if error is not None: + return self._validation_message(error) + assert resolved is not None try: - raw = self.backend.read_file(str(resolved)) - content = raw.content - - if old_string not in content: - return f"String not found in file\n Looking for: {old_string[:100]}..." - - if replace_all: - count = content.count(old_string) - new_content = content.replace(old_string, new_string) - else: - count = content.count(old_string) - if count > 1: - return ( - f"String appears {count} times in file (not unique)\n" - f" Use replace_all=true or provide more context to make it unique" - ) - new_content = content.replace(old_string, new_string, 1) - count = 1 - - result = self.backend.write_file(str(resolved), new_content) - if not result.success: - return f"Error editing file: {result.error}" - - self._update_file_tracking(resolved) - self._record_operation( - operation_type="edit", - file_path=file_path, - before_content=content, - after_content=new_content, - changes=[{"old_string": old_string, "new_string": new_string}], - ) - return f"File edited: {file_path}\n Replaced {count} occurrence(s)" + # @@@edit-critical-lock + # dt-01 requires the reread -> stale check -> write path to be one + # synchronous critical section so two stale concurrent edits cannot + # both commit from the same prior read snapshot. + with self._edit_critical_section: + try: + raw = self.backend.read_file(str(resolved)) + except FileNotFoundError: + if old_string == "": + return self._write_file(file_path, new_string) + return f"File not found: {file_path}" + content = raw.content + + if old_string == "": + return "Cannot use empty old_string on an existing file. Use Write to replace the full file content." + staleness_error = self._check_file_staleness(resolved) + if staleness_error: + return staleness_error + + if old_string == new_string: + return "Error: old_string and new_string are identical (no-op edit)" + + # @@@edit-critical-staleness + # te-06 needs a second stale-read check inside the read->write + # critical section so an external write that lands after the + # preflight check cannot be silently overwritten. + staleness_error = self._check_file_staleness(resolved) + if staleness_error: + return staleness_error + + if old_string not in content: + return f"String not found in file\n Looking for: {old_string[:100]}..." 
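Stripped of I/O, locking, and read tracking, the replacement contract this handler enforces is small enough to state as a pure function. A sketch of just the string semantics (not the handler itself):

```python
def apply_edit(content: str, old: str, new: str, replace_all: bool = False) -> tuple[str, int]:
    """Return (new_content, replacement_count) under the Edit tool's rules."""
    if old not in content:
        raise ValueError("old_string not found in file")
    count = content.count(old)
    if replace_all:
        return content.replace(old, new), count
    if count > 1:
        raise ValueError(f"old_string matches {count} locations; use replace_all or add context")
    return content.replace(old, new, 1), 1


assert apply_edit("a b a", "b", "c") == ("a c a", 1)
assert apply_edit("a b a", "a", "x", replace_all=True) == ("x b x", 2)
```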
+ + if replace_all: + count = content.count(old_string) + new_content = content.replace(old_string, new_string) + else: + count = content.count(old_string) + if count > 1: + return ( + f"String appears {count} times in file (not unique)\n" + f" Use replace_all=true or provide more context to make it unique" + ) + new_content = content.replace(old_string, new_string, 1) + count = 1 + + result = self.backend.write_file(str(resolved), new_content) + if not result.success: + return f"Error editing file: {result.error}" + + self._update_file_tracking(resolved, is_partial=False) + self._record_operation( + operation_type="edit", + file_path=file_path, + before_content=content, + after_content=new_content, + changes=[{"old_string": old_string, "new_string": new_string}], + ) + return f"File edited: {file_path}\n Replaced {count} occurrence(s)" except Exception as e: return f"Error editing file: {e}" - def _list_dir(self, directory_path: str) -> str: - is_valid, error, resolved = self._validate_path(directory_path, "list") - if not is_valid: - return error - - if not self.backend.is_dir(str(resolved)): - if self.backend.file_exists(str(resolved)): - return f"Not a directory: {directory_path}" - return f"Directory not found: {directory_path}" + def _list_dir(self, path: str) -> str: + directory_path = path + error, resolved = self._list_dir_preflight(path=directory_path) + if error is not None: + return self._validation_message(error) + assert resolved is not None try: result = self.backend.list_dir(str(resolved)) diff --git a/core/tools/lsp/__init__.py b/core/tools/lsp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/tools/lsp/service.py b/core/tools/lsp/service.py new file mode 100644 index 000000000..dc480812d --- /dev/null +++ b/core/tools/lsp/service.py @@ -0,0 +1,838 @@ +"""LSP Service - Language Server Protocol code intelligence via multilspy. + +Registers a single DEFERRED `LSP` tool with 9 operations: + goToDefinition, findReferences, hover, documentSymbol, workspaceSymbol, + goToImplementation, prepareCallHierarchy, incomingCalls, outgoingCalls + +Sessions are managed by the process-level _LSPSessionPool singleton — they +start lazily on first use and persist for the lifetime of the process, +surviving agent restarts. Call `await lsp_pool.close_all()` on process exit. + +Supported languages (via multilspy): + python, typescript, javascript, go, rust, java, ruby, kotlin, csharp +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import shutil +import subprocess +from pathlib import Path +from typing import Any + +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema + +_FILE_SIZE_LIMIT = 10 * 1024 * 1024 # 10 MB — matches CC LSP limit + +logger = logging.getLogger(__name__) + +LSP_SCHEMA = make_tool_schema( + name="LSP", + description=( + "Language Server Protocol code intelligence. " + "Operations: goToDefinition, findReferences, hover, documentSymbol, workspaceSymbol, " + "goToImplementation, prepareCallHierarchy, incomingCalls, outgoingCalls. " + "Language servers are auto-downloaded on first use. " + "Supports python, typescript, javascript, go, rust, java, ruby, kotlin. " + "file_path must be absolute. line/character are 1-based. " + "incomingCalls/outgoingCalls require 'item' from prepareCallHierarchy output." 
+ ), + properties={ + "operation": { + "type": "string", + "enum": [ + "goToDefinition", + "findReferences", + "hover", + "documentSymbol", + "workspaceSymbol", + "goToImplementation", + "prepareCallHierarchy", + "incomingCalls", + "outgoingCalls", + ], + "description": "LSP operation to perform", + }, + "file_path": { + "type": "string", + "description": "Absolute path to file (required for all operations except workspaceSymbol)", + }, + "line": { + "type": "integer", + "description": "1-based line number (required for goToDefinition, findReferences, hover)", + }, + "character": { + "type": "integer", + "description": "1-based character offset (required for goToDefinition, findReferences, hover)", + }, + "query": { + "type": "string", + "description": "Symbol name to search (required for workspaceSymbol)", + }, + "language": { + "type": "string", + "description": "Language override. Auto-detected from file extension if omitted.", + }, + "item": { + "type": "object", + "description": "CallHierarchyItem from prepareCallHierarchy (required for incomingCalls/outgoingCalls).", + }, + }, + required=["operation"], +) + +# File extension → multilspy language identifier +_EXT_TO_LANG: dict[str, str] = { + ".py": "python", + ".ts": "typescript", + ".tsx": "typescript", + ".js": "javascript", + ".jsx": "javascript", + ".go": "go", + ".rs": "rust", + ".java": "java", + ".rb": "ruby", + ".kt": "kotlin", + ".cs": "csharp", +} + + +def _find_pyright() -> str | None: + """Locate pyright-langserver: venv-local first, then PATH.""" + for name in ("pyright-langserver", "pyright_langserver"): + # prefer the binary in the same venv as the current interpreter + venv_bin = Path(os.__file__).parent.parent.parent / "bin" / name + if venv_bin.exists(): + return str(venv_bin) + found = shutil.which(name) + if found: + return found + return None + + +class _PyrightSession: + """Minimal asyncio LSP client for pyright-langserver (stdio). + + Used for Python operations not supported by Jedi: + goToImplementation, prepareCallHierarchy, incomingCalls, outgoingCalls. + + Requires pyright in the active venv: pip install pyright + """ + + def __init__(self, workspace_root: str) -> None: + self._workspace_root = workspace_root + self._proc: asyncio.subprocess.Process | None = None + self._pending: dict[int, asyncio.Future] = {} + self._next_id = 1 + self._reader_task: asyncio.Task | None = None + self._open_files: set[str] = set() + + async def start(self) -> None: + server = _find_pyright() + if not server: + raise RuntimeError("pyright-langserver not found. 
Install with: pip install pyright") + self._proc = await asyncio.create_subprocess_exec( + server, + "--stdio", + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + ) + self._reader_task = asyncio.create_task(self._read_loop(), name="pyright-reader") + + # LSP handshake + await self._request( + "initialize", + { + "processId": os.getpid(), + "rootUri": Path(self._workspace_root).as_uri(), + "capabilities": { + "textDocument": { + "synchronization": {"dynamicRegistration": False}, + "implementation": {"dynamicRegistration": False, "linkSupport": True}, + "callHierarchy": {"dynamicRegistration": False}, + } + }, + "initializationOptions": {}, + }, + ) + self._notify("initialized", {}) + + # ── I/O ─────────────────────────────────────────────────────────── + + async def _read_loop(self) -> None: + try: + while True: + assert self._proc and self._proc.stdout + # Read headers until blank line + content_length = 0 + while True: + raw = await self._proc.stdout.readline() + if not raw: + return + line = raw.decode().rstrip() + if not line: + break + if line.lower().startswith("content-length:"): + content_length = int(line.split(":", 1)[1].strip()) + if content_length == 0: + continue + body = await self._proc.stdout.readexactly(content_length) + msg = json.loads(body) + # Route response/error to waiting Future + msg_id = msg.get("id") + msg_method = msg.get("method", "") + if msg_id is not None and msg_method: + # Server-to-client request — must acknowledge with a response + self._write({"jsonrpc": "2.0", "id": msg_id, "result": None}) + await self._drain() + elif msg_id is not None and msg_id in self._pending: + fut = self._pending.pop(msg_id) + if not fut.done(): + if "error" in msg: + fut.set_exception(RuntimeError(f"{msg['error'].get('message', 'LSP error')} ({msg['error'].get('code', '')})")) + else: + fut.set_result(msg.get("result")) + # All other notifications ($/progress, diagnostics, etc.) 
are silently dropped + except Exception as exc: + for fut in self._pending.values(): + if not fut.done(): + fut.set_exception(exc) + + def _write(self, msg: dict) -> None: + """Encode and buffer one LSP message (call drain() to flush).""" + assert self._proc and self._proc.stdin + body = json.dumps(msg, separators=(",", ":")).encode() + header = f"Content-Length: {len(body)}\r\n\r\n".encode() + self._proc.stdin.write(header + body) + + async def _drain(self) -> None: + assert self._proc and self._proc.stdin + await self._proc.stdin.drain() + + def _notify(self, method: str, params: Any) -> None: + self._write({"jsonrpc": "2.0", "method": method, "params": params}) + + async def _request(self, method: str, params: Any, timeout: float = 30.0) -> Any: + req_id = self._next_id + self._next_id += 1 + loop = asyncio.get_event_loop() + fut: asyncio.Future = loop.create_future() + self._pending[req_id] = fut + self._write({"jsonrpc": "2.0", "id": req_id, "method": method, "params": params}) + await self._drain() + return await asyncio.wait_for(fut, timeout=timeout) + + # ── file lifecycle ──────────────────────────────────────────────── + + def _open_file(self, abs_path: str) -> None: + uri = Path(abs_path).as_uri() + if uri in self._open_files: + return + try: + text = Path(abs_path).read_text(encoding="utf-8", errors="replace") + except OSError: + text = "" + self._notify("textDocument/didOpen", {"textDocument": {"uri": uri, "languageId": "python", "version": 1, "text": text}}) + self._open_files.add(uri) + + def _close_file(self, abs_path: str) -> None: + uri = Path(abs_path).as_uri() + if uri not in self._open_files: + return + self._notify("textDocument/didClose", {"textDocument": {"uri": uri}}) + self._open_files.discard(uri) + + def _abs(self, rel_path: str) -> str: + return str(Path(self._workspace_root) / rel_path) + + # ── LSP operations ──────────────────────────────────────────────── + + async def request_implementation(self, rel_path: str, line: int, col: int) -> list: + abs_path = self._abs(rel_path) + self._open_file(abs_path) + await self._drain() + uri = Path(abs_path).as_uri() + response = await self._request( + "textDocument/implementation", + { + "textDocument": {"uri": uri}, + "position": {"line": line, "character": col}, + }, + ) + return self._normalise_locations(response) + + async def request_prepare_call_hierarchy(self, rel_path: str, line: int, col: int) -> list: + abs_path = self._abs(rel_path) + self._open_file(abs_path) + await self._drain() + uri = Path(abs_path).as_uri() + response = await self._request( + "textDocument/prepareCallHierarchy", + { + "textDocument": {"uri": uri}, + "position": {"line": line, "character": col}, + }, + ) + # File stays open — callHierarchy/incomingCalls and outgoingCalls may need it + return response or [] + + async def request_incoming_calls(self, item: dict) -> list: + response = await self._request("callHierarchy/incomingCalls", {"item": item}) + return response or [] + + async def request_outgoing_calls(self, item: dict) -> list: + response = await self._request("callHierarchy/outgoingCalls", {"item": item}) + return response or [] + + @staticmethod + def _normalise_locations(response: Any) -> list: + if not response: + return [] + if isinstance(response, dict): + response = [response] + out = [] + for loc in response: + uri = loc.get("uri") or loc.get("targetUri", "") + rng = loc.get("range") or loc.get("targetSelectionRange") or loc.get("targetRange") or {} + out.append({"uri": uri, "absolutePath": uri.replace("file://", ""), 
"range": rng}) + return out + + # ── shutdown ────────────────────────────────────────────────────── + + async def stop(self) -> None: + if self._proc: + try: + await asyncio.wait_for(self._request("shutdown", {}), timeout=5) + self._notify("exit", {}) + except Exception: + pass + try: + self._proc.terminate() + await asyncio.wait_for(self._proc.wait(), timeout=5) + except Exception: + self._proc.kill() + if self._reader_task and not self._reader_task.done(): + self._reader_task.cancel() + try: + await self._reader_task + except (asyncio.CancelledError, Exception): + pass + + +class _LSPSession: + """Holds a multilspy LanguageServer alive in a background asyncio task. + + Pattern: start_server() is an async context manager that must stay open + for the lifetime of the session. We enter it inside a background Task and + use an Event to signal readiness. Stopping sets a second Event that causes + the background task to exit the context and shut down the server process. + """ + + def __init__(self, language: str, workspace_root: str) -> None: + self.language = language + self._workspace_root = workspace_root + self._ready = asyncio.Event() + self._stop = asyncio.Event() + self._task: asyncio.Task | None = None + self._lsp: Any = None + self._error: Exception | None = None + + async def start(self) -> None: + self._task = asyncio.create_task(self._run(), name=f"lsp-{self.language}") + try: + await asyncio.wait_for(asyncio.shield(self._ready.wait()), timeout=60) + except TimeoutError: + raise TimeoutError(f"LSP server for '{self.language}' did not start within 60s") + if self._error: + raise self._error + + async def _run(self) -> None: + try: + from multilspy import LanguageServer # core dep — always available + from multilspy.multilspy_config import MultilspyConfig + from multilspy.multilspy_logger import MultilspyLogger + + config = MultilspyConfig.from_dict({"code_language": self.language}) + lsp_logger = MultilspyLogger() + self._lsp = LanguageServer.create(config, lsp_logger, self._workspace_root) + async with self._lsp.start_server(): + self._ready.set() + await self._stop.wait() + except Exception as e: + self._error = e + self._ready.set() # unblock any waiters + logger.error("[LSPService] %s server error: %s", self.language, e) + + async def stop(self) -> None: + self._stop.set() + if self._task and not self._task.done(): + try: + await asyncio.wait_for(self._task, timeout=5) + except (TimeoutError, asyncio.CancelledError): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + # ── request methods ─────────────────────────────────────────────── + + async def request_definition(self, rel_path: str, line: int, col: int) -> list: + try: + return await self._lsp.request_definition(rel_path, line, col) or [] + except AssertionError: + return [] # multilspy asserts on None response (no definition found) + + async def request_references(self, rel_path: str, line: int, col: int) -> list: + try: + return await self._lsp.request_references(rel_path, line, col) or [] + except AssertionError: + return [] + + async def request_hover(self, rel_path: str, line: int, col: int) -> Any: + try: + return await self._lsp.request_hover(rel_path, line, col) + except AssertionError: + return None + + async def request_document_symbols(self, rel_path: str) -> list: + try: + symbols, _ = await self._lsp.request_document_symbols(rel_path) + return symbols or [] + except AssertionError: + return [] + + async def request_workspace_symbol(self, query: str) -> list: + return 
await self._lsp.request_workspace_symbol(query) or [] + + # ── advanced ops (direct server.send, for servers that support them) ── + + async def request_implementation(self, rel_path: str, line: int, col: int) -> list: + abs_uri = Path(self._workspace_root, rel_path).as_uri() + with self._lsp.open_file(rel_path): + response = await self._lsp.server.send.implementation( + {"textDocument": {"uri": abs_uri}, "position": {"line": line, "character": col}} + ) + if not response: + return [] + if isinstance(response, dict): + response = [response] + out = [] + for item in response: + if "uri" in item and "range" in item: + item.setdefault("absolutePath", item["uri"].replace("file://", "")) + out.append(item) + elif "targetUri" in item: + out.append( + { + "uri": item["targetUri"], + "absolutePath": item["targetUri"].replace("file://", ""), + "range": item.get("targetSelectionRange", item.get("targetRange", {})), + } + ) + return out + + async def request_prepare_call_hierarchy(self, rel_path: str, line: int, col: int) -> list: + abs_uri = Path(self._workspace_root, rel_path).as_uri() + with self._lsp.open_file(rel_path): + response = await self._lsp.server.send.prepare_call_hierarchy( + {"textDocument": {"uri": abs_uri}, "position": {"line": line, "character": col}} + ) + return response or [] + + async def request_incoming_calls(self, item: dict) -> list: + response = await self._lsp.server.send.incoming_calls({"item": item}) + return response or [] + + async def request_outgoing_calls(self, item: dict) -> list: + response = await self._lsp.server.send.outgoing_calls({"item": item}) + return response or [] + + +class _LSPSessionPool: + """Process-level singleton managing LSP sessions across all agent instances. + + Sessions are keyed by (language, workspace_root) and survive agent restarts. + Call close_all() once at process exit (e.g. from backend lifespan shutdown). 
+ """ + + def __init__(self) -> None: + # (language, workspace_root) → _LSPSession + self._sessions: dict[tuple[str, str], _LSPSession] = {} + # workspace_root → _PyrightSession + self._pyright: dict[str, _PyrightSession] = {} + # In-flight start tasks to prevent duplicate starts under concurrent requests + self._starting: dict[tuple[str, str], asyncio.Task] = {} + self._starting_pyright: dict[str, asyncio.Task] = {} + + async def get_session(self, language: str, workspace_root: str) -> _LSPSession: + key = (language, workspace_root) + if key in self._sessions: + return self._sessions[key] + if key not in self._starting: + + async def _start() -> _LSPSession: + logger.info("[LSPPool] starting %s language server (workspace=%s)...", language, workspace_root) + s = _LSPSession(language, workspace_root) + await s.start() + self._sessions[key] = s + self._starting.pop(key, None) + logger.info("[LSPPool] %s language server ready", language) + return s + + self._starting[key] = asyncio.create_task(_start(), name=f"lsp-start-{language}") + return await self._starting[key] + + async def get_pyright(self, workspace_root: str) -> _PyrightSession: + if workspace_root in self._pyright: + return self._pyright[workspace_root] + if workspace_root not in self._starting_pyright: + + async def _start() -> _PyrightSession: + logger.info("[LSPPool] starting pyright (workspace=%s)...", workspace_root) + s = _PyrightSession(workspace_root) + await s.start() + self._pyright[workspace_root] = s + self._starting_pyright.pop(workspace_root, None) + logger.info("[LSPPool] pyright ready") + return s + + self._starting_pyright[workspace_root] = asyncio.create_task(_start(), name="lsp-start-pyright") + return await self._starting_pyright[workspace_root] + + async def close_all(self) -> None: + """Stop all running language server processes. Call once at process exit.""" + for (lang, ws), session in list(self._sessions.items()): + try: + await session.stop() + logger.debug("[LSPPool] stopped %s server (workspace=%s)", lang, ws) + except Exception as e: + logger.debug("[LSPPool] error stopping %s: %s", lang, e) + self._sessions.clear() + for ws, session in list(self._pyright.items()): + try: + await session.stop() + logger.debug("[LSPPool] stopped pyright (workspace=%s)", ws) + except Exception as e: + logger.debug("[LSPPool] error stopping pyright: %s", e) + self._pyright.clear() + + +# Process-level singleton — import and use directly +lsp_pool = _LSPSessionPool() + + +class LSPService: + """Registers the LSP tool (DEFERRED) into ToolRegistry. + + Delegates all session management to the process-level lsp_pool singleton. + Language servers start lazily on first use and persist across agent restarts. + """ + + # Operations that Jedi doesn't support — routed to pyright for Python, + # or to the native server.send.* for other languages. 
+ _ADVANCED_OPS: frozenset[str] = frozenset({"goToImplementation", "prepareCallHierarchy", "incomingCalls", "outgoingCalls"}) + + def __init__(self, registry: ToolRegistry, workspace_root: str | Path) -> None: + self._workspace_root = str(Path(workspace_root).resolve()) + registry.register( + ToolEntry( + name="LSP", + mode=ToolMode.DEFERRED, + schema=LSP_SCHEMA, + handler=self._handle, + source="LSPService", + search_hint="language server definition references hover symbols go-to", + is_read_only=True, + is_concurrency_safe=True, + ) + ) + logger.debug("[LSPService] registered (workspace=%s)", self._workspace_root) + + # ── session management (delegates to process-level pool) ────────── + + async def _get_session(self, language: str) -> _LSPSession: + return await lsp_pool.get_session(language, self._workspace_root) + + async def _get_pyright(self) -> _PyrightSession: + return await lsp_pool.get_pyright(self._workspace_root) + + def _detect_language(self, file_path: str) -> str | None: + return _EXT_TO_LANG.get(Path(file_path).suffix.lower()) + + def _to_relative(self, file_path: str) -> str: + try: + return str(Path(file_path).relative_to(self._workspace_root)) + except ValueError: + return file_path # fallback: pass as-is + + # ── pre-flight checks ───────────────────────────────────────────── + + @staticmethod + def _check_file(file_path: str) -> str | None: + """Return error string if file exceeds 10 MB limit, else None.""" + try: + size = Path(file_path).stat().st_size + except OSError: + return None # let LSP handle missing file errors + if size > _FILE_SIZE_LIMIT: + mb = size / (1024 * 1024) + return f"File too large ({mb:.1f} MB). LSP file size limit is 10 MB." + return None + + def _filter_gitignored(self, locations: list) -> list: + """Filter out locations inside gitignored paths (batches of 50, like CC).""" + if not locations: + return locations + abs_paths = [loc.get("absolutePath") or loc.get("uri", "").replace("file://", "") for loc in locations] + try: + # git check-ignore exits 0 if any path is ignored, 1 if none are + result = subprocess.run( + ["git", "check-ignore", "--stdin", "-z"], + input="\0".join(abs_paths), + capture_output=True, + text=True, + cwd=self._workspace_root, + timeout=5, + ) + ignored = set(result.stdout.split("\0")) if result.stdout else set() + except Exception: + return locations # on error, return all (fail-open) + return [loc for loc, p in zip(locations, abs_paths) if p not in ignored] + + def _filter_gitignored_batched(self, locations: list) -> list: + """Run _filter_gitignored in batches of 50 (matches CC batch size).""" + out = [] + for i in range(0, len(locations), 50): + out.extend(self._filter_gitignored(locations[i : i + 50])) + return out + + async def _filter_gitignored_batched_async(self, locations: list) -> list: + return await asyncio.to_thread(self._filter_gitignored_batched, locations) + + # ── output formatters ───────────────────────────────────────────── + + @staticmethod + def _fmt_location(loc: Any) -> dict: + start = loc.get("range", {}).get("start", {}) + return { + "file": loc.get("absolutePath") or loc.get("uri", ""), + "line": start.get("line", 0), + "column": start.get("character", 0), + } + + @staticmethod + def _fmt_hover(result: Any) -> str: + contents = result.get("contents", "") + if isinstance(contents, dict): + return contents.get("value", str(contents)) + if isinstance(contents, list): + parts = [] + for c in contents: + parts.append(c.get("value", str(c)) if isinstance(c, dict) else str(c)) + return 
"\n".join(parts) + return str(contents) + + @staticmethod + def _fmt_symbol(sym: Any) -> dict: + loc = sym.get("location") or {} + if loc: + # SymbolInformation (workspaceSymbol) — location.uri + location.range + start = loc.get("range", {}).get("start", {}) + uri = loc.get("uri", "") + file = loc.get("absolutePath") or (uri.replace("file://", "") if uri.startswith("file://") else uri) + else: + # DocumentSymbol (documentSymbol) — range/selectionRange at top level, no file + start = sym.get("selectionRange", sym.get("range", {})).get("start", {}) + file = "" + return { + "name": sym.get("name", ""), + "kind": sym.get("kind"), + "file": file, + "line": start.get("line"), + } + + @staticmethod + def _fmt_call_hierarchy_item(item: Any) -> dict: + uri = item.get("uri", "") + start = item.get("range", {}).get("start", {}) + return { + "name": item.get("name", ""), + "kind": item.get("kind"), + "file": uri.replace("file://", "") if uri.startswith("file://") else uri, + "line": start.get("line"), + "item": item, # pass-through for incomingCalls/outgoingCalls + } + + @staticmethod + def _fmt_call_hierarchy_call(call: Any, direction: str) -> dict: + item_key = "from" if direction == "incoming" else "to" + caller = call.get(item_key, {}) + uri = caller.get("uri", "") + start = caller.get("range", {}).get("start", {}) + ranges = [r.get("start", {}) for r in call.get(f"{item_key}Ranges", [])] + return { + "name": caller.get("name", ""), + "kind": caller.get("kind"), + "file": uri.replace("file://", "") if uri.startswith("file://") else uri, + "line": start.get("line"), + "call_sites": [{"line": r.get("line"), "column": r.get("character")} for r in ranges], + "item": caller, # pass-through for chaining + } + + # ── tool handler ────────────────────────────────────────────────── + + async def _handle( + self, + operation: str, + file_path: str | None = None, + line: int | None = None, + character: int | None = None, + query: str | None = None, + language: str | None = None, + item: dict | None = None, + ) -> str: + # Resolve language (incomingCalls/outgoingCalls carry language in item["uri"]) + lang = language + if not lang and file_path: + lang = self._detect_language(file_path) + if not lang and operation in ("incomingCalls", "outgoingCalls") and item: + uri = item.get("uri", "") + lang = self._detect_language(uri) + if not lang: + supported = ", ".join(sorted(set(_EXT_TO_LANG.values()))) + return f"Cannot detect language. Set 'language' parameter. Supported: {supported}" + + # 10 MB file size guard (matches CC LSP limit) + if file_path: + err = self._check_file(file_path) + if err: + return err + + # Python advanced ops → pyright; other languages → multilspy server.send.* + use_pyright = lang == "python" and operation in self._ADVANCED_OPS + + pyright: _PyrightSession | None = None + session: _LSPSession | None = None + + if use_pyright: + try: + pyright = await self._get_pyright() + except Exception as e: + return f"Failed to start pyright language server: {e}" + else: + try: + session = await self._get_session(lang) + except Exception as e: + return f"Failed to start {lang} language server: {e}" + + rel = self._to_relative(file_path) if file_path else "" + # @@@dt-04-lsp-position-contract - CC exposes editor-facing 1-based + # positions and converts at the tool boundary. Leon must do the same + # or every position-aware operation silently lands one symbol off. 
+ zero_line = line - 1 if line is not None else None + zero_character = character - 1 if character is not None else None + + try: + if operation == "goToDefinition": + if not file_path or zero_line is None or zero_character is None: + return "goToDefinition requires: file_path, line, character" + assert session is not None + results = await session.request_definition(rel, zero_line, zero_character) + results = await self._filter_gitignored_batched_async(results) + if not results: + return "No definition found." + return json.dumps([self._fmt_location(r) for r in results], indent=2) + + elif operation == "findReferences": + if not file_path or zero_line is None or zero_character is None: + return "findReferences requires: file_path, line, character" + assert session is not None + results = await session.request_references(rel, zero_line, zero_character) + results = await self._filter_gitignored_batched_async(results) + if not results: + return "No references found." + return json.dumps([self._fmt_location(r) for r in results], indent=2) + + elif operation == "hover": + if not file_path or zero_line is None or zero_character is None: + return "hover requires: file_path, line, character" + assert session is not None + result = await session.request_hover(rel, zero_line, zero_character) + if not result: + return "No hover info." + return self._fmt_hover(result) + + elif operation == "documentSymbol": + if not file_path: + return "documentSymbol requires: file_path" + assert session is not None + symbols = await session.request_document_symbols(rel) + if not symbols: + return "No symbols found." + return json.dumps([self._fmt_symbol(s) for s in symbols], indent=2) + + elif operation == "workspaceSymbol": + if not query: + return "workspaceSymbol requires: query" + assert session is not None + symbols = await session.request_workspace_symbol(query) + if not symbols: + return f"No symbols matching '{query}'." + return json.dumps([self._fmt_symbol(s) for s in symbols], indent=2) + + elif operation == "goToImplementation": + if not file_path or zero_line is None or zero_character is None: + return "goToImplementation requires: file_path, line, character" + src = pyright if use_pyright else session + assert src is not None + results = await src.request_implementation(rel, zero_line, zero_character) + results = await self._filter_gitignored_batched_async(results) + if not results: + return "No implementation found." + return json.dumps([self._fmt_location(r) for r in results], indent=2) + + elif operation == "prepareCallHierarchy": + if not file_path or zero_line is None or zero_character is None: + return "prepareCallHierarchy requires: file_path, line, character" + src = pyright if use_pyright else session + assert src is not None + items = await src.request_prepare_call_hierarchy(rel, zero_line, zero_character) + if not items: + return "No call hierarchy items found." + return json.dumps([self._fmt_call_hierarchy_item(i) for i in items], indent=2) + + elif operation == "incomingCalls": + if not item: + return "incomingCalls requires: item (CallHierarchyItem from prepareCallHierarchy)" + src = pyright if use_pyright else session + assert src is not None + calls = await src.request_incoming_calls(item) + if not calls: + return "No incoming calls found." 
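Call-hierarchy results are designed to chain: the `item` echoed back by `_fmt_call_hierarchy_item` and `_fmt_call_hierarchy_call` is exactly what the next `incomingCalls`/`outgoingCalls` request consumes. An illustrative wire payload (field values invented) and the flattened shape the formatter derives from it:

```python
# One CallHierarchyIncomingCall roughly as a server returns it (LSP 3.16+).
incoming = {
    "from": {
        "name": "caller_fn",
        "kind": 12,  # SymbolKind.Function
        "uri": "file:///workspace/app.py",
        "range": {"start": {"line": 41, "character": 0}, "end": {"line": 50, "character": 0}},
    },
    # Call sites live in "fromRanges" for incoming and outgoing calls alike.
    "fromRanges": [{"start": {"line": 44, "character": 8}, "end": {"line": 44, "character": 17}}],
}

# _fmt_call_hierarchy_call(incoming, "incoming") flattens this to roughly:
#   {"name": "caller_fn", "kind": 12, "file": "/workspace/app.py", "line": 41,
#    "call_sites": [{"line": 44, "column": 8}], "item": {...the "from" item...}}
# Passing result["item"] back as the tool's `item` argument walks one level
# further up the call graph.
```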
+ return json.dumps([self._fmt_call_hierarchy_call(c, "incoming") for c in calls], indent=2) + + elif operation == "outgoingCalls": + if not item: + return "outgoingCalls requires: item (CallHierarchyItem from prepareCallHierarchy)" + src = pyright if use_pyright else session + assert src is not None + calls = await src.request_outgoing_calls(item) + if not calls: + return "No outgoing calls found." + return json.dumps([self._fmt_call_hierarchy_call(c, "outgoing") for c in calls], indent=2) + + else: + return ( + f"Unknown operation '{operation}'. " + "Valid: goToDefinition, findReferences, hover, documentSymbol, workspaceSymbol, " + "goToImplementation, prepareCallHierarchy, incomingCalls, outgoingCalls" + ) + + except Exception as e: + logger.exception("[LSPService] operation=%s failed", operation) + return f"LSP error: {e}" diff --git a/core/tools/mcp_resources/service.py b/core/tools/mcp_resources/service.py new file mode 100644 index 000000000..bf44c2cbc --- /dev/null +++ b/core/tools/mcp_resources/service.py @@ -0,0 +1,155 @@ +"""Expose MCP resource discovery and reading as agent-callable deferred tools.""" + +from __future__ import annotations + +import base64 +import json +from collections.abc import Callable +from typing import Any + +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema + +LIST_MCP_RESOURCES_SCHEMA = make_tool_schema( + name="ListMcpResources", + description="List MCP resources exposed by connected MCP servers.", + properties={ + "server": { + "type": "string", + "description": "Optional MCP server name to filter by.", + "minLength": 1, + } + }, +) + +READ_MCP_RESOURCE_SCHEMA = make_tool_schema( + name="ReadMcpResource", + description="Read a specific MCP resource by server name and URI.", + properties={ + "server": { + "type": "string", + "description": "MCP server name.", + "minLength": 1, + }, + "uri": { + "type": "string", + "description": "Resource URI to read.", + "minLength": 1, + }, + }, + required=["server", "uri"], +) + + +class McpResourceToolService: + def __init__( + self, + *, + registry: ToolRegistry, + client_fn: Callable[[], Any | None], + server_configs_fn: Callable[[], dict[str, Any]], + ) -> None: + self._client_fn = client_fn + self._server_configs_fn = server_configs_fn + if not self._server_configs_fn(): + return + self._register(registry) + + def _register(self, registry: ToolRegistry) -> None: + for name, schema, handler in [ + ("ListMcpResources", LIST_MCP_RESOURCES_SCHEMA, self._list_resources), + ("ReadMcpResource", READ_MCP_RESOURCE_SCHEMA, self._read_resource), + ]: + registry.register( + ToolEntry( + name=name, + mode=ToolMode.DEFERRED, + schema=schema, + handler=handler, + source="McpResourceToolService", + is_concurrency_safe=True, + is_read_only=True, + ) + ) + + def _get_client(self) -> Any: + client = self._client_fn() + if client is None: + raise ValueError("MCP client is not initialized") + return client + + def _available_servers(self) -> list[str]: + return list(self._server_configs_fn().keys()) + + @staticmethod + def _stringify_uri(value: Any) -> str | None: + if value is None: + return None + return str(value) + + async def _list_resources(self, server: str | None = None, **_kwargs: Any) -> str: + client = self._get_client() + server_names = [server] if server else self._available_servers() + if server and server not in self._available_servers(): + raise ValueError(f'MCP server not found: "{server}"') + + items: list[dict[str, Any]] = [] + for server_name in server_names: + async with 
client.session(server_name) as session: + result = await session.list_resources() + for resource in result.resources: + items.append( + { + "server": server_name, + "uri": self._stringify_uri(resource.uri), + "name": getattr(resource, "name", self._stringify_uri(resource.uri)), + "mime_type": getattr(resource, "mimeType", None), + "description": getattr(resource, "description", None), + } + ) + return json.dumps({"items": items, "total": len(items)}, ensure_ascii=False, indent=2) + + async def _read_resource(self, *, server: str, uri: str, **_kwargs: Any) -> str: + client = self._get_client() + if server not in self._available_servers(): + raise ValueError(f'MCP server not found: "{server}"') + + async with client.session(server) as session: + result = await session.read_resource(uri) + + contents: list[dict[str, Any]] = [] + for content in result.contents: + if hasattr(content, "text"): + contents.append( + { + "uri": self._stringify_uri(content.uri), + "mime_type": getattr(content, "mimeType", None), + "text": content.text, + } + ) + continue + if hasattr(content, "blob"): + blob_size = len(base64.b64decode(content.blob)) + contents.append( + { + "uri": self._stringify_uri(content.uri), + "mime_type": getattr(content, "mimeType", None), + "text": f"Binary MCP resource omitted from context ({blob_size} bytes).", + } + ) + continue + contents.append( + { + "uri": self._stringify_uri(getattr(content, "uri", uri)), + "mime_type": getattr(content, "mimeType", None), + } + ) + + return json.dumps( + { + "server": server, + "uri": uri, + "contents": contents, + }, + ensure_ascii=False, + indent=2, + ) diff --git a/core/tools/search/service.py b/core/tools/search/service.py index 4329de6e4..a6ff0a4d4 100644 --- a/core/tools/search/service.py +++ b/core/tools/search/service.py @@ -12,11 +12,16 @@ import subprocess from pathlib import Path -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema DEFAULT_EXCLUDES: list[str] = [ "node_modules", ".git", + ".svn", + ".hg", + ".bzr", + ".jj", + ".sl", "__pycache__", ".venv", "venv", @@ -50,67 +55,76 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Grep", mode=ToolMode.INLINE, - schema={ - "name": "Grep", - "description": "Search file contents using regex patterns.", - "parameters": { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Regex pattern to search for", - }, - "path": { - "type": "string", - "description": "File or directory (absolute). Defaults to workspace.", - }, - "glob": { - "type": "string", - "description": "Filter files by glob (e.g., '*.py')", - }, - "type": { - "type": "string", - "description": "Filter by file type (e.g., 'py', 'js')", - }, - "case_insensitive": { - "type": "boolean", - "description": "Case insensitive search", - }, - "after_context": { - "type": "integer", - "description": "Lines to show after each match", - }, - "before_context": { - "type": "integer", - "description": "Lines to show before each match", - }, - "context": { - "type": "integer", - "description": "Context lines before and after each match", - }, - "output_mode": { - "type": "string", - "enum": ["content", "files_with_matches", "count"], - "description": "Output format. 
Default: files_with_matches", - }, - "head_limit": { - "type": "integer", - "description": "Limit to first N entries", - }, - "offset": { - "type": "integer", - "description": "Skip first N entries", - }, - "multiline": { - "type": "boolean", - "description": "Allow pattern to span multiple lines", - }, + schema=make_tool_schema( + name="Grep", + description=( + "Regex search across files (ripgrep-based). " + "Default output_mode: files_with_matches (sorted by mtime). Default head_limit: 250 entries. " + "Auto-excludes .git/.svn/.hg dirs. Max column width 500 chars (suppresses minified/base64). " + "Use output_mode='content' with after_context/before_context/context for context lines." + ), + properties={ + "pattern": { + "type": "string", + "description": "Regex pattern to search for", + }, + "path": { + "type": "string", + "description": "File or directory (absolute). Defaults to workspace.", + }, + "glob": { + "type": "string", + "description": "Filter files by glob (e.g., '*.py')", + }, + "type": { + "type": "string", + "description": "Filter by file type (e.g., 'py', 'js')", + }, + "case_insensitive": { + "type": "boolean", + "description": "Case insensitive search", + }, + "after_context": { + "type": "integer", + "description": "Lines to show after each match", + }, + "before_context": { + "type": "integer", + "description": "Lines to show before each match", + }, + "context": { + "type": "integer", + "description": "Context lines before and after each match", + }, + "output_mode": { + "type": "string", + "enum": ["content", "files_with_matches", "count"], + "description": "Output format. Default: files_with_matches", + }, + "head_limit": { + "type": "integer", + "description": "Limit to first N entries", + }, + "offset": { + "type": "integer", + "description": "Skip first N entries", + }, + "multiline": { + "type": "boolean", + "description": "Allow pattern to span multiple lines", + }, + "line_numbers": { + "type": "boolean", + "description": "Show line numbers (default true). Only applies with output_mode='content'.", }, - "required": ["pattern"], }, - }, + required=["pattern"], + ), handler=self._grep, source="SearchService", + search_hint="search file contents regex pattern matching ripgrep", + is_read_only=True, + is_concurrency_safe=True, ) ) @@ -118,26 +132,30 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Glob", mode=ToolMode.INLINE, - schema={ - "name": "Glob", - "description": "Find files by glob pattern. Returns paths sorted by modification time.", - "parameters": { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Glob pattern (e.g., '**/*.py')", - }, - "path": { - "type": "string", - "description": "Directory to search (absolute). Defaults to workspace.", - }, + schema=make_tool_schema( + name="Glob", + description=( + "Fast file pattern matching (ripgrep-based). Returns paths sorted by modification time. " + "Includes hidden files, ignores .gitignore. Default limit 100 results. " + "Use '**/*.py' for recursive search. Path must be absolute." + ), + properties={ + "pattern": { + "type": "string", + "description": "Glob pattern (e.g., '**/*.py')", + }, + "path": { + "type": "string", + "description": "Directory to search (absolute). 
Defaults to workspace.", }, - "required": ["pattern"], }, - }, + required=["pattern"], + ), handler=self._glob, source="SearchService", + search_hint="find files by name glob pattern matching", + is_read_only=True, + is_concurrency_safe=True, ) ) @@ -183,9 +201,10 @@ def _grep( before_context: int | None = None, context: int | None = None, output_mode: str = "files_with_matches", - head_limit: int | None = None, + head_limit: int | None = 250, offset: int | None = None, multiline: bool = False, + line_numbers: bool = True, ) -> str: ok, error, resolved = self._validate_path(path) if not ok: @@ -209,6 +228,7 @@ def _grep( head_limit=head_limit, offset=offset, multiline=multiline, + line_numbers=line_numbers, ) except Exception: pass # fallback to Python @@ -238,8 +258,9 @@ def _ripgrep_search( head_limit: int | None, offset: int | None, multiline: bool, + line_numbers: bool = True, ) -> str: - cmd: list[str] = ["rg", pattern, str(path)] + cmd: list[str] = ["rg", pattern, str(path), "--max-columns", "500"] for excl in DEFAULT_EXCLUDES: cmd.extend(["--glob", f"!{excl}"]) @@ -258,7 +279,8 @@ def _ripgrep_search( elif output_mode == "count": cmd.append("--count") elif output_mode == "content": - cmd.extend(["--line-number", "--no-heading"]) + ln_flag = "--line-number" if line_numbers else "--no-line-number" + cmd.extend([ln_flag, "--no-heading"]) if context is not None: cmd.extend(["-C", str(context)]) else: diff --git a/core/tools/skills/service.py b/core/tools/skills/service.py index e65215a20..17c0b842a 100644 --- a/core/tools/skills/service.py +++ b/core/tools/skills/service.py @@ -9,9 +9,10 @@ from __future__ import annotations import re +from collections.abc import Sequence from pathlib import Path -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema class SkillsService: @@ -20,7 +21,7 @@ class SkillsService: def __init__( self, registry: ToolRegistry, - skill_paths: list[str | Path], + skill_paths: Sequence[str | Path], enabled_skills: dict[str, bool] | None = None, ): self.skill_paths = [Path(p).expanduser().resolve() for p in skill_paths] @@ -65,6 +66,8 @@ def _register(self, registry: ToolRegistry) -> None: schema=self._get_schema, handler=self._load_skill, source="SkillsService", + is_concurrency_safe=True, + is_read_only=True, ) ) @@ -72,24 +75,22 @@ def _get_schema(self) -> dict: available_skills = list(self._skills_index.keys()) skills_list = "\n".join(f"- {name}" for name in available_skills) - return { - "name": "load_skill", - "description": ( - f"Load a specialized skill to access domain-specific knowledge and workflows.\n\n" - f"Available skills:\n{skills_list}\n\n" - f"Returns the skill's instructions and context." + return make_tool_schema( + name="load_skill", + description=( + f"Load a skill for domain-specific guidance. " + f"Use when you need specialized workflows (TDD, debugging, git). " + f"Skills are loaded on-demand to save context.\n\n" + f"Available skills:\n{skills_list}" ), - "parameters": { - "type": "object", - "properties": { - "skill_name": { - "type": "string", - "description": f"Name of the skill to load. Available: {', '.join(self._skills_index.keys())}", - }, + properties={ + "skill_name": { + "type": "string", + "description": f"Name of the skill to load. 
Available: {', '.join(self._skills_index.keys())}", }, - "required": ["skill_name"], }, - } + required=["skill_name"], + ) def _load_skill(self, skill_name: str) -> str: if skill_name not in self._skills_index: diff --git a/core/tools/task/service.py b/core/tools/task/service.py index b6e9f6f96..e09fd39fa 100644 --- a/core/tools/task/service.py +++ b/core/tools/task/service.py @@ -12,118 +12,110 @@ from pathlib import Path from typing import Any -from backend.web.core.storage_factory import make_tool_task_repo -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema from core.tools.task.types import Task, TaskStatus +from storage.runtime import build_tool_task_repo logger = logging.getLogger(__name__) DEFAULT_DB_PATH = Path.home() / ".leon" / "tasks.db" -TASK_CREATE_SCHEMA = { - "name": "TaskCreate", - "description": ("Create a new task to track work progress. Tasks are created with status 'pending'."), - "parameters": { - "type": "object", - "properties": { - "subject": { - "type": "string", - "description": "Brief task title in imperative form", - }, - "description": { - "type": "string", - "description": "Detailed description of what needs to be done", - }, - "active_form": { - "type": "string", - "description": "Present continuous form for spinner display", - }, - "metadata": { - "type": "object", - "description": "Optional metadata to attach to the task", - }, +TASK_CREATE_SCHEMA = make_tool_schema( + name="TaskCreate", + description=( + "Create a task to track multi-step work. " + "Use for complex tasks with 3+ steps or when managing multiple parallel workstreams. " + "Status starts as 'pending'." + ), + properties={ + "subject": { + "type": "string", + "description": "Brief task title in imperative form", }, - "required": ["subject", "description"], - }, -} - -TASK_GET_SCHEMA = { - "name": "TaskGet", - "description": "Get full details of a task including description and dependencies.", - "parameters": { - "type": "object", - "properties": { - "task_id": { - "type": "string", - "description": "The task ID to retrieve", - }, + "description": { + "type": "string", + "description": "Detailed description of what needs to be done", + }, + "active_form": { + "type": "string", + "description": "Present continuous form for spinner display", + }, + "metadata": { + "type": "object", + "description": "Optional metadata to attach to the task", }, - "required": ["task_id"], }, -} - -TASK_LIST_SCHEMA = { - "name": "TaskList", - "description": ("List all tasks with summary info: id, subject, status, owner, blockedBy."), - "parameters": { - "type": "object", - "properties": {}, + required=["subject", "description"], +) + +TASK_GET_SCHEMA = make_tool_schema( + name="TaskGet", + description="Get full details of a task including description and dependencies.", + properties={ + "task_id": { + "type": "string", + "description": "The task ID to retrieve", + }, }, -} - -TASK_UPDATE_SCHEMA = { - "name": "TaskUpdate", - "description": ( + required=["task_id"], +) + +TASK_LIST_SCHEMA = make_tool_schema( + name="TaskList", + description="List all tasks with summary info: id, subject, status, owner, blockedBy.", + properties={}, +) + +TASK_UPDATE_SCHEMA = make_tool_schema( + name="TaskUpdate", + description=( "Update a task's status, dependencies, or other fields. " "Status flow: pending -> in_progress -> completed. " "Use status='deleted' to remove a task." 
), - "parameters": { - "type": "object", - "properties": { - "task_id": { - "type": "string", - "description": "The task ID to update", - }, - "status": { - "type": "string", - "enum": ["pending", "in_progress", "completed", "deleted"], - "description": "New status for the task", - }, - "subject": { - "type": "string", - "description": "New subject for the task", - }, - "description": { - "type": "string", - "description": "New description for the task", - }, - "active_form": { - "type": "string", - "description": "New activeForm for the task", - }, - "owner": { - "type": "string", - "description": "Assign task to an agent", - }, - "add_blocks": { - "type": "array", - "items": {"type": "string"}, - "description": "Task IDs that this task blocks", - }, - "add_blocked_by": { - "type": "array", - "items": {"type": "string"}, - "description": "Task IDs that block this task", - }, - "metadata": { - "type": "object", - "description": "Metadata keys to merge (set key to null to delete)", - }, + properties={ + "task_id": { + "type": "string", + "description": "The task ID to update", + }, + "status": { + "type": "string", + "enum": ["pending", "in_progress", "completed", "deleted"], + "description": "New status for the task", + }, + "subject": { + "type": "string", + "description": "New subject for the task", + }, + "description": { + "type": "string", + "description": "New description for the task", + }, + "active_form": { + "type": "string", + "description": "New activeForm for the task", + }, + "owner": { + "type": "string", + "description": "Assign task to an agent", + }, + "add_blocks": { + "type": "array", + "items": {"type": "string"}, + "description": "Task IDs that this task blocks", + }, + "add_blocked_by": { + "type": "array", + "items": {"type": "string"}, + "description": "Task IDs that block this task", + }, + "metadata": { + "type": "object", + "description": "Metadata keys to merge (set key to null to delete)", }, - "required": ["task_id"], }, -} + required=["task_id"], +) class TaskService: @@ -139,14 +131,15 @@ class TaskService: def __init__( self, registry: ToolRegistry, - workspace_root: str | None = None, + workspace_root: str | Path | None = None, db_path: Path | None = None, thread_id: str | None = None, + repo: Any | None = None, ): - self._repo = make_tool_task_repo(db_path or DEFAULT_DB_PATH) + self._repo = repo or build_tool_task_repo(db_path=db_path or DEFAULT_DB_PATH) self._default_thread_id = thread_id # override for tests / single-agent TUI self._register(registry) - logger.info("TaskService initialized (db=%s)", db_path or DEFAULT_DB_PATH) + logger.info("TaskService initialized") def _get_thread_id(self) -> str: if self._default_thread_id: @@ -157,12 +150,14 @@ def _get_thread_id(self) -> str: return tid or "default" def _register(self, registry: ToolRegistry) -> None: + read_only = {"TaskGet", "TaskList"} for name, schema, handler in [ ("TaskCreate", TASK_CREATE_SCHEMA, self._create), ("TaskGet", TASK_GET_SCHEMA, self._get), ("TaskList", TASK_LIST_SCHEMA, self._list), ("TaskUpdate", TASK_UPDATE_SCHEMA, self._update), ]: + ro = name in read_only registry.register( ToolEntry( name=name, @@ -170,6 +165,8 @@ def _register(self, registry: ToolRegistry) -> None: schema=schema, handler=handler, source="TaskService", + is_concurrency_safe=ro, + is_read_only=ro, ) ) diff --git a/core/tools/tool_search/service.py b/core/tools/tool_search/service.py index 9b5ceba77..234007182 100644 --- a/core/tools/tool_search/service.py +++ b/core/tools/tool_search/service.py @@ -9,24 
+9,26 @@ import json import logging -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema logger = logging.getLogger(__name__) -TOOL_SEARCH_SCHEMA = { - "name": "tool_search", - "description": ("Search for available tools. Use this to discover tools that might help with your task."), - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Search query - tool name or description of what you want to do", - }, +TOOL_SEARCH_SCHEMA = make_tool_schema( + name="tool_search", + description=( + "Search for available deferred tools by name or keyword. " + "Use 'select:ToolA,ToolB' for exact deferred-tool lookup (returns full schema). " + "Use keywords for fuzzy search (up to 5 results). " + "Deferred tools are only usable after discovery via this tool." + ), + properties={ + "query": { + "type": "string", + "description": "Search query. Use 'select:ToolA,ToolB' for exact deferred-tool lookup, or keywords for fuzzy search.", }, - "required": ["query"], }, -} + required=["query"], +) class ToolSearchService: @@ -41,11 +43,34 @@ def __init__(self, registry: ToolRegistry): schema=TOOL_SEARCH_SCHEMA, handler=self._search, source="ToolSearchService", + is_concurrency_safe=True, + is_read_only=True, ) ) logger.info("ToolSearchService initialized") - def _search(self, query: str = "", **kwargs) -> str: - results = self._registry.search(query) + def _search(self, query: str = "", tool_context=None, **kwargs) -> str: + select_names: list[str] = [] + normalized = query.strip() + if normalized.lower().startswith("select:"): + select_names = [name.strip() for name in normalized[len("select:") :].split(",") if name.strip()] + + results = self._registry.search(query, modes={ToolMode.DEFERRED}) + if select_names: + found_names = {entry.name for entry in results} + missing = [name for name in select_names if name not in found_names] + inline = [name for name in missing if (entry := self._registry.get(name)) is not None and entry.mode == ToolMode.INLINE] + unknown = [name for name in missing if self._registry.get(name) is None] + if inline or unknown: + parts: list[str] = [] + if inline: + parts.append(f"inline/already-available tools: {', '.join(inline)}") + if unknown: + parts.append(f"unknown tools: {', '.join(unknown)}") + raise ValueError("tool_search select: only supports deferred tools; " + "; ".join(parts)) + else: + results = results[:5] + if tool_context is not None and hasattr(tool_context, "discovered_tool_names"): + tool_context.discovered_tool_names.update(entry.name for entry in results) schemas = [e.get_schema() for e in results] return json.dumps(schemas, indent=2, ensure_ascii=False) diff --git a/core/tools/web/fetchers/markdownify.py b/core/tools/web/fetchers/markdownify.py index 22e855f8e..508790276 100644 --- a/core/tools/web/fetchers/markdownify.py +++ b/core/tools/web/fetchers/markdownify.py @@ -3,12 +3,15 @@ from __future__ import annotations import re +from collections.abc import Callable +from typing import Any import httpx from core.tools.web.fetchers.base import BaseFetcher from core.tools.web.types import ContentChunk, FetchLimits, FetchResult +md: Callable[..., str] | None = None try: from markdownify import markdownify as md @@ -16,6 +19,7 @@ except ImportError: HAS_MARKDOWNIFY = False +BeautifulSoup: Any | None = None try: from bs4 import BeautifulSoup @@ -112,7 +116,11 @@ def _process_html(self, html: str, result: FetchResult) -> str: 
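> The markdownify.py hunks apply an optional-import sentinel: the symbol is typed as Optional at module scope so both import outcomes satisfy the type checker, and the `None` check is repeated at call time so a broken environment fails with an explicit RuntimeError instead of a NameError. A minimal standalone sketch of the same pattern (the module layout and `html_to_markdown` helper are illustrative, not part of the patch):

```python
from collections.abc import Callable

# Sentinel typed as Optional so type-checkers accept both import outcomes.
md: Callable[..., str] | None = None
try:
    from markdownify import markdownify as md

    HAS_MARKDOWNIFY = True
except ImportError:
    HAS_MARKDOWNIFY = False


def html_to_markdown(html: str) -> str:
    # Re-check the sentinel at call time: even after HAS_MARKDOWNIFY was
    # consulted, this keeps the failure mode explicit rather than a NameError.
    if md is None:
        raise RuntimeError("markdownify import unexpectedly unavailable")
    return md(html, heading_style="ATX")
```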
def _markdownify_html(self, html: str, result: FetchResult) -> str: """Convert HTML to Markdown using markdownify.""" + if md is None: + raise RuntimeError("markdownify import unexpectedly unavailable") if self.has_bs4: + if BeautifulSoup is None: + raise RuntimeError("BeautifulSoup import unexpectedly unavailable") soup = BeautifulSoup(html, "html.parser") title_tag = soup.find("title") @@ -145,6 +153,8 @@ def _markdownify_html(self, html: str, result: FetchResult) -> str: def _bs4_extract(self, html: str, result: FetchResult) -> str: """Extract text using BeautifulSoup.""" + if BeautifulSoup is None: + raise RuntimeError("BeautifulSoup import unexpectedly unavailable") soup = BeautifulSoup(html, "html.parser") title_tag = soup.find("title") diff --git a/core/tools/web/middleware.py b/core/tools/web/middleware.py index fedf1708e..1cfef8827 100644 --- a/core/tools/web/middleware.py +++ b/core/tools/web/middleware.py @@ -103,8 +103,8 @@ async def _web_search_impl( self, Query: str, MaxResults: int | None = None, - IncludeDomains: list[str] | None = None, - ExcludeDomains: list[str] | None = None, + AllowedDomains: list[str] | None = None, + BlockedDomains: list[str] | None = None, ) -> SearchResult: """ 实现 web_search(多提供商降级) @@ -121,8 +121,8 @@ async def _web_search_impl( result = await searcher.search( query=Query, max_results=max_results, - include_domains=IncludeDomains, - exclude_domains=ExcludeDomains, + include_domains=AllowedDomains, + exclude_domains=BlockedDomains, ) if not result.error: return result @@ -217,12 +217,12 @@ def _get_tool_definitions(self) -> list[dict]: "type": "integer", "description": "Maximum number of results (default: 5)", }, - "IncludeDomains": { + "AllowedDomains": { "type": "array", "items": {"type": "string"}, "description": "Only include results from these domains", }, - "ExcludeDomains": { + "BlockedDomains": { "type": "array", "items": {"type": "string"}, "description": "Exclude results from these domains", @@ -281,8 +281,8 @@ async def _handle_tool_call(self, tool_name: str, args: dict, tool_call_id: str) result = await self._web_search_impl( Query=args.get("Query", ""), MaxResults=args.get("MaxResults"), - IncludeDomains=args.get("IncludeDomains"), - ExcludeDomains=args.get("ExcludeDomains"), + AllowedDomains=args.get("AllowedDomains"), + BlockedDomains=args.get("BlockedDomains"), ) return ToolMessage(content=result.format_output(), tool_call_id=tool_call_id) @@ -304,7 +304,8 @@ async def awrap_tool_call( tool_call = request.tool_call tool_name = tool_call.get("name") args = tool_call.get("args", {}) - tool_call_id = tool_call.get("id", "") + raw_tool_call_id = tool_call.get("id", "") + tool_call_id = raw_tool_call_id if isinstance(raw_tool_call_id, str) else "" result = await self._handle_tool_call(tool_name, args, tool_call_id) if result is not None: diff --git a/core/tools/web/service.py b/core/tools/web/service.py index 077db9b70..02d2f12e8 100644 --- a/core/tools/web/service.py +++ b/core/tools/web/service.py @@ -10,7 +10,7 @@ import asyncio from typing import Any -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema from core.tools.web.fetchers.jina import JinaFetcher from core.tools.web.fetchers.markdownify import MarkdownifyFetcher from core.tools.web.searchers.exa import ExaSearcher @@ -59,64 +59,74 @@ def _register(self, registry: ToolRegistry) -> None: registry.register( ToolEntry( name="WebSearch", - mode=ToolMode.INLINE, - schema={ - 
"name": "WebSearch", - "description": "Search the web for current information. Returns titles, URLs, and snippets.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Search query", - }, - "max_results": { - "type": "integer", - "description": "Maximum number of results (default: 5)", - }, - "include_domains": { - "type": "array", - "items": {"type": "string"}, - "description": "Only include results from these domains", - }, - "exclude_domains": { - "type": "array", - "items": {"type": "string"}, - "description": "Exclude results from these domains", - }, + mode=ToolMode.DEFERRED, + schema=make_tool_schema( + name="WebSearch", + description=( + "Search the web. Returns titles, URLs, and text snippets. " + "Use for current events, documentation lookups, or fact-checking. Max 10 results per query." + ), + properties={ + "query": { + "type": "string", + "description": "Search query", + "minLength": 1, + }, + "max_results": { + "type": "integer", + "description": "Maximum number of results (default: 5)", + "minimum": 1, + "maximum": 10, + }, + "allowed_domains": { + "type": "array", + "items": {"type": "string"}, + "description": "Only include results from these domains", + }, + "blocked_domains": { + "type": "array", + "items": {"type": "string"}, + "description": "Exclude results from these domains", }, - "required": ["query"], }, - }, + required=["query"], + ), handler=self._web_search, source="WebService", + is_concurrency_safe=True, + is_read_only=True, ) ) registry.register( ToolEntry( name="WebFetch", - mode=ToolMode.INLINE, - schema={ - "name": "WebFetch", - "description": "Fetch a URL and extract specific information using AI. Returns processed content, not raw HTML.", - "parameters": { - "type": "object", - "properties": { - "url": { - "type": "string", - "description": "URL to fetch content from", - }, - "prompt": { - "type": "string", - "description": "What information to extract from the page", - }, + mode=ToolMode.DEFERRED, + schema=make_tool_schema( + name="WebFetch", + description=( + "Fetch a URL and extract specific information via AI. Returns processed text, not raw HTML. " + "Provide a focused prompt describing what to extract. " + "Useful for reading documentation pages, API references, or articles." 
+ ), + properties={ + "url": { + "type": "string", + "description": "URL to fetch content from", + "minLength": 1, + }, + "prompt": { + "type": "string", + "description": "What information to extract from the page", + "minLength": 1, }, - "required": ["url", "prompt"], }, - }, + required=["url", "prompt"], + ), handler=self._web_fetch, source="WebService", + is_concurrency_safe=True, + is_read_only=True, ) ) @@ -124,8 +134,8 @@ async def _web_search( self, query: str, max_results: int | None = None, - include_domains: list[str] | None = None, - exclude_domains: list[str] | None = None, + allowed_domains: list[str] | None = None, + blocked_domains: list[str] | None = None, ) -> str: if not self._searchers: return "No search providers configured" @@ -137,8 +147,8 @@ async def _web_search( result: SearchResult = await searcher.search( query=query, max_results=effective_max, - include_domains=include_domains, - exclude_domains=exclude_domains, + include_domains=allowed_domains, + exclude_domains=blocked_domains, ) if not result.error: return result.format_output() diff --git a/core/tools/wechat/service.py b/core/tools/wechat/service.py deleted file mode 100644 index 9cb57e233..000000000 --- a/core/tools/wechat/service.py +++ /dev/null @@ -1,109 +0,0 @@ -"""WeChat tool service — registers wechat_send and wechat_contacts into ToolRegistry. - -Thin wrapper: actual API calls go through WeChatConnection (backend). -Tools are scoped to the agent's owner's user_id (the human who connected WeChat). -""" - -from __future__ import annotations - -import logging -from collections.abc import Callable -from typing import TYPE_CHECKING - -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry - -if TYPE_CHECKING: - from backend.web.services.wechat_service import WeChatConnection - -logger = logging.getLogger(__name__) - - -class WeChatToolService: - """Registers WeChat tools for agents to interact with WeChat contacts. - - @@@lazy-connection — connection_fn is called at tool invocation time, not registration. - This avoids import-time dependency on app.state. - """ - - def __init__(self, registry: ToolRegistry, connection_fn: Callable[[], WeChatConnection | None]) -> None: - self._get_conn = connection_fn - self._register(registry) - - def _register(self, registry: ToolRegistry) -> None: - self._register_wechat_send(registry) - self._register_wechat_contacts(registry) - - def _register_wechat_send(self, registry: ToolRegistry) -> None: - get_conn = self._get_conn - - async def handle(user_id: str, text: str) -> str: - conn = get_conn() - if not conn or not conn.connected: - return "Error: WeChat is not connected. Ask the owner to connect via the Connections page." - try: - await conn.send_message(user_id, text) - return f"Message sent to {user_id.split('@')[0]}" - except RuntimeError as e: - return f"Error: {e}" - - registry.register( - ToolEntry( - name="wechat_send", - mode=ToolMode.INLINE, - schema={ - "name": "wechat_send", - "description": ( - "Send a text message to a WeChat user via the connected WeChat bot.\n" - "Use wechat_contacts to find available user_ids.\n" - "The user must have messaged the bot first before you can reply.\n" - "Keep messages concise — WeChat is a chat app. Use plain text, no markdown." - ), - "parameters": { - "type": "object", - "properties": { - "user_id": { - "type": "string", - "description": "WeChat user ID (format: xxx@im.wechat). Get from wechat_contacts.", - }, - "text": { - "type": "string", - "description": "Plain text message to send. 
No markdown — WeChat won't render it.", - }, - }, - "required": ["user_id", "text"], - }, - }, - handler=handle, - source="wechat", - ) - ) - - def _register_wechat_contacts(self, registry: ToolRegistry) -> None: - get_conn = self._get_conn - - def handle() -> str: - conn = get_conn() - if not conn or not conn.connected: - return "WeChat is not connected." - contacts = conn.list_contacts() - if not contacts: - return "No WeChat contacts yet. Users need to message the bot first." - lines = [f"- {c['display_name']} [user_id: {c['user_id']}]" for c in contacts] - return "\n".join(lines) - - registry.register( - ToolEntry( - name="wechat_contacts", - mode=ToolMode.INLINE, - schema={ - "name": "wechat_contacts", - "description": "List WeChat contacts who have messaged the bot. Returns user_ids for use with wechat_send.", - "parameters": { - "type": "object", - "properties": {}, - }, - }, - handler=handle, - source="wechat", - ) - ) diff --git a/docker-compose.yml b/docker-compose.yml index cb302edf3..15c3e7c7a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,6 +3,10 @@ services: build: context: . dockerfile: Dockerfile + volumes: + # @@@staging-leon-home-volume - staging runtime state (models/members/sandboxes) + # must survive container replacement, otherwise each deploy boots with an empty ~/.leon. + - leon-home:/root/.leon restart: unless-stopped frontend: @@ -14,3 +18,6 @@ services: depends_on: - backend restart: unless-stopped + +volumes: + leon-home: diff --git a/docs/en/introduction.mdx b/docs/en/introduction.mdx index 306238336..84e35bd7d 100644 --- a/docs/en/introduction.mdx +++ b/docs/en/introduction.mdx @@ -49,7 +49,7 @@ flowchart LR direction LR H["Human Entity"] A["Agent Entity"] - H <-->|chat_send / chat_read| A + H <-->|send_message / read_messages| A end subgraph Infra["Infrastructure"] diff --git a/docs/en/multi-agent-chat.mdx b/docs/en/multi-agent-chat.mdx index 6a10e8fec..2da8a8591 100644 --- a/docs/en/multi-agent-chat.mdx +++ b/docs/en/multi-agent-chat.mdx @@ -3,7 +3,7 @@ title: Multi-agent chat sidebarTitle: Social layer description: How humans and agents communicate on the Mycel social layer icon: comments -keywords: [entity, chat, agent communication, social, directory, chat_send, SSE] +keywords: [entity, chat, agent communication, social, list_chats, send_message, SSE] --- Mycel's social layer lets humans and agents coexist as equals in a shared messaging environment. Agents can initiate conversations, forward context to teammates, and collaborate autonomously — without any special orchestration code. @@ -19,7 +19,7 @@ flowchart LR direction TB HE["Human Entity"] AE["Agent Entity"] - HE <-->|"chat_send / chat_read"| AE + HE <-->|"send_message / read_messages"| AE end T --> Chat @@ -53,42 +53,33 @@ Every participant on the platform — human or agent — has an **Entity**. When ## Agent chat tools -Agents have five built-in tools for social interaction: +Agents have four built-in tools for social interaction: - - Browse all known Entities. Returns Entity IDs needed for other tools. - - ```text - directory(search="Alice", type="human") - → - Alice [human] entity_id=m_abc123-1 - ``` - - - + List the agent's active chats with unread counts and last message preview. ```text - chats(unread_only=true) + list_chats(unread_only=true) → - Alice [m_abc123-1] (3 unread) — last: "Can you help me with..." ``` - + Read message history in a chat. Automatically marks messages as read. 
```text - chat_read(entity_id="m_abc123-1", limit=10) + read_messages(entity_id="m_abc123-1", limit=10) → [Alice]: Can you help me with this bug? [you]: Sure, let me take a look. ``` - + Send a message. The agent must read unread messages before sending (enforced by the system). ```text - chat_send(content="Here's the fix.", entity_id="m_abc123-1") + send_message(content="Here's the fix.", entity_id="m_abc123-1") ``` **Signal protocol** controls conversation flow: @@ -100,11 +91,11 @@ Agents have five built-in tools for social interaction: | `close` | "Conversation over, do not reply" | - + Search through message history across all chats or within a specific chat. ```text - chat_search(query="bug fix", entity_id="m_abc123-1") + search_messages(query="bug fix", entity_id="m_abc123-1") ``` @@ -124,15 +115,15 @@ sequenceDiagram API->>H: SSE push (message event) API->>Q: Enqueue notification Q->>T: Wake thread (if idle) - T->>API: chat_read (get actual message) + T->>API: read_messages (get actual message) T->>T: Process message - T->>API: chat_send (response) + T->>API: send_message (response) API->>DB: Store response API->>H: SSE push (message event) ``` - Notifications don't include message content — the agent must call `chat_read` to read them. This enforces a consistent **read → respond** pattern and prevents agents from acting on stale summaries. + Notifications don't include message content — the agent must call `read_messages` to read them. This enforces a consistent **read → respond** pattern and prevents agents from acting on stale summaries. ## Real-time updates diff --git a/docs/en/quickstart.mdx b/docs/en/quickstart.mdx index 91954831c..204f99163 100644 --- a/docs/en/quickstart.mdx +++ b/docs/en/quickstart.mdx @@ -100,7 +100,7 @@ Mycel's social layer lets agents message each other — and you — like a group - In the first agent's thread, tell it to message your code reviewer: "Ask the code reviewer to look at this function." The agent will call `chat_send` and the reviewer will respond autonomously. + In the first agent's thread, tell it to message your code reviewer: "Ask the code reviewer to look at this function." The agent will call `send_message` and the reviewer will respond autonomously. 
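> The renamed chat tools keep the docs' read → respond contract mechanical to implement. A hypothetical agent-side handler, assuming a `chat_tools` client whose async methods mirror the tool names above (`compose_reply`, `handle_wake`, and the client object itself are illustrative, not part of the patch):

```python
def compose_reply(messages: list[dict]) -> str:
    # Placeholder for the agent's actual reasoning over the history.
    last = messages[-1]["content"] if messages else ""
    return f"Acknowledged: {last[:80]}"


async def handle_wake(chat_tools, entity_id: str) -> None:
    # Notifications carry no message content: read first (this also marks the
    # messages as read, satisfying the enforced read -> respond order).
    messages = await chat_tools.read_messages(entity_id=entity_id, limit=10)
    # Only after reading may the agent send; the system rejects blind sends.
    await chat_tools.send_message(content=compose_reply(messages), entity_id=entity_id)
```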
diff --git a/docs/zh/introduction.mdx b/docs/zh/introduction.mdx index fdc5e8693..9566e8cfe 100644 --- a/docs/zh/introduction.mdx +++ b/docs/zh/introduction.mdx @@ -49,7 +49,7 @@ flowchart LR direction LR H["人类 Entity"] A["Agent Entity"] - H <-->|"chat_send / chat_read"| A + H <-->|"send_message / read_messages"| A end subgraph Infra["基础设施"] diff --git a/docs/zh/multi-agent-chat.mdx b/docs/zh/multi-agent-chat.mdx index 3a44bd48c..4fb44940a 100644 --- a/docs/zh/multi-agent-chat.mdx +++ b/docs/zh/multi-agent-chat.mdx @@ -3,7 +3,7 @@ title: 多 Agent 通讯 sidebarTitle: 社交层 description: 人与 Agent 如何在 Mycel 社交层中通讯 icon: comments -keywords: [entity, chat, agent 通讯, 社交, directory, chat_send, SSE] +keywords: [entity, chat, agent 通讯, 社交, list_chats, send_message, SSE] --- Mycel 的社交层让人与 Agent 在共享的消息环境中平等共存。Agent 可以主动发起对话、把上下文转发给队友、自主协作 — 无需任何特殊的编排代码。 @@ -19,7 +19,7 @@ flowchart LR direction TB HE["人类 Entity"] AE["Agent Entity"] - HE <-->|"chat_send / chat_read"| AE + HE <-->|"send_message / read_messages"| AE end T --> Chat @@ -52,39 +52,30 @@ flowchart LR ## Agent 聊天工具 - - 浏览所有已知的 Entity,返回其他工具需要的 Entity ID。 - - ```text - directory(search="Alice", type="human") - → - Alice [human] entity_id=m_abc123-1 - ``` - - - + 列出 Agent 的活跃对话,包含未读数和最新消息预览。 ```text - chats(unread_only=true) + list_chats(unread_only=true) → - Alice [m_abc123-1] (3 条未读) — 最新:"能帮我看看..." ``` - + 读取对话消息历史,自动标记为已读。 ```text - chat_read(entity_id="m_abc123-1", limit=10) + read_messages(entity_id="m_abc123-1", limit=10) → [Alice]: 能帮我看看这个 bug 吗? [you]: 好的,我来看看。 ``` - + 发送消息。系统强制要求 Agent 先读取未读消息再发送。 ```text - chat_send(content="这是修复方案。", entity_id="m_abc123-1") + send_message(content="这是修复方案。", entity_id="m_abc123-1") ``` **信号协议**控制对话流转: @@ -96,11 +87,11 @@ flowchart LR | `close` | "对话结束,不需要回复" | - + 在所有对话或指定对话中搜索消息历史。 ```text - chat_search(query="bug 修复", entity_id="m_abc123-1") + search_messages(query="bug 修复", entity_id="m_abc123-1") ``` @@ -120,15 +111,15 @@ sequenceDiagram API->>H: SSE 推送(message 事件) API->>Q: 入队通知 Q->>T: 唤醒 Thread(若空闲) - T->>API: chat_read(读取实际消息) + T->>API: read_messages(读取实际消息) T->>T: 处理消息 - T->>API: chat_send(回复) + T->>API: send_message(回复) API->>DB: 存储回复 API->>H: SSE 推送(message 事件) ``` - 通知不包含消息内容 — Agent 必须调用 `chat_read` 才能读到。这强制执行「先读后发」的一致模式。 + 通知不包含消息内容 — Agent 必须调用 `read_messages` 才能读到。这强制执行「先读后发」的一致模式。 ## 联系人与投递设置 diff --git a/docs/zh/quickstart.mdx b/docs/zh/quickstart.mdx index 884bf09f4..37c67e8c8 100644 --- a/docs/zh/quickstart.mdx +++ b/docs/zh/quickstart.mdx @@ -100,7 +100,7 @@ Mycel 的社交层让 Agent 之间可以像群聊一样互相发消息。 - 在第一个 Agent 的 Thread 中,告诉它去联系代码审查员:「帮我把这个函数发给代码审查员看看。」Agent 会调用 `chat_send` 工具,审查员会自主回复。 + 在第一个 Agent 的 Thread 中,告诉它去联系代码审查员:「帮我把这个函数发给代码审查员看看。」Agent 会调用 `send_message` 工具,审查员会自主回复。 diff --git a/eval/storage.py b/eval/storage.py index 2dd75c523..ba389cdd1 100644 --- a/eval/storage.py +++ b/eval/storage.py @@ -1,7 +1,4 @@ -"""SQLite storage for eval trajectories and metrics. 
- -Database: ~/.leon/eval.db (separate from main leon.db) -""" +"""Storage for eval trajectories and metrics.""" from __future__ import annotations @@ -9,28 +6,28 @@ from datetime import UTC from pathlib import Path -from config.user_paths import user_home_path from eval.models import ( ObjectiveMetrics, RunTrajectory, SystemMetrics, ) -from eval.repo import SQLiteEvalRepo - -_DEFAULT_DB_PATH = user_home_path("eval.db") class TrajectoryStore: - """SQLite-backed storage for eval trajectories and metrics.""" + """Storage for eval trajectories and metrics.""" + + def __init__(self, db_path: str | Path | None = None, eval_repo=None): + if eval_repo is not None: + self._repo = eval_repo + else: + from storage.runtime import build_storage_container - def __init__(self, db_path: str | Path | None = None): - self.db_path = Path(db_path) if db_path else _DEFAULT_DB_PATH - self.db_path.parent.mkdir(parents=True, exist_ok=True) - self._repo = SQLiteEvalRepo(self.db_path) - self._init_db() + container = build_storage_container() + self._repo = container.eval_repo() def _init_db(self) -> None: - self._repo.ensure_schema() + if hasattr(self._repo, "ensure_schema"): + self._repo.ensure_schema() def save_trajectory(self, trajectory: RunTrajectory) -> str: """Save a trajectory and its LLM/tool call records. Returns run_id.""" diff --git a/frontend/app/.env.example b/frontend/app/.env.example new file mode 100644 index 000000000..abfdc2804 --- /dev/null +++ b/frontend/app/.env.example @@ -0,0 +1,2 @@ +VITE_SUPABASE_URL= +VITE_SUPABASE_ANON_KEY= diff --git a/frontend/app/DESIGN_SYSTEM.md b/frontend/app/DESIGN_SYSTEM.md index 5043fe083..62ae20435 100644 --- a/frontend/app/DESIGN_SYSTEM.md +++ b/frontend/app/DESIGN_SYSTEM.md @@ -186,7 +186,6 @@ These are **not** motion tokens. Import from `@/styles/ux-timing`. 
|----------|-------|-------| | `FEEDBACK_BRIEF` | 1500ms | Copy confirmation, save flash | | `FEEDBACK_NORMAL` | 2000ms | Toast display, status message | -| `BLUR_CLOSE_DELAY` | 150ms | Dropdown close delay on blur | ### Rules diff --git a/frontend/app/package-lock.json b/frontend/app/package-lock.json index 8af285c77..e0f68e798 100644 --- a/frontend/app/package-lock.json +++ b/frontend/app/package-lock.json @@ -35,6 +35,7 @@ "@radix-ui/react-toggle": "^1.1.10", "@radix-ui/react-toggle-group": "^1.1.11", "@radix-ui/react-tooltip": "^1.2.8", + "@supabase/supabase-js": "^2.49.8", "@types/diff": "^7.0.2", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", @@ -62,6 +63,7 @@ }, "devDependencies": { "@eslint/js": "^9.39.1", + "@testing-library/react": "^16.3.2", "@types/node": "^24.10.1", "@types/react": "^19.2.5", "@types/react-dom": "^19.2.3", @@ -71,6 +73,7 @@ "eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-react-refresh": "^0.4.24", "globals": "^16.5.0", + "jsdom": "^28.1.0", "kimi-plugin-inspect-react": "^1.0.3", "postcss": "^8.5.6", "tailwindcss": "^3.4.19", @@ -78,9 +81,17 @@ "tw-animate-css": "^1.4.0", "typescript": "~5.9.3", "typescript-eslint": "^8.46.4", - "vite": "^7.2.4" + "vite": "^7.2.4", + "vitest": "^4.1.2" } }, + "node_modules/@acemir/cssom": { + "version": "0.9.31", + "resolved": "https://registry.npmjs.org/@acemir/cssom/-/cssom-0.9.31.tgz", + "integrity": "sha512-ZnR3GSaH+/vJ0YlHau21FjfLYjMpYVIzTD8M8vIEQvIGxeOXyXdzCI140rrCY862p/C/BbzWsjc1dgnM9mkoTA==", + "dev": true, + "license": "MIT" + }, "node_modules/@alloc/quick-lru": { "version": "5.2.0", "resolved": "https://registry.npmjs.org/@alloc/quick-lru/-/quick-lru-5.2.0.tgz", @@ -94,6 +105,64 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/@asamuzakjp/css-color": { + "version": "5.1.5", + "resolved": "https://registry.npmjs.org/@asamuzakjp/css-color/-/css-color-5.1.5.tgz", + "integrity": "sha512-8cMAA1bE66Mb/tfmkhcfJLjEPgyT7SSy6lW6id5XL113ai1ky76d/1L27sGnXCMsLfq66DInAU3OzuahB4lu9Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@csstools/css-calc": "^3.1.1", + "@csstools/css-color-parser": "^4.0.2", + "@csstools/css-parser-algorithms": "^4.0.0", + "@csstools/css-tokenizer": "^4.0.0", + "lru-cache": "^11.2.7" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + } + }, + "node_modules/@asamuzakjp/css-color/node_modules/lru-cache": { + "version": "11.3.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.3.0.tgz", + "integrity": "sha512-sr8xPKE25m6vJVcrdn6NxtC0fVfuPowbscLypegRgOm0yXSqr5JNHCAY3hnusdJ7HRBW04j6Ip4khvHU778DuQ==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": "20 || >=22" + } + }, + "node_modules/@asamuzakjp/dom-selector": { + "version": "6.8.1", + "resolved": "https://registry.npmjs.org/@asamuzakjp/dom-selector/-/dom-selector-6.8.1.tgz", + "integrity": "sha512-MvRz1nCqW0fsy8Qz4dnLIvhOlMzqDVBabZx6lH+YywFDdjXhMY37SmpV1XFX3JzG5GWHn63j6HX6QPr3lZXHvQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@asamuzakjp/nwsapi": "^2.3.9", + "bidi-js": "^1.0.3", + "css-tree": "^3.1.0", + "is-potential-custom-element-name": "^1.0.1", + "lru-cache": "^11.2.6" + } + }, + "node_modules/@asamuzakjp/dom-selector/node_modules/lru-cache": { + "version": "11.3.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.3.0.tgz", + "integrity": "sha512-sr8xPKE25m6vJVcrdn6NxtC0fVfuPowbscLypegRgOm0yXSqr5JNHCAY3hnusdJ7HRBW04j6Ip4khvHU778DuQ==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": "20 
|| >=22" + } + }, + "node_modules/@asamuzakjp/nwsapi": { + "version": "2.3.9", + "resolved": "https://registry.npmjs.org/@asamuzakjp/nwsapi/-/nwsapi-2.3.9.tgz", + "integrity": "sha512-n8GuYSrI9bF7FFZ/SjhwevlHc8xaVlb/7HmHelnc/PZXBD2ZR49NnN9sMMuDdEGPeeRQ5d0hqlSlEpgCX3Wl0Q==", + "dev": true, + "license": "MIT" + }, "node_modules/@babel/code-frame": { "version": "7.28.6", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.28.6.tgz", @@ -1846,6 +1915,161 @@ "node": ">=6.9.0" } }, + "node_modules/@bramus/specificity": { + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/@bramus/specificity/-/specificity-2.4.2.tgz", + "integrity": "sha512-ctxtJ/eA+t+6q2++vj5j7FYX3nRu311q1wfYH3xjlLOsczhlhxAg2FWNUXhpGvAw3BWo1xBcvOV6/YLc2r5FJw==", + "dev": true, + "license": "MIT", + "dependencies": { + "css-tree": "^3.0.0" + }, + "bin": { + "specificity": "bin/cli.js" + } + }, + "node_modules/@csstools/color-helpers": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/@csstools/color-helpers/-/color-helpers-6.0.2.tgz", + "integrity": "sha512-LMGQLS9EuADloEFkcTBR3BwV/CGHV7zyDxVRtVDTwdI2Ca4it0CCVTT9wCkxSgokjE5Ho41hEPgb8OEUwoXr6Q==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT-0", + "engines": { + "node": ">=20.19.0" + } + }, + "node_modules/@csstools/css-calc": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/@csstools/css-calc/-/css-calc-3.1.1.tgz", + "integrity": "sha512-HJ26Z/vmsZQqs/o3a6bgKslXGFAungXGbinULZO3eMsOyNJHeBBZfup5FiZInOghgoM4Hwnmw+OgbJCNg1wwUQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "engines": { + "node": ">=20.19.0" + }, + "peerDependencies": { + "@csstools/css-parser-algorithms": "^4.0.0", + "@csstools/css-tokenizer": "^4.0.0" + } + }, + "node_modules/@csstools/css-color-parser": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/@csstools/css-color-parser/-/css-color-parser-4.0.2.tgz", + "integrity": "sha512-0GEfbBLmTFf0dJlpsNU7zwxRIH0/BGEMuXLTCvFYxuL1tNhqzTbtnFICyJLTNK4a+RechKP75e7w42ClXSnJQw==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "dependencies": { + "@csstools/color-helpers": "^6.0.2", + "@csstools/css-calc": "^3.1.1" + }, + "engines": { + "node": ">=20.19.0" + }, + "peerDependencies": { + "@csstools/css-parser-algorithms": "^4.0.0", + "@csstools/css-tokenizer": "^4.0.0" + } + }, + "node_modules/@csstools/css-parser-algorithms": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/@csstools/css-parser-algorithms/-/css-parser-algorithms-4.0.0.tgz", + "integrity": "sha512-+B87qS7fIG3L5h3qwJ/IFbjoVoOe/bpOdh9hAjXbvx0o8ImEmUsGXN0inFOnk2ChCFgqkkGFQ+TpM5rbhkKe4w==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "peer": true, + "engines": { + "node": ">=20.19.0" + }, + "peerDependencies": { + "@csstools/css-tokenizer": "^4.0.0" + } + }, + "node_modules/@csstools/css-syntax-patches-for-csstree": { + "version": "1.1.2", + 
"resolved": "https://registry.npmjs.org/@csstools/css-syntax-patches-for-csstree/-/css-syntax-patches-for-csstree-1.1.2.tgz", + "integrity": "sha512-5GkLzz4prTIpoyeUiIu3iV6CSG3Plo7xRVOFPKI7FVEJ3mZ0A8SwK0XU3Gl7xAkiQ+mDyam+NNp875/C5y+jSA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT-0", + "peerDependencies": { + "css-tree": "^3.2.1" + }, + "peerDependenciesMeta": { + "css-tree": { + "optional": true + } + } + }, + "node_modules/@csstools/css-tokenizer": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/@csstools/css-tokenizer/-/css-tokenizer-4.0.0.tgz", + "integrity": "sha512-QxULHAm7cNu72w97JUNCBFODFaXpbDg+dP8b/oWFAZ2MTRppA3U00Y2L1HqaS4J6yBqxwa/Y3nMBaxVKbB/NsA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "peer": true, + "engines": { + "node": ">=20.19.0" + } + }, "node_modules/@date-fns/tz": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/@date-fns/tz/-/tz-1.4.1.tgz", @@ -2451,6 +2675,24 @@ "node": "^18.18.0 || ^20.9.0 || >=21.1.0" } }, + "node_modules/@exodus/bytes": { + "version": "1.15.0", + "resolved": "https://registry.npmjs.org/@exodus/bytes/-/bytes-1.15.0.tgz", + "integrity": "sha512-UY0nlA+feH81UGSHv92sLEPLCeZFjXOuHhrIo0HQydScuQc8s0A7kL/UdgwgDq8g8ilksmuoF35YVTNphV2aBQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + }, + "peerDependencies": { + "@noble/hashes": "^1.8.0 || ^2.0.0" + }, + "peerDependenciesMeta": { + "@noble/hashes": { + "optional": true + } + } + }, "node_modules/@floating-ui/core": { "version": "1.7.3", "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.7.3.tgz", @@ -4607,12 +4849,161 @@ "win32" ] }, + "node_modules/@standard-schema/spec": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz", + "integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==", + "dev": true, + "license": "MIT" + }, "node_modules/@standard-schema/utils": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/@standard-schema/utils/-/utils-0.3.0.tgz", "integrity": "sha512-e7Mew686owMaPJVNNLs55PUvgz371nKgwsc4vxE49zsODpJEnxgxRo2y/OKrqueavXgZNMDVj3DdHFlaSAeU8g==", "license": "MIT" }, + "node_modules/@supabase/auth-js": { + "version": "2.101.1", + "resolved": "https://registry.npmjs.org/@supabase/auth-js/-/auth-js-2.101.1.tgz", + "integrity": "sha512-Kd0Wey+RkFHgyVep7adS6UOE2pN6MJ3mZ32PAXSvfw6IjUkFRC7IQpdZZjUOcUe5pXr1ejufCRgF6lsGINe4Tw==", + "license": "MIT", + "dependencies": { + "tslib": "2.8.1" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@supabase/functions-js": { + "version": "2.101.1", + "resolved": "https://registry.npmjs.org/@supabase/functions-js/-/functions-js-2.101.1.tgz", + "integrity": "sha512-OZWU7YtaG+NNNFZK8p/FuJ6gpq7pFyrG2fLOopP73HAIDHDGpOttPJapvO8ADu3RkqfQfkwrB354vPkSBbZ20A==", + "license": "MIT", + "dependencies": { + "tslib": "2.8.1" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@supabase/phoenix": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/@supabase/phoenix/-/phoenix-0.4.0.tgz", + "integrity": 
"sha512-RHSx8bHS02xwfHdAbX5Lpbo6PXbgyf7lTaXTlwtFDPwOIw64NnVRwFAXGojHhjtVYI+PEPNSWwkL90f4agN3bw==", + "license": "MIT" + }, + "node_modules/@supabase/postgrest-js": { + "version": "2.101.1", + "resolved": "https://registry.npmjs.org/@supabase/postgrest-js/-/postgrest-js-2.101.1.tgz", + "integrity": "sha512-UW1RajH5jbZoK+ldAJ1I6VZ+HWwZ2oaKjEQ6Gn+AQ67CHQVxGl8wNQoLYyumbyaExm41I+wn7arulcY1eHeZJw==", + "license": "MIT", + "dependencies": { + "tslib": "2.8.1" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@supabase/realtime-js": { + "version": "2.101.1", + "resolved": "https://registry.npmjs.org/@supabase/realtime-js/-/realtime-js-2.101.1.tgz", + "integrity": "sha512-Oa6dno0OB9I+hv5do5zsZHbFu41ViZnE9IWjmkeeF/8fPmB5fWoHGqeTYEC3/0DAgtpUoFJa4FpvzFH0SBHo1Q==", + "license": "MIT", + "dependencies": { + "@supabase/phoenix": "^0.4.0", + "@types/ws": "^8.18.1", + "tslib": "2.8.1", + "ws": "^8.18.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@supabase/storage-js": { + "version": "2.101.1", + "resolved": "https://registry.npmjs.org/@supabase/storage-js/-/storage-js-2.101.1.tgz", + "integrity": "sha512-WhTaUOBgeEvnKLy95Cdlp6+D5igSF/65yC727w1olxbet5nzUvMlajKUWyzNtQu2efrz2cQ7FcdVBdQqgT9YKQ==", + "license": "MIT", + "dependencies": { + "iceberg-js": "^0.8.1", + "tslib": "2.8.1" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@supabase/supabase-js": { + "version": "2.101.1", + "resolved": "https://registry.npmjs.org/@supabase/supabase-js/-/supabase-js-2.101.1.tgz", + "integrity": "sha512-Jnhm3LfuACwjIzvk2pfUbGQn7pa7hi6MFzfSyPrRYWVCCu69RPLCFyHSBl7HSBwadbQ3UZOznnD3gPca3ePrRA==", + "license": "MIT", + "dependencies": { + "@supabase/auth-js": "2.101.1", + "@supabase/functions-js": "2.101.1", + "@supabase/postgrest-js": "2.101.1", + "@supabase/realtime-js": "2.101.1", + "@supabase/storage-js": "2.101.1" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@testing-library/dom": { + "version": "10.4.1", + "resolved": "https://registry.npmjs.org/@testing-library/dom/-/dom-10.4.1.tgz", + "integrity": "sha512-o4PXJQidqJl82ckFaXUeoAW+XysPLauYI43Abki5hABd853iMhitooc6znOnczgbTYmEP6U6/y1ZyKAIsvMKGg==", + "dev": true, + "license": "MIT", + "peer": true, + "dependencies": { + "@babel/code-frame": "^7.10.4", + "@babel/runtime": "^7.12.5", + "@types/aria-query": "^5.0.1", + "aria-query": "5.3.0", + "dom-accessibility-api": "^0.5.9", + "lz-string": "^1.5.0", + "picocolors": "1.1.1", + "pretty-format": "^27.0.2" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@testing-library/react": { + "version": "16.3.2", + "resolved": "https://registry.npmjs.org/@testing-library/react/-/react-16.3.2.tgz", + "integrity": "sha512-XU5/SytQM+ykqMnAnvB2umaJNIOsLF3PVv//1Ew4CTcpz0/BRyy/af40qqrt7SjKpDdT1saBMc42CUok5gaw+g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.12.5" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@testing-library/dom": "^10.0.0", + "@types/react": "^18.0.0 || ^19.0.0", + "@types/react-dom": "^18.0.0 || ^19.0.0", + "react": "^18.0.0 || ^19.0.0", + "react-dom": "^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@types/aria-query": { + "version": "5.0.4", + "resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.4.tgz", + "integrity": "sha512-rfT93uj5s0PRL7EzccGMs3brplhcrghnDoV26NqKhCAS1hVo+WdNsPvE/yb6ilfr5hi2MEk6d5EWJTKdxg8jVw==", + "dev": true, + 
"license": "MIT" + }, "node_modules/@types/babel__core": { "version": "7.20.5", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", @@ -4658,6 +5049,17 @@ "@babel/types": "^7.28.2" } }, + "node_modules/@types/chai": { + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/@types/chai/-/chai-5.2.3.tgz", + "integrity": "sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/deep-eql": "*", + "assertion-error": "^2.0.1" + } + }, "node_modules/@types/d3-array": { "version": "3.2.2", "resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.2.tgz", @@ -4730,6 +5132,13 @@ "@types/ms": "*" } }, + "node_modules/@types/deep-eql": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/@types/deep-eql/-/deep-eql-4.0.2.tgz", + "integrity": "sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/diff": { "version": "7.0.2", "resolved": "https://registry.npmjs.org/@types/diff/-/diff-7.0.2.tgz", @@ -4786,9 +5195,7 @@ "version": "24.10.4", "resolved": "https://registry.npmjs.org/@types/node/-/node-24.10.4.tgz", "integrity": "sha512-vnDVpYPMzs4wunl27jHrfmwojOGKya0xyM3sH+UE5iv5uPS6vX7UIoh6m+vQc5LGBq52HBKPIn/zcSZVzeDEZg==", - "dev": true, "license": "MIT", - "peer": true, "dependencies": { "undici-types": "~7.16.0" } @@ -4821,6 +5228,15 @@ "integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==", "license": "MIT" }, + "node_modules/@types/ws": { + "version": "8.18.1", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "8.52.0", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.52.0.tgz", @@ -5118,29 +5534,152 @@ "vite": "^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0" } }, - "node_modules/acorn": { - "version": "8.15.0", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", - "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", + "node_modules/@vitest/expect": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-4.1.2.tgz", + "integrity": "sha512-gbu+7B0YgUJ2nkdsRJrFFW6X7NTP44WlhiclHniUhxADQJH5Szt9mZ9hWnJPJ8YwOK5zUOSSlSvyzRf0u1DSBQ==", "dev": true, "license": "MIT", - "peer": true, - "bin": { - "acorn": "bin/acorn" + "dependencies": { + "@standard-schema/spec": "^1.1.0", + "@types/chai": "^5.2.2", + "@vitest/spy": "4.1.2", + "@vitest/utils": "4.1.2", + "chai": "^6.2.2", + "tinyrainbow": "^3.1.0" }, - "engines": { - "node": ">=0.4.0" + "funding": { + "url": "https://opencollective.com/vitest" } }, - "node_modules/acorn-jsx": { - "version": "5.3.2", - "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", - "integrity": "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==", + "node_modules/@vitest/mocker": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-4.1.2.tgz", + "integrity": "sha512-Ize4iQtEALHDttPRCmN+FKqOl2vxTiNUhzobQFFt/BM1lRUTG7zRCLOykG/6Vo4E4hnUdfVLo5/eqKPukcWW7Q==", 
"dev": true, "license": "MIT", - "peerDependencies": { - "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" - } + "dependencies": { + "@vitest/spy": "4.1.2", + "estree-walker": "^3.0.3", + "magic-string": "^0.30.21" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "msw": "^2.4.9", + "vite": "^6.0.0 || ^7.0.0 || ^8.0.0" + }, + "peerDependenciesMeta": { + "msw": { + "optional": true + }, + "vite": { + "optional": true + } + } + }, + "node_modules/@vitest/pretty-format": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-4.1.2.tgz", + "integrity": "sha512-dwQga8aejqeuB+TvXCMzSQemvV9hNEtDDpgUKDzOmNQayl2OG241PSWeJwKRH3CiC+sESrmoFd49rfnq7T4RnA==", + "dev": true, + "license": "MIT", + "dependencies": { + "tinyrainbow": "^3.1.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/runner": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-4.1.2.tgz", + "integrity": "sha512-Gr+FQan34CdiYAwpGJmQG8PgkyFVmARK8/xSijia3eTFgVfpcpztWLuP6FttGNfPLJhaZVP/euvujeNYar36OQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/utils": "4.1.2", + "pathe": "^2.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/snapshot": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-4.1.2.tgz", + "integrity": "sha512-g7yfUmxYS4mNxk31qbOYsSt2F4m1E02LFqO53Xpzg3zKMhLAPZAjjfyl9e6z7HrW6LvUdTwAQR3HHfLjpko16A==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "4.1.2", + "@vitest/utils": "4.1.2", + "magic-string": "^0.30.21", + "pathe": "^2.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/spy": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-4.1.2.tgz", + "integrity": "sha512-DU4fBnbVCJGNBwVA6xSToNXrkZNSiw59H8tcuUspVMsBDBST4nfvsPsEHDHGtWRRnqBERBQu7TrTKskmjqTXKA==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/utils": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-4.1.2.tgz", + "integrity": "sha512-xw2/TiX82lQHA06cgbqRKFb5lCAy3axQ4H4SoUFhUsg+wztiet+co86IAMDtF6Vm1hc7J6j09oh/rgDn+JdKIQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "4.1.2", + "convert-source-map": "^2.0.0", + "tinyrainbow": "^3.1.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/acorn": { + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", + "dev": true, + "license": "MIT", + "peer": true, + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-jsx": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", + "integrity": "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" + } + }, + "node_modules/agent-base": { + "version": "7.1.4", + "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-7.1.4.tgz", + "integrity": 
"sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 14" + } }, "node_modules/ajv": { "version": "6.12.6", @@ -5159,6 +5698,16 @@ "url": "https://github.com/sponsors/epoberezkin" } }, + "node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, "node_modules/ansi-styles": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", @@ -5235,6 +5784,26 @@ "node": ">=10" } }, + "node_modules/aria-query": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/aria-query/-/aria-query-5.3.0.tgz", + "integrity": "sha512-b0P0sZPKtyu8HkeRAfCq0IfURZK+SuwMjY1UXGBU27wpAiTwQAIlq56IbIO+ytk/JjS1fMR14ee5WBBfKi5J6A==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "dequal": "^2.0.3" + } + }, + "node_modules/assertion-error": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/assertion-error/-/assertion-error-2.0.1.tgz", + "integrity": "sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + } + }, "node_modules/autoprefixer": { "version": "10.4.23", "resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.23.tgz", @@ -5341,6 +5910,16 @@ "baseline-browser-mapping": "dist/cli.js" } }, + "node_modules/bidi-js": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/bidi-js/-/bidi-js-1.0.3.tgz", + "integrity": "sha512-RKshQI1R3YQ+n9YJz2QQ147P66ELpa1FQEg20Dk8oW9t2KgLbpDLLp9aGZ7y8WHSshDknG0bknqGw5/tyCs5tw==", + "dev": true, + "license": "MIT", + "dependencies": { + "require-from-string": "^2.0.2" + } + }, "node_modules/binary-extensions": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", @@ -5464,6 +6043,16 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/chai": { + "version": "6.2.2", + "resolved": "https://registry.npmjs.org/chai/-/chai-6.2.2.tgz", + "integrity": "sha512-NUPRluOfOiTKBKvWPtSD4PhFvWCqOi0BGStNWs57X9js7XGTprSmFoz5F0tWhR4WPjNeR9jXqdC7/UpSJTnlRg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, "node_modules/chalk": { "version": "4.1.2", "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", @@ -5692,6 +6281,21 @@ "node": ">= 8" } }, + "node_modules/css-tree": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/css-tree/-/css-tree-3.2.1.tgz", + "integrity": "sha512-X7sjQzceUhu1u7Y/ylrRZFU2FS6LRiFVp6rKLPg23y3x3c3DOKAwuXGDp+PAGjh6CSnCjYeAul8pcT8bAl+lSA==", + "dev": true, + "license": "MIT", + "peer": true, + "dependencies": { + "mdn-data": "2.27.1", + "source-map-js": "^1.2.1" + }, + "engines": { + "node": "^10 || ^12.20.0 || ^14.13.0 || >=15.0.0" + } + }, "node_modules/cssesc": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz", @@ -5705,6 +6309,32 @@ "node": ">=4" } }, + "node_modules/cssstyle": { + "version": "6.2.0", + "resolved": "https://registry.npmjs.org/cssstyle/-/cssstyle-6.2.0.tgz", + "integrity": "sha512-Fm5NvhYathRnXNVndkUsCCuR63DCLVVwGOOwQw782coXFi5HhkXdu289l59HlXZBawsyNccXfWRYvLzcDCdDig==", + "dev": true, + "license": "MIT", + 
"dependencies": { + "@asamuzakjp/css-color": "^5.0.1", + "@csstools/css-syntax-patches-for-csstree": "^1.0.28", + "css-tree": "^3.1.0", + "lru-cache": "^11.2.6" + }, + "engines": { + "node": ">=20" + } + }, + "node_modules/cssstyle/node_modules/lru-cache": { + "version": "11.3.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.3.0.tgz", + "integrity": "sha512-sr8xPKE25m6vJVcrdn6NxtC0fVfuPowbscLypegRgOm0yXSqr5JNHCAY3hnusdJ7HRBW04j6Ip4khvHU778DuQ==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": "20 || >=22" + } + }, "node_modules/csstype": { "version": "3.2.3", "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.2.3.tgz", @@ -5832,6 +6462,20 @@ "node": ">=12" } }, + "node_modules/data-urls": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/data-urls/-/data-urls-7.0.0.tgz", + "integrity": "sha512-23XHcCF+coGYevirZceTVD7NdJOqVn+49IHyxgszm+JIiHLoB2TkmPtsYkNWT1pvRSGkc35L6NHs0yHkN2SumA==", + "dev": true, + "license": "MIT", + "dependencies": { + "whatwg-mimetype": "^5.0.0", + "whatwg-url": "^16.0.0" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + } + }, "node_modules/date-fns": { "version": "4.1.0", "resolved": "https://registry.npmjs.org/date-fns/-/date-fns-4.1.0.tgz", @@ -5865,6 +6509,13 @@ } } }, + "node_modules/decimal.js": { + "version": "10.6.0", + "resolved": "https://registry.npmjs.org/decimal.js/-/decimal.js-10.6.0.tgz", + "integrity": "sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==", + "dev": true, + "license": "MIT" + }, "node_modules/decimal.js-light": { "version": "2.5.1", "resolved": "https://registry.npmjs.org/decimal.js-light/-/decimal.js-light-2.5.1.tgz", @@ -5942,6 +6593,13 @@ "dev": true, "license": "MIT" }, + "node_modules/dom-accessibility-api": { + "version": "0.5.16", + "resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.5.16.tgz", + "integrity": "sha512-X7BJ2yElsnOJ30pZF4uIIDfBEVgF4XEBxL9Bxhy6dnrm5hkzqmsWHGTiHqRiITNhMyFLyAiWndIJP7Z1NTteDg==", + "dev": true, + "license": "MIT" + }, "node_modules/dom-helpers": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/dom-helpers/-/dom-helpers-5.2.1.tgz", @@ -6000,6 +6658,13 @@ "url": "https://github.com/fb55/entities?sponsor=1" } }, + "node_modules/es-module-lexer": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-2.0.0.tgz", + "integrity": "sha512-5POEcUuZybH7IdmGsD8wlf0AI55wMecM9rVBTI/qEAy2c1kTOm3DjFYjrBdI2K3BaJjJYfYFeRtM0t9ssnRuxw==", + "dev": true, + "license": "MIT" + }, "node_modules/esbuild": { "version": "0.27.2", "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.27.2.tgz", @@ -6250,6 +6915,16 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/estree-walker": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz", + "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0" + } + }, "node_modules/esutils": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", @@ -6266,6 +6941,16 @@ "integrity": "sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==", "license": "MIT" }, + "node_modules/expect-type": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/expect-type/-/expect-type-1.3.0.tgz", + 
"integrity": "sha512-knvyeauYhqjOYvQ66MznSMs83wmHrCycNEN6Ao+2AeYEfxUIkuiVxdEa1qlGEPK+We3n0THiDciYSsCcgW/DoA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/extend": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz", @@ -6697,6 +7382,19 @@ "hermes-estree": "0.25.1" } }, + "node_modules/html-encoding-sniffer": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-6.0.0.tgz", + "integrity": "sha512-CV9TW3Y3f8/wT0BRFc1/KAVQ3TUHiXmaAb6VW9vtiMFf7SLoMd1PdAc4W3KFOFETBJUb90KatHqlsZMWV+R9Gg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@exodus/bytes": "^1.6.0" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + } + }, "node_modules/html-url-attributes": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/html-url-attributes/-/html-url-attributes-3.0.1.tgz", @@ -6717,6 +7415,43 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/http-proxy-agent": { + "version": "7.0.2", + "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-7.0.2.tgz", + "integrity": "sha512-T1gkAiYYDWYx3V5Bmyu7HcfcvL7mUrTWiM6yOfa3PIphViJ/gFPbvidQ+veqSOHci/PxBcDabeUNCzpOODJZig==", + "dev": true, + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.0", + "debug": "^4.3.4" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/https-proxy-agent": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.6.tgz", + "integrity": "sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw==", + "dev": true, + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.2", + "debug": "4" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/iceberg-js": { + "version": "0.8.1", + "resolved": "https://registry.npmjs.org/iceberg-js/-/iceberg-js-0.8.1.tgz", + "integrity": "sha512-1dhVQZXhcHje7798IVM+xoo/1ZdVfzOMIc8/rgVSijRK38EDqOJoGula9N/8ZI5RD8QTxNQtK/Gozpr+qUqRRA==", + "license": "MIT", + "engines": { + "node": ">=20.0.0" + } + }, "node_modules/ignore": { "version": "5.3.2", "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz", @@ -6897,6 +7632,13 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/is-potential-custom-element-name": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/is-potential-custom-element-name/-/is-potential-custom-element-name-1.0.1.tgz", + "integrity": "sha512-bCYeRA2rVibKZd+s2625gGnGF/t7DSqDs4dP7CrLA1m7jKWz6pps0LpYLJN8Q64HtmPKJ1hrN3nzPNKFEKOUiQ==", + "dev": true, + "license": "MIT" + }, "node_modules/isexe": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", @@ -6934,6 +7676,60 @@ "js-yaml": "bin/js-yaml.js" } }, + "node_modules/jsdom": { + "version": "28.1.0", + "resolved": "https://registry.npmjs.org/jsdom/-/jsdom-28.1.0.tgz", + "integrity": "sha512-0+MoQNYyr2rBHqO1xilltfDjV9G7ymYGlAUazgcDLQaUf8JDHbuGwsxN6U9qWaElZ4w1B2r7yEGIL3GdeW3Rug==", + "dev": true, + "license": "MIT", + "dependencies": { + "@acemir/cssom": "^0.9.31", + "@asamuzakjp/dom-selector": "^6.8.1", + "@bramus/specificity": "^2.4.2", + "@exodus/bytes": "^1.11.0", + "cssstyle": "^6.0.1", + "data-urls": "^7.0.0", + "decimal.js": "^10.6.0", + "html-encoding-sniffer": "^6.0.0", + "http-proxy-agent": "^7.0.2", + "https-proxy-agent": "^7.0.6", + "is-potential-custom-element-name": "^1.0.1", + "parse5": "^8.0.0", + "saxes": "^6.0.0", + 
"symbol-tree": "^3.2.4", + "tough-cookie": "^6.0.0", + "undici": "^7.21.0", + "w3c-xmlserializer": "^5.0.0", + "webidl-conversions": "^8.0.1", + "whatwg-mimetype": "^5.0.0", + "whatwg-url": "^16.0.0", + "xml-name-validator": "^5.0.0" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + }, + "peerDependencies": { + "canvas": "^3.0.0" + }, + "peerDependenciesMeta": { + "canvas": { + "optional": true + } + } + }, + "node_modules/jsdom/node_modules/parse5": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/parse5/-/parse5-8.0.0.tgz", + "integrity": "sha512-9m4m5GSgXjL4AjumKzq1Fgfp3Z8rsvjRNbnkVwfu2ImRqE5D0LnY2QfDen18FSY9C573YU5XxSapdHZTZ2WolA==", + "dev": true, + "license": "MIT", + "dependencies": { + "entities": "^6.0.0" + }, + "funding": { + "url": "https://github.com/inikulin/parse5?sponsor=1" + } + }, "node_modules/jsesc": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-3.1.0.tgz", @@ -7121,6 +7917,16 @@ "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, + "node_modules/lz-string": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.5.0.tgz", + "integrity": "sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==", + "dev": true, + "license": "MIT", + "bin": { + "lz-string": "bin/bin.js" + } + }, "node_modules/magic-string": { "version": "0.30.21", "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.21.tgz", @@ -7435,6 +8241,13 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/mdn-data": { + "version": "2.27.1", + "resolved": "https://registry.npmjs.org/mdn-data/-/mdn-data-2.27.1.tgz", + "integrity": "sha512-9Yubnt3e8A0OKwxYSXyhLymGW4sCufcLG6VdiDdUGVkPhpqLxlvP5vl1983gQjJl3tqbrM731mjaZaP68AgosQ==", + "dev": true, + "license": "CC0-1.0" + }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", @@ -8138,6 +8951,17 @@ "node": ">= 6" } }, + "node_modules/obug": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/obug/-/obug-2.1.1.tgz", + "integrity": "sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==", + "dev": true, + "funding": [ + "https://github.com/sponsors/sxzz", + "https://opencollective.com/debug" + ], + "license": "MIT" + }, "node_modules/optionator": { "version": "0.9.4", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", @@ -8265,6 +9089,13 @@ "dev": true, "license": "MIT" }, + "node_modules/pathe": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/pathe/-/pathe-2.0.3.tgz", + "integrity": "sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==", + "dev": true, + "license": "MIT" + }, "node_modules/picocolors": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", @@ -8480,6 +9311,41 @@ "node": ">= 0.8.0" } }, + "node_modules/pretty-format": { + "version": "27.5.1", + "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-27.5.1.tgz", + "integrity": "sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1", + "ansi-styles": "^5.0.0", + "react-is": "^17.0.1" + }, + "engines": { + "node": "^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0" + } + }, + "node_modules/pretty-format/node_modules/ansi-styles": { + "version": "5.2.0", + 
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-5.2.0.tgz", + "integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/pretty-format/node_modules/react-is": { + "version": "17.0.2", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz", + "integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==", + "dev": true, + "license": "MIT" + }, "node_modules/prop-types": { "version": "15.8.1", "resolved": "https://registry.npmjs.org/prop-types/-/prop-types-15.8.1.tgz", @@ -9008,6 +9874,16 @@ "integrity": "sha512-4ZJgIB9EG9fQE41mOJCRHMmnxDTKHWawQoJWZyUbZuj680wVyogu2ihnj8Edqm7vh2mo/TWHyEZpn2kqeDvS7w==", "license": "Apache-2.0" }, + "node_modules/require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/resolve": { "version": "1.22.11", "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.11.tgz", @@ -9119,6 +9995,19 @@ "queue-microtask": "^1.2.2" } }, + "node_modules/saxes": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/saxes/-/saxes-6.0.0.tgz", + "integrity": "sha512-xAg7SOnEhrm5zI3puOOKyy1OMcMlIJZYNJY7xLBwSze0UjhPLnWfj2GF2EpT0jmzaJKIWKHLsaSSajf35bcYnA==", + "dev": true, + "license": "ISC", + "dependencies": { + "xmlchars": "^2.2.0" + }, + "engines": { + "node": ">=v12.22.7" + } + }, "node_modules/scheduler": { "version": "0.27.0", "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.27.0.tgz", @@ -9164,6 +10053,13 @@ "node": ">=8" } }, + "node_modules/siginfo": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/siginfo/-/siginfo-2.0.0.tgz", + "integrity": "sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==", + "dev": true, + "license": "ISC" + }, "node_modules/sonner": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/sonner/-/sonner-2.0.7.tgz", @@ -9194,6 +10090,20 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/stackback": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz", + "integrity": "sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==", + "dev": true, + "license": "MIT" + }, + "node_modules/std-env": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/std-env/-/std-env-4.0.0.tgz", + "integrity": "sha512-zUMPtQ/HBY3/50VbpkupYHbRroTRZJPRLvreamgErJVys0ceuzMkD44J/QjqhHjOzK42GQ3QZIeFG1OYfOtKqQ==", + "dev": true, + "license": "MIT" + }, "node_modules/streamdown": { "version": "2.4.0", "resolved": "https://registry.npmjs.org/streamdown/-/streamdown-2.4.0.tgz", @@ -9315,6 +10225,13 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/symbol-tree": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/symbol-tree/-/symbol-tree-3.2.4.tgz", + "integrity": "sha512-9QNk5KwDF+Bvz+PyObkmSYjI5ksVUYtjW7AU22r2NKcfLJcXp96hkDWU3+XndOsUb+AQ9QhfzfCT2O+CNWT5Tw==", + "dev": true, + "license": "MIT" + }, "node_modules/tailwind-merge": { "version": "3.4.0", 
"resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-3.4.0.tgz", @@ -9403,6 +10320,23 @@ "integrity": "sha512-+FbBPE1o9QAYvviau/qC5SE3caw21q3xkvWKBtja5vgqOWIHHJ3ioaq1VPfn/Szqctz2bU/oYeKd9/z5BL+PVg==", "license": "MIT" }, + "node_modules/tinybench": { + "version": "2.9.0", + "resolved": "https://registry.npmjs.org/tinybench/-/tinybench-2.9.0.tgz", + "integrity": "sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==", + "dev": true, + "license": "MIT" + }, + "node_modules/tinyexec": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/tinyexec/-/tinyexec-1.0.4.tgz", + "integrity": "sha512-u9r3uZC0bdpGOXtlxUIdwf9pkmvhqJdrVCH9fapQtgy/OeTTMZ1nqH7agtvEfmGui6e1XxjcdrlxvxJvc3sMqw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, "node_modules/tinyglobby": { "version": "0.2.15", "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", @@ -9420,6 +10354,36 @@ "url": "https://github.com/sponsors/SuperchupuDev" } }, + "node_modules/tinyrainbow": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/tinyrainbow/-/tinyrainbow-3.1.0.tgz", + "integrity": "sha512-Bf+ILmBgretUrdJxzXM0SgXLZ3XfiaUuOj/IKQHuTXip+05Xn+uyEYdVg0kYDipTBcLrCVyUzAPz7QmArb0mmw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/tldts": { + "version": "7.0.28", + "resolved": "https://registry.npmjs.org/tldts/-/tldts-7.0.28.tgz", + "integrity": "sha512-+Zg3vWhRUv8B1maGSTFdev9mjoo8Etn2Ayfs4cnjlD3CsGkxXX4QyW3j2WJ0wdjYcYmy7Lx2RDsZMhgCWafKIw==", + "dev": true, + "license": "MIT", + "dependencies": { + "tldts-core": "^7.0.28" + }, + "bin": { + "tldts": "bin/cli.js" + } + }, + "node_modules/tldts-core": { + "version": "7.0.28", + "resolved": "https://registry.npmjs.org/tldts-core/-/tldts-core-7.0.28.tgz", + "integrity": "sha512-7W5Efjhsc3chVdFhqtaU0KtK32J37Zcr9RKtID54nG+tIpcY79CQK/veYPODxtD/LJ4Lue66jvrQzIX2Z2/pUQ==", + "dev": true, + "license": "MIT" + }, "node_modules/to-regex-range": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", @@ -9433,6 +10397,32 @@ "node": ">=8.0" } }, + "node_modules/tough-cookie": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-6.0.1.tgz", + "integrity": "sha512-LktZQb3IeoUWB9lqR5EWTHgW/VTITCXg4D21M+lvybRVdylLrRMnqaIONLVb5mav8vM19m44HIcGq4qASeu2Qw==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "tldts": "^7.0.5" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/tr46": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/tr46/-/tr46-6.0.0.tgz", + "integrity": "sha512-bLVMLPtstlZ4iMQHpFHTR7GAGj2jxi8Dg0s2h2MafAE4uSWF98FC/3MomU51iQAMf8/qDUbKWf5GxuvvVcXEhw==", + "dev": true, + "license": "MIT", + "dependencies": { + "punycode": "^2.3.1" + }, + "engines": { + "node": ">=20" + } + }, "node_modules/trim-lines": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/trim-lines/-/trim-lines-3.0.1.tgz", @@ -9541,11 +10531,20 @@ "typescript": ">=4.8.4 <6.0.0" } }, + "node_modules/undici": { + "version": "7.24.7", + "resolved": "https://registry.npmjs.org/undici/-/undici-7.24.7.tgz", + "integrity": "sha512-H/nlJ/h0ggGC+uRL3ovD+G0i4bqhvsDOpbDv7At5eFLlj2b41L8QliGbnl2H7SnDiYhENphh1tQFJZf+MyfLsQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=20.18.1" + } + }, "node_modules/undici-types": { "version": "7.16.0", "resolved": 
"https://registry.npmjs.org/undici-types/-/undici-types-7.16.0.tgz", "integrity": "sha512-Zz+aZWSj8LE6zoxD+xrjh4VfkIG8Ya6LvYkZqtUQGJPZjYl53ypCaUwWqo7eI0x66KBGeRo+mlBEkMSeSZ38Nw==", - "dev": true, "license": "MIT" }, "node_modules/unicode-canonical-property-names-ecmascript": { @@ -9932,6 +10931,101 @@ } } }, + "node_modules/vitest": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/vitest/-/vitest-4.1.2.tgz", + "integrity": "sha512-xjR1dMTVHlFLh98JE3i/f/WePqJsah4A0FK9cc8Ehp9Udk0AZk6ccpIZhh1qJ/yxVWRZ+Q54ocnD8TXmkhspGg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/expect": "4.1.2", + "@vitest/mocker": "4.1.2", + "@vitest/pretty-format": "4.1.2", + "@vitest/runner": "4.1.2", + "@vitest/snapshot": "4.1.2", + "@vitest/spy": "4.1.2", + "@vitest/utils": "4.1.2", + "es-module-lexer": "^2.0.0", + "expect-type": "^1.3.0", + "magic-string": "^0.30.21", + "obug": "^2.1.1", + "pathe": "^2.0.3", + "picomatch": "^4.0.3", + "std-env": "^4.0.0-rc.1", + "tinybench": "^2.9.0", + "tinyexec": "^1.0.2", + "tinyglobby": "^0.2.15", + "tinyrainbow": "^3.1.0", + "vite": "^6.0.0 || ^7.0.0 || ^8.0.0", + "why-is-node-running": "^2.3.0" + }, + "bin": { + "vitest": "vitest.mjs" + }, + "engines": { + "node": "^20.0.0 || ^22.0.0 || >=24.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@edge-runtime/vm": "*", + "@opentelemetry/api": "^1.9.0", + "@types/node": "^20.0.0 || ^22.0.0 || >=24.0.0", + "@vitest/browser-playwright": "4.1.2", + "@vitest/browser-preview": "4.1.2", + "@vitest/browser-webdriverio": "4.1.2", + "@vitest/ui": "4.1.2", + "happy-dom": "*", + "jsdom": "*", + "vite": "^6.0.0 || ^7.0.0 || ^8.0.0" + }, + "peerDependenciesMeta": { + "@edge-runtime/vm": { + "optional": true + }, + "@opentelemetry/api": { + "optional": true + }, + "@types/node": { + "optional": true + }, + "@vitest/browser-playwright": { + "optional": true + }, + "@vitest/browser-preview": { + "optional": true + }, + "@vitest/browser-webdriverio": { + "optional": true + }, + "@vitest/ui": { + "optional": true + }, + "happy-dom": { + "optional": true + }, + "jsdom": { + "optional": true + }, + "vite": { + "optional": false + } + } + }, + "node_modules/w3c-xmlserializer": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-5.0.0.tgz", + "integrity": "sha512-o8qghlI8NZHU1lLPrpi2+Uq7abh4GGPpYANlalzWxyWteJOCsr/P+oPBA49TOLu5FTZO4d3F9MnWJfiMo4BkmA==", + "dev": true, + "license": "MIT", + "dependencies": { + "xml-name-validator": "^5.0.0" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/web-namespaces": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/web-namespaces/-/web-namespaces-2.0.1.tgz", @@ -9942,6 +11036,41 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/webidl-conversions": { + "version": "8.0.1", + "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-8.0.1.tgz", + "integrity": "sha512-BMhLD/Sw+GbJC21C/UgyaZX41nPt8bUTg+jWyDeg7e7YN4xOM05YPSIXceACnXVtqyEw/LMClUQMtMZ+PGGpqQ==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=20" + } + }, + "node_modules/whatwg-mimetype": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/whatwg-mimetype/-/whatwg-mimetype-5.0.0.tgz", + "integrity": "sha512-sXcNcHOC51uPGF0P/D4NVtrkjSU2fNsm9iog4ZvZJsL3rjoDAzXZhkm2MWt1y+PUdggKAYVoMAIYcs78wJ51Cw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=20" + } + }, + "node_modules/whatwg-url": { + "version": 
"16.0.1", + "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-16.0.1.tgz", + "integrity": "sha512-1to4zXBxmXHV3IiSSEInrreIlu02vUOvrhxJJH5vcxYTBDAx51cqZiKdyTxlecdKNSjj8EcxGBxNf6Vg+945gw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@exodus/bytes": "^1.11.0", + "tr46": "^6.0.0", + "webidl-conversions": "^8.0.1" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + } + }, "node_modules/which": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", @@ -9958,6 +11087,23 @@ "node": ">= 8" } }, + "node_modules/why-is-node-running": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/why-is-node-running/-/why-is-node-running-2.3.0.tgz", + "integrity": "sha512-hUrmaWBdVDcxvYqnyh09zunKzROWjbZTiNy8dBEjkS7ehEDQibXJ7XvlmtbwuTclUiIyN+CyXQD4Vmko8fNm8w==", + "dev": true, + "license": "MIT", + "dependencies": { + "siginfo": "^2.0.0", + "stackback": "0.0.2" + }, + "bin": { + "why-is-node-running": "cli.js" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/word-wrap": { "version": "1.2.5", "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.5.tgz", @@ -9968,6 +11114,44 @@ "node": ">=0.10.0" } }, + "node_modules/ws": { + "version": "8.20.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.20.0.tgz", + "integrity": "sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA==", + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/xml-name-validator": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/xml-name-validator/-/xml-name-validator-5.0.0.tgz", + "integrity": "sha512-EvGK8EJ3DhaHfbRlETOWAS5pO9MZITeauHKJyb8wyajUfQUenkIg2MvLDTZ4T/TgIcm3HU0TFBgWWboAZ30UHg==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18" + } + }, + "node_modules/xmlchars": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/xmlchars/-/xmlchars-2.2.0.tgz", + "integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==", + "dev": true, + "license": "MIT" + }, "node_modules/yallist": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz", diff --git a/frontend/app/package.json b/frontend/app/package.json index 52199cd30..1e247d29d 100644 --- a/frontend/app/package.json +++ b/frontend/app/package.json @@ -7,7 +7,8 @@ "dev": "vite", "build": "tsc -b && vite build", "lint": "eslint .", - "preview": "vite preview" + "preview": "vite preview", + "test": "vitest run" }, "dependencies": { "@hookform/resolvers": "^5.2.2", @@ -55,6 +56,7 @@ "react-resizable-panels": "^4.2.2", "react-router-dom": "^7.13.0", "recharts": "^2.15.4", + "@supabase/supabase-js": "^2.49.8", "sonner": "^2.0.7", "streamdown": "^2.4.0", "tailwind-merge": "^3.4.0", @@ -64,6 +66,7 @@ }, "devDependencies": { "@eslint/js": "^9.39.1", + "@testing-library/react": "^16.3.2", "@types/node": "^24.10.1", "@types/react": "^19.2.5", "@types/react-dom": "^19.2.3", @@ -73,6 +76,7 @@ "eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-react-refresh": "^0.4.24", "globals": "^16.5.0", + "jsdom": "^28.1.0", "kimi-plugin-inspect-react": "^1.0.3", "postcss": "^8.5.6", "tailwindcss": "^3.4.19", @@ -80,6 +84,7 @@ "tw-animate-css": "^1.4.0", "typescript": "~5.9.3", 
"typescript-eslint": "^8.46.4", - "vite": "^7.2.4" + "vite": "^7.2.4", + "vitest": "^4.1.2" } } diff --git a/frontend/app/src/api/client.ts b/frontend/app/src/api/client.ts index 2dd5c8c56..894171688 100644 --- a/frontend/app/src/api/client.ts +++ b/frontend/app/src/api/client.ts @@ -11,7 +11,10 @@ import type { LeaseStatus, ThreadDetail, ThreadSummary, - SandboxChannelFilesResult, + ThreadPermissions, + ThreadPermissionRules, + PermissionRuleBehavior, + AskUserAnswer, SandboxFileResult, SandboxFilesListResult, SandboxUploadResult, @@ -31,21 +34,11 @@ export async function request(url: string, init?: RequestInit): Promise { return (await response.json()) as T; } -function toThreads(payload: unknown): ThreadSummary[] { - if (payload && typeof payload === "object" && Array.isArray((payload as { threads?: unknown }).threads)) { - return (payload as { threads: ThreadSummary[] }).threads; - } - if (Array.isArray(payload)) { - return payload as ThreadSummary[]; - } - throw new Error("Unexpected /api/threads response shape"); -} - // --- Thread API --- export async function listThreads(): Promise { - const payload = await request("/api/threads"); - return toThreads(payload); + const payload = await request<{ threads: ThreadSummary[] }>("/api/threads"); + return payload.threads; } export interface CreateThreadOptions { @@ -68,7 +61,10 @@ export async function createThread(opts: CreateThreadOptions): Promise("/api/threads", { method: "POST", body: JSON.stringify(body) }); } -export async function getMainThread(memberId: string, signal?: AbortSignal): Promise { +export async function getDefaultThread(memberId: string, signal?: AbortSignal): Promise { + // @@@default-thread-wire-legacy - frontend now treats this as a template -> + // default-thread resolver, but the backend endpoint name stays `/threads/main` + // until the route contract is renamed in a later slice. 
const payload = await request<{ thread: ThreadSummary | null }>("/api/threads/main", { method: "POST", body: JSON.stringify({ member_id: memberId }), @@ -99,26 +95,55 @@ export async function getThread(threadId: string): Promise { return request(`/api/threads/${encodeURIComponent(threadId)}`); } -export async function getThreadRuntime(threadId: string): Promise { - return request(`/api/threads/${encodeURIComponent(threadId)}/runtime`); +export async function getThreadPermissions(threadId: string, signal?: AbortSignal): Promise { + return request(`/api/threads/${encodeURIComponent(threadId)}/permissions`, { signal }); } -export async function sendMessage(threadId: string, message: string): Promise<{ status: string; routing: string }> { - return request(`/api/threads/${encodeURIComponent(threadId)}/messages`, { +export async function resolveThreadPermission( + threadId: string, + requestId: string, + decision: "allow" | "deny", + message?: string, + answers?: AskUserAnswer[], + annotations?: Record, +): Promise<{ ok: boolean; thread_id: string; request_id: string }> { + return request(`/api/threads/${encodeURIComponent(threadId)}/permissions/${encodeURIComponent(requestId)}/resolve`, { method: "POST", - body: JSON.stringify({ message }), + body: JSON.stringify({ decision, message, answers, annotations }), }); } -export async function queueMessage(threadId: string, message: string): Promise { - await request(`/api/threads/${encodeURIComponent(threadId)}/queue`, { +export async function addThreadPermissionRule( + threadId: string, + behavior: PermissionRuleBehavior, + toolName: string, +): Promise<{ ok: boolean; thread_id: string; scope: string; rules: ThreadPermissionRules; managed_only: boolean }> { + return request(`/api/threads/${encodeURIComponent(threadId)}/permissions/rules`, { method: "POST", - body: JSON.stringify({ message }), + body: JSON.stringify({ behavior, tool_name: toolName }), }); } -export async function getQueue(threadId: string): Promise<{ messages: Array<{ id: number; content: string; created_at: string }> }> { - return request(`/api/threads/${encodeURIComponent(threadId)}/queue`); +export async function removeThreadPermissionRule( + threadId: string, + behavior: PermissionRuleBehavior, + toolName: string, +): Promise<{ ok: boolean; thread_id: string; scope: string; rules: ThreadPermissionRules; managed_only: boolean }> { + return request( + `/api/threads/${encodeURIComponent(threadId)}/permissions/rules/${encodeURIComponent(behavior)}/${encodeURIComponent(toolName)}`, + { method: "DELETE" }, + ); +} + +export async function getThreadRuntime(threadId: string): Promise { + return request(`/api/threads/${encodeURIComponent(threadId)}/runtime`); +} + +export async function sendMessage(threadId: string, message: string): Promise<{ status: string; routing: string }> { + return request(`/api/threads/${encodeURIComponent(threadId)}/messages`, { + method: "POST", + body: JSON.stringify({ message }), + }); } // --- Sandbox API --- @@ -163,32 +188,6 @@ export async function listMyLeases(signal?: AbortSignal): Promise { - await request(`/api/threads/${encodeURIComponent(threadId)}/sandbox/pause`, { method: "POST" }); -} - -export async function resumeThreadSandbox(threadId: string): Promise { - await request(`/api/threads/${encodeURIComponent(threadId)}/sandbox/resume`, { method: "POST" }); -} - -export async function destroyThreadSandbox(threadId: string): Promise { - await request(`/api/threads/${encodeURIComponent(threadId)}/sandbox`, { method: "DELETE" }); -} - -export async 
function pauseSandboxSession(sessionId: string, provider: string): Promise { - await request( - `/api/sandbox/sessions/${encodeURIComponent(sessionId)}/pause?provider=${encodeURIComponent(provider)}`, - { method: "POST" }, - ); -} - -export async function resumeSandboxSession(sessionId: string, provider: string): Promise { - await request( - `/api/sandbox/sessions/${encodeURIComponent(sessionId)}/resume?provider=${encodeURIComponent(provider)}`, - { method: "POST" }, - ); -} - export async function destroySandboxSession(sessionId: string, provider: string): Promise { await request( `/api/sandbox/sessions/${encodeURIComponent(sessionId)}?provider=${encodeURIComponent(provider)}`, @@ -206,8 +205,16 @@ export async function getThreadTerminal(threadId: string): Promise { - return request(`/api/threads/${encodeURIComponent(threadId)}/lease`); +export async function getThreadLease(threadId: string): Promise { + const response = await authFetch(`/api/threads/${encodeURIComponent(threadId)}/lease`); + if (response.status === 404) { + return null; + } + if (!response.ok) { + const body = await response.text(); + throw new Error(`API ${response.status}: ${body || response.statusText}`); + } + return (await response.json()) as LeaseStatus; } // --- Sandbox Files API --- @@ -225,12 +232,6 @@ export async function readSandboxFile(threadId: string, path: string): Promise { - return request(`${sandboxFilesBase(threadId)}/channel-files`); -} - export async function uploadSandboxFile( threadId: string, opts: { file: File; path?: string }, @@ -261,11 +262,6 @@ export function getSandboxDownloadUrl( // --- Settings API --- -export async function listSandboxConfigs(): Promise>> { - const payload = await request<{ sandboxes: Record> }>("/api/settings/sandboxes"); - return payload.sandboxes; -} - export async function saveSandboxConfig(name: string, config: Record): Promise { await request("/api/settings/sandboxes", { method: "POST", @@ -275,10 +271,6 @@ export async function saveSandboxConfig(name: string, config: Record> { - return request("/api/settings/observation"); -} - export async function saveObservationConfig( active: string | null, config?: Record, @@ -309,9 +301,8 @@ export interface InviteCode { } export async function fetchInviteCodes(): Promise { - const payload = await request<{ codes: InviteCode[] } | InviteCode[]>("/api/invite-codes"); - if (Array.isArray(payload)) return payload; - return (payload as { codes: InviteCode[] }).codes; + const payload = await request<{ codes: InviteCode[] }>("/api/invite-codes"); + return payload.codes; } export async function generateInviteCode(expiresDays = 7): Promise { diff --git a/frontend/app/src/api/types.ts b/frontend/app/src/api/types.ts index 08d990935..7aa8548cb 100644 --- a/frontend/app/src/api/types.ts +++ b/frontend/app/src/api/types.ts @@ -28,11 +28,12 @@ export interface ThreadSummary { preview?: string; updated_at?: string; running?: boolean; + /** Template entry id for this thread; actor identity still lives in `thread_id`. */ member_id?: string; + /** Template-facing secondary label; child threads should prefer `sidebar_label` when present. */ member_name?: string; - /** Canonical thread/entity display name. Main: {member}. Child: {member} · 分身N */ - entity_name?: string; branch_index?: number; + /** Canonical actor-facing label for sidebar/header surfaces. 
*/ sidebar_label?: string | null; avatar_url?: string; is_main?: boolean; @@ -45,6 +46,49 @@ export interface ThreadDetail { sandbox: SandboxInfo | null; } +export interface PermissionRequest { + request_id: string; + thread_id: string; + tool_name: string; + args: Record; + message?: string | null; +} + +export interface AskUserQuestionOption { + label: string; + description: string; + preview?: string | null; +} + +export interface AskUserQuestionPrompt { + header: string; + question: string; + options: AskUserQuestionOption[]; + multiSelect?: boolean; +} + +export interface AskUserAnswer { + header?: string; + question?: string; + selected_options: string[]; + free_text?: string | null; +} + +export type PermissionRuleBehavior = "allow" | "deny" | "ask"; + +export interface ThreadPermissionRules { + allow: string[]; + deny: string[]; + ask: string[]; +} + +export interface ThreadPermissions { + thread_id: string; + requests: PermissionRequest[]; + session_rules: ThreadPermissionRules; + managed_only: boolean; +} + export interface SandboxType { name: string; provider?: string; @@ -109,7 +153,9 @@ export interface UserLeaseSummary { cwd?: string | null; thread_ids: string[]; agents: Array<{ + /** Template entry bound to the lease; not an actor thread id. */ member_id: string; + /** Template-facing label for the lease summary card. */ member_name: string; avatar_url?: string | null; }>; @@ -200,6 +246,11 @@ export interface UserMessage { timestamp: number; /** Backend-computed: is this message visible to thread owner? */ showing?: boolean; + ask_user_question_answered?: { + questions: AskUserQuestionPrompt[]; + answers: AskUserAnswer[]; + annotations?: Record; + }; senderName?: string; senderAvatarUrl?: string; attachments?: string[]; @@ -219,6 +270,7 @@ export interface StreamStatus { state: { state: string; flags: Record }; tokens: { total_tokens: number; input_tokens: number; output_tokens: number; cost: number }; context: { message_count: number; estimated_tokens: number; usage_percent: number; near_limit: boolean }; + model?: string; current_tool?: string; last_seq?: number; run_start_seq?: number; @@ -278,35 +330,29 @@ export interface SandboxFileResult { size: number; } -// --- Entity Chat types --- +// --- Chat types --- -export interface ChatEntity { +export interface ChatMember { id: string; + /** Current chat-facing display label for this participant. */ name: string; type: string; avatar_url?: string; owner_name?: string | null; + /** Template-facing auxiliary label when this chat member is thread-backed. */ member_name?: string | null; + /** Actor thread backing this participant when applicable. 
*/ thread_id?: string | null; is_main?: boolean | null; branch_index?: number | null; } -export interface ChatSummary { - id: string; - title: string | null; - entities: ChatEntity[]; - last_message?: { content: string; sender_name: string; created_at: number }; - unread_count: number; - has_mention: boolean; -} - export interface ChatDetail { id: string; title: string | null; status: string; created_at: number; - entities: ChatEntity[]; + entities: ChatMember[]; } export interface ChatMessage { @@ -319,29 +365,6 @@ export interface ChatMessage { created_at: number; } -export interface TaskAgentRequest { - subagent_type: string; - prompt: string; - description?: string; - model?: string; - max_turns?: number; -} - -// @@@channel-kind - string union used directly as a selector, not an object -export type SandboxChannelKind = "upload" | "download"; - -export interface SandboxChannelFileEntry { - relative_path: string; - size_bytes: number; - updated_at: string; -} - -export interface SandboxChannelFilesResult { - thread_id: string; - channel: SandboxChannelKind; - entries: SandboxChannelFileEntry[]; -} - export interface SandboxUploadResult { thread_id: string; relative_path: string; diff --git a/frontend/app/src/components/ChatArea.test.tsx b/frontend/app/src/components/ChatArea.test.tsx new file mode 100644 index 000000000..6c4350157 --- /dev/null +++ b/frontend/app/src/components/ChatArea.test.tsx @@ -0,0 +1,185 @@ +// @vitest-environment jsdom + +import { afterEach, describe, expect, it } from "vitest"; +import { cleanup, fireEvent, render, screen } from "@testing-library/react"; + +import ChatArea from "./ChatArea"; + +afterEach(() => { + cleanup(); +}); + +describe("ChatArea", () => { + it("does not render hidden user entries", () => { + render( + {}", + timestamp: Date.now(), + showing: false, + }, + ]} + runtimeStatus={null} + loading={false} + />, + ); + + expect(screen.queryByText(/ask_user_question_answers/i)).toBeNull(); + }); + + it("renders AskUserQuestion inline inside the assistant turn", () => { + render( + undefined, + onSubmit: () => undefined, + selectionKeyForIndex: (index) => String(index), + }} + />, + ); + + expect(screen.getByText("等待回答")).toBeTruthy(); + expect(screen.getByText("选择一个方向")).toBeTruthy(); + expect(screen.getByRole("button", { name: "提交回答" })).toBeTruthy(); + }); + + it("anchors hidden ask-user answers back onto the original assistant turn", () => { + render( + \n{"questions":[{"header":"选择一个方向","question":"你希望我问什么?","options":[{"label":"A","description":"简单问题"},{"label":"B","description":"工作问题"}]}],"answers":[{"header":"选择一个方向","question":"你希望我问什么?","selected_options":["B"]}]}\n', + timestamp: Date.now() + 1, + showing: false, + }, + ]} + runtimeStatus={null} + loading={false} + />, + ); + + expect(screen.queryByText(/ask_user_question_answers/i)).toBeNull(); + expect(screen.getByText(/已回答 · 选择一个方向:B/)).toBeTruthy(); + expect(screen.queryByText("你希望我问什么?")).toBeNull(); + + fireEvent.click(screen.getByRole("button", { name: "查看已回答详情" })); + + expect(screen.getByText("你希望我问什么?")).toBeTruthy(); + expect(screen.getByText("B")).toBeTruthy(); + }); + + it("prefers explicit answered payload metadata over parsing hidden content", () => { + render( + , + ); + + expect(screen.getByText(/已回答 · 选择一个方向:A/)).toBeTruthy(); + }); +}); diff --git a/frontend/app/src/components/ChatArea.tsx b/frontend/app/src/components/ChatArea.tsx index b203acdf2..b385c580f 100644 --- a/frontend/app/src/components/ChatArea.tsx +++ b/frontend/app/src/components/ChatArea.tsx @@ -1,5 +1,7 
@@ import type { AssistantTurn, ChatEntry, NoticeMessage, StreamStatus } from "../api"; import { useStickyScroll } from "../hooks/use-sticky-scroll"; +import type { AskUserQuestionPendingState } from "../pages/ask-user-question"; +import { parseAskUserQuestionAnswerPayload } from "../pages/ask-user-question"; import { AssistantBlock } from "./chat-area/AssistantBlock"; import { ChatSkeleton } from "./chat-area/ChatSkeleton"; import { NoticeBubble } from "./chat-area/NoticeBubble"; @@ -15,10 +17,47 @@ interface ChatAreaProps { agentAvatarUrl?: string; userName?: string; userAvatarUrl?: string; + askUserQuestion?: AskUserQuestionPendingState; } -export default function ChatArea({ entries, runtimeStatus, loading, onFocusAgent, onTaskNoticeClick, agentName, agentAvatarUrl, userName, userAvatarUrl }: ChatAreaProps) { +function hasAskUserQuestionTool(entry: AssistantTurn): boolean { + return entry.segments.some((segment) => segment.type === "tool" && segment.step.name === "AskUserQuestion"); +} + +export default function ChatArea({ entries, runtimeStatus, loading, onFocusAgent, onTaskNoticeClick, agentName, agentAvatarUrl, userName, userAvatarUrl, askUserQuestion }: ChatAreaProps) { const containerRef = useStickyScroll(); + const askUserQuestionDisplays = new Map< + string, + | { mode: "pending"; pending: AskUserQuestionPendingState } + | { + mode: "answered"; + answered: NonNullable>; + } + >(); + + let lastAskAssistantId: string | null = null; + for (const entry of entries) { + if (entry.role === "assistant" && hasAskUserQuestionTool(entry as AssistantTurn)) { + lastAskAssistantId = entry.id; + continue; + } + if (entry.role === "user" && "showing" in entry && entry.showing === false) { + const answered = entry.ask_user_question_answered ?? parseAskUserQuestionAnswerPayload(entry.content); + if (answered && lastAskAssistantId) { + askUserQuestionDisplays.set(lastAskAssistantId, { mode: "answered", answered }); + lastAskAssistantId = null; + } + } + } + + if (askUserQuestion) { + const pendingAssistant = [...entries] + .reverse() + .find((entry): entry is AssistantTurn => entry.role === "assistant" && hasAskUserQuestionTool(entry as AssistantTurn)); + if (pendingAssistant) { + askUserQuestionDisplays.set(pendingAssistant.id, { mode: "pending", pending: askUserQuestion }); + } + } return (
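For review, a worked example of the anchoring pass added above. The `Entry` shape and the literals are invented stand-ins; the rule itself — a hidden (`showing === false`) answered user entry attaches to the most recent assistant turn that carried an `AskUserQuestion` tool call — is the one this hunk implements.

```ts
// Standalone sketch, not the component code: same anchoring rule, minimal types.
type Entry =
  | { id: string; role: "assistant"; asksUserQuestion: boolean }
  | { id: string; role: "user"; showing: boolean; answered?: unknown };

function anchorAnswers(entries: Entry[]): Map<string, unknown> {
  const displays = new Map<string, unknown>();
  let lastAskAssistantId: string | null = null;
  for (const entry of entries) {
    if (entry.role === "assistant" && entry.asksUserQuestion) {
      lastAskAssistantId = entry.id; // remember the asking turn
      continue;
    }
    if (entry.role === "user" && !entry.showing && entry.answered && lastAskAssistantId) {
      displays.set(lastAskAssistantId, { mode: "answered", answered: entry.answered });
      lastAskAssistantId = null; // one answer per asking turn
    }
  }
  return displays;
}

// anchorAnswers([
//   { id: "a1", role: "assistant", asksUserQuestion: true },
//   { id: "u1", role: "user", showing: false, answered: { answers: [] } },
// ]) yields Map { "a1" => { mode: "answered", answered: { answers: [] } } },
// while the hidden "u1" entry itself renders nothing (isHidden short-circuits).
```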
@@ -28,23 +67,21 @@ export default function ChatArea({ entries, runtimeStatus, loading, onFocusAgent
{entries.map((entry) => { const isHidden = "showing" in entry && entry.showing === false; + if (isHidden) return null; if (entry.role === "notice") { return ; } if (entry.role === "user") { return ( -
- {isHidden && entry.senderName && ( -
{entry.senderName}
- )} - +
+
); } const assistantEntry = entry as AssistantTurn; const isStreamingThis = assistantEntry.streaming === true; return ( -
+
); diff --git a/frontend/app/src/components/ComputerPanel.tsx b/frontend/app/src/components/ComputerPanel.tsx deleted file mode 100644 index 5a9f92065..000000000 --- a/frontend/app/src/components/ComputerPanel.tsx +++ /dev/null @@ -1,3 +0,0 @@ -// Re-export from refactored module -export { default } from "./computer-panel"; -export type { ComputerPanelProps } from "./computer-panel"; diff --git a/frontend/app/src/components/CreateMemberDialog.tsx b/frontend/app/src/components/CreateMemberDialog.tsx index fff6bfb34..58c6c401f 100644 --- a/frontend/app/src/components/CreateMemberDialog.tsx +++ b/frontend/app/src/components/CreateMemberDialog.tsx @@ -29,7 +29,7 @@ export default function CreateMemberDialog({ open, onOpenChange }: Props) { onOpenChange(false); setName(""); setDescription(""); - navigate(`/members/${member.id}`); + navigate(`/contacts/agents/${member.id}`); } catch (e) { toast.error("创建失败,请重试"); } diff --git a/frontend/app/src/components/FileBrowser.tsx b/frontend/app/src/components/FileBrowser.tsx deleted file mode 100644 index 4cef7086a..000000000 --- a/frontend/app/src/components/FileBrowser.tsx +++ /dev/null @@ -1,101 +0,0 @@ -import { useState } from 'react'; -import { authFetch } from '@/store/auth-store'; -import { useFileList } from '@/hooks/useFileList'; -import { MoreVertical } from 'lucide-react'; -import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuItem, - DropdownMenuTrigger, -} from '@/components/ui/dropdown-menu'; -import { Button } from '@/components/ui/button'; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from '@/components/ui/alert-dialog'; - -interface FileBrowserProps { - threadId: string; -} - -export function FileBrowser({ threadId }: FileBrowserProps) { - const { files, loading, error, refetch } = useFileList(threadId); - const [deleteTarget, setDeleteTarget] = useState(null); - const [deleting, setDeleting] = useState(false); - - const handleDownload = (path: string) => { - const url = `/api/threads/${threadId}/files/download?path=${encodeURIComponent(path)}`; - window.open(url, '_blank'); - }; - - const handleDelete = async () => { - if (!deleteTarget) return; - setDeleting(true); - try { - const res = await authFetch( - `/api/threads/${threadId}/files/files?path=${encodeURIComponent(deleteTarget)}`, - { method: 'DELETE' } - ); - if (!res.ok) throw new Error('Failed to delete file'); - await refetch(); - } catch (e) { - alert(e instanceof Error ? e.message : 'Failed to delete file'); - } finally { - setDeleting(false); - setDeleteTarget(null); - } - }; - - if (loading) return
加载文件中...
; - if (error) return
错误:{error}
; - if (files.length === 0) return
暂无已上传文件
; - - return ( - <> -
- {files.map((file) => ( -
- {file.relative_path} -
- {(file.size_bytes / 1024).toFixed(1)} KB - - - - - - handleDownload(file.relative_path)}>下载 - setDeleteTarget(file.relative_path)} disabled={deleting}>删除 - - -
-
- ))} -
- - setDeleteTarget(null)}> - - - 删除文件? - - 确定要删除 "{deleteTarget}" 吗?此操作无法撤销。 - - - - 取消 - - {deleting ? '删除中...' : '删除'} - - - - - - ); -} diff --git a/frontend/app/src/components/Header.tsx b/frontend/app/src/components/Header.tsx index 9273f8c7b..1d850dbaf 100644 --- a/frontend/app/src/components/Header.tsx +++ b/frontend/app/src/components/Header.tsx @@ -1,4 +1,4 @@ -import { ChevronLeft, PanelLeft, Pause, Play } from "lucide-react"; +import { ChevronLeft, PanelLeft } from "lucide-react"; import { useNavigate } from "react-router-dom"; import type { SandboxInfo } from "../api"; import { useIsMobile } from "../hooks/use-mobile"; @@ -22,8 +22,6 @@ interface HeaderProps { sandboxInfo: SandboxInfo | null; currentModel?: string; onToggleSidebar: () => void; - onPauseSandbox: () => void; - onResumeSandbox: () => void; onModelChange?: (model: string) => void; } @@ -33,8 +31,6 @@ export default function Header({ sandboxInfo, currentModel = "leon:medium", onToggleSidebar, - onPauseSandbox, - onResumeSandbox, onModelChange, }: HeaderProps) { const isMobile = useIsMobile(); @@ -52,7 +48,7 @@ export default function Header({
       {isMobile ? (
-        <Button onClick={onPauseSandbox}>
-          <Pause />
-        </Button>
       )}
-      {hasRemote && sandboxInfo?.status === "paused" && (
-        <Button onClick={onResumeSandbox}>
-          <Play />
-        </Button>
-      )}
       );
diff --git a/frontend/app/src/components/LibraryEditor.tsx b/frontend/app/src/components/LibraryEditor.tsx
deleted file mode 100644
index 33c269af0..000000000
--- a/frontend/app/src/components/LibraryEditor.tsx
+++ /dev/null
@@ -1,145 +0,0 @@
-import { useState, useEffect } from "react";
-import { X, Save, Tag, Users, Calendar, FileText } from "lucide-react";
-import { Button } from "@/components/ui/button";
-import { Input } from "@/components/ui/input";
-import { toast } from "sonner";
-import { useAppStore } from "@/store/app-store";
-import { formatDistanceToNow } from "date-fns";
-import { zhCN } from "date-fns/locale";
-import type { ResourceItem } from "@/store/types";
-
-interface Props {
-  item: ResourceItem | null;
-  type: "skill" | "mcp" | "agent";
-  onClose: () => void;
-  onCreated?: (item: ResourceItem) => void;
-}
-
-export default function LibraryEditor({ item, type, onClose, onCreated }: Props) {
-  const fetchResourceContent = useAppStore(s => s.fetchResourceContent);
-  const updateResourceContent = useAppStore(s => s.updateResourceContent);
-  const updateResource = useAppStore(s => s.updateResource);
-  const addResource = useAppStore(s => s.addResource);
-  const getResourceUsedBy = useAppStore(s => s.getResourceUsedBy);
-
-  const isNew = item === null;
-
-  const [name, setName] = useState("");
-  const [content, setContent] = useState("");
-  const [savedContent, setSavedContent] = useState("");
-  const [loading, setLoading] = useState(!isNew);
-  const [saving, setSaving] = useState(false);
-  const [desc, setDesc] = useState("");
-
-  // Load existing item data
-  useEffect(() => {
-    if (!item) {
-      setName(""); setDesc("");
-      setContent(""); setSavedContent("");
-      setLoading(false);
-      return;
-    }
-    setName(item.name);
-    setDesc(item.desc);
-    setLoading(true);
-    fetchResourceContent(type, item.id)
-      .then(c => { setContent(c); setSavedContent(c); })
-      .catch(() => { setContent(""); setSavedContent(""); })
-      .finally(() => setLoading(false));
-  }, [item?.id, type, fetchResourceContent]);
-
-  const savedMeta = item ? { name: item.name, desc: item.desc } : null;
-  const contentDirty = content !== savedContent;
-  const metaDirty = isNew
-    ? name.trim().length > 0
-    : (desc !== savedMeta!.desc);
-  const dirty = contentDirty || metaDirty;
-  const canSave = isNew ? name.trim().length > 0 : dirty;
-
-  const usedByMembers = item ? getResourceUsedBy(type, item.name) : [];
-  const updatedText = item?.updated_at
-    ? formatDistanceToNow(new Date(item.updated_at), { addSuffix: true, locale: zhCN })
-    : "";
-
-  const handleSave = async () => {
-    setSaving(true);
-    try {
-      if (isNew) {
-        const created = await addResource(type, name.trim(), desc.trim());
-        if (content.trim()) await updateResourceContent(type, created.id, content);
-        toast.success(`${name.trim()} 已创建`);
-        onCreated?.(created);
-      } else {
-        if (metaDirty) await updateResource(type, item.id, { desc });
-        if (contentDirty) await updateResourceContent(type, item.id, content);
-        setSavedContent(content);
-        toast.success("已保存");
-      }
-    } catch { toast.error(isNew ? "创建失败" : "保存失败"); }
-    finally { setSaving(false); }
-  };
-
-  const typeLabel = type === "skill" ? "Skill" : type === "mcp" ? "MCP" : "Agent";
-  const fileHint = type === "skill" ? "SKILL.md" : type === "agent" ? `${item?.id || "new"}.md` : ".mcp.json";
-
-  return (
-    <div>
-      {/* Header */}
-      <div>
-        {isNew ? (
-          <Input value={name} onChange={e => setName(e.target.value)} autoFocus />
-        ) : (
-          <div>
-            <FileText />
-            <span>{item.name}</span>
-          </div>
-        )}
-        <div>
-          <Button onClick={handleSave} disabled={!canSave || saving}><Save /></Button>
-          <Button onClick={onClose}><X /></Button>
-        </div>
-      </div>
-
-      {/* Meta section */}
-      <div>
-        {!isNew && (
-          <div>
-            <Tag />
-            <span>{typeLabel}</span>
-            <Users />
-            <span>{usedByMembers.length ? usedByMembers.join(", ") : "未被使用"}</span>
-            {updatedText && <span><Calendar /> {updatedText}</span>}
-          </div>
-        )}
-        <Input value={desc} onChange={e => setDesc(e.target.value)} />
-      </div>
-
-      {/* Content editor */}
-      <div>
-        <div>
-          <FileText />
-          <span>{fileHint}</span>
-        </div>
-        {loading ? (
-          <div>
-            <div>加载中...</div>
-          </div>
-        ) : (
-