diff --git a/Cargo.toml b/Cargo.toml index 31d4b065..00d69362 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vectorless" -version = "0.1.9" +version = "0.1.10" edition = "2024" authors = ["zTgx "] description = "Hierarchical, reasoning-native document intelligence engine" diff --git a/README.md b/README.md index 300818ff..50b88351 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,25 @@ Uses adaptive, multi-stage retrieval with backtracking: This mimics how humans navigate documentation: skim the TOC, drill into relevant sections, and backtrack when needed. +### Pilot: The Brain + +**Pilot** is the intelligence layer that guides retrieval: + +- **Intervention Points** — Pilot acts at key decision moments: + - **START** — Analyze query intent, set initial direction + - **FORK** — Rank candidates at branch points + - **BACKTRACK** — Suggest alternatives when search fails + - **EVALUATE** — Assess content sufficiency + +- **Score Merging** — Combines algorithm scores with LLM reasoning: + ``` + final_score = α × algorithm_score + β × llm_score + ``` + +- **Fallback Strategy** — 4-level degradation (Normal → Retry → Simplified → Algorithm-only) + +- **Budget Control** — Token and call limits with intelligent allocation + ## Comparison | Aspect | Vectorless | Traditional RAG | @@ -116,6 +135,12 @@ See the [examples/](examples/) directory for complete working examples: ## Architecture +### Pilot Architecture + +![Pilot Architecture](docs/design/pilot-architecture.svg) + +### System Overview + ![Architecture](docs/design/architecture-v2.svg) ## Contributing diff --git a/docs/design/pilot-architecture.svg b/docs/design/pilot-architecture.svg new file mode 100644 index 00000000..e85caea3 --- /dev/null +++ b/docs/design/pilot-architecture.svg @@ -0,0 +1,197 @@ + + + + + + Pilot: The Brain of Retrieval Pipeline + + + + Retrieval Pipeline + + + + Analyze + • Complexity + • Keywords + + + + + Plan + • Strategy + • Algorithm + + + + + Search + • 
Beam/MCTS/Greedy + • Tree Traversal + + + + + Judge + • Sufficiency + • Backtrack? + + + + 🧠 Pilot + The Brain of Retrieval + + + + Budget + Controller + + + Context + Builder + + + Fallback + Manager + + + LLM Client + Metrics + + + + Intervention Points (When Pilot Acts) + + + + + 1 + START + Before search begins + • Analyze query intent + • Identify target sections + • Set initial direction + • Provide entry points + + + + + 2 + FORK + At branch points + • Multiple children match + • Rank candidates + • Merge LLM + algo scores + • Guide path selection + + + + + 3 + BACKTRACK + When search fails + • Insufficient results + • Suggest alternatives + • Re-rank candidates + • Adjust search params + + + + + 4 + EVALUATE + After content found + • Check sufficiency + • Quality assessment + • Decide more data? + • Final confidence + + + + Score Merging: Algorithm + LLM + + + Algorithm Score + Text similarity + Heuristics + + + + + + LLM Score + Semantic relevance + Reasoning + + = + + + Final + Score + + final = α × algo + β × llm (configurable weights) + + + + 4-Level Fallback Strategy + + + + Normal + LLM OK + + + + + + Retry + Backoff + 2x delay + + + + + Simplified + Less tokens + Short ctx + + + + + Algo + Only + + Automatic escalation on consecutive failures + + + + Design Philosophy + + Algorithm = "How to search" + Efficient, deterministic + + Pilot = "Where to go" + Semantic understanding, direction + + Intervention at key points + Not every step, only when needed + + + + + + + + Backtrack with Pilot guidance + + + + + + + + + + + + + + diff --git a/docs/design/pilot.md b/docs/design/pilot.md new file mode 100644 index 00000000..0f907f25 --- /dev/null +++ b/docs/design/pilot.md @@ -0,0 +1,1613 @@ +# Pilot 设计文档 + +> Pilot - Retriever Pipeline 的大脑 + +## 概述 + +Pilot 是 Vectorless 检索系统的核心智能组件,负责理解查询、分析文档结构、做出搜索决策。与传统的向量检索不同,Pilot 使用 LLM 进行语义理解和导航决策,同时保持算法的高效执行。 + +### 设计哲学 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ 设计哲学 │ 
+├─────────────────────────────────────────────────────────────────┤ +│ 1. 算法负责 "怎么走" - 高效、确定性、低延迟 │ +│ 2. Pilot 负责 "去哪里" - 语义理解、歧义消解、方向判断 │ +│ 3. 关键决策点介入 - 不是每步都问 LLM,而是在需要时才问 │ +│ 4. 分层 fallback - LLM 失败时算法接管,算法失败时 Pilot 救援 │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### 命名由来 + +**Pilot (驾驶员)** - 像飞机的驾驶员一样,Pilot 不直接操作每个机械部件(那是 Algorithm 的职责),而是负责: +- 理解目的地(用户查询) +- 规划航线(搜索策略) +- 在关键节点做决策(介入点) +- 应对突发情况(fallback) + +--- + +## 1. Pilot 详细设计 + +### 1.1 整体架构 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Pilot 架构 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌───────────────────────────────────────────────────────────────────────┐ │ +│ │ Pilot (Core) │ │ +│ │ │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ Query │ │ Context │ │ Decision │ │ │ +│ │ │ Analyzer │──▶│ Builder │──▶│ Engine │ │ │ +│ │ │ 查询分析器 │ │ 上下文构建 │ │ 决策引擎 │ │ │ +│ │ └─────────────┘ └─────────────┘ └──────┬──────┘ │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ Response │◀──│ LLM │◀──│ Prompt │ │ │ +│ │ │ Parser │ │ Client │ │ Builder │ │ │ +│ │ │ 响应解析器 │ │ 客户端 │ │ 提示词构建 │ │ │ +│ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │ +│ │ │ │ +│ └───────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌───────────────────────────────────────────────────────────────────────┐ │ +│ │ Supporting Systems │ │ +│ │ │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ Budget │ │ Fallback │ │ Metrics │ │ │ +│ │ │ Controller │ │ Manager │ │ Collector │ │ │ +│ │ │ 预算控制器 │ │ 降级管理 │ │ 指标收集器 │ │ │ +│ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │ +│ │ │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ Policy │ │ Cache │ │ Logger │ │ │ +│ │ │ Manager │ │ (Optional) │ │ (Tracing) │ │ │ +│ │ │ 策略管理器 │ │ 缓存 │ │ 日志追踪 │ │ │ +│ │ └─────────────┘ └─────────────┘ 
└─────────────┘ │ │ +│ │ │ │ +│ └───────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 1.2 核心接口定义 + +```rust +/// 搜索状态 - 传给 Pilot 的上下文信息 +pub struct SearchState<'a> { + /// 文档树 + pub tree: &'a DocumentTree, + /// 用户查询 + pub query: &'a str, + /// 当前路径(从根到当前节点) + pub path: &'a [NodeId], + /// 候选子节点 + pub candidates: &'a [NodeId], + /// 已访问的节点 + pub visited: &'a HashSet, + /// 当前深度 + pub depth: usize, + /// 搜索迭代次数 + pub iteration: usize, + /// 当前最高分 + pub best_score: f32, + /// 是否在回溯中 + pub is_backtracking: bool, +} + +/// Pilot trait - 核心接口 +#[async_trait] +pub trait Pilot: Send + Sync { + /// 获取 Pilot 名称 + fn name(&self) -> &str; + + /// 判断是否应该介入 + fn should_intervene(&self, state: &SearchState<'_>) -> bool; + + /// 做出决策 + async fn decide(&self, state: &SearchState<'_>) -> PilotDecision; + + /// 搜索开始前的指导 + async fn guide_start( + &self, + tree: &DocumentTree, + query: &str + ) -> Option; + + /// 获取配置 + fn config(&self) -> &PilotConfig; + + /// 获取指标 + fn metrics(&self) -> &PilotMetrics; + + /// 重置状态(新查询开始时调用) + fn reset(&self); +} +``` + +### 1.3 Pilot 决策类型 + +```rust +/// Pilot 决策结果 +#[derive(Debug, Clone)] +pub struct PilotDecision { + /// 候选节点排序(按推荐优先级) + pub ranked_candidates: Vec, + /// 搜索方向建议 + pub direction: SearchDirection, + /// 置信度 (0.0 - 1.0) + pub confidence: f32, + /// 决策原因(可解释性) + pub reasoning: String, + /// 介入点标识 + pub intervention_point: InterventionPoint, +} + +/// 排序后的候选节点 +#[derive(Debug, Clone)] +pub struct RankedCandidate { + pub node_id: NodeId, + pub score: f32, + pub reason: Option, +} + +/// 搜索方向建议 +#[derive(Debug, Clone)] +pub enum SearchDirection { + /// 继续深入当前分支 + GoDeeper { + reason: String, + }, + /// 探索兄弟节点 + ExploreSiblings { + recommended: Vec, + }, + /// 回溯到父节点 + Backtrack { + reason: String, + alternative_branches: Vec, + }, + /// 跳转到指定节点(非局部移动) + JumpTo { + target: NodeId, + reason: String, + }, + /// 当前节点就是答案 + 
FoundAnswer { + confidence: f32, + }, +} +``` + +--- + +## 1.4 Pilot 决策的信息来源 + +Pilot 的决策依赖于多层信息,其中 TOC View 是核心——它就像导航电子地图。 + +### 信息来源架构 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Pilot 的"导航地图" │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────────┐ │ +│ │ User Query │ │ +│ │ "PostgreSQL │ │ +│ │ 连接池配置" │ │ +│ └────────┬────────┘ │ +│ │ │ +│ ▼ │ +│ ┌───────────────────────────────────────────────────────────────────────┐ │ +│ │ Pilot 上下文 │ │ +│ │ │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ TOC View │ │ Current │ │ Candidates │ │ │ +│ │ │ (电子地图) │ │ Path │ │ Info │ │ │ +│ │ │ │ │ (当前位置) │ │ (候选路口) │ │ │ +│ │ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │ │ +│ │ │ │ │ │ │ +│ │ └─────────────────┼─────────────────┘ │ │ +│ │ ▼ │ │ +│ │ ┌─────────────────┐ │ │ +│ │ │ LLM Decision │ │ │ +│ │ │ (去哪里) │ │ │ +│ │ └─────────────────┘ │ │ +│ │ │ │ +│ └───────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### TOC View - 电子地图(核心) + +TOC View 是 Pilot 决策的核心依据,由 Index 阶段生成的内容构建: + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ TOC View - 电子地图 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ Index 阶段生成的内容: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ TreeNode { │ │ +│ │ title: "配置", // 标题 │ │ +│ │ summary: "本章介绍...", // LLM 生成的摘要 ← 关键! │ │ +│ │ depth: 1, │ │ +│ │ children: [...], │ │ +│ │ } │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ TOC View 构建逻辑: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ generate_toc_view(tree, current_node): │ │ +│ │ │ │ +│ │ // 1. 从当前节点视角生成 │ │ +│ │ // 2. 包含兄弟节点(横向视野) │ │ +│ │ // 3. 包含子节点(纵向视野) │ │ +│ │ // 4. 
每个节点包含 title + summary │ │ +│ │ │ │ +│ │ 输出示例: │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ 📍 当前位置: Root → 配置 │ │ │ +│ │ │ │ │ │ +│ │ │ 📂 兄弟节点: │ │ │ +│ │ │ ├─ 简介 [概述项目功能和架构] │ │ │ +│ │ │ ├─ 安装 [安装步骤和环境要求] │ │ │ +│ │ │ ├─ 配置 ⭐ [配置项详解] ← 当前节点 │ │ │ +│ │ │ │ ├─ 基本配置 [基础参数设置] │ │ │ +│ │ │ │ ├─ 数据库配置 [数据库连接相关] ← 关键匹配! │ │ │ +│ │ │ │ └─ 高级配置 [性能调优选项] │ │ │ +│ │ │ └─ API 参考 [接口文档] │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 三层信息结构 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Pilot 决策的三层信息 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ Layer 1: TOC View (全局地图) │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 作用: 提供文档的全局结构视图 │ │ +│ │ 来源: Index Pipeline 的 Enrich 阶段生成的 summary │ │ +│ │ Token: 约 200-500 tokens │ │ +│ │ │ │ +│ │ 示例: │ │ +│ │ "本文档结构: 1.简介 2.安装 3.配置(3.1基本 3.2数据库 3.3高级) 4.API" │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ Layer 2: Current Path (当前位置) │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 作用: 告诉 LLM 我们已经走了哪里 │ │ +│ │ 来源: 搜索过程的路径记录 │ │ +│ │ Token: 约 50-100 tokens │ │ +│ │ │ │ +│ │ 示例: │ │ +│ │ "当前路径: Root → 配置 → 数据库配置" │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ Layer 3: Candidates Detail (候选路口详情) │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 作用: 提供候选节点的详细信息,供 LLM 判断 │ │ +│ │ 来源: TreeNode 的 title + summary + 部分内容 │ │ +│ │ Token: 约 100-300 tokens │ │ +│ │ │ │ +│ │ 示例: │ │ +│ │ 候选节点: │ │ +│ │ A. 连接字符串 │ │ +│ │ 摘要: 配置数据库连接 URL 和认证信息 │ │ +│ │ B. 连接池 ⭐ │ │ +│ │ 摘要: 配置连接池大小、超时、最大连接数等 │ │ +│ │ C. 
超时设置 │ │ +│ │ 摘要: 配置查询和连接超时时间 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 决策过程示例 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Pilot 决策过程示例 │ +└─────────────────────────────────────────────────────────────────────────────┘ + +Query: "PostgreSQL 连接池的最大连接数怎么配置?" + +Step 1: 构建 TOC View (从 Index 阶段的 summary) +┌─────────────────────────────────────────────────────────────────────────────┐ +│ TOC View (简化版): │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 文档结构: │ │ +│ │ 1. 快速开始 │ │ +│ │ 2. 配置 │ │ +│ │ 2.1 基本配置 │ │ +│ │ 2.2 数据库配置 │ │ +│ │ - 连接字符串 │ │ +│ │ - 连接池 ← 包含"连接池" │ │ +│ │ - 超时设置 │ │ +│ │ 2.3 高级配置 │ │ +│ │ 3. API │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 这个 TOC 是 Index 阶段 LLM 生成的 summary 构成的! │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +Step 2: LLM 分析 +┌─────────────────────────────────────────────────────────────────────────────┐ +│ LLM 看到的信息: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 用户查询: "PostgreSQL 连接池的最大连接数怎么配置?" │ │ +│ │ │ │ +│ │ 当前位置: 配置 → 数据库配置 │ │ +│ │ │ │ +│ │ 候选节点: │ │ +│ │ 1. 连接字符串 [配置数据库 URL 和认证] │ │ +│ │ 2. 连接池 [配置池大小、超时、最大连接数] ← 直接匹配! │ │ +│ │ 3. 超时设置 [配置查询超时时间] │ │ +│ │ │ │ +│ │ 请判断哪个节点最可能包含答案? 
│ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ LLM 推理: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 查询关键词: "连接池", "最大连接数" │ │ +│ │ 候选 2 的摘要包含: "连接池", "最大连接数" │ │ +│ │ → 候选 2 直接匹配,置信度 0.95 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +Step 3: 返回决策 +┌─────────────────────────────────────────────────────────────────────────────┐ +│ PilotDecision { │ +│ ranked_candidates: [ │ +│ (Node 2 "连接池", score: 0.95, reason: "摘要直接匹配查询关键词"), │ │ +│ (Node 3 "超时设置", score: 0.30, reason: "不太相关"), │ │ +│ (Node 1 "连接字符串", score: 0.20, reason: "不相关"), │ │ +│ ], │ +│ direction: GoDeeper, │ +│ confidence: 0.95, │ +│ reasoning: "候选节点'连接池'的摘要明确提到'最大连接数',直接匹配查询", │ │ +│ } │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 关键洞察 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 关键洞察 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 1. Index 阶段的 summary 质量决定 Pilot 效果 │ +│ ┌─────────────────────────────────────────────────────────────────┐ │ +│ │ 好的 summary: "配置连接池大小、超时、最大连接数等参数" │ │ +│ │ 差的 summary: "本章介绍连接池相关内容" │ │ +│ │ │ │ +│ │ → Index Enrich 阶段的 prompt 很重要! │ │ +│ └─────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 2. TOC View 需要动态生成 │ +│ ┌─────────────────────────────────────────────────────────────────┐ │ +│ │ 不是整个文档的 TOC,而是从"当前节点"视角的局部视图 │ │ +│ │ 包含: 兄弟节点 + 子节点 + 父节点链 │ │ +│ │ │ │ +│ │ 这样 Token 消耗可控,且有上下文 │ │ +│ └─────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 3. 
类比: 高德地图导航 │ +│ ┌─────────────────────────────────────────────────────────────────┐ │ +│ │ TOC View = 地图 (道路网络) │ │ +│ │ Summary = 路标 (路口描述) │ │ +│ │ Current Path = GPS 定位 (当前位置) │ │ +│ │ Candidates = 前方路口 (可选方向) │ │ +│ │ Query = 目的地 (要去哪里) │ │ +│ │ │ │ +│ │ Pilot = 驾驶员 (综合以上信息做决策) │ │ +│ └─────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### ContextBuilder Token 预算分配 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ ContextBuilder - Token 预算分配 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ Token 预算分配 (假设 500 tokens 总预算): │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ │ │ +│ │ ┌────────────────────────────────────────┐ 30% (150 tokens) │ │ +│ │ │ Query + Intent │ │ │ +│ │ │ "PostgreSQL 连接池最大连接数配置" │ │ │ +│ │ └────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ ┌────────────────────────────────────────┐ 20% (100 tokens) │ │ +│ │ │ Current Path │ │ │ +│ │ │ Root → 配置 → 数据库配置 │ │ │ +│ │ └────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ ┌────────────────────────────────────────┐ 40% (200 tokens) │ │ +│ │ │ Candidates (title + summary each) │ │ │ +│ │ │ A. 连接字符串 [配置URL和认证] │ │ │ +│ │ │ B. 连接池 [配置池大小、最大连接数] │ │ │ +│ │ │ C. 
超时设置 [配置超时时间] │ │ │ +│ │ └────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ ┌────────────────────────────────────────┐ 10% (50 tokens) │ │ +│ │ │ Sibling Context (兄弟节点概览) │ │ │ +│ │ │ 同级还有: 基本配置、高级配置 │ │ │ +│ │ └────────────────────────────────────────┘ │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 动态调整策略: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ if candidates.len() > 5: │ │ +│ │ // 候选太多,减少每个候选的 detail │ │ +│ │ 只包含 title,不包含 summary │ │ +│ │ │ │ +│ │ if depth > 3: │ │ +│ │ // 深层搜索,减少 TOC 范围 │ │ +│ │ 只显示当前层和子层 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 2. 介入点详细设计 + +### 2.1 介入点类型 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Pilot 介入点 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ START - 搜索开始 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 时机: 搜索算法开始前 │ │ +│ │ 任务: 理解查询意图,确定搜索起点和优先方向 │ │ +│ │ 输入: query, tree (ToC view) │ │ +│ │ 输出: entry_points, initial_direction, confidence │ │ +│ │ 配置: guide_at_start: bool │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ FORK - 分叉路口 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 时机: 当前节点有多个候选子节点时 │ │ +│ │ 任务: 判断哪个分支更可能包含答案 │ │ +│ │ 输入: path, candidates, query │ │ +│ │ 输出: ranked_candidates, direction, confidence │ │ +│ │ 触发条件: candidates.len() > fork_threshold │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 
BACKTRACK - 回溯 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 时机: Judge 判断内容不足,需要回溯时 │ │ +│ │ 任务: 分析失败原因,建议新的搜索方向 │ │ +│ │ 输入: failed_path, visited, query │ │ +│ │ 输出: alternative_branches, backtrack_reason │ │ +│ │ 配置: guide_at_backtrack: bool │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ EVALUATE - 节点评估 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 时机: 需要判断当前节点是否包含答案时 │ │ +│ │ 任务: 评估节点内容与查询的相关性 │ │ +│ │ 输入: node_content, query │ │ +│ │ 输出: relevance_score, is_answer, reasoning │ │ +│ │ 触发条件: 到达叶子节点或算法不确定时 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 2.2 介入判断逻辑 + +```rust +impl Pilot for LlmPilot { + fn should_intervene(&self, state: &SearchState<'_>) -> bool { + let config = &self.config.intervention; + + // 条件 1: 预算检查(最高优先级) + if !self.budget.can_call() { + return false; + } + + // 条件 2: 候选数量超过阈值(分叉路口) + if state.candidates.len() > config.fork_threshold { + return true; + } + + // 条件 3: 候选分数接近(算法无法区分) + if self.scores_are_close(state.candidates, state.tree, config.score_gap_threshold) { + return true; + } + + // 条件 4: 当前分数过低(可能走错方向) + if state.best_score < config.low_score_threshold { + return true; + } + + // 条件 5: 回溯时且配置允许 + if state.is_backtracking && self.config.guide_at_backtrack { + return true; + } + + // 条件 6: 每层介入次数限制 + let level_calls = self.get_level_calls(state.depth); + if level_calls >= config.max_interventions_per_level { + return false; + } + + false + } +} + +/// 判断候选分数是否接近 +fn scores_are_close(&self, candidates: &[NodeId], tree: &DocumentTree, threshold: f32) -> bool { + if candidates.len() < 2 { + return false; + } + + let scores: Vec = candidates.iter() + .map(|&id| self.scorer.quick_score(tree, id)) + .collect(); 
+ + let max_score = scores.iter().cloned().fold(0.0, f32::max); + let min_score = scores.iter().cloned().fold(1.0, f32::min); + + (max_score - min_score) < threshold +} +``` + +### 2.3 介入配置 + +```rust +/// 介入配置 +#[derive(Debug, Clone)] +pub struct InterventionConfig { + /// 候选数量阈值(超过此值考虑介入) + pub fork_threshold: usize, + /// 分数差距阈值(差距小于此值时介入) + pub score_gap_threshold: f32, + /// 低分阈值(最高分低于此值时介入) + pub low_score_threshold: f32, + /// 每层最大介入次数 + pub max_interventions_per_level: usize, +} + +impl Default for InterventionConfig { + fn default() -> Self { + Self { + fork_threshold: 3, // 3 个以上候选时介入 + score_gap_threshold: 0.15, // 分数差距 < 0.15 时介入 + low_score_threshold: 0.3, // 分数 < 0.3 时介入 + max_interventions_per_level: 2, // 每层最多介入 2 次 + } + } +} +``` + +--- + +## 3. Fallback 机制 + +### 3.1 降级层级 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Fallback 降级层级 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ Level 0: 正常 LLM 调用 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 条件: 预算充足,LLM 服务可用 │ │ +│ │ 行为: 正常调用 LLM,获取决策 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ 失败 │ +│ ▼ │ +│ Level 1: 重试 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 条件: 网络错误、超时、rate limit │ │ +│ │ 行为: 指数退避重试,最多 3 次 │ │ +│ │ 参数: initial_delay=1s, max_delay=10s, max_attempts=3 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ 失败 │ +│ ▼ │ +│ Level 2: 简化上下文 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 条件: token 超限、上下文过长 │ │ +│ │ 行为: 减少上下文信息,只保留核心内容 │ │ +│ │ 策略: 移除 ToC,只保留当前节点和候选标题 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ 失败 │ +│ ▼ │ +│ Level 3: 纯算法模式 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 条件: LLM 完全不可用、预算耗尽 │ │ +│ │ 行为: 完全依赖算法打分,不调用 LLM │ │ +│ │ 结果: 使用 
NodeScorer 的关键词匹配 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 3.2 Fallback 策略定义 + +```rust +/// 降级策略 +#[derive(Debug, Clone)] +pub enum FallbackStrategy { + /// 重试策略 + Retry { + max_attempts: usize, + backoff: BackoffPolicy, + }, + /// 简化上下文 + SimplifyContext { + remove_toc: bool, + max_candidates: usize, + }, + /// 使用算法替代 + UseAlgorithm, + /// 返回默认决策 + ReturnDefault, +} + +/// 退避策略 +#[derive(Debug, Clone)] +pub enum BackoffPolicy { + /// 固定间隔 + Fixed { delay_ms: u64 }, + /// 线性增长 + Linear { initial_ms: u64, increment_ms: u64 }, + /// 指数增长 + Exponential { initial_ms: u64, multiplier: f64, max_ms: u64 }, +} + +impl Default for BackoffPolicy { + fn default() -> Self { + Self::Exponential { + initial_ms: 1000, + multiplier: 2.0, + max_ms: 10000, + } + } +} +``` + +### 3.3 FallbackManager 实现 + +```rust +/// 降级管理器 +pub struct FallbackManager { + config: FallbackConfig, + /// 当前降级级别 + current_level: AtomicU8, + /// 连续失败次数 + consecutive_failures: AtomicUsize, +} + +impl FallbackManager { + /// 执行带降级的调用 + pub async fn execute_with_fallback( + &self, + operation: F, + ) -> Result + where + F: Fn() -> std::pin::Pin> + Send>>, + { + let mut level = self.current_level.load(Ordering::Relaxed); + + loop { + match level { + 0 => { + // Level 0: 正常调用 + match operation().await { + Ok(result) => { + self.on_success(); + return Ok(result); + } + Err(e) => { + self.on_failure(); + if self.should_escalate() { + level = 1; + continue; + } + return Err(FallbackError::from(e)); + } + } + } + 1 => { + // Level 1: 重试 + match self.retry_operation(&operation).await { + Ok(result) => { + self.on_success(); + return Ok(result); + } + Err(_) => { + level = 2; + continue; + } + } + } + 2 => { + // Level 2: 简化上下文 + // 由调用方处理,返回特定错误 + return Err(FallbackError::SimplifyContextRequired); + } + 3 => { + // Level 3: 纯算法模式 + return Err(FallbackError::AlgorithmFallback); 
+ } + _ => unreachable!(), + } + } + } + + /// 重试操作 + async fn retry_operation(&self, operation: &F) -> Result + where + F: Fn() -> std::pin::Pin> + Send>>, + { + let policy = &self.config.retry_policy; + let mut delay = policy.initial_delay_ms(); + + for attempt in 0..policy.max_attempts { + if attempt > 0 { + tokio::time::sleep(Duration::from_millis(delay)).await; + delay = policy.next_delay(delay); + } + + match operation().await { + Ok(result) => return Ok(result), + Err(e) if attempt == policy.max_attempts - 1 => return Err(e), + Err(_) => continue, + } + } + + Err(PilotError::RetryExhausted) + } + + fn on_success(&self) { + self.consecutive_failures.store(0, Ordering::Relaxed); + // 逐渐恢复到更高级别 + let current = self.current_level.load(Ordering::Relaxed); + if current > 0 { + self.current_level.fetch_sub(1, Ordering::Relaxed); + } + } + + fn on_failure(&self) { + let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed); + // 连续失败 3 次后升级降级级别 + if failures >= 2 { + let current = self.current_level.load(Ordering::Relaxed); + if current < 3 { + self.current_level.fetch_add(1, Ordering::Relaxed); + } + self.consecutive_failures.store(0, Ordering::Relaxed); + } + } + + fn should_escalate(&self) -> bool { + self.consecutive_failures.load(Ordering::Relaxed) >= 3 + } +} +``` + +--- + +## 4. 
Token 消耗衡量 + +### 4.1 预算配置 + +```rust +/// 预算配置 +#[derive(Debug, Clone)] +pub struct BudgetConfig { + /// 单次检索最大 token 数 + pub max_tokens_per_query: usize, + /// 单次 LLM 调用最大 token 数 + pub max_tokens_per_call: usize, + /// 单次检索最大 LLM 调用次数 + pub max_calls_per_query: usize, + /// 每层(深度)最大调用次数 + pub max_calls_per_level: usize, + /// 是否硬性限制(true: 超预算直接拒绝;false: 尝试继续) + pub hard_limit: bool, +} + +impl Default for BudgetConfig { + fn default() -> Self { + Self { + max_tokens_per_query: 2000, // 单次检索最多 2000 tokens + max_tokens_per_call: 500, // 单次调用最多 500 tokens + max_calls_per_query: 5, // 最多调用 5 次 + max_calls_per_level: 2, // 每层最多 2 次 + hard_limit: true, + } + } +} +``` + +### 4.2 预算控制器 + +```rust +/// 预算控制器 +pub struct BudgetController { + config: BudgetConfig, + /// 已使用的 token 数 + tokens_used: AtomicUsize, + /// 已调用的次数 + calls_made: AtomicUsize, + /// 每层调用次数 + level_calls: RwLock>, +} + +impl BudgetController { + /// 创建新的预算控制器 + pub fn new(config: BudgetConfig) -> Self { + Self { + config, + tokens_used: AtomicUsize::new(0), + calls_made: AtomicUsize::new(0), + level_calls: RwLock::new(HashMap::new()), + } + } + + /// 检查是否可以调用 LLM + pub fn can_call(&self) -> bool { + let calls = self.calls_made.load(Ordering::Relaxed); + let tokens = self.tokens_used.load(Ordering::Relaxed); + + calls < self.config.max_calls_per_query + && tokens < self.config.max_tokens_per_query + } + + /// 检查特定层是否可以调用 + pub fn can_call_at_level(&self, level: usize) -> bool { + if !self.can_call() { + return false; + } + + let level_calls = self.level_calls.read().unwrap(); + let calls = level_calls.get(&level).copied().unwrap_or(0); + calls < self.config.max_calls_per_level + } + + /// 预估调用成本 + pub fn estimate_cost(&self, context: &str) -> usize { + // 使用 tiktoken 或简单的字符估算 + // 粗略估算:1 token ≈ 4 chars (英文) 或 1.5 chars (中文) + let char_count = context.chars().count(); + // 保守估计,按中文计算 + char_count / 2 + 100 // +100 为输出预留 + } + + /// 检查预估成本是否在预算内 + pub fn can_afford(&self, estimated_cost: usize) -> bool 
{ + let remaining = self.remaining_budget(); + estimated_cost <= remaining && estimated_cost <= self.config.max_tokens_per_call + } + + /// 获取剩余预算 + pub fn remaining_budget(&self) -> usize { + let used = self.tokens_used.load(Ordering::Relaxed); + self.config.max_tokens_per_query.saturating_sub(used) + } + + /// 记录 token 使用 + pub fn record_usage(&self, input_tokens: usize, output_tokens: usize, level: usize) { + let total = input_tokens + output_tokens; + self.tokens_used.fetch_add(total, Ordering::Relaxed); + self.calls_made.fetch_add(1, Ordering::Relaxed); + + // 记录层级调用 + let mut level_calls = self.level_calls.write().unwrap(); + *level_calls.entry(level).or_insert(0) += 1; + } + + /// 获取使用统计 + pub fn get_usage_stats(&self) -> BudgetUsage { + BudgetUsage { + tokens_used: self.tokens_used.load(Ordering::Relaxed), + calls_made: self.calls_made.load(Ordering::Relaxed), + max_tokens: self.config.max_tokens_per_query, + max_calls: self.config.max_calls_per_query, + } + } + + /// 重置(新查询开始时) + pub fn reset(&self) { + self.tokens_used.store(0, Ordering::Relaxed); + self.calls_made.store(0, Ordering::Relaxed); + self.level_calls.write().unwrap().clear(); + } +} + +/// 预算使用统计 +#[derive(Debug, Clone)] +pub struct BudgetUsage { + pub tokens_used: usize, + pub calls_made: usize, + pub max_tokens: usize, + pub max_calls: usize, +} + +impl BudgetUsage { + pub fn utilization(&self) -> f32 { + self.tokens_used as f32 / self.max_tokens as f32 + } +} +``` + +### 4.3 Token 消耗流程 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Token 消耗流程 │ +└─────────────────────────────────────────────────────────────────────────────┘ + +LLM 调用前: +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 1. BudgetController.can_call() │ +│ └─ 检查: calls_made < max_calls_per_query │ +│ └─ 检查: tokens_used < max_tokens_per_query │ +│ │ +│ 2. 
BudgetController.can_call_at_level(depth) │ +│ └─ 检查: level_calls[depth] < max_calls_per_level │ +│ │ +│ 3. BudgetController.estimate_cost(context) │ +│ └─ 预估: input_tokens + output_tokens (预留) │ +│ │ +│ 4. BudgetController.can_afford(estimated_cost) │ +│ └─ 检查: estimated_cost <= remaining_budget │ +│ └─ 检查: estimated_cost <= max_tokens_per_call │ +│ │ +│ 决策: 全部通过 → 继续调用;任一失败 → 跳过或降级 │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +LLM 调用: +┌─────────────────────────────────────────────────────────────────────────────┐ +│ LLM Client 返回: │ +│ - usage.prompt_tokens (输入 tokens) │ +│ - usage.completion_tokens (输出 tokens) │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +LLM 调用后: +┌─────────────────────────────────────────────────────────────────────────────┐ +│ BudgetController.record_usage(input_tokens, output_tokens, level) │ +│ └─ tokens_used += input_tokens + output_tokens │ +│ └─ calls_made += 1 │ +│ └─ level_calls[level] += 1 │ +│ │ +│ MetricsCollector.record(...): │ +│ └─ total_input_tokens += input_tokens │ +│ └─ total_output_tokens += output_tokens │ +│ └─ estimated_cost = calculate_cost(tokens, model_price) │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 5. 
职责划分 + +### 5.1 模块职责 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Pilot 模块职责划分 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ QueryAnalyzer - 查询分析器 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 职责: │ │ +│ │ • 分析查询复杂度(简单/中等/复杂) │ │ +│ │ • 提取关键词和实体 │ │ +│ │ • 识别查询意图(事实查询/对比/解释/操作指南) │ │ +│ │ • 判断是否需要 Pilot 介入 │ │ +│ │ │ │ +│ │ 输入: query: String │ │ +│ │ 输出: QueryAnalysis { complexity, keywords, intent, needs_pilot } │ │ +│ │ │ │ +│ │ 实现策略: │ │ +│ │ • 轻量级:基于规则(关键词计数、句子结构) │ │ +│ │ • 重量级:LLM 分析(复杂查询) │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ ContextBuilder - 上下文构建器 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 职责: │ │ +│ │ • 构建发送给 LLM 的上下文信息 │ │ +│ │ • 提取当前路径的节点信息(标题、摘要、深度) │ │ +│ │ • 构建候选节点的描述 │ │ +│ │ • 生成 ToC 视图(从当前节点视角) │ │ +│ │ • 控制 token 预算分配 │ │ +│ │ │ │ +│ │ 输入: tree, path, candidates, query │ │ +│ │ 输出: PilotContext { path_info, candidates_info, toc_view } │ │ +│ │ │ │ +│ │ Token 预算分配: │ │ +│ │ • path_info: 20% │ │ +│ │ • candidates_info: 50% │ │ +│ │ • toc_view: 30% │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ PromptBuilder - 提示词构建器 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 职责: │ │ +│ │ • 根据场景选择合适的 prompt 模板 │ │ +│ │ • 填充模板变量 │ │ +│ │ • 管理 system prompt 和 user prompt │ │ +│ │ • 支持多语言 │ │ +│ │ │ │ +│ │ 场景类型: │ │ +│ │ • START: 搜索开始,确定起点 │ │ +│ │ • FORK: 分叉路口,选择分支 │ │ +│ │ • BACKTRACK: 回溯时,分析失败原因 │ │ +│ │ • EVALUATE: 评估节点是否包含答案 │ │ +│ │ │ │ +│ │ 设计要点: │ │ +│ │ • 模板可配置(用户可自定义) │ │ +│ │ • 包含 few-shot 示例(提高质量) │ │ +│ │ • 输出格式明确(JSON 
schema) │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ DecisionEngine - 决策引擎 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 职责: │ │ +│ │ • 判断何时需要调用 LLM(should_intervene) │ │ +│ │ • 协调 LLM 调用 │ │ +│ │ • 融合算法打分和 LLM 建议 │ │ +│ │ • 做出最终决策 │ │ +│ │ │ │ +│ │ 决策逻辑: │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ should_intervene(state) -> bool │ │ │ +│ │ │ │ │ │ +│ │ │ // 策略 1: 分叉路口 │ │ │ +│ │ │ if candidates.len() > config.fork_threshold { return true } │ │ │ +│ │ │ │ │ │ +│ │ │ // 策略 2: 算法不确定 │ │ │ +│ │ │ if scores_are_close(candidates) { return true } │ │ │ +│ │ │ │ │ │ +│ │ │ // 策略 3: 低置信度 │ │ │ +│ │ │ if best_score < config.low_confidence_threshold { return true }│ │ +│ │ │ │ │ │ +│ │ │ // 策略 4: 预算检查 │ │ │ +│ │ │ if budget_exhausted() { return false } │ │ │ +│ │ │ │ │ │ +│ │ │ return false │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ 融合逻辑: │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ final_score = α * algo_score + β * llm_confidence │ │ │ +│ │ │ │ │ │ +│ │ │ // α 和 β 根据场景动态调整 │ │ │ +│ │ │ // - LLM 高置信度时 β 更高 │ │ │ +│ │ │ // - 算法高分且 LLM 低置信度时 α 更高 │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ ResponseParser - 响应解析器 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 职责: │ │ +│ │ • 解析 LLM 返回的 JSON │ │ +│ │ • 处理格式错误 │ │ +│ │ • 提取结构化信息(ranked_candidates, direction, confidence) │ │ +│ │ • 验证响应有效性 │ │ +│ │ │ │ +│ │ 解析策略: │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ parse(response: String) -> Result │ │ │ +│ │ │ │ │ │ +│ │ 
│ // 优先级 1: JSON 解析 │ │ │ +│ │ │ if let Ok(json) = parse_json(response) { return json } │ │ │ +│ │ │ │ │ │ +│ │ │ // 优先级 2: 正则提取 │ │ │ +│ │ │ if let Some(data) = extract_by_regex(response) { return data }│ │ +│ │ │ │ │ │ +│ │ │ // 优先级 3: 默认值 │ │ │ +│ │ │ return PilotDecision::default() │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ BudgetController - 预算控制器 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 职责: │ │ +│ │ • 追踪 token 消耗 │ │ +│ │ • 控制 LLM 调用次数 │ │ +│ │ • 预估调用成本 │ │ +│ │ • 强制执行预算限制 +│ │ │ │ │ │ │ │ +│ │ 配置: │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ BudgetConfig { │ │ │ +│ │ │ max_tokens_per_query: usize, // 单次检索总预算 │ │ │ +│ │ │ max_tokens_per_call: usize, // 单次调用预算 │ │ │ +│ │ │ max_calls_per_query: usize, // 最大调用次数 │ │ │ +│ │ │ max_calls_per_level: usize, // 每层最大调用 │ │ │ +│ │ │ hard_limit: bool, // 是否硬性限制 │ │ │ +│ │ │ } │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ 接口: │ │ +│ │ • can_call() -> bool │ │ +│ │ • can_call_at_level(level) -> bool │ │ +│ │ • estimate_cost(context) -> usize │ │ +│ │ • can_afford(estimated_cost) -> bool │ │ +│ │ • record_usage(input, output, level) │ │ +│ │ • remaining_budget() -> usize │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ FallbackManager - 降级管理器 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 职责: │ │ +│ │ • 处理 LLM 调用失败 │ │ +│ │ • 提供降级策略 │ │ +│ │ • 记录失败原因 │ │ +│ │ • 自动恢复机制 │ │ +│ │ │ │ +│ │ 降级层级: │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ Level 0: 正常 LLM 调用 │ │ │ +│ │ │ ↓ 失败 │ │ │ +│ │ │ Level 1: 重试 (最多 
3 次,指数退避) │ │ │ +│ │ │ ↓ 失败 │ │ │ +│ │ │ Level 2: 简化 prompt (减少上下文) │ │ │ +│ │ │ ↓ 失败 │ │ │ +│ │ │ Level 3: 纯算法模式 (完全降级) │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ 降级策略: │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ enum FallbackStrategy { │ │ │ +│ │ │ Retry { max_attempts: usize, backoff: BackoffPolicy }, │ │ │ +│ │ │ SimplifyContext, // 减少上下文信息 │ │ │ +│ │ │ UseAlgorithm, // 使用算法打分 │ │ │ +│ │ │ ReturnDefault, // 返回默认决策 │ │ │ +│ │ │ } │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ PolicyManager - 策略管理器 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 职责: │ │ +│ │ • 管理介入策略配置 │ │ +│ │ • 支持多种运行模式 │ │ +│ │ • 动态调整参数(可选) │ │ +│ │ │ │ +│ │ 策略模式: │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ enum PilotMode { │ │ │ +│ │ │ Aggressive, // 激进模式:频繁调用 LLM │ │ │ +│ │ │ Balanced, // 平衡模式:按需调用 (默认) │ │ │ +│ │ │ Conservative, // 保守模式:尽量少调用 │ │ │ +│ │ │ AlgorithmOnly,// 纯算法模式:不调用 LLM │ │ │ +│ │ │ } │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ 参数调整: │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ // 根据历史效果动态调整 │ │ │ +│ │ │ fn adjust_threshold(&mut self, performance: &PerformanceMetrics) {│ │ +│ │ │ // 如果 LLM 建议准确率高,降低介入阈值 │ │ │ +│ │ │ if performance.llm_accuracy > 0.8 { │ │ │ +│ │ │ self.fork_threshold = 2; │ │ │ +│ │ │ } │ │ │ +│ │ │ } │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 
┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ MetricsCollector - 指标收集器 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ 职责: │ │ +│ │ • 收集性能指标 │ │ +│ │ • 追踪 LLM 调用详情 │ │ +│ │ • 计算成本 │ │ +│ │ • 支持可观测性 │ │ +│ │ │ │ +│ │ 指标类型: │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ PilotMetrics { │ │ │ +│ │ │ // 调用统计 │ │ │ +│ │ │ total_calls: usize, │ │ │ +│ │ │ successful_calls: usize, │ │ │ +│ │ │ failed_calls: usize, │ │ │ +│ │ │ fallback_count: usize, │ │ │ +│ │ │ │ │ │ +│ │ │ // Token 统计 │ │ │ +│ │ │ total_input_tokens: usize, │ │ │ +│ │ │ total_output_tokens: usize, │ │ │ +│ │ │ avg_tokens_per_call: f64, │ │ │ +│ │ │ │ │ │ +│ │ │ // 延迟统计 │ │ │ +│ │ │ total_latency_ms: u64, │ │ │ +│ │ │ avg_latency_ms: f64, │ │ │ +│ │ │ p50_latency_ms: u64, │ │ │ +│ │ │ p99_latency_ms: u64, │ │ │ +│ │ │ │ │ │ +│ │ │ // 效果统计 (需要反馈) │ │ │ +│ │ │ llm_decision_accuracy: Option, // LLM 决策准确率 │ │ │ +│ │ │ retrieval_precision: Option, // 检索准确率 │ │ │ +│ │ │ } │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 5.2 Pilot 与 Algorithm 的协作 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Pilot 与 Algorithm 协作关系 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 职责边界 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ │ │ +│ │ Pilot (大脑) Algorithm (手脚) │ │ +│ │ ┌─────────────────────┐ ┌─────────────────────┐ │ │ +│ │ │ • 理解查询意图 │ │ • 执行树遍历 │ │ │ +│ │ │ • 分析文档结构 │ │ • 高效搜索路径 │ │ │ +│ │ │ • 语义判断 │ │ • 计算节点分数 │ │ │ +│ │ │ • 方向决策 │ │ • 管理搜索状态 │ │ │ +│ │ │ • 歧义消解 │ │ • 返回搜索结果 │ │ │ +│ │ 
└─────────────────────┘ └─────────────────────┘ │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 协作流程 │ │ +│ ├─────────────────────────────────────────────────────────────────────┤ │ +│ │ │ │ +│ │ 1. Algorithm 执行搜索 │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ 2. Algorithm 遇到决策点,询问 Pilot │ │ +│ │ │ Pilot.should_intervene(state) │ │ +│ │ ▼ │ │ +│ │ 3a. Pilot 返回 false → Algorithm 继续用自己的 scorer │ │ +│ │ │ │ │ +│ │ 3b. Pilot 返回 true → Pilot.decide(state) │ │ +│ │ │ │ │ │ +│ │ │ ▼ │ │ +│ │ │ Pilot 返回决策 → Algorithm 融合决策继续搜索 │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ 4. 重复直到搜索完成 │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 6. Pilot 完整调用流程 + +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Pilot 完整调用流程 │ +└─────────────────────────────────────────────────────────────────────────────┘ + +用户查询: "如何配置 PostgreSQL 连接池的最大连接数?" + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Step 1: QueryAnalyzer 分析查询 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ QueryAnalysis { │ +│ complexity: Medium, // 中等复杂度 │ +│ keywords: ["PostgreSQL", "连接池", "最大连接数", "配置"], │ +│ intent: HowTo, // 操作指南类 │ +│ needs_pilot: true, // 需要 Pilot 介入 │ +│ } │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Step 2: Pilot.guide_start() - 搜索前指导 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ BudgetController: 检查预算 (通过) │ +│ │ +│ ContextBuilder: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ ToC View: │ │ +│ │ 1. 简介 │ │ +│ │ 2. 安装 │ │ +│ │ 3. 
配置 │ │ +│ │ 3.1 基本配置 │ │ +│ │ 3.2 数据库配置 │ │ +│ │ 3.3 高级配置 │ │ +│ │ 4. API 参考 │ │ +│ │ ... │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ PromptBuilder: 构建 START 场景 prompt │ +│ │ +│ LLM Response: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ { │ │ +│ │ "entry_points": ["配置", "数据库配置"], │ │ +│ │ "reasoning": "查询关于数据库连接池配置,应从配置章节开始", │ │ +│ │ "confidence": 0.9 │ │ +│ │ } │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ MetricsCollector: 记录 (input: 150, output: 50, latency: 230ms) │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Step 3: BeamSearch 开始搜索 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 迭代 1: Root → [简介, 安装, 配置, API, ...] │ +│ │ +│ Algorithm 打分: │ +│ "配置" -> 0.75 (关键词匹配) │ +│ "API" -> 0.35 │ +│ "安装" -> 0.10 │ +│ │ +│ Pilot.should_intervene(): │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ candidates.len() (6) > fork_threshold (3) → true │ │ +│ │ → 需要介入 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ Pilot.decide(): │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ LLM 分析: │ │ +│ │ "查询明确指向配置相关内容,'配置'章节最相关" │ │ +│ │ │ │ +│ │ ranked_candidates: [ │ │ +│ │ ("配置", 0.95, "明确提到配置"), │ │ +│ │ ("API", 0.40, "可能有相关 API"), │ │ +│ │ ] │ │ +│ │ confidence: 0.9 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 融合打分: │ +│ "配置" -> 0.75*0.4 + 0.95*0.6*0.9 = 0.84 │ +│ │ +│ 选择: "配置" 节点深入 │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Step 4: 继续搜索 - 迭代 2 │ 
+├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 当前位置: Root → 配置 │ +│ 候选: [基本配置, 数据库配置, 高级配置, 性能调优] │ +│ │ +│ Algorithm 打分: │ +│ "数据库配置" -> 0.92 (强匹配!) │ +│ "高级配置" -> 0.45 │ +│ │ +│ Pilot.should_intervene(): │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ best_score (0.92) > low_score_threshold (0.3) → OK │ │ +│ │ score_gap (0.47) > threshold (0.15) → OK │ │ +│ │ → 不需要介入,算法很确定 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 直接使用算法打分,选择 "数据库配置" │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Step 5: 继续搜索 - 迭代 3 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 当前位置: Root → 配置 → 数据库配置 │ +│ 候选: [连接字符串, 连接池, 超时设置, SSL配置] │ +│ │ +│ Algorithm 打分: │ +│ "连接池" -> 0.98 (完美匹配!) │ +│ │ +│ → 找到目标,搜索结束 │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Step 6: 返回结果 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ SearchResult { │ +│ path: [Root → 配置 → 数据库配置 → 连接池], │ +│ nodes_visited: 8, │ +│ } │ +│ │ +│ PilotMetrics { │ +│ llm_calls: 2, │ +│ total_tokens: 380, │ +│ avg_latency: 185ms, │ +│ estimated_cost: $0.0012, │ +│ } │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ + + +7. 
代码结构
+
+```
+src/retrieval/
+├── mod.rs
+├── pilot/ # Pilot 模块
+│ ├── mod.rs # 模块入口
+│ ├── trait.rs # Pilot trait 定义
+│ ├── config.rs # 配置类型(PilotConfig, BudgetConfig, InterventionConfig)
+│ ├── decision.rs # 决策类型(PilotDecision, SearchDirection)
+│ ├── analyzer.rs # QueryAnalyzer
+│ ├── builder.rs # ContextBuilder
+│ ├── engine.rs # DecisionEngine
+│ ├── parser.rs # ResponseParser
+│ ├── policy.rs # PolicyManager
+│ ├── budget.rs # BudgetController
+│ ├── fallback.rs # FallbackManager
+│ ├── metrics.rs # MetricsCollector
+│ ├── llm_pilot.rs # LlmPilot 实现
+│ ├── noop_pilot.rs # NoopPilot 实现(空实现,用于纯算法模式)
+│ └── prompts/ # Prompt 模板
+│ ├── mod.rs
+│ ├── start.rs # START 场景模板
+│ ├── fork.rs # FORK 场景模板
+│ ├── backtrack.rs # BACKTRACK 场景模板
+│ └── evaluate.rs # EVALUATE 场景模板
+├── search/
+│ ├── mod.rs
+│ ├── trait.rs # SearchTree trait(修改:增加 pilot 参数)
+│ ├── scorer.rs # NodeScorer(现有)
+│ ├── beam.rs # BeamSearch(修改:集成 Pilot)
+│ ├── greedy.rs # GreedySearch(修改:集成 Pilot)
+│ └── mcts.rs # MctsSearch(修改:集成 Pilot)
+├── stages/
+│ ├── search.rs # SearchStage(修改:注入 Pilot)
+│ └── ...
+└── ...
+```
+
+---
+
+## 8. 配置示例
+
+```rust
+// 默认配置
+let config = PilotConfig {
+    mode: PilotMode::Balanced,
+    budget: BudgetConfig::default(),
+    intervention: InterventionConfig::default(),
+    guide_at_start: true,
+    guide_at_backtrack: true,
+    prompt_template_path: None,
+};
+
+// 高质量模式(更多 LLM 调用)
+let high_quality_config = PilotConfig {
+    mode: PilotMode::Aggressive,
+    budget: BudgetConfig {
+        max_tokens_per_query: 5000,
+        max_tokens_per_call: 1000,
+        max_calls_per_query: 10,
+        max_calls_per_level: 3,
+        hard_limit: false,
+    },
+    intervention: InterventionConfig {
+        fork_threshold: 2,
+        score_gap_threshold: 0.2,
+        low_score_threshold: 0.4,
+        max_interventions_per_level: 3,
+    },
+    guide_at_start: true,
+    guide_at_backtrack: true,
+    prompt_template_path: None,
+};
+
+// 低成本模式(最少 LLM 调用)
+let low_cost_config = PilotConfig {
+    mode: PilotMode::Conservative,
+    budget: BudgetConfig {
+        max_tokens_per_query: 500,
+        max_tokens_per_call: 200,
+        max_calls_per_query: 2,
+        max_calls_per_level: 1,
+        hard_limit: true,
+    },
+    intervention: InterventionConfig {
+        fork_threshold: 5,
+        score_gap_threshold: 0.1,
+        low_score_threshold: 0.2,
+        max_interventions_per_level: 1,
+    },
+    guide_at_start: false,
+    guide_at_backtrack: true,
+    prompt_template_path: None,
+};
+
+// 纯算法模式(不调用 LLM)
+let algorithm_only_config = PilotConfig {
+    mode: PilotMode::AlgorithmOnly,
+    ..Default::default()
+};
+```
+
+---
+
+## 9. 
使用示例 + +```rust +use vectorless::retrieval::pilot::{LlmPilot, PilotConfig, PilotMode}; +use vectorless::retrieval::search::BeamSearch; +use vectorless::llm::LlmClient; + +// 创建 Pilot +let llm_client = LlmClient::from_env()?; +let pilot = LlmPilot::new(llm_client, PilotConfig::default()); + +// 创建搜索引擎(注入 Pilot) +let search = BeamSearch::new().with_pilot(pilot); + +// 执行搜索 +let result = search.search(&tree, &context, &config).await?; + +// 查看指标 +println!("LLM calls: {}", result.metrics.llm_calls); +println!("Tokens used: {}", result.metrics.tokens_used); +println!("Avg latency: {}ms", result.metrics.avg_latency_ms); +``` diff --git a/docs/design/v3.md b/docs/design/v3.md new file mode 100644 index 00000000..df575a77 --- /dev/null +++ b/docs/design/v3.md @@ -0,0 +1,436 @@ +# V3 Design: LLM Navigator + Algorithm 协同检索 + +## 🏗️ 架构设计:LLM + 算法协同的 Retriever Pipeline + +### 核心设计原则 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ 设计哲学 │ +├─────────────────────────────────────────────────────────────────┤ +│ 1. 算法负责 "怎么走" - 高效、确定性、低延迟 │ +│ 2. LLM 负责 "去哪里" - 语义理解、歧义消解、方向判断 │ +│ 3. 关键决策点介入 - 不是每步都问 LLM,而是在需要时才问 │ +│ 4. 分层 fallback - LLM 失败时算法接管,算法失败时 LLM 救援 │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### 整体架构 + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ Index Pipeline (不变) │ +│ Parse → Build → Enhance → Enrich(LLM) → Optimize │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────┐ + │ DocumentTree │ + │ + NodeSummary │ + └─────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Retrieval Pipeline (增强) │ +│ │ +│ ┌─────────┐ ┌─────────┐ ┌─────────────────────┐ ┌─────────┐ │ +│ │ Analyze │───▶│ Plan │───▶│ Search │───▶│ Judge │ │ +│ │ (LLM?) │ │ (LLM?) 
│ │ ┌───────────────┐ │ │ (LLM) │ │ +│ └─────────┘ └─────────┘ │ │ Navigator │ │ └─────────┘ │ +│ │ │ │ │ ┌───────────┐ │ │ │ │ +│ │ │ │ │ │ LLM + │ │ │ │ │ +│ ▼ ▼ │ │ │ Algorithm │ │ │ ▼ │ +│ ┌─────────────────────────┐ │ │ └───────────┘ │ │ ┌───────────┐ │ +│ │ LLM Navigator │◀──┼──┤ │ │ │ NeedMore │ │ +│ │ (关键决策点介入) │ │ │ Search Alg │ │ │ ◀───────│ │ +│ └─────────────────────────┘ │ │ (Greedy/Beam)│ │ └───────────┘ │ +│ │ │ └───────────────┘ │ │ │ +│ └──────────────────┴─────────────────────┘ │ │ +│ ▼ │ +│ ┌───────────┐ │ +│ │ Backtrack │───┘ +│ └───────────┘ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 🧭 LLM Navigator 设计 + +### Navigator 的职责 + +Navigator 不是替代 Search 算法,而是**在关键决策点提供语义判断**: + +``` +┌────────────────────────────────────────────────────────────┐ +│ LLM Navigator 职责 │ +├────────────────┬───────────────────────────────────────────┤ +│ 时机 │ LLM 任务 │ +├────────────────┼───────────────────────────────────────────┤ +│ 搜索开始前 │ 理解 query,确定搜索起点和优先方向 │ +│ 分叉路口 │ 多个候选路径时,判断哪个更相关 │ +│ 迷路时 │ 算法陷入低分路径时,提供纠正建议 │ +│ 不确定时 │ 算法评分接近时,做语义判断 │ +│ 回溯时 │ 分析失败原因,建议新的搜索方向 │ +└────────────────┴───────────────────────────────────────────┘ +``` + +### Navigator 接口设计 + +```rust +/// LLM Navigator - 在关键决策点提供语义导航 +pub struct LlmNavigator { + client: LlmClient, + config: NavigatorConfig, +} + +/// Navigator 配置 +pub struct NavigatorConfig { + /// 是否在搜索开始前介入 + pub guide_at_start: bool, + /// 是否在分叉路口介入 (候选数 > threshold 时) + pub guide_at_fork: bool, + /// 分叉路口阈值 + pub fork_threshold: usize, + /// 是否在回溯时介入 + pub guide_at_backtrack: bool, + /// 低分阈值 (低于此值时请求 LLM 干预) + pub low_score_threshold: f32, + /// 最大 LLM 调用次数 (控制成本) + pub max_llm_calls: usize, +} + +/// 导航建议 +pub struct NavigationGuidance { + /// 推荐的节点顺序 (按相关性排序) + pub preferred_order: Vec, + /// 推荐的搜索方向 + pub direction: SearchDirection, + /// LLM 的推理过程 (可解释性) + pub reasoning: String, + /// 置信度 + pub confidence: f32, +} + +pub enum SearchDirection { + /// 深入当前分支 + GoDeeper, + /// 
探索兄弟节点 + ExploreSiblings, + /// 回溯到父节点 + Backtrack, + /// 跳转到特定节点 + JumpTo(NodeId), + /// 当前路径就是答案 + ThisIsIt, +} + +impl LlmNavigator { + /// 搜索开始前:理解 query,确定起点 + pub async fn guide_start( + &self, + tree: &DocumentTree, + query: &str, + ) -> Result; + + /// 分叉路口:选择最佳分支 + pub async fn guide_fork( + &self, + tree: &DocumentTree, + current_path: &[NodeId], + candidates: &[NodeId], + query: &str, + ) -> Result; + + /// 回溯时:分析失败,建议新方向 + pub async fn guide_backtrack( + &self, + tree: &DocumentTree, + failed_path: &[NodeId], + visited: &HashSet, + query: &str, + ) -> Result; +} +``` + +--- + +## 🔄 Search 阶段集成方案 + +### 新的 Search 架构 + +```rust +/// 增强 Search 阶段 - 算法 + LLM 协同 +pub struct SearchStage { + /// 搜索算法 + algorithm: SearchAlgorithm, + /// LLM Navigator (可选) + navigator: Option>, + /// 配置 + config: SearchConfig, +} + +/// 协同搜索器 +pub struct CollaborativeSearch { + /// 底层搜索算法 + algorithm: Box, + /// LLM Navigator + navigator: LlmNavigator, + /// 调用统计 + stats: SearchStats, +} + +impl CollaborativeSearch { + pub async fn search(&mut self, tree: &DocumentTree, ctx: &RetrievalContext) -> SearchResult { + let mut result = SearchResult::default(); + let mut state = SearchState::new(tree.root()); + + // 1. 开始前:LLM 指导起点 + if self.navigator.config.guide_at_start { + let guidance = self.navigator.guide_start(tree, &ctx.query).await?; + state.apply_guidance(guidance); + } + + // 2. 
搜索循环 + while !state.is_complete() { + // 2.1 算法选择候选 + let candidates = self.algorithm.select_candidates(tree, &state); + + // 2.2 判断是否需要 LLM 介入 + if self.should_consult_llm(&candidates, &state) { + let guidance = self.navigator.guide_fork( + tree, + &state.path, + &candidates, + &ctx.query + ).await?; + + // 2.3 用 LLM 建议重排序候选 + state.candidates = self.merge_algorithm_and_llm( + candidates, + guidance + ); + } + + // 2.4 算法执行下一步 + self.algorithm.step(tree, &mut state); + + // 2.5 检查是否需要回溯 + if state.needs_backtrack() { + if self.navigator.config.guide_at_backtrack { + let guidance = self.navigator.guide_backtrack( + tree, + &state.path, + &state.visited, + &ctx.query + ).await?; + state.apply_backtrack_guidance(guidance); + } else { + state.backtrack(); + } + } + + self.stats.iterations += 1; + } + + result + } + + /// 判断是否需要咨询 LLM + fn should_consult_llm(&self, candidates: &[NodeId], state: &SearchState) -> bool { + // 条件 1: 候选数量超过阈值 (分叉路口) + if candidates.len() > self.navigator.config.fork_threshold { + return true; + } + + // 条件 2: 候选分数接近 (算法无法区分) + if self.scores_are_close(candidates) { + return true; + } + + // 条件 3: 当前分数过低 (可能走错方向) + if state.best_score < self.navigator.config.low_score_threshold { + return true; + } + + // 条件 4: 未超过 LLM 调用限制 + self.stats.llm_calls < self.navigator.config.max_llm_calls + } +} +``` + +--- + +## 📊 Pipeline 各阶段的 LLM 介入点 + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ Retrieval Pipeline │ +├─────────────────────────────────────────────────────────────────────────┤ +│ │ +│ Analyze Stage │ +│ ┌─────────────────────────────────────────────────────────────────┐ │ +│ │ [算法] 关键词提取、复杂度估计 │ │ +│ │ [LLM] 可选:深度语义分析、意图识别 │ │ +│ └─────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ Plan Stage │ +│ ┌─────────────────────────────────────────────────────────────────┐ │ +│ │ [算法] 根据复杂度选择策略 (keyword/llm/semantic) │ │ +│ │ [LLM] 可选:复杂查询的策略推荐 │ │ +│ 
└─────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ Search Stage ◀━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ │ +│ ┌─────────────────────────────────────────────────────────────────┐ │ +│ │ │ │ +│ │ ┌─────────────┐ ┌─────────────────────────────────────┐ │ │ +│ │ │ Algorithm │────▶│ LLM Navigator │ │ │ +│ │ │ (主控) │ │ ┌─────────────────────────────┐ │ │ │ +│ │ │ │ │ │ guide_start() 开始指导 │ │ │ │ +│ │ │ - Greedy │◀───▶│ │ guide_fork() 分叉选择 │ │ │ │ +│ │ │ - Beam │ │ │ guide_backtrack()回溯指导 │ │ │ │ +│ │ │ - MCTS │ │ └─────────────────────────────┘ │ │ │ +│ │ │ │ │ │ │ │ +│ │ └─────────────┘ └─────────────────────────────────────┘ │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ Judge Stage │ +│ ┌─────────────────────────────────────────────────────────────────┐ │ +│ │ [算法] Token 数量检查、阈值判断 │ │ +│ │ [LLM] 内容充分性判断、答案完整性评估 │ │ +│ └─────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌───────────────┐ │ +│ │ Sufficient? │─── No ──▶ Backtrack ──┐ │ +│ └───────────────┘ │ │ +│ │ Yes │ │ +│ ▼ │ │ +│ ┌───────────────┐ │ │ +│ │ Result │◀───────────────────────┘ │ +│ └───────────────┘ │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 🎯 实施方案 + +### 阶段一:基础集成 (1-2 周) + +```rust +// 1. 定义 Navigator trait 和基础实现 +pub trait Navigator: Send + Sync { + async fn guide_fork(&self, ctx: &NavigationContext) -> NavigationGuidance; +} + +// 2. 在 SearchStage 中集成 +pub struct SearchStage { + algorithm: SearchAlgorithm, + navigator: Option>, // 新增 +} + +// 3. 修改搜索循环,在分叉点调用 navigator +``` + +### 阶段二:增强能力 (2-3 周) + +```rust +// 1. 实现完整的 LlmNavigator +// 2. 添加 guide_start, guide_backtrack +// 3. 实现智能介入判断逻辑 +// 4. 添加缓存 (相同 query + 相同上下文 → 缓存结果) +``` + +### 阶段三:优化与监控 (1-2 周) + +```rust +// 1. 添加 A/B 测试能力 (纯算法 vs 算法+LLM) +// 2. 添加成本控制 (max_llm_calls, budget) +// 3. 添加效果监控 (检索准确率、延迟、成本) +// 4. 
自适应介入 (根据历史效果动态调整介入频率) +``` + +--- + +## 📁 代码结构建议 + +``` +src/retrieval/ +├── mod.rs +├── pipeline/ +│ ├── mod.rs +│ ├── stage.rs +│ ├── orchestrator.rs +│ └── context.rs +├── stages/ +│ ├── analyze.rs +│ ├── plan.rs +│ ├── search.rs # 集成 Navigator +│ └── judge.rs +├── search/ +│ ├── mod.rs +│ ├── trait.rs +│ ├── greedy.rs +│ ├── beam.rs +│ └── mcts.rs +├── navigator/ # 新增模块 +│ ├── mod.rs +│ ├── trait.rs # Navigator trait +│ ├── llm_navigator.rs # LLM 实现 +│ ├── noop_navigator.rs # 空实现 +│ ├── guidance.rs # NavigationGuidance 类型 +│ └── config.rs # NavigatorConfig +├── strategy/ +│ ├── mod.rs +│ ├── keyword.rs +│ ├── llm.rs +│ └── semantic.rs +``` + +--- + +## 🤔 几个关键问题 + +### Q1: Navigator 和 Strategy 的区别? + +| | Strategy | Navigator | +|---|----------|-----------| +| 粒度 | 单节点评估 | 全局导航建议 | +| 输入 | 单个节点信息 | 路径 + 候选 + 上下文 | +| 输出 | 分数 (0-1) | 方向 + 排序 + 推理 | +| 调用频率 | 每个候选节点 | 关键决策点 | + +### Q2: 如何控制 LLM 调用成本? + +```rust +pub struct CostControl { + /// 单次检索最大 LLM 调用 + max_calls_per_query: usize, + /// 每日预算 + daily_budget: Option, + /// 低置信度时才调用 + min_uncertainty: f32, +} +``` + +### Q3: 如何评估效果? 
+ +```rust +pub struct RetrievalMetrics { + /// 检索准确率 + pub precision: f32, + /// 检索召回率 + pub recall: f32, + /// LLM 调用次数 + pub llm_calls: usize, + /// 总延迟 + pub latency_ms: u64, + /// 成本 + pub cost: Money, +} +``` diff --git a/src/retrieval/mod.rs b/src/retrieval/mod.rs index cb58b930..5ff07413 100644 --- a/src/retrieval/mod.rs +++ b/src/retrieval/mod.rs @@ -54,6 +54,7 @@ mod types; pub mod cache; pub mod complexity; +pub mod pilot; pub mod pipeline; pub mod search; pub mod stages; @@ -96,3 +97,10 @@ pub use complexity::ComplexityDetector; // Cache exports pub use cache::PathCache; + +// Pilot exports +pub use pilot::{ + BudgetConfig, InterventionConfig, InterventionPoint, Pilot, PilotConfig, PilotDecision, + PilotMode, RankedCandidate, SearchDirection, SearchState, +}; +pub use pilot::NoopPilot; diff --git a/src/retrieval/pilot/budget.rs b/src/retrieval/pilot/budget.rs new file mode 100644 index 00000000..defa1847 --- /dev/null +++ b/src/retrieval/pilot/budget.rs @@ -0,0 +1,351 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Budget controller for Pilot LLM calls. +//! +//! Tracks token consumption and call counts to enforce budget limits +//! and control costs during retrieval. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::RwLock; + +use super::config::BudgetConfig; + +/// Budget usage statistics. +#[derive(Debug, Clone, Default)] +pub struct BudgetUsage { + /// Total input tokens used. + pub input_tokens: usize, + /// Total output tokens used. + pub output_tokens: usize, + /// Total LLM calls made. + pub calls_made: usize, + /// Maximum tokens allowed. + pub max_tokens: usize, + /// Maximum calls allowed. + pub max_calls: usize, +} + +impl BudgetUsage { + /// Get total tokens used (input + output). + pub fn total_tokens(&self) -> usize { + self.input_tokens + self.output_tokens + } + + /// Get token utilization (0.0 - 1.0). 
+ pub fn token_utilization(&self) -> f32 { + if self.max_tokens == 0 { + 0.0 + } else { + (self.total_tokens() as f32 / self.max_tokens as f32).min(1.0) + } + } + + /// Get call utilization (0.0 - 1.0). + pub fn call_utilization(&self) -> f32 { + if self.max_calls == 0 { + 0.0 + } else { + (self.calls_made as f32 / self.max_calls as f32).min(1.0) + } + } + + /// Check if budget is exhausted. + pub fn is_exhausted(&self) -> bool { + self.total_tokens() >= self.max_tokens || self.calls_made >= self.max_calls + } +} + +/// Controller for Pilot budget management. +/// +/// Tracks token usage and call counts per query, enforcing limits +/// to control costs. Thread-safe for concurrent access. +/// +/// # Example +/// +/// ```rust,ignore +/// use vectorless::retrieval::pilot::{BudgetController, BudgetConfig}; +/// +/// let config = BudgetConfig::default(); +/// let budget = BudgetController::new(config); +/// +/// // Check if we can make a call +/// if budget.can_call() { +/// // Estimate cost first +/// let estimated = budget.estimate_cost(context); +/// if budget.can_afford(estimated) { +/// // Make the call... +/// budget.record_usage(150, 50, 0); +/// } +/// } +/// ``` +pub struct BudgetController { + config: BudgetConfig, + /// Total input tokens used. + input_tokens: AtomicUsize, + /// Total output tokens used. + output_tokens: AtomicUsize, + /// Total calls made. + calls_made: AtomicUsize, + /// Calls per level (for level-based limits). + level_calls: RwLock>, +} + +impl BudgetController { + /// Create a new budget controller with the given config. + pub fn new(config: BudgetConfig) -> Self { + Self { + config, + input_tokens: AtomicUsize::new(0), + output_tokens: AtomicUsize::new(0), + calls_made: AtomicUsize::new(0), + level_calls: RwLock::new(HashMap::new()), + } + } + + /// Create with default configuration. + pub fn with_defaults() -> Self { + Self::new(BudgetConfig::default()) + } + + /// Check if a new LLM call is allowed. 
+ /// + /// Returns `true` if: + /// - Token budget not exhausted + /// - Call count not exceeded + pub fn can_call(&self) -> bool { + let tokens = self.total_tokens(); + let calls = self.calls_made.load(Ordering::Relaxed); + + tokens < self.config.max_tokens_per_query + && calls < self.config.max_calls_per_query + } + + /// Check if a call is allowed at a specific tree level. + pub fn can_call_at_level(&self, level: usize) -> bool { + if !self.can_call() { + return false; + } + + let level_calls = self.level_calls.read().unwrap(); + let calls = level_calls.get(&level).copied().unwrap_or(0); + calls < self.config.max_calls_per_level + } + + /// Estimate token cost for a context string. + /// + /// Uses a simple heuristic: + /// - 1 token ≈ 4 chars (English) + /// - 1 token ≈ 1.5 chars (Chinese) + /// - Plus output reserve (100 tokens) + pub fn estimate_cost(&self, context: &str) -> usize { + let char_count = context.chars().count(); + + // Count Chinese characters + let chinese_count = context + .chars() + .filter(|c| ('\u{4E00}'..='\u{9FFF}').contains(c)) + .count(); + + let english_count = char_count - chinese_count; + + // Estimate tokens + let input_tokens = (chinese_count as f32 / 1.5 + english_count as f32 / 4.0).ceil() as usize; + + // Add output reserve + input_tokens + 100 + } + + /// Check if we can afford an estimated cost. + pub fn can_afford(&self, estimated_cost: usize) -> bool { + let remaining = self.remaining_tokens(); + + estimated_cost <= remaining + && estimated_cost <= self.config.max_tokens_per_call + } + + /// Get remaining token budget. + pub fn remaining_tokens(&self) -> usize { + self.config + .max_tokens_per_query + .saturating_sub(self.total_tokens()) + } + + /// Get remaining call budget. + pub fn remaining_calls(&self) -> usize { + self.config + .max_calls_per_query + .saturating_sub(self.calls_made.load(Ordering::Relaxed)) + } + + /// Record token usage after an LLM call. 
+ /// + /// # Arguments + /// + /// * `input_tokens` - Tokens in the prompt + /// * `output_tokens` - Tokens in the response + /// * `level` - Tree level where call was made + pub fn record_usage(&self, input_tokens: usize, output_tokens: usize, level: usize) { + self.input_tokens.fetch_add(input_tokens, Ordering::Relaxed); + self.output_tokens.fetch_add(output_tokens, Ordering::Relaxed); + self.calls_made.fetch_add(1, Ordering::Relaxed); + + // Track level calls + { + let mut level_calls = self.level_calls.write().unwrap(); + *level_calls.entry(level).or_insert(0) += 1; + } + } + + /// Get total tokens used. + pub fn total_tokens(&self) -> usize { + self.input_tokens.load(Ordering::Relaxed) + + self.output_tokens.load(Ordering::Relaxed) + } + + /// Get current usage statistics. + pub fn usage(&self) -> BudgetUsage { + BudgetUsage { + input_tokens: self.input_tokens.load(Ordering::Relaxed), + output_tokens: self.output_tokens.load(Ordering::Relaxed), + calls_made: self.calls_made.load(Ordering::Relaxed), + max_tokens: self.config.max_tokens_per_query, + max_calls: self.config.max_calls_per_query, + } + } + + /// Get calls made at a specific level. + pub fn calls_at_level(&self, level: usize) -> usize { + let level_calls = self.level_calls.read().unwrap(); + level_calls.get(&level).copied().unwrap_or(0) + } + + /// Reset budget state for a new query. + pub fn reset(&self) { + self.input_tokens.store(0, Ordering::Relaxed); + self.output_tokens.store(0, Ordering::Relaxed); + self.calls_made.store(0, Ordering::Relaxed); + self.level_calls.write().unwrap().clear(); + } + + /// Get the configuration. + pub fn config(&self) -> &BudgetConfig { + &self.config + } + + /// Check if hard limit is enforced. 
+ pub fn is_hard_limit(&self) -> bool { + self.config.hard_limit + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_budget_controller_new() { + let config = BudgetConfig::default(); + let max_calls = config.max_calls_per_query; + let budget = BudgetController::new(config); + + assert!(budget.can_call()); + assert_eq!(budget.remaining_calls(), max_calls); + } + + #[test] + fn test_budget_can_call() { + let config = BudgetConfig { + max_tokens_per_query: 100, + max_calls_per_query: 2, + ..Default::default() + }; + let budget = BudgetController::new(config); + + assert!(budget.can_call()); + + budget.record_usage(50, 30, 0); + assert!(budget.can_call()); // 80 tokens, 1 call + + budget.record_usage(50, 30, 0); + assert!(!budget.can_call()); // 160 tokens, 2 calls - exceeded + } + + #[test] + fn test_budget_level_limit() { + let config = BudgetConfig { + max_calls_per_query: 10, + max_calls_per_level: 2, + ..Default::default() + }; + let budget = BudgetController::new(config); + + assert!(budget.can_call_at_level(0)); + + budget.record_usage(10, 10, 0); + budget.record_usage(10, 10, 0); + assert!(!budget.can_call_at_level(0)); // 2 calls at level 0 + assert!(budget.can_call_at_level(1)); // Can still call at level 1 + } + + #[test] + fn test_budget_estimate_cost() { + let budget = BudgetController::with_defaults(); + + // English text - 26 chars ≈ 7 tokens + 100 output reserve = ~107 + let english = "Hello world this is a test"; + let cost = budget.estimate_cost(english); + assert!(cost > 100 && cost < 150, "Expected cost between 100-150, got {}", cost); + + // Chinese text - 6 chars ≈ 4 tokens + 100 output reserve = ~104 + let chinese = "这是一个测试"; + let cost_chinese = budget.estimate_cost(chinese); + // Both have ~100 token base from output reserve, so just check it's reasonable + assert!(cost_chinese > 100, "Expected Chinese cost > 100, got {}", cost_chinese); + } + + #[test] + fn test_budget_can_afford() { + let config = BudgetConfig { + 
max_tokens_per_query: 200, + max_tokens_per_call: 100, + ..Default::default() + }; + let budget = BudgetController::new(config); + + assert!(budget.can_afford(50)); + assert!(budget.can_afford(100)); + assert!(!budget.can_afford(150)); // Exceeds max_tokens_per_call + + budget.record_usage(100, 50, 0); // 150 tokens used + assert!(budget.can_afford(50)); // 50 remaining + assert!(!budget.can_afford(60)); // Only 50 remaining + } + + #[test] + fn test_budget_reset() { + let budget = BudgetController::with_defaults(); + + budget.record_usage(100, 50, 0); + assert_eq!(budget.total_tokens(), 150); + assert_eq!(budget.calls_made.load(Ordering::Relaxed), 1); + + budget.reset(); + assert_eq!(budget.total_tokens(), 0); + assert_eq!(budget.calls_made.load(Ordering::Relaxed), 0); + } + + #[test] + fn test_budget_usage_stats() { + let budget = BudgetController::with_defaults(); + + budget.record_usage(100, 50, 0); + let usage = budget.usage(); + + assert_eq!(usage.input_tokens, 100); + assert_eq!(usage.output_tokens, 50); + assert_eq!(usage.calls_made, 1); + assert_eq!(usage.total_tokens(), 150); + } +} diff --git a/src/retrieval/pilot/builder.rs b/src/retrieval/pilot/builder.rs new file mode 100644 index 00000000..725b4394 --- /dev/null +++ b/src/retrieval/pilot/builder.rs @@ -0,0 +1,546 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Context builder for Pilot LLM calls. +//! +//! Constructs the context information sent to the LLM, including: +//! - Current path in the document tree +//! - Candidate nodes with their summaries +//! - TOC view for navigation context +//! +//! Token budget is distributed across components: +//! - Query: 30% +//! - Current path: 20% +//! - Candidates: 40% +//! - Sibling context: 10% + +use std::collections::HashSet; + +use crate::domain::{DocumentTree, NodeId}; +use super::SearchState; + +/// Token budget distribution for context building. 
+#[derive(Debug, Clone)]
+pub struct TokenBudget {
+    /// Total tokens available.
+    pub total: usize,
+    /// Tokens for query section.
+    pub query: usize,
+    /// Tokens for current path.
+    pub path: usize,
+    /// Tokens for candidates.
+    pub candidates: usize,
+    /// Tokens for sibling context.
+    pub siblings: usize,
+}
+
+impl TokenBudget {
+    /// Create a new token budget with the given total.
+    ///
+    /// Uses the default distribution: query 30%, path 20%,
+    /// candidates 40%, siblings 10%.
+    pub fn new(total: usize) -> Self {
+        Self {
+            total,
+            query: (total as f32 * 0.30) as usize,
+            path: (total as f32 * 0.20) as usize,
+            candidates: (total as f32 * 0.40) as usize,
+            siblings: (total as f32 * 0.10) as usize,
+        }
+    }
+
+    /// Create budget with custom distribution.
+    ///
+    /// Percentages are normalized by their sum, so they need not add up to 1.0.
+    pub fn with_distribution(total: usize, query_pct: f32, path_pct: f32, candidates_pct: f32, siblings_pct: f32) -> Self {
+        let sum = query_pct + path_pct + candidates_pct + siblings_pct;
+        Self {
+            total,
+            query: (total as f32 * query_pct / sum) as usize,
+            path: (total as f32 * path_pct / sum) as usize,
+            candidates: (total as f32 * candidates_pct / sum) as usize,
+            siblings: (total as f32 * siblings_pct / sum) as usize,
+        }
+    }
+}
+
+impl Default for TokenBudget {
+    fn default() -> Self {
+        Self::new(500)
+    }
+}
+
+/// Built context for LLM call.
+#[derive(Debug, Clone, Default)]
+pub struct PilotContext {
+    /// Formatted query section.
+    pub query_section: String,
+    /// Formatted current path.
+    pub path_section: String,
+    /// Formatted candidates section.
+    pub candidates_section: String,
+    /// Formatted TOC/sibling context.
+    pub toc_section: String,
+    /// Estimated total tokens.
+    pub estimated_tokens: usize,
+}
+
+/// Render the full context as a single newline-joined string.
+///
+/// Implementing `Display` (instead of the former inherent `to_string`, which
+/// shadowed the `ToString` trait method — clippy::inherent_to_string) keeps
+/// `ctx.to_string()` working via the blanket `ToString` impl while also
+/// enabling `format!("{ctx}")`.
+impl std::fmt::Display for PilotContext {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{}\n{}\n{}\n{}",
+            self.query_section, self.path_section, self.candidates_section, self.toc_section
+        )
+    }
+}
+
+impl PilotContext {
+    /// Check if context is empty.
+ pub fn is_empty(&self) -> bool { + self.query_section.is_empty() + && self.path_section.is_empty() + && self.candidates_section.is_empty() + } +} + +/// Context builder for Pilot LLM calls. +/// +/// Builds structured context from search state, optimized for +/// token efficiency while providing enough information for +/// good LLM decisions. +/// +/// # Example +/// +/// ```rust,ignore +/// use vectorless::retrieval::pilot::ContextBuilder; +/// +/// let builder = ContextBuilder::new(500); +/// let context = builder.build(&state, &tree); +/// println!("Estimated tokens: {}", context.estimated_tokens); +/// ``` +pub struct ContextBuilder { + /// Token budget for context. + budget: TokenBudget, + /// Maximum candidates to include. + max_candidates: usize, + /// Maximum path depth to show. + max_path_depth: usize, + /// Whether to include summaries for candidates. + include_summaries: bool, +} + +impl Default for ContextBuilder { + fn default() -> Self { + Self::new(500) + } +} + +impl ContextBuilder { + /// Create a new context builder with the given token budget. + pub fn new(token_budget: usize) -> Self { + Self { + budget: TokenBudget::new(token_budget), + max_candidates: 10, + max_path_depth: 5, + include_summaries: true, + } + } + + /// Create with custom budget object. + pub fn with_budget(budget: TokenBudget) -> Self { + Self { + budget, + max_candidates: 10, + max_path_depth: 5, + include_summaries: true, + } + } + + /// Set maximum candidates to include. + pub fn with_max_candidates(mut self, max: usize) -> Self { + self.max_candidates = max; + self + } + + /// Set maximum path depth to show. + pub fn with_max_path_depth(mut self, max: usize) -> Self { + self.max_path_depth = max; + self + } + + /// Set whether to include summaries for candidates. + pub fn with_summaries(mut self, include: bool) -> Self { + self.include_summaries = include; + self + } + + /// Build context from search state. 
+ pub fn build(&self, state: &SearchState<'_>) -> PilotContext { + let mut ctx = PilotContext::default(); + + // Build query section + ctx.query_section = self.build_query_section(state.query); + ctx.estimated_tokens += self.estimate_tokens(&ctx.query_section); + + // Build path section + ctx.path_section = self.build_path_section(state.tree, state.path); + ctx.estimated_tokens += self.estimate_tokens(&ctx.path_section); + + // Build candidates section + ctx.candidates_section = self.build_candidates_section(state.tree, state.candidates); + ctx.estimated_tokens += self.estimate_tokens(&ctx.candidates_section); + + // Build TOC section (siblings context) + ctx.toc_section = self.build_toc_section(state.tree, state.path); + ctx.estimated_tokens += self.estimate_tokens(&ctx.toc_section); + + ctx + } + + /// Build context for START intervention point. + pub fn build_start_context(&self, tree: &DocumentTree, query: &str) -> PilotContext { + let mut ctx = PilotContext::default(); + + // Build query section + ctx.query_section = self.build_query_section(query); + ctx.estimated_tokens += self.estimate_tokens(&ctx.query_section); + + // Build full TOC for start + ctx.toc_section = self.build_full_toc(tree); + ctx.estimated_tokens += self.estimate_tokens(&ctx.toc_section); + + ctx + } + + /// Build context for BACKTRACK intervention point. 
+ pub fn build_backtrack_context( + &self, + state: &SearchState<'_>, + failed_path: &[NodeId], + ) -> PilotContext { + let mut ctx = PilotContext::default(); + + // Build query section + ctx.query_section = self.build_query_section(state.query); + ctx.estimated_tokens += self.estimate_tokens(&ctx.query_section); + + // Show failed path + ctx.path_section = format!("Failed path:\n{}", self.build_path_section(state.tree, failed_path)); + ctx.estimated_tokens += self.estimate_tokens(&ctx.path_section); + + // Show unvisited alternatives + ctx.candidates_section = self.build_unvisited_section(state.tree, state.visited); + ctx.estimated_tokens += self.estimate_tokens(&ctx.candidates_section); + + ctx + } + + /// Build query section. + fn build_query_section(&self, query: &str) -> String { + // Truncate if needed + let truncated = if query.chars().count() > self.budget.query * 4 { + let chars: Vec = query.chars().take(self.budget.query * 4).collect(); + format!("{}...", chars.into_iter().collect::()) + } else { + query.to_string() + }; + + format!("User Query:\n{}\n", truncated) + } + + /// Build current path section. + fn build_path_section(&self, tree: &DocumentTree, path: &[NodeId]) -> String { + if path.is_empty() { + return "Current Position: Root\n".to_string(); + } + + let mut result = String::from("Current Path:\n"); + result.push_str("Root"); + + // Limit depth shown + let start = if path.len() > self.max_path_depth { + path.len() - self.max_path_depth + } else { + 0 + }; + + if start > 0 { + result.push_str(" → ..."); + } + + for node_id in path.iter().skip(start) { + if let Some(node) = tree.get(*node_id) { + result.push_str(" → "); + result.push_str(&node.title); + } + } + + result.push('\n'); + result + } + + /// Build candidates section. 
+ fn build_candidates_section(&self, tree: &DocumentTree, candidates: &[NodeId]) -> String { + if candidates.is_empty() { + return "Candidates: (none)\n".to_string(); + } + + let mut result = String::from("Candidate Nodes:\n"); + let mut tokens_used = 0; + let max_tokens = self.budget.candidates; + + for (i, node_id) in candidates.iter().take(self.max_candidates).enumerate() { + if tokens_used >= max_tokens { + result.push_str("... (more candidates omitted)\n"); + break; + } + + if let Some(node) = tree.get(*node_id) { + let entry = if self.include_summaries && !node.summary.is_empty() { + format!("{}. {} [{}]\n", i + 1, node.title, node.summary) + } else { + format!("{}. {}\n", i + 1, node.title) + }; + + tokens_used += self.estimate_tokens(&entry); + result.push_str(&entry); + } + } + + result + } + + /// Build TOC section showing siblings. + fn build_toc_section(&self, tree: &DocumentTree, path: &[NodeId]) -> String { + if path.is_empty() { + return String::new(); + } + + // Get parent of current node + let parent_id = if path.len() >= 2 { + path[path.len() - 2] + } else { + tree.root() + }; + + let siblings = tree.children(parent_id); + if siblings.len() <= 1 { + return String::new(); + } + + let current_id = path[path.len() - 1]; + let mut result = String::from("Sibling Context:\n"); + + for sibling_id in siblings.iter().take(8) { + if let Some(node) = tree.get(*sibling_id) { + let marker = if *sibling_id == current_id { "⭐ " } else { "" }; + result.push_str(&format!(" {}{}\n", marker, node.title)); + } + } + + result + } + + /// Build full TOC for start context. 
+ fn build_full_toc(&self, tree: &DocumentTree) -> String { + let mut result = String::from("Document Structure:\n"); + let mut tokens_used = 0; + let max_tokens = self.budget.siblings + self.budget.candidates; + + fn build_toc_recursive( + tree: &DocumentTree, + node_id: NodeId, + depth: usize, + result: &mut String, + tokens_used: &mut usize, + max_tokens: usize, + max_depth: usize, + ) { + if *tokens_used >= max_tokens || depth > max_depth { + return; + } + + if let Some(node) = tree.get(node_id) { + let indent = " ".repeat(depth); + let entry = format!("{}{}\n", indent, node.title); + *tokens_used += entry.len() / 4; // Rough estimate + result.push_str(&entry); + + // Only show children for first few levels + if depth < max_depth { + for child_id in tree.children(node_id) { + build_toc_recursive(tree, child_id, depth + 1, result, tokens_used, max_tokens, max_depth); + } + } + } + } + + build_toc_recursive( + tree, + tree.root(), + 0, + &mut result, + &mut tokens_used, + max_tokens, + 3, // Max depth to show + ); + + result + } + + /// Build section showing unvisited nodes. + fn build_unvisited_section(&self, tree: &DocumentTree, visited: &HashSet) -> String { + let mut result = String::from("Unvisited Alternatives:\n"); + let mut count = 0; + + // Find unvisited nodes from root's children + for child_id in tree.children(tree.root()) { + if !visited.contains(&child_id) { + if let Some(node) = tree.get(child_id) { + result.push_str(&format!("• {} [{}]\n", node.title, node.summary)); + count += 1; + if count >= 5 { + break; + } + } + } + } + + if count == 0 { + result.push_str("(all branches explored)\n"); + } + + result + } + + /// Estimate token count for a string. 
+ fn estimate_tokens(&self, text: &str) -> usize { + // Rough estimation: 1 token ≈ 4 chars (English) or 1.5 chars (Chinese) + let char_count = text.chars().count(); + let chinese_count = text + .chars() + .filter(|c| ('\u{4E00}'..='\u{9FFF}').contains(c)) + .count(); + let english_count = char_count - chinese_count; + + (chinese_count as f32 / 1.5 + english_count as f32 / 4.0).ceil() as usize + } + + /// Get the token budget. + pub fn budget(&self) -> &TokenBudget { + &self.budget + } +} + +#[cfg(test)] +mod tests { + use super::*; + use indextree::Arena; + + fn create_test_tree() -> DocumentTree { + let mut arena = Arena::new(); + let root = arena.new_node(crate::domain::TreeNode { + title: "Root".to_string(), + content: "Root content".to_string(), + summary: "Root summary".to_string(), + depth: 0, + ..Default::default() + }); + + let child1 = arena.new_node(crate::domain::TreeNode { + title: "Configuration".to_string(), + content: "Config content".to_string(), + summary: "Configuration options".to_string(), + depth: 1, + ..Default::default() + }); + + let child2 = arena.new_node(crate::domain::TreeNode { + title: "API Reference".to_string(), + content: "API content".to_string(), + summary: "API documentation".to_string(), + depth: 1, + ..Default::default() + }); + + root.append(child1, &mut arena); + root.append(child2, &mut arena); + + DocumentTree::from_raw(arena, crate::domain::NodeId(root)) + } + + #[test] + fn test_token_budget_distribution() { + let budget = TokenBudget::new(500); + assert_eq!(budget.query, 150); // 30% + assert_eq!(budget.path, 100); // 20% + assert_eq!(budget.candidates, 200); // 40% + assert_eq!(budget.siblings, 50); // 10% + } + + #[test] + fn test_context_builder_creation() { + let builder = ContextBuilder::new(500); + assert_eq!(builder.max_candidates, 10); + assert_eq!(builder.max_path_depth, 5); + assert!(builder.include_summaries); + } + + #[test] + fn test_build_query_section() { + let builder = ContextBuilder::new(500); + let 
result = builder.build_query_section("How to configure PostgreSQL?"); + assert!(result.contains("How to configure PostgreSQL?")); + assert!(result.starts_with("User Query:")); + } + + #[test] + fn test_build_query_section_truncation() { + let builder = ContextBuilder::new(20); // Very small budget - 20 * 0.30 = 6 tokens for query = ~24 chars + let long_query = "This is a very long query that should be truncated because it exceeds the token budget"; + let result = builder.build_query_section(long_query); + assert!(result.contains("..."), "Expected truncation, got: {}", result); + } + + #[test] + fn test_estimate_tokens_english() { + let builder = ContextBuilder::new(500); + let text = "Hello world"; // 11 chars ≈ 3 tokens + let tokens = builder.estimate_tokens(text); + assert!(tokens >= 2 && tokens <= 4); + } + + #[test] + fn test_estimate_tokens_chinese() { + let builder = ContextBuilder::new(500); + let text = "这是一个测试"; // 6 chars ≈ 4 tokens + let tokens = builder.estimate_tokens(text); + assert!(tokens >= 3 && tokens <= 5); + } + + #[test] + fn test_pilot_context_to_string() { + let ctx = PilotContext { + query_section: "Query".to_string(), + path_section: "Path".to_string(), + candidates_section: "Candidates".to_string(), + toc_section: "TOC".to_string(), + estimated_tokens: 100, + }; + + let result = ctx.to_string(); + assert!(result.contains("Query")); + assert!(result.contains("Path")); + assert!(result.contains("Candidates")); + assert!(result.contains("TOC")); + } + + #[test] + fn test_pilot_context_is_empty() { + let empty = PilotContext::default(); + assert!(empty.is_empty()); + + let non_empty = PilotContext { + query_section: "Query".to_string(), + ..Default::default() + }; + assert!(!non_empty.is_empty()); + } +} diff --git a/src/retrieval/pilot/config.rs b/src/retrieval/pilot/config.rs new file mode 100644 index 00000000..f381393e --- /dev/null +++ b/src/retrieval/pilot/config.rs @@ -0,0 +1,279 @@ +// Copyright (c) 2026 vectorless developers +// 
SPDX-License-Identifier: Apache-2.0 + +//! Configuration types for Pilot. +//! +//! This module defines all configuration structures that control +//! Pilot's behavior, including budget limits, intervention thresholds, +//! and operation modes. + +use serde::{Deserialize, Serialize}; + +/// Main Pilot configuration. +/// +/// Controls all aspects of Pilot behavior including budget, +/// intervention strategy, and feature flags. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PilotConfig { + /// Operation mode controlling how aggressively Pilot intervenes. + pub mode: PilotMode, + /// Token and call budget constraints. + pub budget: BudgetConfig, + /// Intervention threshold settings. + pub intervention: InterventionConfig, + /// Whether to provide guidance at search start. + pub guide_at_start: bool, + /// Whether to provide guidance during backtracking. + pub guide_at_backtrack: bool, + /// Optional path to custom prompt templates. + pub prompt_template_path: Option, +} + +impl Default for PilotConfig { + fn default() -> Self { + Self { + mode: PilotMode::Balanced, + budget: BudgetConfig::default(), + intervention: InterventionConfig::default(), + guide_at_start: true, + guide_at_backtrack: true, + prompt_template_path: None, + } + } +} + +impl PilotConfig { + /// Create a new config with specified mode. + pub fn with_mode(mode: PilotMode) -> Self { + Self { + mode, + ..Default::default() + } + } + + /// Create a high-quality config (more LLM calls). 
+ pub fn high_quality() -> Self { + Self { + mode: PilotMode::Aggressive, + budget: BudgetConfig { + max_tokens_per_query: 5000, + max_tokens_per_call: 1000, + max_calls_per_query: 10, + max_calls_per_level: 3, + hard_limit: false, + }, + intervention: InterventionConfig { + fork_threshold: 2, + score_gap_threshold: 0.2, + low_score_threshold: 0.4, + max_interventions_per_level: 3, + }, + guide_at_start: true, + guide_at_backtrack: true, + prompt_template_path: None, + } + } + + /// Create a low-cost config (fewer LLM calls). + pub fn low_cost() -> Self { + Self { + mode: PilotMode::Conservative, + budget: BudgetConfig { + max_tokens_per_query: 500, + max_tokens_per_call: 200, + max_calls_per_query: 2, + max_calls_per_level: 1, + hard_limit: true, + }, + intervention: InterventionConfig { + fork_threshold: 5, + score_gap_threshold: 0.1, + low_score_threshold: 0.2, + max_interventions_per_level: 1, + }, + guide_at_start: false, + guide_at_backtrack: true, + prompt_template_path: None, + } + } + + /// Create a pure algorithm config (no LLM calls). + pub fn algorithm_only() -> Self { + Self { + mode: PilotMode::AlgorithmOnly, + ..Default::default() + } + } +} + +/// Pilot operation mode. +/// +/// Controls the trade-off between LLM usage and algorithm-only search. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +pub enum PilotMode { + /// Aggressive mode: frequent LLM calls for maximum accuracy. + Aggressive, + /// Balanced mode: LLM calls at key decision points (default). + #[default] + Balanced, + /// Conservative mode: minimal LLM calls, rely more on algorithm. + Conservative, + /// Pure algorithm mode: no LLM calls at all. + AlgorithmOnly, +} + +impl PilotMode { + /// Check if this mode uses LLM at all. + pub fn uses_llm(&self) -> bool { + !matches!(self, PilotMode::AlgorithmOnly) + } + + /// Get the fork threshold multiplier for this mode. 
+ pub fn fork_threshold_multiplier(&self) -> f32 { + match self { + PilotMode::Aggressive => 0.5, // Lower threshold = more interventions + PilotMode::Balanced => 1.0, + PilotMode::Conservative => 2.0, // Higher threshold = fewer interventions + PilotMode::AlgorithmOnly => f32::MAX, + } + } +} + +/// Token and call budget configuration. +/// +/// Controls resource consumption during retrieval. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BudgetConfig { + /// Maximum total tokens per query (input + output). + pub max_tokens_per_query: usize, + /// Maximum tokens per single LLM call. + pub max_tokens_per_call: usize, + /// Maximum number of LLM calls per query. + pub max_calls_per_query: usize, + /// Maximum number of LLM calls per tree level. + pub max_calls_per_level: usize, + /// Whether to enforce hard limits (true) or soft limits with warnings (false). + pub hard_limit: bool, +} + +impl Default for BudgetConfig { + fn default() -> Self { + Self { + max_tokens_per_query: 2000, + max_tokens_per_call: 500, + max_calls_per_query: 5, + max_calls_per_level: 2, + hard_limit: true, + } + } +} + +impl BudgetConfig { + /// Check if a given token count is within budget. + pub fn is_within_budget(&self, used: usize) -> bool { + used < self.max_tokens_per_query + } + + /// Get remaining tokens given current usage. + pub fn remaining_tokens(&self, used: usize) -> usize { + self.max_tokens_per_query.saturating_sub(used) + } +} + +/// Intervention threshold configuration. +/// +/// Controls when Pilot decides to intervene in the search process. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InterventionConfig { + /// Minimum number of candidates to trigger fork intervention. + pub fork_threshold: usize, + /// Score gap threshold (intervene when top scores are within this range). + pub score_gap_threshold: f32, + /// Low score threshold (intervene when best score is below this). 
+    pub low_score_threshold: f32,
+    /// Maximum interventions allowed per tree level.
+    pub max_interventions_per_level: usize,
+}
+
+impl Default for InterventionConfig {
+    fn default() -> Self {
+        Self {
+            fork_threshold: 3,
+            score_gap_threshold: 0.15,
+            low_score_threshold: 0.3,
+            max_interventions_per_level: 2,
+        }
+    }
+}
+
+impl InterventionConfig {
+    /// Check if the candidate count triggers intervention.
+    pub fn should_intervene_at_fork(&self, candidate_count: usize) -> bool {
+        candidate_count > self.fork_threshold
+    }
+
+    /// Check if scores are too close (algorithm uncertain).
+    pub fn scores_are_close(&self, scores: &[f32]) -> bool {
+        if scores.len() < 2 {
+            return false;
+        }
+        // Seed the folds with ±infinity rather than 0.0/1.0 so the observed
+        // extremes are correct even if a caller passes scores outside [0, 1]
+        // (the old seeds silently clamped max to >= 0.0 and min to <= 1.0,
+        // producing a wrong gap for out-of-range inputs).
+        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
+        let min_score = scores.iter().cloned().fold(f32::INFINITY, f32::min);
+        (max_score - min_score) < self.score_gap_threshold
+    }
+
+    /// Check if the best score is too low.
+    pub fn is_low_confidence(&self, best_score: f32) -> bool {
+        best_score < self.low_score_threshold
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_pilot_mode_uses_llm() {
+        assert!(PilotMode::Aggressive.uses_llm());
+        assert!(PilotMode::Balanced.uses_llm());
+        assert!(PilotMode::Conservative.uses_llm());
+        assert!(!PilotMode::AlgorithmOnly.uses_llm());
+    }
+
+    #[test]
+    fn test_budget_config() {
+        let config = BudgetConfig::default();
+        assert!(config.is_within_budget(1000));
+        assert!(!config.is_within_budget(3000));
+        assert_eq!(config.remaining_tokens(1500), 500);
+    }
+
+    #[test]
+    fn test_intervention_config() {
+        let config = InterventionConfig::default();
+
+        // Fork threshold
+        assert!(!config.should_intervene_at_fork(2));
+        assert!(config.should_intervene_at_fork(4));
+
+        // Scores close
+        assert!(config.scores_are_close(&[0.5, 0.55, 0.52]));
+        assert!(!config.scores_are_close(&[0.3, 0.8]));
+
+        // Low confidence
+        assert!(config.is_low_confidence(0.2));
+        assert!(!config.is_low_confidence(0.5));
+    }
+
+    #[test]
+    fn test_pilot_config_presets() {
+        let high = PilotConfig::high_quality();
+        assert_eq!(high.mode, PilotMode::Aggressive);
+
+        let low = PilotConfig::low_cost();
+        assert_eq!(low.mode, PilotMode::Conservative);
+
+        let algo = PilotConfig::algorithm_only();
+        assert_eq!(algo.mode, PilotMode::AlgorithmOnly);
+    }
+}
diff --git a/src/retrieval/pilot/decision.rs b/src/retrieval/pilot/decision.rs
new file mode 100644
index 00000000..09c76add
--- /dev/null
+++ b/src/retrieval/pilot/decision.rs
@@ -0,0 +1,312 @@
+// Copyright (c) 2026 vectorless developers
+// SPDX-License-Identifier: Apache-2.0
+
+//! Decision types for Pilot navigation.
+//!
+//! This module defines the types that represent Pilot's navigation decisions,
+//! including direction recommendations, candidate rankings, and intervention points.
+
+use serde::{Deserialize, Serialize};
+
+use crate::domain::NodeId;
+
+/// Pilot's navigation decision result.
+///
+/// Contains all information about where to go next and why.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PilotDecision {
+    /// Ranked list of candidate nodes (most relevant first).
+    pub ranked_candidates: Vec<RankedCandidate>,
+    /// Recommended search direction.
+    pub direction: SearchDirection,
+    /// Confidence level of this decision (0.0 - 1.0).
+    pub confidence: f32,
+    /// Human-readable explanation of the decision.
+    pub reasoning: String,
+    /// The intervention point that triggered this decision.
+    pub intervention_point: InterventionPoint,
+}
+
+impl Default for PilotDecision {
+    fn default() -> Self {
+        Self {
+            ranked_candidates: Vec::new(),
+            direction: SearchDirection::GoDeeper {
+                reason: "Default decision".to_string(),
+            },
+            confidence: 0.0,
+            reasoning: "No specific guidance available".to_string(),
+            intervention_point: InterventionPoint::Evaluate,
+        }
+    }
+}
+
+impl PilotDecision {
+    /// Create a new decision with the given candidates and direction.
+ pub fn new( + ranked_candidates: Vec, + direction: SearchDirection, + confidence: f32, + reasoning: String, + ) -> Self { + Self { + ranked_candidates, + direction, + confidence, + reasoning, + intervention_point: InterventionPoint::Fork, + } + } + + /// Create a decision that preserves original order (no-op). + pub fn preserve_order(candidates: &[NodeId]) -> Self { + Self { + ranked_candidates: candidates + .iter() + .enumerate() + .map(|(i, &id)| RankedCandidate { + node_id: id, + score: 1.0 - (i as f32 * 0.1), + reason: None, + }) + .collect(), + direction: SearchDirection::GoDeeper { + reason: "Preserving original order".to_string(), + }, + confidence: 0.0, + reasoning: "No intervention performed".to_string(), + intervention_point: InterventionPoint::Fork, + } + } + + /// Check if this decision has any ranked candidates. + pub fn has_candidates(&self) -> bool { + !self.ranked_candidates.is_empty() + } + + /// Get the top-ranked candidate. + pub fn top_candidate(&self) -> Option<&RankedCandidate> { + self.ranked_candidates.first() + } + + /// Get node IDs in ranked order. + pub fn ranked_node_ids(&self) -> Vec { + self.ranked_candidates.iter().map(|c| c.node_id).collect() + } + + /// Check if the decision indicates an answer was found. + pub fn found_answer(&self) -> bool { + matches!(self.direction, SearchDirection::FoundAnswer { .. }) + } + + /// Check if the decision indicates backtracking is needed. + pub fn needs_backtrack(&self) -> bool { + matches!(self.direction, SearchDirection::Backtrack { .. }) + } +} + +/// A ranked candidate node with score and optional reason. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RankedCandidate { + /// The node ID. + pub node_id: NodeId, + /// Relevance score (0.0 - 1.0). + pub score: f32, + /// Optional reason for this ranking. + pub reason: Option, +} + +impl RankedCandidate { + /// Create a new ranked candidate. 
+ pub fn new(node_id: NodeId, score: f32) -> Self { + Self { + node_id, + score, + reason: None, + } + } + + /// Create with a reason. + pub fn with_reason(node_id: NodeId, score: f32, reason: impl Into) -> Self { + Self { + node_id, + score, + reason: Some(reason.into()), + } + } + + /// Set the reason for this ranking. + pub fn reason(mut self, reason: impl Into) -> Self { + self.reason = Some(reason.into()); + self + } +} + +/// Search direction recommendation from Pilot. +/// +/// Indicates where the search should go next. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum SearchDirection { + /// Continue deeper into the current branch. + GoDeeper { + /// Reason for going deeper. + reason: String, + }, + /// Explore sibling nodes at the same level. + ExploreSiblings { + /// Recommended siblings to explore. + recommended: Vec, + }, + /// Backtrack to parent and try other branches. + Backtrack { + /// Reason for backtracking. + reason: String, + /// Alternative branches to try. + alternative_branches: Vec, + }, + /// Jump to a non-local node (global navigation). + JumpTo { + /// Target node to jump to. + target: NodeId, + /// Reason for the jump. + reason: String, + }, + /// Current node contains the answer. + FoundAnswer { + /// Confidence that this is the answer. + confidence: f32, + }, +} + +impl SearchDirection { + /// Create a GoDeeper direction. + pub fn go_deeper(reason: impl Into) -> Self { + Self::GoDeeper { + reason: reason.into(), + } + } + + /// Create a Backtrack direction. + pub fn backtrack(reason: impl Into, alternatives: Vec) -> Self { + Self::Backtrack { + reason: reason.into(), + alternative_branches: alternatives, + } + } + + /// Create a JumpTo direction. + pub fn jump_to(target: NodeId, reason: impl Into) -> Self { + Self::JumpTo { + target, + reason: reason.into(), + } + } + + /// Create a FoundAnswer direction. 
+    pub fn found_answer(confidence: f32) -> Self {
+        Self::FoundAnswer { confidence }
+    }
+}
+
+/// The point in search where Pilot intervenes.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
+pub enum InterventionPoint {
+    /// Before search begins (initial guidance).
+    Start,
+    /// At a fork with multiple candidates.
+    #[default]
+    Fork,
+    /// During backtracking after a dead end.
+    Backtrack,
+    /// Evaluating a specific node for relevance.
+    Evaluate,
+}
+
+impl InterventionPoint {
+    /// Get a human-readable name for this point.
+    pub fn name(&self) -> &'static str {
+        match self {
+            Self::Start => "start",
+            Self::Fork => "fork",
+            Self::Backtrack => "backtrack",
+            Self::Evaluate => "evaluate",
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use indextree::Arena;
+
+    // Build `count` fresh NodeIds backed by a throwaway arena; the arena is
+    // dropped, so these ids are only valid as opaque identifiers in tests.
+    fn create_test_node_ids(count: usize) -> Vec<NodeId> {
+        let mut arena = Arena::new();
+        let mut ids = Vec::new();
+        for i in 0..count {
+            let node = crate::domain::TreeNode {
+                title: format!("Node {}", i),
+                content: String::new(),
+                summary: String::new(),
+                depth: 0,
+                start_index: 1,
+                end_index: 1,
+                start_page: None,
+                end_page: None,
+                node_id: None,
+                physical_index: None,
+                token_count: None,
+            };
+            ids.push(NodeId(arena.new_node(node)));
+        }
+        ids
+    }
+
+    #[test]
+    fn test_pilot_decision_default() {
+        let decision = PilotDecision::default();
+        assert!(!decision.has_candidates());
+        assert!(decision.top_candidate().is_none());
+        assert!(!decision.found_answer());
+        assert!(!decision.needs_backtrack());
+    }
+
+    #[test]
+    fn test_pilot_decision_preserve_order() {
+        let node_ids = create_test_node_ids(2);
+        let decision = PilotDecision::preserve_order(&node_ids);
+
+        assert!(decision.has_candidates());
+        assert_eq!(decision.ranked_candidates.len(), 2);
+        assert_eq!(decision.confidence, 0.0);
+    }
+
+    #[test]
+    fn test_ranked_candidate() {
+        let node_ids = create_test_node_ids(1);
+        let candidate = RankedCandidate::new(node_ids[0], 0.8);
+
assert_eq!(candidate.score, 0.8); + assert!(candidate.reason.is_none()); + + let candidate_with_reason = + RankedCandidate::with_reason(node_ids[0], 0.9, "test reason"); + assert_eq!(candidate_with_reason.score, 0.9); + assert_eq!(candidate_with_reason.reason, Some("test reason".to_string())); + } + + #[test] + fn test_search_direction_constructors() { + let deeper = SearchDirection::go_deeper("test"); + assert!(matches!(deeper, SearchDirection::GoDeeper { .. })); + + let found = SearchDirection::found_answer(0.9); + assert!(matches!(found, SearchDirection::FoundAnswer { confidence: 0.9 })); + } + + #[test] + fn test_intervention_point() { + assert_eq!(InterventionPoint::Start.name(), "start"); + assert_eq!(InterventionPoint::Fork.name(), "fork"); + assert_eq!(InterventionPoint::Backtrack.name(), "backtrack"); + assert_eq!(InterventionPoint::Evaluate.name(), "evaluate"); + } +} diff --git a/src/retrieval/pilot/fallback.rs b/src/retrieval/pilot/fallback.rs new file mode 100644 index 00000000..874cdd96 --- /dev/null +++ b/src/retrieval/pilot/fallback.rs @@ -0,0 +1,445 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Fallback manager for Pilot LLM calls. +//! +//! Implements layered fallback strategy: +//! 1. Normal LLM call +//! 2. Retry with exponential backoff +//! 3. Simplified context (reduce tokens) +//! 4. Algorithm-only mode (no LLM) + +use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering}; +use std::time::Duration; +use tracing::{debug, warn}; + +/// Fallback level indicating current degradation state. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FallbackLevel { + /// Normal operation - LLM calls working. + Normal = 0, + /// Retrying - transient failures, using backoff. + Retry = 1, + /// Simplified - using reduced context. + Simplified = 2, + /// Algorithm only - LLM unavailable. 
+    AlgorithmOnly = 3,
+}
+
+impl Default for FallbackLevel {
+    fn default() -> Self {
+        Self::Normal
+    }
+}
+
+impl From<u8> for FallbackLevel {
+    fn from(value: u8) -> Self {
+        match value {
+            0 => Self::Normal,
+            1 => Self::Retry,
+            2 => Self::Simplified,
+            // Any out-of-range value saturates to the most degraded level.
+            _ => Self::AlgorithmOnly,
+        }
+    }
+}
+
+/// Configuration for fallback behavior.
+#[derive(Debug, Clone)]
+pub struct FallbackConfig {
+    /// Maximum retry attempts before escalating.
+    pub max_retries: usize,
+    /// Initial delay for exponential backoff (ms).
+    pub initial_delay_ms: u64,
+    /// Maximum delay for exponential backoff (ms).
+    pub max_delay_ms: u64,
+    /// Multiplier for exponential backoff.
+    pub backoff_multiplier: f64,
+    /// Consecutive failures before escalating level.
+    pub failures_before_escalate: usize,
+    /// Consecutive successes before de-escalating level.
+    pub successes_before_deescalate: usize,
+}
+
+impl Default for FallbackConfig {
+    fn default() -> Self {
+        Self {
+            max_retries: 3,
+            initial_delay_ms: 1000,
+            max_delay_ms: 10000,
+            backoff_multiplier: 2.0,
+            failures_before_escalate: 3,
+            successes_before_deescalate: 2,
+        }
+    }
+}
+
+/// Errors that can trigger fallback.
+#[derive(Debug, Clone, thiserror::Error)]
+pub enum FallbackError {
+    /// Network/timeout error (retryable).
+    #[error("Network error: {0}")]
+    Network(String),
+    /// Rate limit error (retryable with backoff).
+    #[error("Rate limited")]
+    RateLimited,
+    /// Token limit exceeded (need simplified context).
+    #[error("Token limit exceeded")]
+    TokenLimitExceeded,
+    /// LLM service unavailable (use algorithm).
+    #[error("LLM unavailable: {0}")]
+    Unavailable(String),
+    /// Parsing error (may use default).
+    #[error("Response parsing failed: {0}")]
+    ParseError(String),
+    /// All fallbacks exhausted.
+    #[error("All fallback strategies exhausted")]
+    Exhausted,
+}
+
+impl FallbackError {
+    /// Check if this error should trigger a retry.
+ pub fn is_retryable(&self) -> bool { + matches!(self, Self::Network(_) | Self::RateLimited) + } + + /// Check if this error suggests using simplified context. + pub fn needs_simplification(&self) -> bool { + matches!(self, Self::TokenLimitExceeded) + } + + /// Check if this error requires algorithm fallback. + pub fn needs_algorithm_fallback(&self) -> bool { + matches!(self, Self::Unavailable(_) | Self::Exhausted) + } +} + +/// Statistics for fallback operations. +#[derive(Debug, Clone, Default)] +pub struct FallbackStats { + /// Total operations attempted. + pub total_attempts: usize, + /// Successful operations (no fallback needed). + pub successful: usize, + /// Operations that needed retry. + pub retried: usize, + /// Operations that needed simplified context. + pub simplified: usize, + /// Operations that fell back to algorithm. + pub algorithm_fallbacks: usize, + /// Current fallback level. + pub current_level: FallbackLevel, +} + +/// Manager for handling LLM call failures with layered fallback. +/// +/// Implements a 4-level fallback strategy: +/// 1. Normal: Direct LLM calls +/// 2. Retry: Exponential backoff retry +/// 3. Simplified: Reduced context to fit token limits +/// 4. Algorithm: Pure algorithm mode, no LLM +/// +/// # Example +/// +/// ```rust,ignore +/// use vectorless::retrieval::pilot::FallbackManager; +/// +/// let manager = FallbackManager::new(FallbackConfig::default()); +/// +/// // Check current level +/// if manager.current_level() == FallbackLevel::Normal { +/// // Make LLM call +/// } +/// +/// // Record failure +/// manager.record_failure(&error); +/// ``` +pub struct FallbackManager { + config: FallbackConfig, + /// Current fallback level. + current_level: AtomicU8, + /// Consecutive failures at current level. + consecutive_failures: AtomicUsize, + /// Consecutive successes at current level. + consecutive_successes: AtomicUsize, + /// Total retry attempts in current session. 
+ retry_attempts: AtomicUsize, +} + +impl std::fmt::Debug for FallbackManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FallbackManager") + .field("config", &self.config) + .field("current_level", &self.current_level()) + .field("consecutive_failures", &self.consecutive_failures.load(Ordering::Relaxed)) + .finish() + } +} + +impl FallbackManager { + /// Create a new fallback manager with configuration. + pub fn new(config: FallbackConfig) -> Self { + Self { + config, + current_level: AtomicU8::new(0), + consecutive_failures: AtomicUsize::new(0), + consecutive_successes: AtomicUsize::new(0), + retry_attempts: AtomicUsize::new(0), + } + } + + /// Create with default configuration. + pub fn with_defaults() -> Self { + Self::new(FallbackConfig::default()) + } + + /// Get current fallback level. + pub fn current_level(&self) -> FallbackLevel { + self.current_level.load(Ordering::Relaxed).into() + } + + /// Check if we're at algorithm-only level. + pub fn is_algorithm_only(&self) -> bool { + self.current_level() == FallbackLevel::AlgorithmOnly + } + + /// Check if we should use simplified context. + pub fn should_simplify(&self) -> bool { + matches!( + self.current_level(), + FallbackLevel::Simplified | FallbackLevel::AlgorithmOnly + ) + } + + /// Get delay for next retry based on attempt number. + pub fn retry_delay(&self, attempt: usize) -> Duration { + let delay = self.config.initial_delay_ms as f64 + * self.config.backoff_multiplier.powi(attempt as i32); + let delay = delay.min(self.config.max_delay_ms as f64); + Duration::from_millis(delay as u64) + } + + /// Record a successful operation. + /// + /// May de-escalate the fallback level after consecutive successes. 
+ pub fn record_success(&self) { + self.consecutive_failures.store(0, Ordering::Relaxed); + + let successes = self.consecutive_successes.fetch_add(1, Ordering::Relaxed) + 1; + + // De-escalate after enough consecutive successes + if successes >= self.config.successes_before_deescalate { + let current = self.current_level.load(Ordering::Relaxed); + if current > 0 { + self.current_level.fetch_sub(1, Ordering::Relaxed); + debug!("Fallback level de-escalated to {:?}", self.current_level()); + } + self.consecutive_successes.store(0, Ordering::Relaxed); + } + } + + /// Record a failure and potentially escalate level. + /// + /// Returns the recommended action. + pub fn record_failure(&self, error: &FallbackError) -> FallbackAction { + self.consecutive_successes.store(0, Ordering::Relaxed); + + // Check if we should escalate + let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1; + + if failures >= self.config.failures_before_escalate { + self.escalate_level(); + self.consecutive_failures.store(0, Ordering::Relaxed); + } + + // Determine action based on error and current level + match error { + FallbackError::Network(_) | FallbackError::RateLimited => { + if self.retry_attempts.load(Ordering::Relaxed) < self.config.max_retries { + FallbackAction::Retry + } else { + FallbackAction::Escalate + } + } + FallbackError::TokenLimitExceeded => FallbackAction::Simplify, + FallbackError::Unavailable(_) | FallbackError::Exhausted => { + FallbackAction::UseAlgorithm + } + FallbackError::ParseError(_) => { + // Try default decision, don't escalate + FallbackAction::UseDefault + } + } + } + + /// Escalate to next fallback level. + fn escalate_level(&self) { + let current = self.current_level.load(Ordering::Relaxed); + if current < 3 { + self.current_level.fetch_add(1, Ordering::Relaxed); + warn!("Fallback level escalated to {:?}", self.current_level()); + } + } + + /// Start a retry attempt. 
+ pub fn start_retry(&self) { + self.retry_attempts.fetch_add(1, Ordering::Relaxed); + } + + /// Reset retry counter (after successful operation). + pub fn reset_retry_count(&self) { + self.retry_attempts.store(0, Ordering::Relaxed); + } + + /// Reset all state for new query. + pub fn reset(&self) { + self.current_level.store(0, Ordering::Relaxed); + self.consecutive_failures.store(0, Ordering::Relaxed); + self.consecutive_successes.store(0, Ordering::Relaxed); + self.retry_attempts.store(0, Ordering::Relaxed); + } + + /// Get current statistics. + pub fn stats(&self) -> FallbackStats { + FallbackStats { + current_level: self.current_level(), + ..Default::default() + } + } + + /// Get the configuration. + pub fn config(&self) -> &FallbackConfig { + &self.config + } +} + +/// Action to take after a failure. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FallbackAction { + /// Retry the operation (with backoff). + Retry, + /// Simplify context and retry. + Simplify, + /// Escalate to next fallback level. + Escalate, + /// Use algorithm-only mode. + UseAlgorithm, + /// Use a default decision. 
+ UseDefault, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fallback_level_conversion() { + assert_eq!(FallbackLevel::from(0), FallbackLevel::Normal); + assert_eq!(FallbackLevel::from(1), FallbackLevel::Retry); + assert_eq!(FallbackLevel::from(2), FallbackLevel::Simplified); + assert_eq!(FallbackLevel::from(3), FallbackLevel::AlgorithmOnly); + assert_eq!(FallbackLevel::from(4), FallbackLevel::AlgorithmOnly); + } + + #[test] + fn test_fallback_manager_creation() { + let manager = FallbackManager::with_defaults(); + assert_eq!(manager.current_level(), FallbackLevel::Normal); + assert!(!manager.is_algorithm_only()); + assert!(!manager.should_simplify()); + } + + #[test] + fn test_retry_delay() { + let manager = FallbackManager::with_defaults(); + + let d0 = manager.retry_delay(0); + let d1 = manager.retry_delay(1); + let d2 = manager.retry_delay(2); + + assert!(d1 > d0); + assert!(d2 > d1); + } + + #[test] + fn test_retry_delay_max() { + let config = FallbackConfig { + max_delay_ms: 5000, + ..Default::default() + }; + let manager = FallbackManager::new(config); + + // High attempt should cap at max + let delay = manager.retry_delay(10); + assert!(delay.as_millis() <= 5000); + } + + #[test] + fn test_record_success() { + let manager = FallbackManager::with_defaults(); + manager.current_level.store(1, Ordering::Relaxed); + + // Need multiple successes to de-escalate + for _ in 0..manager.config.successes_before_deescalate { + manager.record_success(); + } + + assert_eq!(manager.current_level(), FallbackLevel::Normal); + } + + #[test] + fn test_record_failure_escalate() { + let manager = FallbackManager::with_defaults(); + + // Trigger failures to escalate + for _ in 0..manager.config.failures_before_escalate { + let action = manager.record_failure(&FallbackError::Network("test".to_string())); + assert!(matches!(action, FallbackAction::Retry | FallbackAction::Escalate)); + } + + assert_eq!(manager.current_level(), FallbackLevel::Retry); + } + + 
#[test] + fn test_record_failure_token_limit() { + let manager = FallbackManager::with_defaults(); + + let action = manager.record_failure(&FallbackError::TokenLimitExceeded); + assert_eq!(action, FallbackAction::Simplify); + } + + #[test] + fn test_record_failure_unavailable() { + let manager = FallbackManager::with_defaults(); + + let action = manager.record_failure(&FallbackError::Unavailable("test".to_string())); + assert_eq!(action, FallbackAction::UseAlgorithm); + } + + #[test] + fn test_reset() { + let manager = FallbackManager::with_defaults(); + + // Escalate level + manager.current_level.store(3, Ordering::Relaxed); + manager.consecutive_failures.store(5, Ordering::Relaxed); + + manager.reset(); + + assert_eq!(manager.current_level(), FallbackLevel::Normal); + assert_eq!(manager.consecutive_failures.load(Ordering::Relaxed), 0); + } + + #[test] + fn test_error_retryable() { + assert!(FallbackError::Network("test".to_string()).is_retryable()); + assert!(FallbackError::RateLimited.is_retryable()); + assert!(!FallbackError::TokenLimitExceeded.is_retryable()); + assert!(!FallbackError::Unavailable("test".to_string()).is_retryable()); + } + + #[test] + fn test_error_needs_simplification() { + assert!(FallbackError::TokenLimitExceeded.needs_simplification()); + assert!(!FallbackError::Network("test".to_string()).needs_simplification()); + } +} diff --git a/src/retrieval/pilot/llm_pilot.rs b/src/retrieval/pilot/llm_pilot.rs new file mode 100644 index 00000000..9342ffa4 --- /dev/null +++ b/src/retrieval/pilot/llm_pilot.rs @@ -0,0 +1,436 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! LLM-based Pilot implementation. +//! +//! This module provides the main Pilot implementation that uses LLM +//! for semantic navigation guidance. 
+ +use async_trait::async_trait; +use std::sync::Arc; +use tracing::{debug, info, warn}; + +use crate::domain::DocumentTree; +use crate::llm::LlmClient; + +use super::builder::ContextBuilder; +use super::budget::BudgetController; +use super::config::PilotConfig; +use super::decision::{InterventionPoint, PilotDecision}; +use super::parser::ResponseParser; +use super::prompts::PromptBuilder; +use super::r#trait::{Pilot, SearchState}; + +/// LLM-based Pilot implementation. +/// +/// Uses an LLM client to provide semantic navigation guidance +/// at key decision points during tree search. +/// +/// # Architecture +/// +/// ```text +/// ┌─────────────────────────────────────────────────────────────┐ +/// │ LlmPilot │ +/// ├─────────────────────────────────────────────────────────────┤ +/// │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +/// │ │ Context │ │ Prompt │ │ Response │ │ +/// │ │ Builder │─▶│ Builder │─▶│ Parser │ │ +/// │ └─────────────┘ └─────────────┘ └─────────────┘ │ +/// │ │ +/// │ ┌─────────────┐ ┌─────────────┐ │ +/// │ │ Budget │ │ LLM │ │ +/// │ │ Controller │ │ Client │ │ +/// │ └─────────────┘ └─────────────┘ │ +/// └─────────────────────────────────────────────────────────────┘ +/// ``` +/// +/// # Example +/// +/// ```rust,ignore +/// use vectorless::retrieval::pilot::{LlmPilot, PilotConfig}; +/// use vectorless::llm::LlmClient; +/// +/// let client = LlmClient::for_model("gpt-4o-mini"); +/// let pilot = LlmPilot::new(client, PilotConfig::default()); +/// +/// // Use in search +/// if pilot.should_intervene(&state) { +/// let decision = pilot.decide(&state).await; +/// } +/// ``` +pub struct LlmPilot { + /// LLM client for making requests. + client: LlmClient, + /// Pilot configuration. + config: PilotConfig, + /// Budget controller. + budget: BudgetController, + /// Context builder. + context_builder: ContextBuilder, + /// Prompt builder. + prompt_builder: PromptBuilder, + /// Response parser. 
+ response_parser: ResponseParser, +} + +impl std::fmt::Debug for LlmPilot { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LlmPilot") + .field("config", &self.config) + .field("budget", &self.budget.usage()) + .finish() + } +} + +impl LlmPilot { + /// Create a new LLM-based Pilot. + pub fn new(client: LlmClient, config: PilotConfig) -> Self { + let budget = BudgetController::new(config.budget.clone()); + let token_budget = config.budget.max_tokens_per_call; + + Self { + client, + config, + budget, + context_builder: ContextBuilder::new(token_budget), + prompt_builder: PromptBuilder::new(), + response_parser: ResponseParser::new(), + } + } + + /// Create with custom builders. + pub fn with_builders( + client: LlmClient, + config: PilotConfig, + context_builder: ContextBuilder, + prompt_builder: PromptBuilder, + ) -> Self { + let budget = BudgetController::new(config.budget.clone()); + + Self { + client, + config, + budget, + context_builder, + prompt_builder, + response_parser: ResponseParser::new(), + } + } + + /// Check if budget allows LLM calls. + fn has_budget(&self) -> bool { + self.budget.can_call() + } + + /// Check if scores are too close (algorithm uncertain). + fn scores_are_close(&self, state: &SearchState<'_>) -> bool { + // Use the config's score_gap_threshold with the state's best_score + // If best_score is low, consider scores as close + state.candidates.len() >= 2 && state.best_score < self.config.intervention.score_gap_threshold + } + + /// Determine the intervention point type. + fn get_intervention_point(&self, state: &SearchState<'_>) -> InterventionPoint { + if state.is_at_root() || state.iteration == 0 { + InterventionPoint::Start + } else if state.is_backtracking { + InterventionPoint::Backtrack + } else if state.is_fork_point() { + InterventionPoint::Fork + } else { + InterventionPoint::Evaluate + } + } + + /// Make an LLM call and return the decision. 
+ async fn call_llm( + &self, + point: InterventionPoint, + context: &super::builder::PilotContext, + candidates: &[crate::domain::NodeId], + ) -> PilotDecision { + // Build prompt + let prompt = self.prompt_builder.build(point, context); + + // Check if we can afford this call + if !self.budget.can_afford(prompt.estimated_tokens) { + warn!("Budget cannot afford LLM call (estimated: {} tokens)", prompt.estimated_tokens); + return self.default_decision(candidates, point); + } + + debug!( + "Calling LLM for {:?} point (estimated: {} tokens)", + point, prompt.estimated_tokens + ); + + // Make LLM call + match self.client.complete(&prompt.system, &prompt.user).await { + Ok(response) => { + // Record usage (estimate output tokens) + let output_tokens = self.estimate_tokens(&response); + self.budget.record_usage(prompt.estimated_tokens, output_tokens, 0); + + // Parse response + let decision = self.response_parser.parse(&response, candidates, point); + + info!( + "LLM decision: direction={:?}, confidence={:.2}, candidates={}", + std::mem::discriminant(&decision.direction), + decision.confidence, + decision.ranked_candidates.len() + ); + + decision + } + Err(e) => { + warn!("LLM call failed: {}", e); + self.default_decision(candidates, point) + } + } + } + + /// Create a default decision when LLM fails. + fn default_decision( + &self, + candidates: &[crate::domain::NodeId], + point: InterventionPoint, + ) -> PilotDecision { + let ranked = candidates + .iter() + .enumerate() + .map(|(i, &node_id)| super::decision::RankedCandidate { + node_id, + score: 1.0 / (i + 1) as f32, + reason: None, + }) + .collect(); + + PilotDecision { + ranked_candidates: ranked, + direction: super::decision::SearchDirection::GoDeeper { + reason: "Default decision (LLM unavailable)".to_string(), + }, + confidence: 0.0, + reasoning: "LLM call failed or budget exhausted".to_string(), + intervention_point: point, + } + } + + /// Estimate token count for a string. 
+    fn estimate_tokens(&self, text: &str) -> usize {
+        let char_count = text.chars().count();
+        // CJK ideographs (U+4E00..U+9FFF) tokenize denser than ASCII, so the
+        // two character classes are weighted differently (~1.5 vs ~4 chars
+        // per token). This is a heuristic estimate, not an exact tokenizer.
+        let chinese_count = text
+            .chars()
+            .filter(|c| ('\u{4E00}'..='\u{9FFF}').contains(c))
+            .count();
+        let english_count = char_count - chinese_count;
+
+        (chinese_count as f32 / 1.5 + english_count as f32 / 4.0).ceil() as usize
+    }
+}
+
+#[async_trait]
+impl Pilot for LlmPilot {
+    fn name(&self) -> &str {
+        "llm_pilot"
+    }
+
+    fn should_intervene(&self, state: &SearchState<'_>) -> bool {
+        // Check mode
+        if !self.config.mode.uses_llm() {
+            return false;
+        }
+
+        // Check budget
+        if !self.has_budget() {
+            debug!("Budget exhausted, skipping intervention");
+            return false;
+        }
+
+        let intervention = &self.config.intervention;
+
+        // Condition 1: Fork point with enough candidates
+        if state.candidates.len() > intervention.fork_threshold {
+            debug!("Intervening: fork point with {} candidates", state.candidates.len());
+            return true;
+        }
+
+        // Condition 2: Scores are too close (algorithm uncertain)
+        if self.scores_are_close(state) {
+            debug!("Intervening: scores are close");
+            return true;
+        }
+
+        // Condition 3: Low confidence (best score too low)
+        if intervention.is_low_confidence(state.best_score) {
+            debug!("Intervening: low confidence (best_score={:.2})", state.best_score);
+            return true;
+        }
+
+        // Condition 4: Backtracking and guide_at_backtrack is enabled
+        if state.is_backtracking && self.config.guide_at_backtrack {
+            debug!("Intervening: backtracking");
+            return true;
+        }
+
+        false
+    }
+
+    async fn decide(&self, state: &SearchState<'_>) -> PilotDecision {
+        let point = self.get_intervention_point(state);
+
+        // Build context
+        let context = self.context_builder.build(state);
+
+        // Make LLM call
+        self.call_llm(point, &context, state.candidates).await
+    }
+
+    async fn guide_start(
+        &self,
+        tree: &DocumentTree,
+        query: &str,
+    ) -> Option<PilotDecision> {
+        // Check if guide_at_start is enabled
+        if !self.config.guide_at_start {
+            return None;
+        }
+
+        // Check budget
+        if
!self.has_budget() { + return None; + } + + // Build start context + let context = self.context_builder.build_start_context(tree, query); + + // Get root's children as candidates + let candidates = tree.children(tree.root()); + + // Make LLM call + Some(self.call_llm(InterventionPoint::Start, &context, &candidates).await) + } + + async fn guide_backtrack( + &self, + state: &SearchState<'_>, + ) -> Option { + // Check if guide_at_backtrack is enabled + if !self.config.guide_at_backtrack { + return None; + } + + // Check budget + if !self.has_budget() { + return None; + } + + // Build backtrack context + let context = self.context_builder.build_backtrack_context(state, state.path); + + // Make LLM call + Some(self.call_llm(InterventionPoint::Backtrack, &context, state.candidates).await) + } + + fn config(&self) -> &PilotConfig { + &self.config + } + + fn is_active(&self) -> bool { + self.config.mode.uses_llm() && self.has_budget() + } + + fn reset(&self) { + self.budget.reset(); + debug!("LlmPilot reset for new query"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::domain::NodeId; + use indextree::Arena; + + fn create_test_node_ids(count: usize) -> Vec { + let mut arena = Arena::new(); + let mut ids = Vec::new(); + for i in 0..count { + let node = crate::domain::TreeNode { + title: format!("Node {}", i), + content: String::new(), + summary: String::new(), + depth: 0, + start_index: 1, + end_index: 1, + start_page: None, + end_page: None, + node_id: None, + physical_index: None, + token_count: None, + }; + ids.push(NodeId(arena.new_node(node))); + } + ids + } + + #[test] + fn test_llm_pilot_creation() { + let client = LlmClient::for_model("gpt-4o-mini"); + let config = PilotConfig::default(); + let pilot = LlmPilot::new(client, config); + + assert_eq!(pilot.name(), "llm_pilot"); + assert!(pilot.is_active()); + } + + #[test] + fn test_llm_pilot_algorithm_only_mode() { + let client = LlmClient::for_model("gpt-4o-mini"); + let config = 
PilotConfig::algorithm_only(); + let pilot = LlmPilot::new(client, config); + + assert!(!pilot.config().mode.uses_llm()); + } + + #[test] + fn test_llm_pilot_budget_exhausted() { + let client = LlmClient::for_model("gpt-4o-mini"); + let config = PilotConfig::default(); + let pilot = LlmPilot::new(client, config); + + // Exhaust budget + pilot.budget.record_usage(3000, 500, 0); + + assert!(!pilot.has_budget()); + } + + #[test] + fn test_default_decision() { + let client = LlmClient::for_model("gpt-4o-mini"); + let config = PilotConfig::default(); + let pilot = LlmPilot::new(client, config); + + let candidates = create_test_node_ids(2); + let decision = pilot.default_decision(&candidates, InterventionPoint::Fork); + + assert_eq!(decision.ranked_candidates.len(), 2); + assert_eq!(decision.confidence, 0.0); + assert!(decision.reasoning.contains("LLM")); + } + + #[test] + fn test_reset() { + let client = LlmClient::for_model("gpt-4o-mini"); + let config = PilotConfig::default(); + let pilot = LlmPilot::new(client, config); + + // Use some budget + pilot.budget.record_usage(100, 50, 0); + assert!(pilot.budget.total_tokens() > 0); + + // Reset + pilot.reset(); + assert_eq!(pilot.budget.total_tokens(), 0); + } +} diff --git a/src/retrieval/pilot/metrics.rs b/src/retrieval/pilot/metrics.rs new file mode 100644 index 00000000..d9e7b656 --- /dev/null +++ b/src/retrieval/pilot/metrics.rs @@ -0,0 +1,537 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Metrics collector for Pilot operations. +//! +//! Collects performance metrics including: +//! - LLM call statistics (count, success/failure) +//! - Token usage (input, output, total) +//! - Latency tracking (average, p50, p99) +//! - Decision quality metrics + +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::time::{Duration, Instant}; + +use super::decision::InterventionPoint; + +/// Snapshot of Pilot metrics at a point in time. 
+#[derive(Debug, Clone, Default)] +pub struct PilotMetrics { + // LLM call statistics + /// Total LLM calls attempted. + pub total_calls: usize, + /// Successful LLM calls. + pub successful_calls: usize, + /// Failed LLM calls. + pub failed_calls: usize, + /// Calls that needed fallback. + pub fallback_calls: usize, + + // Token statistics + /// Total input tokens consumed. + pub total_input_tokens: usize, + /// Total output tokens generated. + pub total_output_tokens: usize, + /// Average tokens per call. + pub avg_tokens_per_call: f64, + + // Latency statistics + /// Total time spent in LLM calls (ms). + pub total_latency_ms: u64, + /// Average latency per call (ms). + pub avg_latency_ms: f64, + /// P50 latency (ms). + pub p50_latency_ms: u64, + /// P99 latency (ms). + pub p99_latency_ms: u64, + + // Intervention statistics + /// Calls at START point. + pub start_interventions: usize, + /// Calls at FORK point. + pub fork_interventions: usize, + /// Calls at BACKTRACK point. + pub backtrack_interventions: usize, + /// Calls at EVALUATE point. + pub evaluate_interventions: usize, + + // Quality metrics (require feedback) + /// LLM decision accuracy (0.0-1.0). + pub llm_accuracy: Option, + /// Retrieval precision (0.0-1.0). + pub retrieval_precision: Option, +} + +impl PilotMetrics { + /// Calculate success rate (0.0-1.0). + pub fn success_rate(&self) -> f64 { + if self.total_calls == 0 { + return 0.0; + } + self.successful_calls as f64 / self.total_calls as f64 + } + + /// Calculate token utilization. + pub fn token_utilization(&self, budget: usize) -> f64 { + if budget == 0 { + return 0.0; + } + let total = self.total_input_tokens + self.total_output_tokens; + (total as f64 / budget as f64).min(1.0) + } + + /// Calculate fallback rate (0.0-1.0). + pub fn fallback_rate(&self) -> f64 { + if self.total_calls == 0 { + return 0.0; + } + self.fallback_calls as f64 / self.total_calls as f64 + } +} + +/// Record of a single LLM call. 
+#[derive(Debug, Clone)]
+pub struct CallRecord {
+    /// Intervention point.
+    pub point: InterventionPoint,
+    /// Input tokens used.
+    pub input_tokens: usize,
+    /// Output tokens generated.
+    pub output_tokens: usize,
+    /// Latency in milliseconds.
+    pub latency_ms: u64,
+    /// Whether the call succeeded.
+    pub success: bool,
+    /// Whether fallback was used.
+    pub used_fallback: bool,
+}
+
+/// Latency sample for percentile calculation.
+#[derive(Debug, Clone)]
+struct LatencySample {
+    latency_ms: u64,
+}
+
+/// Metrics collector for Pilot operations.
+///
+/// Thread-safe collector that tracks all Pilot metrics.
+/// Uses atomic operations for concurrent access.
+///
+/// # Example
+///
+/// ```rust,ignore
+/// use vectorless::retrieval::pilot::MetricsCollector;
+///
+/// let metrics = MetricsCollector::new();
+///
+/// // Record a call
+/// let start = std::time::Instant::now();
+/// // ... make LLM call ...
+/// metrics.record_call(InterventionPoint::Fork, 100, 50, start.elapsed(), true, false);
+///
+/// // Get snapshot
+/// let snapshot = metrics.snapshot();
+/// println!("Success rate: {:.2}%", snapshot.success_rate() * 100.0);
+/// ```
+pub struct MetricsCollector {
+    // Call counters
+    total_calls: AtomicUsize,
+    successful_calls: AtomicUsize,
+    failed_calls: AtomicUsize,
+    fallback_calls: AtomicUsize,
+
+    // Token counters
+    total_input_tokens: AtomicUsize,
+    total_output_tokens: AtomicUsize,
+
+    // Latency tracking; samples need an RwLock since percentile computation
+    // reads the whole buffer while recorders append to it.
+    total_latency_ms: AtomicU64,
+    latency_samples: std::sync::RwLock<Vec<LatencySample>>,
+
+    // Intervention counters
+    start_interventions: AtomicUsize,
+    fork_interventions: AtomicUsize,
+    backtrack_interventions: AtomicUsize,
+    evaluate_interventions: AtomicUsize,
+
+    // Quality metrics (set externally)
+    llm_accuracy: std::sync::RwLock<Option<f64>>,
+    retrieval_precision: std::sync::RwLock<Option<f64>>,
+}
+
+impl Default for MetricsCollector {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl MetricsCollector {
+    /// Create a new metrics collector.
+ pub fn new() -> Self { + Self { + total_calls: AtomicUsize::new(0), + successful_calls: AtomicUsize::new(0), + failed_calls: AtomicUsize::new(0), + fallback_calls: AtomicUsize::new(0), + total_input_tokens: AtomicUsize::new(0), + total_output_tokens: AtomicUsize::new(0), + total_latency_ms: AtomicU64::new(0), + latency_samples: std::sync::RwLock::new(Vec::with_capacity(100)), + start_interventions: AtomicUsize::new(0), + fork_interventions: AtomicUsize::new(0), + backtrack_interventions: AtomicUsize::new(0), + evaluate_interventions: AtomicUsize::new(0), + llm_accuracy: std::sync::RwLock::new(None), + retrieval_precision: std::sync::RwLock::new(None), + } + } + + /// Record an LLM call. + pub fn record_call( + &self, + point: InterventionPoint, + input_tokens: usize, + output_tokens: usize, + latency: Duration, + success: bool, + used_fallback: bool, + ) { + // Update call counters + self.total_calls.fetch_add(1, Ordering::Relaxed); + if success { + self.successful_calls.fetch_add(1, Ordering::Relaxed); + } else { + self.failed_calls.fetch_add(1, Ordering::Relaxed); + } + if used_fallback { + self.fallback_calls.fetch_add(1, Ordering::Relaxed); + } + + // Update token counters + self.total_input_tokens.fetch_add(input_tokens, Ordering::Relaxed); + self.total_output_tokens.fetch_add(output_tokens, Ordering::Relaxed); + + // Update latency + let latency_ms = latency.as_millis() as u64; + self.total_latency_ms.fetch_add(latency_ms, Ordering::Relaxed); + + // Store latency sample + if let Ok(mut samples) = self.latency_samples.write() { + samples.push(LatencySample { latency_ms }); + // Keep last 1000 samples + if samples.len() > 1000 { + samples.remove(0); + } + } + + // Update intervention counters + match point { + InterventionPoint::Start => { + self.start_interventions.fetch_add(1, Ordering::Relaxed); + } + InterventionPoint::Fork => { + self.fork_interventions.fetch_add(1, Ordering::Relaxed); + } + InterventionPoint::Backtrack => { + 
self.backtrack_interventions.fetch_add(1, Ordering::Relaxed); + } + InterventionPoint::Evaluate => { + self.evaluate_interventions.fetch_add(1, Ordering::Relaxed); + } + } + } + + /// Record a call using CallRecord. + pub fn record(&self, record: CallRecord) { + let latency = Duration::from_millis(record.latency_ms); + self.record_call( + record.point, + record.input_tokens, + record.output_tokens, + latency, + record.success, + record.used_fallback, + ); + } + + /// Set LLM accuracy (from external feedback). + pub fn set_llm_accuracy(&self, accuracy: f64) { + if let Ok(mut acc) = self.llm_accuracy.write() { + *acc = Some(accuracy.clamp(0.0, 1.0)); + } + } + + /// Set retrieval precision (from external feedback). + pub fn set_retrieval_precision(&self, precision: f64) { + if let Ok(mut prec) = self.retrieval_precision.write() { + *prec = Some(precision.clamp(0.0, 1.0)); + } + } + + /// Get a snapshot of current metrics. + pub fn snapshot(&self) -> PilotMetrics { + let total_calls = self.total_calls.load(Ordering::Relaxed); + let successful_calls = self.successful_calls.load(Ordering::Relaxed); + let failed_calls = self.failed_calls.load(Ordering::Relaxed); + let fallback_calls = self.fallback_calls.load(Ordering::Relaxed); + let total_input_tokens = self.total_input_tokens.load(Ordering::Relaxed); + let total_output_tokens = self.total_output_tokens.load(Ordering::Relaxed); + let total_latency_ms = self.total_latency_ms.load(Ordering::Relaxed); + + let avg_tokens_per_call = if total_calls > 0 { + (total_input_tokens + total_output_tokens) as f64 / total_calls as f64 + } else { + 0.0 + }; + + let avg_latency_ms = if total_calls > 0 { + total_latency_ms as f64 / total_calls as f64 + } else { + 0.0 + }; + + // Calculate percentiles from samples + let (p50_latency_ms, p99_latency_ms) = self.calculate_percentiles(); + + PilotMetrics { + total_calls, + successful_calls, + failed_calls, + fallback_calls, + total_input_tokens, + total_output_tokens, + avg_tokens_per_call, 
+            total_latency_ms,
+            avg_latency_ms,
+            p50_latency_ms,
+            p99_latency_ms,
+            start_interventions: self.start_interventions.load(Ordering::Relaxed),
+            fork_interventions: self.fork_interventions.load(Ordering::Relaxed),
+            backtrack_interventions: self.backtrack_interventions.load(Ordering::Relaxed),
+            evaluate_interventions: self.evaluate_interventions.load(Ordering::Relaxed),
+            llm_accuracy: self.llm_accuracy.read().ok().and_then(|v| *v),
+            retrieval_precision: self.retrieval_precision.read().ok().and_then(|v| *v),
+        }
+    }
+
+    /// Calculate p50 and p99 latencies.
+    fn calculate_percentiles(&self) -> (u64, u64) {
+        if let Ok(samples) = self.latency_samples.read() {
+            if samples.is_empty() {
+                return (0, 0);
+            }
+
+            let mut latencies: Vec<u64> = samples.iter().map(|s| s.latency_ms).collect();
+            latencies.sort();
+
+            let p50_idx = (latencies.len() as f64 * 0.50) as usize;
+            let p99_idx = (latencies.len() as f64 * 0.99) as usize;
+
+            let p50 = latencies.get(p50_idx).copied().unwrap_or(0);
+            let p99 = latencies.get(p99_idx.min(latencies.len() - 1)).copied().unwrap_or(0);
+
+            (p50, p99)
+        } else {
+            (0, 0)
+        }
+    }
+
+    /// Reset all metrics for a new query.
+    pub fn reset(&self) {
+        self.total_calls.store(0, Ordering::Relaxed);
+        self.successful_calls.store(0, Ordering::Relaxed);
+        self.failed_calls.store(0, Ordering::Relaxed);
+        self.fallback_calls.store(0, Ordering::Relaxed);
+        self.total_input_tokens.store(0, Ordering::Relaxed);
+        self.total_output_tokens.store(0, Ordering::Relaxed);
+        self.total_latency_ms.store(0, Ordering::Relaxed);
+        self.start_interventions.store(0, Ordering::Relaxed);
+        self.fork_interventions.store(0, Ordering::Relaxed);
+        self.backtrack_interventions.store(0, Ordering::Relaxed);
+        self.evaluate_interventions.store(0, Ordering::Relaxed);
+
+        if let Ok(mut samples) = self.latency_samples.write() {
+            samples.clear();
+        }
+    }
+
+    /// Get total tokens used.
+ pub fn total_tokens(&self) -> usize { + self.total_input_tokens.load(Ordering::Relaxed) + + self.total_output_tokens.load(Ordering::Relaxed) + } + + /// Get total calls made. + pub fn total_calls(&self) -> usize { + self.total_calls.load(Ordering::Relaxed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn test_metrics_creation() { + let metrics = MetricsCollector::new(); + let snapshot = metrics.snapshot(); + + assert_eq!(snapshot.total_calls, 0); + assert_eq!(snapshot.successful_calls, 0); + assert_eq!(snapshot.failed_calls, 0); + } + + #[test] + fn test_record_call() { + let metrics = MetricsCollector::new(); + + metrics.record_call( + InterventionPoint::Fork, + 100, + 50, + Duration::from_millis(200), + true, + false, + ); + + let snapshot = metrics.snapshot(); + assert_eq!(snapshot.total_calls, 1); + assert_eq!(snapshot.successful_calls, 1); + assert_eq!(snapshot.failed_calls, 0); + assert_eq!(snapshot.total_input_tokens, 100); + assert_eq!(snapshot.total_output_tokens, 50); + assert_eq!(snapshot.fork_interventions, 1); + } + + #[test] + fn test_record_failed_call() { + let metrics = MetricsCollector::new(); + + metrics.record_call( + InterventionPoint::Start, + 100, + 0, + Duration::from_millis(100), + false, + true, + ); + + let snapshot = metrics.snapshot(); + assert_eq!(snapshot.total_calls, 1); + assert_eq!(snapshot.successful_calls, 0); + assert_eq!(snapshot.failed_calls, 1); + assert_eq!(snapshot.fallback_calls, 1); + assert_eq!(snapshot.start_interventions, 1); + } + + #[test] + fn test_success_rate() { + let metrics = MetricsCollector::new(); + + // No calls + assert_eq!(metrics.snapshot().success_rate(), 0.0); + + // 3 successful, 1 failed + metrics.record_call(InterventionPoint::Fork, 0, 0, Duration::ZERO, true, false); + metrics.record_call(InterventionPoint::Fork, 0, 0, Duration::ZERO, true, false); + metrics.record_call(InterventionPoint::Fork, 0, 0, Duration::ZERO, true, false); + 
metrics.record_call(InterventionPoint::Fork, 0, 0, Duration::ZERO, false, false); + + assert!((metrics.snapshot().success_rate() - 0.75).abs() < 0.01); + } + + #[test] + fn test_token_utilization() { + let metrics = MetricsCollector::new(); + + metrics.record_call(InterventionPoint::Fork, 500, 200, Duration::ZERO, true, false); + + let utilization = metrics.snapshot().token_utilization(1000); + assert!((utilization - 0.7).abs() < 0.01); + } + + #[test] + fn test_latency_percentiles() { + let metrics = MetricsCollector::new(); + + // Add 100 samples with increasing latency + for i in 0..100 { + metrics.record_call( + InterventionPoint::Fork, + 0, + 0, + Duration::from_millis(i as u64 + 1), + true, + false, + ); + } + + let snapshot = metrics.snapshot(); + + // P50 should be around 50 + assert!(snapshot.p50_latency_ms >= 40 && snapshot.p50_latency_ms <= 60); + + // P99 should be around 99 + assert!(snapshot.p99_latency_ms >= 90 && snapshot.p99_latency_ms <= 100); + } + + #[test] + fn test_reset() { + let metrics = MetricsCollector::new(); + + metrics.record_call(InterventionPoint::Fork, 100, 50, Duration::from_millis(200), true, false); + assert!(metrics.total_calls() > 0); + + metrics.reset(); + + let snapshot = metrics.snapshot(); + assert_eq!(snapshot.total_calls, 0); + assert_eq!(snapshot.total_input_tokens, 0); + } + + #[test] + fn test_quality_metrics() { + let metrics = MetricsCollector::new(); + + metrics.set_llm_accuracy(0.85); + metrics.set_retrieval_precision(0.92); + + let snapshot = metrics.snapshot(); + assert_eq!(snapshot.llm_accuracy, Some(0.85)); + assert_eq!(snapshot.retrieval_precision, Some(0.92)); + } + + #[test] + fn test_quality_metrics_clamping() { + let metrics = MetricsCollector::new(); + + metrics.set_llm_accuracy(1.5); + metrics.set_retrieval_precision(-0.1); + + let snapshot = metrics.snapshot(); + assert_eq!(snapshot.llm_accuracy, Some(1.0)); + assert_eq!(snapshot.retrieval_precision, Some(0.0)); + } + + #[test] + fn test_call_record() { 
+ let metrics = MetricsCollector::new(); + + let record = CallRecord { + point: InterventionPoint::Backtrack, + input_tokens: 150, + output_tokens: 75, + latency_ms: 300, + success: true, + used_fallback: false, + }; + + metrics.record(record); + + let snapshot = metrics.snapshot(); + assert_eq!(snapshot.total_calls, 1); + assert_eq!(snapshot.backtrack_interventions, 1); + assert_eq!(snapshot.total_input_tokens, 150); + } +} diff --git a/src/retrieval/pilot/mod.rs b/src/retrieval/pilot/mod.rs new file mode 100644 index 00000000..6aaa9eb5 --- /dev/null +++ b/src/retrieval/pilot/mod.rs @@ -0,0 +1,80 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Pilot - The brain of the Retriever Pipeline. +//! +//! Pilot is the core intelligence component responsible for understanding queries, +//! analyzing document structure, and making navigation decisions. Unlike traditional +//! vector-based retrieval, Pilot uses LLM for semantic understanding and navigation +//! while keeping the algorithm efficient for execution. +//! +//! # Design Philosophy +//! +//! 1. Algorithm handles "how to search" - efficient, deterministic, low latency +//! 2. Pilot handles "where to go" - semantic understanding, disambiguation, direction +//! 3. Intervention at key decision points - not every step, only when needed +//! 4. Layered fallback - algorithm takes over when LLM fails, Pilot rescues when algorithm fails +//! +//! # Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────────────────┐ +//! │ Pilot Architecture │ +//! ├─────────────────────────────────────────────────────────────────────────┤ +//! │ │ +//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +//! │ │ Query │ │ Context │ │ Decision │ │ +//! │ │ Analyzer │──▶│ Builder │──▶│ Engine │ │ +//! │ └─────────────┘ └─────────────┘ └──────┬──────┘ │ +//! │ │ │ +//! │ ▼ │ +//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +//! 
│ │ Response │◀──│ LLM │◀──│ Prompt │ │ +//! │ │ Parser │ │ Client │ │ Builder │ │ +//! │ └─────────────┘ └─────────────┘ └─────────────┘ │ +//! │ │ +//! │ Supporting: BudgetController, FallbackManager, MetricsCollector │ +//! └─────────────────────────────────────────────────────────────────────────┘ +//! ``` +//! +//! # Quick Start +//! +//! ```rust,ignore +//! use vectorless::retrieval::pilot::{LlmPilot, PilotConfig, Pilot}; +//! +//! let pilot = LlmPilot::new(llm_client, PilotConfig::default()); +//! +//! // Check if intervention needed +//! if pilot.should_intervene(&state) { +//! let decision = pilot.decide(&state).await; +//! // Use decision to guide search +//! } +//! ``` + +mod budget; +mod builder; +mod config; +mod decision; +mod fallback; +mod llm_pilot; +mod metrics; +mod noop; +mod parser; +mod prompts; +mod r#trait; + +pub use budget::{BudgetController, BudgetUsage}; +pub use builder::{ContextBuilder, PilotContext, TokenBudget}; +pub use config::{ + BudgetConfig, InterventionConfig, PilotConfig, PilotMode, +}; +pub use decision::{ + InterventionPoint, PilotDecision, RankedCandidate, SearchDirection, +}; +pub use fallback::{FallbackAction, FallbackConfig, FallbackError, FallbackLevel, FallbackManager}; +pub use llm_pilot::LlmPilot; +pub use metrics::{CallRecord, MetricsCollector, PilotMetrics}; +pub use noop::NoopPilot; +pub use parser::ResponseParser; +pub use prompts::PromptBuilder; +pub use r#trait::{Pilot, PilotExt, SearchState}; diff --git a/src/retrieval/pilot/noop.rs b/src/retrieval/pilot/noop.rs new file mode 100644 index 00000000..daa95648 --- /dev/null +++ b/src/retrieval/pilot/noop.rs @@ -0,0 +1,157 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! NoopPilot - A no-operation Pilot implementation. +//! +//! This module provides a Pilot implementation that never intervenes, +//! useful for testing, benchmarking, and as a fallback when LLM +//! is unavailable. 
+
+use async_trait::async_trait;
+
+use crate::domain::DocumentTree;
+
+use super::{InterventionPoint, Pilot, PilotConfig, PilotDecision, SearchState};
+
+/// A Pilot implementation that never intervenes.
+///
+/// This is useful for:
+/// - Testing the search algorithm without LLM interference
+/// - Benchmarking baseline performance
+/// - Fallback when LLM is unavailable
+///
+/// # Example
+///
+/// ```rust,ignore
+/// use vectorless::retrieval::pilot::NoopPilot;
+///
+/// let pilot = NoopPilot::new();
+///
+/// // This will always return false
+/// assert!(!pilot.should_intervene(&state));
+/// ```
+#[derive(Debug, Clone, Default)]
+pub struct NoopPilot {
+    config: PilotConfig,
+}
+
+impl NoopPilot {
+    /// Create a new NoopPilot.
+    pub fn new() -> Self {
+        Self {
+            config: PilotConfig::algorithm_only(),
+        }
+    }
+
+    /// Create with custom config.
+    pub fn with_config(config: PilotConfig) -> Self {
+        Self { config }
+    }
+}
+
+#[async_trait]
+impl Pilot for NoopPilot {
+    fn name(&self) -> &str {
+        "noop"
+    }
+
+    fn should_intervene(&self, _state: &SearchState<'_>) -> bool {
+        // Never intervene
+        false
+    }
+
+    async fn decide(&self, state: &SearchState<'_>) -> PilotDecision {
+        // Return a default decision that preserves original order
+        let decision = PilotDecision::preserve_order(state.candidates);
+        PilotDecision {
+            intervention_point: InterventionPoint::Fork,
+            ..decision
+        }
+    }
+
+    async fn guide_start(
+        &self,
+        _tree: &DocumentTree,
+        _query: &str,
+    ) -> Option<PilotDecision> {
+        // No guidance at start
+        None
+    }
+
+    async fn guide_backtrack(
+        &self,
+        _state: &SearchState<'_>,
+    ) -> Option<PilotDecision> {
+        // No guidance during backtrack
+        None
+    }
+
+    fn config(&self) -> &PilotConfig {
+        &self.config
+    }
+
+    fn is_active(&self) -> bool {
+        // NoopPilot is never active
+        false
+    }
+
+    fn reset(&self) {
+        // No state to reset
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::domain::NodeId;
+    use std::collections::HashSet;
+
+    #[test]
+    fn test_noop_pilot_never_intervenes() {
+        let pilot = NoopPilot::new();
+
+        // Create a minimal state
+        let tree = DocumentTree::new("test", "test content");
+        let query = "test query";
+        let path: &[NodeId] = &[];
+        let candidates: &[NodeId] = &[];
+        let visited = HashSet::new();
+
+        let state = SearchState::new(&tree, query, path, candidates, &visited);
+
+        // Should never intervene
+        assert!(!pilot.should_intervene(&state));
+    }
+
+    #[tokio::test]
+    async fn test_noop_pilot_returns_default_decision() {
+        let pilot = NoopPilot::new();
+
+        let tree = DocumentTree::new("test", "test content");
+        let query = "test query";
+        let path: &[NodeId] = &[];
+        let candidates: &[NodeId] = &[];
+        let visited = HashSet::new();
+
+        let state = SearchState::new(&tree, query, path, candidates, &visited);
+        let decision = pilot.decide(&state).await;
+
+        assert_eq!(decision.confidence, 0.0);
+        assert!(!decision.has_candidates());
+    }
+
+    #[tokio::test]
+    async fn test_noop_pilot_no_start_guidance() {
+        let pilot = NoopPilot::new();
+        let tree = DocumentTree::new("test", "test content");
+
+        let guidance = pilot.guide_start(&tree, "test").await;
+        assert!(guidance.is_none());
+    }
+
+    #[test]
+    fn test_noop_pilot_not_active() {
+        let pilot = NoopPilot::new();
+        assert!(!pilot.is_active());
+    }
+}
diff --git a/src/retrieval/pilot/parser.rs b/src/retrieval/pilot/parser.rs
new file mode 100644
index 00000000..0447a259
--- /dev/null
+++ b/src/retrieval/pilot/parser.rs
@@ -0,0 +1,484 @@
+// Copyright (c) 2026 vectorless developers
+// SPDX-License-Identifier: Apache-2.0
+
+//! Response parser for Pilot LLM calls.
+//!
+//! Parses LLM responses into structured `PilotDecision` objects.
+//! Uses multiple parsing strategies with graceful fallbacks:
+//! 1. JSON parse (preferred)
+//! 2. Regex extraction
+//! 3. Default decision (fallback)
+
+use regex::Regex;
+use serde::{Deserialize, Serialize};
+use tracing::warn;
+
+use crate::domain::NodeId;
+use super::decision::{PilotDecision, RankedCandidate, SearchDirection, InterventionPoint};
+
+/// Parsed response from LLM.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct LlmResponse {
+    /// Ranked candidates with scores.
+    #[serde(default)]
+    pub ranked_candidates: Vec<CandidateScore>,
+    /// Recommended search direction.
+    #[serde(default)]
+    pub direction: DirectionResponse,
+    /// Confidence level (0.0 - 1.0).
+    #[serde(default = "default_confidence")]
+    pub confidence: f32,
+    /// Reasoning for the decision.
+    #[serde(default)]
+    pub reasoning: String,
+}
+
+/// Candidate score from LLM response.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CandidateScore {
+    /// Index of the candidate (0-based).
+    pub index: usize,
+    /// Score for this candidate (0.0 - 1.0).
+    pub score: f32,
+    /// Optional reason for the score.
+    #[serde(default)]
+    pub reason: Option<String>,
+}
+
+/// Direction response from LLM.
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+#[serde(rename_all = "snake_case")]
+pub enum DirectionResponse {
+    #[default]
+    GoDeeper,
+    ExploreSiblings,
+    Backtrack,
+    FoundAnswer,
+}
+
+fn default_confidence() -> f32 {
+    0.5
+}
+
+/// Response parser for LLM outputs.
+///
+/// Implements layered parsing with graceful degradation:
+/// 1. Try JSON parse first
+/// 2. Fall back to regex extraction
+/// 3. Return default decision if all else fails
+///
+/// # Example
+///
+/// ```rust,ignore
+/// use vectorless::retrieval::pilot::ResponseParser;
+///
+/// let parser = ResponseParser::new();
+/// let decision = parser.parse(&llm_response, candidates, InterventionPoint::Fork);
+/// ```
+pub struct ResponseParser {
+    /// Regex for extracting JSON from markdown code blocks.
+    json_block_regex: Regex,
+    /// Regex for extracting confidence.
+    confidence_regex: Regex,
+    /// Regex for extracting direction.
+    direction_regex: Regex,
+}
+
+impl Default for ResponseParser {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ResponseParser {
+    /// Create a new response parser.
+    pub fn new() -> Self {
+        Self {
+            // Match JSON in markdown code blocks
+            json_block_regex: Regex::new(r"```(?:json)?\s*([\s\S]*?)```").unwrap(),
+            // Match confidence: 0.8 or confidence: 0.8
+            confidence_regex: Regex::new(r"(?i)confidence[:\s]+([0-9.]+)").unwrap(),
+            // Match direction keywords
+            direction_regex: Regex::new(
+                r"(?i)(go.?deeper|explore.?siblings|backtrack|found.?answer)"
+            ).unwrap(),
+        }
+    }
+
+    /// Parse LLM response into a PilotDecision.
+    ///
+    /// # Arguments
+    ///
+    /// * `response` - Raw LLM response text
+    /// * `candidates` - Original candidate NodeIds (for mapping indices)
+    /// * `point` - The intervention point
+    pub fn parse(
+        &self,
+        response: &str,
+        candidates: &[NodeId],
+        point: InterventionPoint,
+    ) -> PilotDecision {
+        // Try JSON parse first
+        if let Some(decision) = self.try_json_parse(response, candidates, point) {
+            return decision;
+        }
+
+        // Try regex extraction
+        if let Some(decision) = self.try_regex_parse(response, candidates, point) {
+            return decision;
+        }
+
+        // Return default decision
+        self.default_decision(candidates, point)
+    }
+
+    /// Try to parse response as JSON.
+    fn try_json_parse(
+        &self,
+        response: &str,
+        candidates: &[NodeId],
+        point: InterventionPoint,
+    ) -> Option<PilotDecision> {
+        // First, try to extract JSON from code blocks
+        let json_str = if let Some(caps) = self.json_block_regex.captures(response) {
+            caps.get(1)?.as_str().trim().to_string()
+        } else {
+            // Try to find raw JSON object
+            let start = response.find('{')?;
+            let end = response.rfind('}')? + 1;
+            response[start..end].to_string()
+        };
+
+        // Parse JSON
+        let llm_response: LlmResponse = match serde_json::from_str(&json_str) {
+            Ok(r) => r,
+            Err(e) => {
+                warn!("Failed to parse LLM response as JSON: {}", e);
+                return None;
+            }
+        };
+
+        // Convert to PilotDecision
+        Some(self.llm_response_to_decision(llm_response, candidates, point))
+    }
+
+    /// Try to parse response using regex.
+    fn try_regex_parse(
+        &self,
+        response: &str,
+        candidates: &[NodeId],
+        point: InterventionPoint,
+    ) -> Option<PilotDecision> {
+        // Extract confidence
+        let confidence = self.confidence_regex
+            .captures(response)
+            .and_then(|caps| caps.get(1)?.as_str().parse::<f32>().ok())
+            .unwrap_or(0.5)
+            .clamp(0.0, 1.0);
+
+        // Extract direction
+        let direction = self.direction_regex
+            .captures(response)
+            .map(|caps| {
+                let dir = caps.get(1)?.as_str().to_lowercase();
+                match dir.as_str() {
+                    d if d.contains("deeper") => Some(SearchDirection::GoDeeper { reason: String::new() }),
+                    d if d.contains("sibling") => Some(SearchDirection::ExploreSiblings { recommended: vec![] }),
+                    d if d.contains("backtrack") => Some(SearchDirection::Backtrack {
+                        reason: String::new(),
+                        alternative_branches: vec![],
+                    }),
+                    d if d.contains("found") || d.contains("answer") => Some(SearchDirection::FoundAnswer { confidence }),
+                    _ => None,
+                }
+            })
+            .flatten()
+            .unwrap_or_else(|| SearchDirection::GoDeeper { reason: String::new() });
+
+        // Try to extract candidate rankings from numbered list
+        let ranked = self.extract_ranked_candidates(response, candidates);
+
+        if ranked.is_empty() && candidates.len() > 1 {
+            return None; // Regex parse failed
+        }
+
+        Some(PilotDecision {
+            ranked_candidates: ranked,
+            direction,
+            confidence,
+            reasoning: "Extracted via regex".to_string(),
+            intervention_point: point,
+        })
+    }
+
+    /// Extract ranked candidates from text using patterns.
+    fn extract_ranked_candidates(&self, response: &str, candidates: &[NodeId]) -> Vec<RankedCandidate> {
+        let mut ranked = Vec::new();
+
+        // Pattern: "1. Candidate Name (score: 0.8)"
+        let ranking_pattern = Regex::new(r"(\d+)[.\)]\s*(?:Candidate\s*)?(\d+)[\s:]+(?:score[:\s]*)?([0-9.]+)?").unwrap();
+
+        for caps in ranking_pattern.captures_iter(response) {
+            if let Some(index_match) = caps.get(2) {
+                if let Ok(index) = index_match.as_str().parse::<usize>() {
+                    let score: f32 = caps.get(3)
+                        .and_then(|m| m.as_str().parse().ok())
+                        .unwrap_or(0.5);
+
+                    if index < candidates.len() {
+                        ranked.push(RankedCandidate {
+                            node_id: candidates[index],
+                            score: score.clamp(0.0, 1.0),
+                            reason: None,
+                        });
+                    }
+                }
+            }
+        }
+
+        // If we got some rankings, return them
+        if !ranked.is_empty() {
+            return ranked;
+        }
+
+        // Fallback: look for numbers that might be candidate indices
+        let number_pattern = Regex::new(r"\b(\d+)\b").unwrap();
+        let mut seen = std::collections::HashSet::new();
+
+        for caps in number_pattern.captures_iter(response) {
+            if let Some(match_1) = caps.get(1) {
+                if let Ok(idx) = match_1.as_str().parse::<usize>() {
+                    if idx < candidates.len() && seen.insert(idx) {
+                        ranked.push(RankedCandidate {
+                            node_id: candidates[idx],
+                            score: 1.0 - (ranked.len() as f32 * 0.1), // Decreasing scores
+                            reason: None,
+                        });
+                    }
+                }
+            }
+
+            if ranked.len() >= candidates.len() {
+                break;
+            }
+        }
+
+        ranked
+    }
+
+    /// Convert LlmResponse to PilotDecision.
+    fn llm_response_to_decision(
+        &self,
+        llm_response: LlmResponse,
+        candidates: &[NodeId],
+        point: InterventionPoint,
+    ) -> PilotDecision {
+        // Convert candidate scores to RankedCandidate
+        let ranked_candidates: Vec<RankedCandidate> = llm_response
+            .ranked_candidates
+            .into_iter()
+            .filter_map(|cs| {
+                if cs.index < candidates.len() {
+                    Some(RankedCandidate {
+                        node_id: candidates[cs.index],
+                        score: cs.score.clamp(0.0, 1.0),
+                        reason: cs.reason,
+                    })
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        // Convert direction
+        let direction = match llm_response.direction {
+            DirectionResponse::GoDeeper => SearchDirection::GoDeeper {
+                reason: llm_response.reasoning.clone(),
+            },
+            DirectionResponse::ExploreSiblings => SearchDirection::ExploreSiblings {
+                recommended: ranked_candidates.iter().take(3).map(|c| c.node_id).collect(),
+            },
+            DirectionResponse::Backtrack => SearchDirection::Backtrack {
+                reason: llm_response.reasoning.clone(),
+                alternative_branches: ranked_candidates.iter().take(3).map(|c| c.node_id).collect(),
+            },
+            DirectionResponse::FoundAnswer => SearchDirection::FoundAnswer {
+                confidence: llm_response.confidence,
+            },
+        };
+
+        PilotDecision {
+            ranked_candidates,
+            direction,
+            confidence: llm_response.confidence.clamp(0.0, 1.0),
+            reasoning: llm_response.reasoning,
+            intervention_point: point,
+        }
+    }
+
+    /// Create a default decision when parsing fails.
+    fn default_decision(&self, candidates: &[NodeId], point: InterventionPoint) -> PilotDecision {
+        // Score candidates uniformly
+        let ranked: Vec<RankedCandidate> = candidates
+            .iter()
+            .enumerate()
+            .map(|(i, &node_id)| RankedCandidate {
+                node_id,
+                score: 1.0 / (i + 1) as f32, // Decreasing scores
+                reason: None,
+            })
+            .collect();
+
+        PilotDecision {
+            ranked_candidates: ranked,
+            direction: SearchDirection::GoDeeper { reason: String::new() },
+            confidence: 0.0,
+            reasoning: "Default decision (parsing failed)".to_string(),
+            intervention_point: point,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use indextree::Arena;
+
+    fn create_test_node_ids(count: usize) -> Vec<NodeId> {
+        let mut arena = Arena::new();
+        let mut ids = Vec::new();
+        for i in 0..count {
+            let node = crate::domain::TreeNode {
+                title: format!("Node {}", i),
+                content: String::new(),
+                summary: String::new(),
+                depth: 0,
+                start_index: 1,
+                end_index: 1,
+                start_page: None,
+                end_page: None,
+                node_id: None,
+                physical_index: None,
+                token_count: None,
+            };
+            ids.push(NodeId(arena.new_node(node)));
+        }
+        ids
+    }
+
+    #[test]
+    fn test_parse_json_response() {
+        let parser = ResponseParser::new();
+        let candidates = create_test_node_ids(3);
+
+        let response = r#"{
+            "ranked_candidates": [
+                {"index": 1, "score": 0.9, "reason": "Best match"},
+                {"index": 0, "score": 0.5}
+            ],
+            "direction": "go_deeper",
+            "confidence": 0.85,
+            "reasoning": "Candidate 1 is most relevant"
+        }"#;
+
+        let decision = parser.parse(response, &candidates, InterventionPoint::Fork);
+
+        assert_eq!(decision.ranked_candidates.len(), 2);
+        assert_eq!(decision.ranked_candidates[0].node_id, candidates[1]);
+        assert!((decision.confidence - 0.85).abs() < 0.01);
+        assert!(matches!(decision.direction, SearchDirection::GoDeeper { .. }));
+    }
+
+    #[test]
+    fn test_parse_json_in_code_block() {
+        let parser = ResponseParser::new();
+        let candidates = create_test_node_ids(2);
+
+        let response = r#"
+Here's my analysis:
+
+```json
+{
+    "ranked_candidates": [{"index": 0, "score": 0.8}],
+    "direction": "go_deeper",
+    "confidence": 0.8,
+    "reasoning": "Test"
+}
+```
+"#;
+
+        let decision = parser.parse(response, &candidates, InterventionPoint::Fork);
+        assert_eq!(decision.ranked_candidates.len(), 1);
+    }
+
+    #[test]
+    fn test_parse_with_regex_fallback() {
+        let parser = ResponseParser::new();
+        let candidates = create_test_node_ids(2);
+
+        // Non-JSON response with some structure
+        let response = r#"
+I think candidate 0 is the best match.
+Confidence: 0.75
+Direction: go_deeper
+"#;
+
+        let decision = parser.parse(response, &candidates, InterventionPoint::Fork);
+
+        // Should use regex extraction
+        assert!((decision.confidence - 0.75).abs() < 0.01);
+    }
+
+    #[test]
+    fn test_default_decision() {
+        let parser = ResponseParser::new();
+        let candidates = create_test_node_ids(2);
+
+        let decision = parser.parse(
+            "This is unparseable gibberish",
+            &candidates,
+            InterventionPoint::Fork,
+        );
+
+        // Should return default
+        assert_eq!(decision.ranked_candidates.len(), 2);
+        assert_eq!(decision.confidence, 0.0);
+        assert!(decision.reasoning.contains("parsing failed"));
+    }
+
+    #[test]
+    fn test_confidence_clamping() {
+        let parser = ResponseParser::new();
+        let candidates = create_test_node_ids(1);
+
+        let response = r#"{
+            "ranked_candidates": [{"index": 0, "score": 1.5}],
+            "confidence": 1.5,
+            "direction": "go_deeper"
+        }"#;
+
+        let decision = parser.parse(response, &candidates, InterventionPoint::Fork);
+
+        // Confidence should be clamped to 1.0
+        assert!((decision.confidence - 1.0).abs() < 0.01);
+    }
+
+    #[test]
+    fn test_direction_conversion() {
+        let parser = ResponseParser::new();
+        let candidates = create_test_node_ids(1);
+
+        let test_cases = vec![
+            ("\"direction\": \"go_deeper\"", true),
+            ("\"direction\": \"explore_siblings\"", true),
+            ("\"direction\": \"backtrack\"", true),
+            ("\"direction\": \"found_answer\"", true),
+        ];
+
+        for (dir_json, should_parse) in test_cases {
+            let response = format!(r#"{{"ranked_candidates": [], "confidence": 0.5, {}}}"#, dir_json);
+            let decision = parser.parse(&response, &candidates, InterventionPoint::Fork);
+            assert!(should_parse, "Direction should parse correctly");
+        }
+    }
+}
diff --git a/src/retrieval/pilot/prompts/builder.rs b/src/retrieval/pilot/prompts/builder.rs
new file mode 100644
index 00000000..47e90d47
--- /dev/null
+++ b/src/retrieval/pilot/prompts/builder.rs
@@ -0,0 +1,290 @@
+// Copyright (c) 2026 vectorless developers
+// SPDX-License-Identifier: Apache-2.0
+
+//! Prompt builder for constructing LLM prompts.
+//!
+//! Combines templates with context to produce final prompts.
+
+use super::super::builder::PilotContext;
+use super::super::decision::InterventionPoint;
+use super::templates::{ForkPrompt, PromptTemplate, StartPrompt, BacktrackPrompt, EvaluatePrompt};
+
+/// Built prompt ready for LLM call.
+#[derive(Debug, Clone)]
+pub struct BuiltPrompt {
+    /// System prompt.
+    pub system: String,
+    /// User prompt.
+    pub user: String,
+    /// Total estimated tokens.
+    pub estimated_tokens: usize,
+}
+
+/// Builder for constructing LLM prompts.
+///
+/// Manages prompt templates and constructs final prompts
+/// by combining templates with context.
+/// +/// # Example +/// +/// ```rust,ignore +/// use vectorless::retrieval::pilot::prompts::PromptBuilder; +/// +/// let builder = PromptBuilder::new(); +/// let prompt = builder.build(InterventionPoint::Fork, &context); +/// println!("System: {}", prompt.system); +/// println!("User: {}", prompt.user); +/// ``` +pub struct PromptBuilder { + start_template: StartPrompt, + fork_template: ForkPrompt, + backtrack_template: BacktrackPrompt, + evaluate_template: EvaluatePrompt, +} + +impl Default for PromptBuilder { + fn default() -> Self { + Self::new() + } +} + +impl PromptBuilder { + /// Create a new prompt builder with default templates. + pub fn new() -> Self { + Self { + start_template: StartPrompt::with_fallback(), + fork_template: ForkPrompt::with_fallback(), + backtrack_template: BacktrackPrompt::with_fallback(), + evaluate_template: EvaluatePrompt::with_fallback(), + } + } + + /// Create with custom templates. + pub fn with_templates( + start: StartPrompt, + fork: ForkPrompt, + backtrack: BacktrackPrompt, + evaluate: EvaluatePrompt, + ) -> Self { + Self { + start_template: start, + fork_template: fork, + backtrack_template: backtrack, + evaluate_template: evaluate, + } + } + + /// Build a prompt for the given intervention point. + pub fn build(&self, point: InterventionPoint, context: &PilotContext) -> BuiltPrompt { + match point { + InterventionPoint::Start => self.build_start(context), + InterventionPoint::Fork => self.build_fork(context), + InterventionPoint::Backtrack => self.build_backtrack(context), + InterventionPoint::Evaluate => self.build_evaluate(context), + } + } + + /// Build START prompt. 
+ fn build_start(&self, context: &PilotContext) -> BuiltPrompt { + let template = &self.start_template; + let system = template.system_prompt().to_string(); + let user = self.fill_template(template.user_prompt_template(), context); + let estimated_tokens = self.estimate_tokens(&system) + self.estimate_tokens(&user); + + BuiltPrompt { + system, + user, + estimated_tokens, + } + } + + /// Build FORK prompt. + fn build_fork(&self, context: &PilotContext) -> BuiltPrompt { + let template = &self.fork_template; + let system = template.system_prompt().to_string(); + let user = self.fill_template(template.user_prompt_template(), context); + let estimated_tokens = self.estimate_tokens(&system) + self.estimate_tokens(&user); + + BuiltPrompt { + system, + user, + estimated_tokens, + } + } + + /// Build BACKTRACK prompt. + fn build_backtrack(&self, context: &PilotContext) -> BuiltPrompt { + let template = &self.backtrack_template; + let system = template.system_prompt().to_string(); + let user = self.fill_template(template.user_prompt_template(), context); + let estimated_tokens = self.estimate_tokens(&system) + self.estimate_tokens(&user); + + BuiltPrompt { + system, + user, + estimated_tokens, + } + } + + /// Build EVALUATE prompt. + fn build_evaluate(&self, context: &PilotContext) -> BuiltPrompt { + let template = &self.evaluate_template; + let system = template.system_prompt().to_string(); + let user = self.fill_template(template.user_prompt_template(), context); + let estimated_tokens = self.estimate_tokens(&system) + self.estimate_tokens(&user); + + BuiltPrompt { + system, + user, + estimated_tokens, + } + } + + /// Fill template with context. 
+ fn fill_template(&self, template: &str, context: &PilotContext) -> String { + let mut result = template.to_string(); + + // Replace context placeholder with full context + result = result.replace("{context}", &context.to_string()); + + // Replace individual sections + result = result.replace("{query}", &context.query_section); + result = result.replace("{path}", &context.path_section); + result = result.replace("{candidates}", &context.candidates_section); + result = result.replace("{toc}", &context.toc_section); + + result + } + + /// Estimate token count for a string. + fn estimate_tokens(&self, text: &str) -> usize { + let char_count = text.chars().count(); + let chinese_count = text + .chars() + .filter(|c| ('\u{4E00}'..='\u{9FFF}').contains(c)) + .count(); + let english_count = char_count - chinese_count; + + (chinese_count as f32 / 1.5 + english_count as f32 / 4.0).ceil() as usize + } + + /// Get the template for an intervention point. + pub fn get_template(&self, point: InterventionPoint) -> &dyn PromptTemplate { + match point { + InterventionPoint::Start => &self.start_template, + InterventionPoint::Fork => &self.fork_template, + InterventionPoint::Backtrack => &self.backtrack_template, + InterventionPoint::Evaluate => &self.evaluate_template, + } + } + + /// Get output format hint for an intervention point. 
+ pub fn output_format(&self, point: InterventionPoint) -> &'static str { + match point { + InterventionPoint::Start => { + r#"{ + "entry_points": ["list of starting node titles"], + "reasoning": "explanation", + "confidence": 0.0-1.0 +}"# + } + InterventionPoint::Fork => { + r#"{ + "ranked_candidates": [ + {"index": 0, "score": 0.9, "reason": "explanation"} + ], + "direction": "go_deeper|explore_siblings|backtrack|found_answer", + "confidence": 0.0-1.0, + "reasoning": "explanation" +}"# + } + InterventionPoint::Backtrack => { + r#"{ + "alternative_branches": [ + {"index": 0, "score": 0.8, "reason": "explanation"} + ], + "direction": "backtrack", + "confidence": 0.0-1.0, + "reasoning": "explanation" +}"# + } + InterventionPoint::Evaluate => { + r#"{ + "relevance_score": 0.0-1.0, + "is_answer": true|false, + "direction": "go_deeper|found_answer", + "confidence": 0.0-1.0, + "reasoning": "explanation" +}"# + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_prompt_builder_creation() { + let builder = PromptBuilder::new(); + assert!(!builder.start_template.system_prompt().is_empty()); + assert!(!builder.fork_template.system_prompt().is_empty()); + } + + #[test] + fn test_build_fork_prompt() { + let builder = PromptBuilder::new(); + let context = PilotContext { + query_section: "Query: test query\n".to_string(), + path_section: "Path: Root → Test\n".to_string(), + candidates_section: "Candidates:\n1. 
Option A\n".to_string(), + toc_section: String::new(), + estimated_tokens: 50, + }; + + let prompt = builder.build(InterventionPoint::Fork, &context); + + assert!(!prompt.system.is_empty()); + assert!(!prompt.user.is_empty()); + assert!(prompt.user.contains("test query") || prompt.user.contains("Query")); + } + + #[test] + fn test_build_start_prompt() { + let builder = PromptBuilder::new(); + let context = PilotContext { + query_section: "Query: how to configure\n".to_string(), + path_section: String::new(), + candidates_section: String::new(), + toc_section: "TOC:\n1. Config\n".to_string(), + estimated_tokens: 30, + }; + + let prompt = builder.build(InterventionPoint::Start, &context); + + assert!(!prompt.system.is_empty()); + assert!(prompt.estimated_tokens > 0); + } + + #[test] + fn test_output_format() { + let builder = PromptBuilder::new(); + + let fork_format = builder.output_format(InterventionPoint::Fork); + assert!(fork_format.contains("ranked_candidates")); + + let start_format = builder.output_format(InterventionPoint::Start); + assert!(start_format.contains("entry_points")); + } + + #[test] + fn test_template_fallback() { + let start = StartPrompt::with_fallback(); + assert!(!start.system_prompt().is_empty()); + assert!(!start.user_prompt_template().is_empty()); + + let fork = ForkPrompt::with_fallback(); + assert!(!fork.system_prompt().is_empty()); + } +} diff --git a/src/retrieval/pilot/prompts/mod.rs b/src/retrieval/pilot/prompts/mod.rs new file mode 100644 index 00000000..52684dcb --- /dev/null +++ b/src/retrieval/pilot/prompts/mod.rs @@ -0,0 +1,18 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Prompt builders for Pilot LLM calls. +//! +//! Provides specialized prompts for each intervention point: +//! - START: Search initialization guidance +//! - FORK: Branch selection at decision points +//! - BACKTRACK: Recovery after dead ends +//! 
- EVALUATE: Node relevance assessment + +mod builder; +mod templates; + +pub use builder::PromptBuilder; +pub use templates::{ + ForkPrompt, PromptTemplate, StartPrompt, BacktrackPrompt, EvaluatePrompt, +}; diff --git a/src/retrieval/pilot/prompts/system_backtrack.txt b/src/retrieval/pilot/prompts/system_backtrack.txt new file mode 100644 index 00000000..fef7e16b --- /dev/null +++ b/src/retrieval/pilot/prompts/system_backtrack.txt @@ -0,0 +1,11 @@ +You are a document navigation assistant specialized in recovery strategies. + +Your task is to analyze why a search path failed to find the answer and suggest alternative branches to explore. + +Guidelines: +- Identify what made the failed path unsuccessful +- Look for unexplored branches that might contain the answer +- Consider if the query might be satisfied by combining information from multiple branches +- Suggest the most promising alternatives first + +You must respond in valid JSON format. diff --git a/src/retrieval/pilot/prompts/system_evaluate.txt b/src/retrieval/pilot/prompts/system_evaluate.txt new file mode 100644 index 00000000..f5d66410 --- /dev/null +++ b/src/retrieval/pilot/prompts/system_evaluate.txt @@ -0,0 +1,11 @@ +You are a document analysis assistant specialized in content evaluation. + +Your task is to determine if the current node contains the answer to the user's query. + +Guidelines: +- Carefully analyze the node's content against the query +- Consider if the content fully or partially answers the query +- If the answer seems to be in child nodes, suggest going deeper +- Only mark as "found_answer" if you're confident the content satisfies the query + +You must respond in valid JSON format. 
diff --git a/src/retrieval/pilot/prompts/system_fork.txt b/src/retrieval/pilot/prompts/system_fork.txt new file mode 100644 index 00000000..e4a4a5f8 --- /dev/null +++ b/src/retrieval/pilot/prompts/system_fork.txt @@ -0,0 +1,19 @@ +You are a document navigation assistant specialized in making decisions at branch points. + +Your task is to rank candidate branches by their likelihood of containing the answer to the user's query. + +Guidelines: +- Analyze each candidate's title and summary for relevance +- Consider the current search path and context +- Higher scores should go to more relevant candidates +- If uncertain between candidates, reflect this in closer scores +- If no candidate seems relevant, suggest backtracking + +Scoring guide: +- 0.9-1.0: Highly confident this branch contains the answer +- 0.7-0.9: Likely contains relevant information +- 0.5-0.7: Possibly relevant, worth exploring +- 0.3-0.5: Unlikely but may have related content +- 0.0-0.3: Not relevant + +You must respond in valid JSON format. diff --git a/src/retrieval/pilot/prompts/system_start.txt b/src/retrieval/pilot/prompts/system_start.txt new file mode 100644 index 00000000..d3a65f49 --- /dev/null +++ b/src/retrieval/pilot/prompts/system_start.txt @@ -0,0 +1,11 @@ +You are a document navigation assistant specialized in hierarchical document search. + +Your task is to analyze a user's query and the document structure to identify the best starting points for search. + +Guidelines: +- Identify sections that are most likely to contain the answer +- Consider the query's domain, keywords, and intent +- Prefer more specific sections over general ones when appropriate +- Multiple entry points can be suggested if the query is ambiguous + +You must respond in valid JSON format. 
diff --git a/src/retrieval/pilot/prompts/templates.rs b/src/retrieval/pilot/prompts/templates.rs new file mode 100644 index 00000000..d9106f6d --- /dev/null +++ b/src/retrieval/pilot/prompts/templates.rs @@ -0,0 +1,337 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Prompt templates for different intervention points. +//! +//! Each template is designed for a specific decision context +//! and follows a consistent structure: +//! 1. System context (role definition) +//! 2. Task description +//! 3. Input format +//! 4. Output format (JSON schema) + +use super::super::decision::InterventionPoint; + +/// Common trait for prompt templates. +pub trait PromptTemplate: Send + Sync { + /// Get the system prompt. + fn system_prompt(&self) -> &str; + + /// Get the user prompt template. + fn user_prompt_template(&self) -> &str; + + /// Get the intervention point this template is for. + fn intervention_point(&self) -> InterventionPoint; + + /// Get the expected output format (JSON schema hint). + fn output_format_hint(&self) -> &str; +} + +/// Prompt template for START intervention point. +/// +/// Used at the beginning of search to: +/// - Understand query intent +/// - Identify entry points +/// - Set search direction +#[derive(Debug, Clone)] +pub struct StartPrompt { + system: String, + template: String, +} + +impl Default for StartPrompt { + fn default() -> Self { + Self::with_fallback() + } +} + +impl StartPrompt { + /// Create a new start prompt template. + pub fn new() -> Self { + Self::default() + } + + /// Create with custom templates. 
+ pub fn with_templates(system: String, template: String) -> Self { + Self { system, template } + } +} + +impl PromptTemplate for StartPrompt { + fn system_prompt(&self) -> &str { + &self.system + } + + fn user_prompt_template(&self) -> &str { + &self.template + } + + fn intervention_point(&self) -> InterventionPoint { + InterventionPoint::Start + } + + fn output_format_hint(&self) -> &str { + r#"{ + "entry_points": ["list of node titles to start from"], + "reasoning": "explanation of why these entry points", + "confidence": 0.0-1.0 +}"# + } +} + +/// Prompt template for FORK intervention point. +/// +/// Used when multiple candidate branches are available to: +/// - Rank candidates by relevance +/// - Recommend search direction +/// - Provide reasoning +#[derive(Debug, Clone)] +pub struct ForkPrompt { + system: String, + template: String, +} + +impl Default for ForkPrompt { + fn default() -> Self { + Self::with_fallback() + } +} + +impl ForkPrompt { + /// Create a new fork prompt template. + pub fn new() -> Self { + Self::default() + } + + /// Create with custom templates. + pub fn with_templates(system: String, template: String) -> Self { + Self { system, template } + } +} + +impl PromptTemplate for ForkPrompt { + fn system_prompt(&self) -> &str { + &self.system + } + + fn user_prompt_template(&self) -> &str { + &self.template + } + + fn intervention_point(&self) -> InterventionPoint { + InterventionPoint::Fork + } + + fn output_format_hint(&self) -> &str { + r#"{ + "ranked_candidates": [ + {"index": 0, "score": 0.9, "reason": "why this candidate"} + ], + "direction": "go_deeper|explore_siblings|backtrack|found_answer", + "confidence": 0.0-1.0, + "reasoning": "overall explanation" +}"# + } +} + +/// Prompt template for BACKTRACK intervention point. 
+/// +/// Used when search needs to recover from a dead end to: +/// - Analyze failure reason +/// - Suggest alternative branches +/// - Guide recovery strategy +#[derive(Debug, Clone)] +pub struct BacktrackPrompt { + system: String, + template: String, +} + +impl Default for BacktrackPrompt { + fn default() -> Self { + Self::with_fallback() + } +} + +impl BacktrackPrompt { + /// Create a new backtrack prompt template. + pub fn new() -> Self { + Self::default() + } + + /// Create with custom templates. + pub fn with_templates(system: String, template: String) -> Self { + Self { system, template } + } +} + +impl PromptTemplate for BacktrackPrompt { + fn system_prompt(&self) -> &str { + &self.system + } + + fn user_prompt_template(&self) -> &str { + &self.template + } + + fn intervention_point(&self) -> InterventionPoint { + InterventionPoint::Backtrack + } + + fn output_format_hint(&self) -> &str { + r#"{ + "alternative_branches": [ + {"index": 0, "score": 0.8, "reason": "why this alternative"} + ], + "direction": "backtrack", + "confidence": 0.0-1.0, + "reasoning": "why the original path failed and alternatives chosen" +}"# + } +} + +/// Prompt template for EVALUATE intervention point. +/// +/// Used to assess if a node contains the answer to: +/// - Determine relevance score +/// - Check if answer is found +/// - Guide further search +#[derive(Debug, Clone)] +pub struct EvaluatePrompt { + system: String, + template: String, +} + +impl Default for EvaluatePrompt { + fn default() -> Self { + Self::with_fallback() + } +} + +impl EvaluatePrompt { + /// Create a new evaluate prompt template. + pub fn new() -> Self { + Self::default() + } + + /// Create with custom templates. 
+ pub fn with_templates(system: String, template: String) -> Self { + Self { system, template } + } +} + +impl PromptTemplate for EvaluatePrompt { + fn system_prompt(&self) -> &str { + &self.system + } + + fn user_prompt_template(&self) -> &str { + &self.template + } + + fn intervention_point(&self) -> InterventionPoint { + InterventionPoint::Evaluate + } + + fn output_format_hint(&self) -> &str { + r#"{ + "relevance_score": 0.0-1.0, + "is_answer": true|false, + "direction": "go_deeper|found_answer", + "confidence": 0.0-1.0, + "reasoning": "why this node is or isn't the answer" +}"# + } +} + +/// Fallback templates when file loading fails. +pub mod fallback { + use super::*; + + pub fn system_start() -> String { + "You are a document navigation assistant. Help identify the best starting point for searching a hierarchical document.".to_string() + } + + pub fn user_start() -> String { + r#"Given the following document structure and user query, identify the best entry points for search. + +{context} + +Respond in JSON format with your analysis."#.to_string() + } + + pub fn system_fork() -> String { + "You are a document navigation assistant. At each decision point, rank the candidate branches by their likelihood of containing the answer to the user's query.".to_string() + } + + pub fn user_fork() -> String { + r#"Given the current search context and candidate branches, rank them by relevance. + +{context} + +Respond in JSON format with ranked candidates."#.to_string() + } + + pub fn system_backtrack() -> String { + "You are a document navigation assistant. When a search path fails to find the answer, analyze why and suggest alternative branches to explore.".to_string() + } + + pub fn user_backtrack() -> String { + r#"The current search path did not find the answer. Analyze the failure and suggest alternatives. 
+ +{context} + +Respond in JSON format with alternative branches."#.to_string() + } + + pub fn system_evaluate() -> String { + "You are a document analysis assistant. Evaluate whether the current node contains the answer to the user's query.".to_string() + } + + pub fn user_evaluate() -> String { + r#"Evaluate if this node contains the answer to the user's query. + +{context} + +Respond in JSON format with your evaluation."#.to_string() + } +} + +impl StartPrompt { + /// Get template with fallback. + pub fn with_fallback() -> Self { + Self { + system: fallback::system_start(), + template: fallback::user_start(), + } + } +} + +impl ForkPrompt { + /// Get template with fallback. + pub fn with_fallback() -> Self { + Self { + system: fallback::system_fork(), + template: fallback::user_fork(), + } + } +} + +impl BacktrackPrompt { + /// Get template with fallback. + pub fn with_fallback() -> Self { + Self { + system: fallback::system_backtrack(), + template: fallback::user_backtrack(), + } + } +} + +impl EvaluatePrompt { + /// Get template with fallback. + pub fn with_fallback() -> Self { + Self { + system: fallback::system_evaluate(), + template: fallback::user_evaluate(), + } + } +} diff --git a/src/retrieval/pilot/prompts/user_backtrack.txt b/src/retrieval/pilot/prompts/user_backtrack.txt new file mode 100644 index 00000000..b8feab8b --- /dev/null +++ b/src/retrieval/pilot/prompts/user_backtrack.txt @@ -0,0 +1,9 @@ +The current search path did not find a satisfactory answer. Analyze the situation and suggest alternative branches. 
+ +{context} + +Provide your response as a JSON object with: +- alternative_branches: array of suggested branches with index, score, and reason +- direction: should be "backtrack" +- confidence: your confidence in these alternatives (0.0-1.0) +- reasoning: explanation of why the original path failed and why these alternatives are promising diff --git a/src/retrieval/pilot/prompts/user_evaluate.txt b/src/retrieval/pilot/prompts/user_evaluate.txt new file mode 100644 index 00000000..ca4bf51c --- /dev/null +++ b/src/retrieval/pilot/prompts/user_evaluate.txt @@ -0,0 +1,10 @@ +Evaluate whether this node's content answers the user's query. + +{context} + +Provide your response as a JSON object with: +- relevance_score: how relevant is this content (0.0-1.0) +- is_answer: true if this content answers the query, false otherwise +- direction: "go_deeper" if children might have the answer, or "found_answer" +- confidence: your confidence in this evaluation (0.0-1.0) +- reasoning: explanation of your evaluation diff --git a/src/retrieval/pilot/prompts/user_fork.txt b/src/retrieval/pilot/prompts/user_fork.txt new file mode 100644 index 00000000..a4d7f37e --- /dev/null +++ b/src/retrieval/pilot/prompts/user_fork.txt @@ -0,0 +1,9 @@ +Given the current search context and candidate branches, rank them by relevance to the user's query. + +{context} + +Provide your response as a JSON object with: +- ranked_candidates: array of objects with index, score (0.0-1.0), and reason +- direction: one of "go_deeper", "explore_siblings", "backtrack", or "found_answer" +- confidence: your overall confidence (0.0-1.0) +- reasoning: brief explanation of your decision diff --git a/src/retrieval/pilot/prompts/user_start.txt b/src/retrieval/pilot/prompts/user_start.txt new file mode 100644 index 00000000..b091735e --- /dev/null +++ b/src/retrieval/pilot/prompts/user_start.txt @@ -0,0 +1,8 @@ +Analyze the following document structure and user query to identify the best entry points for search. 
+ +{context} + +Provide your response as a JSON object with: +- entry_points: list of section titles to start searching from +- reasoning: brief explanation of why these entry points +- confidence: your confidence in this recommendation (0.0-1.0) diff --git a/src/retrieval/pilot/trait.rs b/src/retrieval/pilot/trait.rs new file mode 100644 index 00000000..2017aa94 --- /dev/null +++ b/src/retrieval/pilot/trait.rs @@ -0,0 +1,227 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Pilot trait definition - the core interface for navigation intelligence. +//! +//! This module defines the [`Pilot`] trait which represents the brain of the +//! retrieval pipeline. Implementations provide navigation guidance at key +//! decision points during tree search. + +use async_trait::async_trait; +use std::collections::HashSet; +use std::sync::LazyLock; + +use crate::domain::{DocumentTree, NodeId}; + +use super::{PilotConfig, PilotDecision, InterventionPoint}; + +/// Empty HashSet for use in SearchState::for_start +static EMPTY_VISITED: LazyLock> = LazyLock::new(HashSet::new); + +/// Search state passed to Pilot for decision making. +/// +/// This struct contains all the context Pilot needs to understand +/// the current search situation and make informed decisions. +#[derive(Debug, Clone)] +pub struct SearchState<'a> { + /// The document tree being searched. + pub tree: &'a DocumentTree, + /// The user's query string. + pub query: &'a str, + /// Current path from root to current node. + pub path: &'a [NodeId], + /// Candidate child nodes to evaluate. + pub candidates: &'a [NodeId], + /// Set of already visited nodes (to avoid cycles). + pub visited: &'a HashSet, + /// Current depth in the tree. + pub depth: usize, + /// Current search iteration number. + pub iteration: usize, + /// Best score found so far in this search. + pub best_score: f32, + /// Whether the search is currently backtracking. 
+ pub is_backtracking: bool, +} + +impl<'a> SearchState<'a> { + /// Create a new search state. + pub fn new( + tree: &'a DocumentTree, + query: &'a str, + path: &'a [NodeId], + candidates: &'a [NodeId], + visited: &'a HashSet, + ) -> Self { + Self { + tree, + query, + path, + candidates, + visited, + depth: path.len(), + iteration: 0, + best_score: 0.0, + is_backtracking: false, + } + } + + /// Create a minimal search state for start guidance. + pub fn for_start(tree: &'a DocumentTree, query: &'a str) -> Self { + Self { + tree, + query, + path: &[], + candidates: &[], + visited: &EMPTY_VISITED, + depth: 0, + iteration: 0, + best_score: 0.0, + is_backtracking: false, + } + } + + /// Check if we're at the root level. + pub fn is_at_root(&self) -> bool { + self.path.is_empty() + } + + /// Check if there are multiple candidates (fork point). + pub fn is_fork_point(&self) -> bool { + self.candidates.len() > 1 + } + + /// Get the current node (last in path). + pub fn current_node(&self) -> Option { + self.path.last().copied() + } +} + +/// Pilot trait - the brain of the retrieval pipeline. +/// +/// Pilot provides navigation guidance at key decision points during +/// tree search. It uses LLM intelligence for semantic understanding +/// while allowing the algorithm to handle efficient execution. 
+/// +/// # Implementation Notes +/// +/// Implementations should: +/// - Be cheap to construct +/// - Handle LLM failures gracefully +/// - Respect budget constraints +/// - Provide explainable decisions +/// +/// # Example +/// +/// ```rust,ignore +/// use vectorless::retrieval::pilot::{Pilot, SearchState, PilotDecision}; +/// +/// struct MyPilot; +/// +/// #[async_trait] +/// impl Pilot for MyPilot { +/// fn name(&self) -> &str { "my_pilot" } +/// +/// fn should_intervene(&self, state: &SearchState<'_>) -> bool { +/// state.candidates.len() > 3 +/// } +/// +/// async fn decide(&self, state: &SearchState<'_>) -> PilotDecision { +/// // LLM-based decision making +/// PilotDecision::default() +/// } +/// } +/// ``` +#[async_trait] +pub trait Pilot: Send + Sync { + /// Get the name of this Pilot implementation. + fn name(&self) -> &str; + + /// Determine if Pilot should intervene at this point. + /// + /// This is the key method for controlling when LLM is called. + /// Implementations should consider: + /// - Candidate count (fork points) + /// - Score uncertainty + /// - Budget constraints + /// - Current depth and iteration + /// + /// Returns `true` if Pilot should be consulted for a decision. + fn should_intervene(&self, state: &SearchState<'_>) -> bool; + + /// Make a navigation decision. + /// + /// Called when `should_intervene` returns `true`. + /// Implementations should: + /// - Build appropriate context + /// - Call LLM (if applicable) + /// - Parse and validate response + /// - Return a structured decision + /// + /// This method should never panic. On errors, return a default + /// decision that preserves the original candidate order. + async fn decide(&self, state: &SearchState<'_>) -> PilotDecision; + + /// Provide guidance before search starts. + /// + /// Called once at the beginning of search to help determine + /// the starting point and initial direction. + /// + /// Returns `None` if no guidance is available or needed. 
+ async fn guide_start( + &self, + tree: &DocumentTree, + query: &str, + ) -> Option; + + /// Provide guidance during backtracking. + /// + /// Called when search needs to backtrack due to insufficient + /// results. Pilot can analyze the failure and suggest + /// alternative paths. + /// + /// Returns `None` if no guidance is available. + async fn guide_backtrack( + &self, + state: &SearchState<'_>, + ) -> Option; + + /// Get the current configuration. + fn config(&self) -> &PilotConfig; + + /// Check if this Pilot is actually capable of providing guidance. + /// + /// Returns `false` for NoopPilot or when budget is exhausted. + fn is_active(&self) -> bool { + true + } + + /// Reset internal state for a new query. + /// + /// Called at the start of each new search to reset + /// budget counters, caches, and other per-query state. + fn reset(&self); +} + +/// Extension trait for Pilot with utility methods. +pub trait PilotExt: Pilot { + /// Check if Pilot can intervene given current state and budget. + fn can_intervene(&self, state: &SearchState<'_>) -> bool { + self.is_active() && self.should_intervene(state) + } + + /// Get the current intervention point type. 
+ fn intervention_point(&self, state: &SearchState<'_>) -> InterventionPoint { + if state.is_at_root() || state.iteration == 0 { + InterventionPoint::Start + } else if state.is_backtracking { + InterventionPoint::Backtrack + } else if state.is_fork_point() { + InterventionPoint::Fork + } else { + InterventionPoint::Evaluate + } + } +} + +impl PilotExt for T {} diff --git a/src/retrieval/pipeline/context.rs b/src/retrieval/pipeline/context.rs index 5dd7c388..5dafaf36 100644 --- a/src/retrieval/pipeline/context.rs +++ b/src/retrieval/pipeline/context.rs @@ -11,6 +11,7 @@ use std::sync::Arc; use std::time::Instant; use crate::domain::{DocumentTree, NodeId}; +use crate::retrieval::pilot::Pilot; use crate::retrieval::types::{ NavigationStep, QueryComplexity, RetrieveOptions, RetrieveResponse, SearchPath, StrategyPreference, SufficiencyLevel, @@ -196,6 +197,8 @@ pub struct PipelineContext { pub tree: Arc, /// Retrieval options. pub options: RetrieveOptions, + /// Optional Pilot for navigation guidance. + pub pilot: Option>, // ============ Analyze Stage Output ============ /// Detected query complexity. @@ -255,6 +258,7 @@ impl PipelineContext { query: query.into(), tree, options, + pilot: None, complexity: None, keywords: Vec::new(), target_sections: Vec::new(), @@ -275,6 +279,28 @@ impl PipelineContext { } } + /// Create a new retrieval context with Pilot. + pub fn with_pilot( + tree: Arc, + query: impl Into, + options: RetrieveOptions, + pilot: Option>, + ) -> Self { + let mut ctx = Self::new(tree, query, options); + ctx.pilot = pilot; + ctx + } + + /// Set the Pilot for this context. + pub fn set_pilot(&mut self, pilot: Option>) { + self.pilot = pilot; + } + + /// Get the Pilot reference, if available. + pub fn pilot(&self) -> Option<&dyn Pilot> { + self.pilot.as_deref() + } + /// Start timing a stage. 
pub fn start_stage(&mut self) { self.stage_start = Some(Instant::now()); diff --git a/src/retrieval/pipeline/orchestrator.rs b/src/retrieval/pipeline/orchestrator.rs index 682d3e22..2dcde02e 100644 --- a/src/retrieval/pipeline/orchestrator.rs +++ b/src/retrieval/pipeline/orchestrator.rs @@ -8,17 +8,19 @@ //! - Parallel execution of independent stages //! - Backtracking for incremental retrieval //! - Failure policies +//! - Pilot integration for navigation guidance use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; use crate::domain::{DocumentTree, Result}; +use crate::retrieval::pilot::{Pilot, SearchState}; // FailurePolicy is re-exported for stages use crate::retrieval::types::{RetrieveOptions, RetrieveResponse}; -use super::context::PipelineContext; +use super::context::{CandidateNode, PipelineContext}; use super::outcome::StageOutcome; use super::stage::RetrievalStage; @@ -55,6 +57,7 @@ pub struct ExecutionGroup { /// - Parallel execution of independent stages /// - Backtracking support for incremental retrieval /// - Configurable failure policies +/// - Pilot integration for intelligent navigation /// /// # Example /// @@ -64,12 +67,14 @@ pub struct ExecutionGroup { /// .stage(PlanStage::new()) /// .stage(SearchStage::new()) /// .stage(JudgeStage::new()) +/// .with_pilot(pilot) /// .with_max_backtracks(3); /// /// let response = orchestrator.execute(tree, query, options).await?; /// ``` pub struct RetrievalOrchestrator { stages: Vec, + pilot: Option>, max_backtracks: usize, max_total_iterations: usize, } @@ -85,6 +90,7 @@ impl RetrievalOrchestrator { pub fn new() -> Self { Self { stages: Vec::new(), + pilot: None, max_backtracks: 5, max_total_iterations: 10, } @@ -107,6 +113,15 @@ impl RetrievalOrchestrator { self } + /// Add Pilot for navigation guidance during backtracking. 
+ /// + /// When set, the Pilot will be consulted during backtracking + /// to provide intelligent guidance on alternative search paths. + pub fn with_pilot(mut self, pilot: Arc) -> Self { + self.pilot = Some(pilot); + self + } + /// Set maximum number of backtracks allowed. pub fn with_max_backtracks(mut self, n: usize) -> Self { self.max_backtracks = n; @@ -301,8 +316,8 @@ impl RetrievalOrchestrator { groups.iter().filter(|g| g.parallel).count() ); - // Create context - let mut ctx = PipelineContext::new(tree, query, options); + // Create context with Pilot + let mut ctx = PipelineContext::with_pilot(tree, query, options, self.pilot.clone()); // Track execution state let mut backtrack_count = 0; @@ -361,6 +376,54 @@ impl RetrievalOrchestrator { additional_beam, go_deeper ); + // Consult Pilot for backtrack guidance + if let Some(ref pilot) = self.pilot { + if pilot.config().guide_at_backtrack { + // Build search state for Pilot + let visited: std::collections::HashSet<_> = + ctx.search_paths + .iter() + .flat_map(|p| p.nodes.iter().copied()) + .collect(); + let candidates: Vec<_> = + ctx.candidates.iter().map(|c| c.node_id).collect(); + + let state = SearchState::new( + &ctx.tree, + &ctx.query, + &[], + &candidates, + &visited, + ); + + match pilot.guide_backtrack(&state).await { + Some(guidance) => { + debug!( + "Pilot backtrack guidance: confidence={}, candidates={}", + guidance.confidence, + guidance.ranked_candidates.len() + ); + // Update candidates with Pilot's suggestions + if guidance.has_candidates() { + ctx.candidates = guidance + .ranked_candidates + .iter() + .map(|rc| CandidateNode { + node_id: rc.node_id, + score: rc.score, + depth: 0, + is_leaf: false, + }) + .collect(); + } + } + None => { + debug!("Pilot provided no backtrack guidance"); + } + } + } + } + // Update search config if let Some(ref mut config) = ctx.search_config { config.beam_width += additional_beam; @@ -392,6 +455,48 @@ impl RetrievalOrchestrator { .iter() .position(|e| 
e.stage.name() == target_stage) { + // Consult Pilot for backtrack guidance if going to search + if target_stage == "search" { + if let Some(ref pilot) = self.pilot { + if pilot.config().guide_at_backtrack { + let visited: std::collections::HashSet<_> = + ctx.search_paths + .iter() + .flat_map(|p| p.nodes.iter().copied()) + .collect(); + let candidates: Vec<_> = + ctx.candidates.iter().map(|c| c.node_id).collect(); + + let state = SearchState::new( + &ctx.tree, + &ctx.query, + &[], + &candidates, + &visited, + ); + + if let Some(guidance) = pilot.guide_backtrack(&state).await { + debug!( + "Pilot backtrack guidance for explicit backtrack: confidence={}", + guidance.confidence + ); + if guidance.has_candidates() { + ctx.candidates = guidance + .ranked_candidates + .iter() + .map(|rc| CandidateNode { + node_id: rc.node_id, + score: rc.score, + depth: 0, + is_leaf: false, + }) + .collect(); + } + } + } + } + } + ctx.increment_backtrack(); backtrack_count += 1; diff --git a/src/retrieval/search/beam.rs b/src/retrieval/search/beam.rs index d2d05b1e..63cdcec1 100644 --- a/src/retrieval/search/beam.rs +++ b/src/retrieval/search/beam.rs @@ -1,22 +1,32 @@ // Copyright (c) 2026 vectorless developers // SPDX-License-Identifier: Apache-2.0 -//! Beam search algorithm. +//! Beam search algorithm with Pilot integration. //! //! Explores multiple paths in parallel, keeping only the top-k candidates at each level. +//! When a Pilot is provided, it can intervene at fork points to provide semantic guidance. use async_trait::async_trait; +use std::collections::HashSet; +use tracing::{debug, trace}; use super::super::RetrievalContext; use super::super::types::{NavigationDecision, NavigationStep, SearchPath}; use super::scorer::NodeScorer; use super::{SearchConfig, SearchResult, SearchTree}; -use crate::domain::DocumentTree; +use crate::domain::{DocumentTree, NodeId}; +use crate::retrieval::pilot::{Pilot, SearchState}; /// Beam search - explores multiple paths simultaneously. 
/// /// Keeps top `beam_width` candidates at each level, providing /// a balance between exploration and computational cost. +/// +/// # Pilot Integration +/// +/// When a Pilot is provided, the algorithm consults it at fork points +/// (when multiple candidates are available) to get semantic guidance +/// on which branches are most relevant to the query. pub struct BeamSearch { scorer: NodeScorer, beam_width: usize, @@ -38,6 +48,60 @@ impl BeamSearch { beam_width: width.max(1), } } + + /// Score candidates using the algorithm's scorer. + fn score_candidates( + &self, + tree: &DocumentTree, + candidates: &[NodeId], + ) -> Vec<(NodeId, f32)> { + self.scorer.score_and_sort(tree, candidates) + } + + /// Merge algorithm scores with Pilot decision. + /// + /// Uses weighted combination: `final = α * algo + β * pilot` + /// where α = 0.4 and β = 0.6 * confidence + fn merge_with_pilot_decision( + &self, + tree: &DocumentTree, + candidates: &[NodeId], + pilot_decision: &crate::retrieval::pilot::PilotDecision, + ) -> Vec<(NodeId, f32)> { + let alpha = 0.4; + let beta = 0.6 * pilot_decision.confidence; + + // Build a map from node_id to pilot score + let mut pilot_scores: std::collections::HashMap<NodeId, f32> = std::collections::HashMap::new(); + for ranked in &pilot_decision.ranked_candidates { + pilot_scores.insert(ranked.node_id, ranked.score); + } + + // Merge scores + let mut merged: Vec<(NodeId, f32)> = candidates + .iter() + .map(|&node_id| { + let algo_score = self.scorer.score(tree, node_id); + let pilot_score = pilot_scores.get(&node_id).copied().unwrap_or(0.0); + + // Weighted combination + let final_score = if beta > 0.0 { + (alpha * algo_score + beta * pilot_score) / (alpha + beta) + } else { + algo_score + }; + + (node_id, final_score) + }) + .collect(); + + // Sort by merged score + merged.sort_by(|a, b| { + b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) + }); + + merged + } } impl Default for BeamSearch { @@ -53,26 +117,47 @@
tree: &DocumentTree, context: &RetrievalContext, config: &SearchConfig, + pilot: Option<&dyn Pilot>, ) -> SearchResult { let mut result = SearchResult::default(); let beam_width = config.beam_width.min(self.beam_width); + let mut visited: HashSet<NodeId> = HashSet::new(); + + // Track Pilot interventions + let mut pilot_interventions = 0; // Initialize with root's children let root_children = tree.children(tree.root()); - let mut current_beam: Vec<SearchPath> = root_children - .iter() - .map(|&child_id| { - let score = self.scorer.score(tree, child_id); - SearchPath::from_node(child_id, score) - }) + + // Check if Pilot wants to guide the start + let initial_candidates = if let Some(p) = pilot { + if p.config().guide_at_start { + if let Some(guidance) = p.guide_start(tree, &context.query).await { + debug!("Pilot provided start guidance with confidence {}", guidance.confidence); + pilot_interventions += 1; + + // Use Pilot's ranked order if available + if guidance.has_candidates() { + self.merge_with_pilot_decision(tree, &root_children, &guidance) + } else { + self.score_candidates(tree, &root_children) + } + } else { + self.score_candidates(tree, &root_children) + } + } else { + self.score_candidates(tree, &root_children) + } + } else { + self.score_candidates(tree, &root_children) + }; + + let mut current_beam: Vec<SearchPath> = initial_candidates + .into_iter() + .map(|(node_id, score)| SearchPath::from_node(node_id, score)) .collect(); - // Sort by score and keep top beam_width - current_beam.sort_by(|a, b| { - b.score - .partial_cmp(&a.score) - .unwrap_or(std::cmp::Ordering::Equal) - }); + // Keep top beam_width current_beam.truncate(beam_width); for iteration in 0..config.max_iterations { @@ -86,6 +171,8 @@ impl SearchTree for BeamSearch { for path in &current_beam { if let Some(leaf_id) = path.leaf { + visited.insert(leaf_id); + // Check if this is a leaf node if tree.is_leaf(leaf_id) { // Add to final results @@ -98,7 +185,44 @@ impl SearchTree for BeamSearch { // Expand this path let children =
tree.children(leaf_id); - let scored_children = self.scorer.score_and_sort(tree, &children); + + // ========== Pilot Intervention Point ========== + let scored_children = if let Some(p) = pilot { + // Build search state for Pilot + let state = SearchState::new( + tree, + &context.query, + &path.nodes, + &children, + &visited, + ); + + // Check if Pilot wants to intervene + if p.should_intervene(&state) { + trace!("Pilot intervening at fork with {} candidates", children.len()); + + match p.decide(&state).await { + decision => { + pilot_interventions += 1; + debug!( + "Pilot decision: confidence={}, direction={:?}", + decision.confidence, + std::mem::discriminant(&decision.direction) + ); + + // Merge algorithm scores with Pilot decision + self.merge_with_pilot_decision(tree, &children, &decision) + } + } + } else { + // No intervention, use algorithm scoring + self.score_candidates(tree, &children) + } + } else { + // No Pilot, use algorithm scoring + self.score_candidates(tree, &children) + }; + // ============================================== for (child_id, child_score) in scored_children.into_iter().take(beam_width) { let new_path = path.extend(child_id, child_score); @@ -152,6 +276,9 @@ impl SearchTree for BeamSearch { }); result.paths.truncate(config.top_k); + // Record Pilot interventions + result.pilot_interventions = pilot_interventions; + result } @@ -159,3 +286,23 @@ impl SearchTree for BeamSearch { "beam" } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_beam_search_creation() { + let search = BeamSearch::new(); + assert_eq!(search.beam_width, 3); + + let search_wide = BeamSearch::with_width(5); + assert_eq!(search_wide.beam_width, 5); + } + + #[test] + fn test_beam_search_minimum_width() { + let search = BeamSearch::with_width(0); + assert_eq!(search.beam_width, 1); + } +} diff --git a/src/retrieval/search/greedy.rs b/src/retrieval/search/greedy.rs index 933b20d9..f016a066 100644 --- a/src/retrieval/search/greedy.rs +++ 
b/src/retrieval/search/greedy.rs @@ -1,9 +1,10 @@ // Copyright (c) 2026 vectorless developers // SPDX-License-Identifier: Apache-2.0 -//! Greedy search algorithm. +//! Greedy search algorithm with Pilot integration. //! //! Simple depth-first search that always follows the highest-scoring child. +//! When a Pilot is provided, it can provide semantic guidance at decision points. use async_trait::async_trait; @@ -12,6 +13,7 @@ use super::super::types::{NavigationDecision, NavigationStep, SearchPath}; use super::scorer::NodeScorer; use super::{SearchConfig, SearchResult, SearchTree}; use crate::domain::DocumentTree; +use crate::retrieval::pilot::Pilot; /// Greedy search - always follows the best single path. /// @@ -42,7 +44,10 @@ impl SearchTree for GreedySearch { tree: &DocumentTree, context: &RetrievalContext, config: &SearchConfig, + _pilot: Option<&dyn Pilot>, ) -> SearchResult { + // Note: Pilot integration for GreedySearch can be added in Phase 2 + // For now, we keep the original behavior let mut result = SearchResult::default(); let mut current_path = SearchPath::new(); let mut current_node = tree.root(); diff --git a/src/retrieval/search/mcts.rs b/src/retrieval/search/mcts.rs index c556d18f..2cc6fbd0 100644 --- a/src/retrieval/search/mcts.rs +++ b/src/retrieval/search/mcts.rs @@ -1,9 +1,10 @@ // Copyright (c) 2026 vectorless developers // SPDX-License-Identifier: Apache-2.0 -//! Monte Carlo Tree Search (MCTS) algorithm. +//! Monte Carlo Tree Search (MCTS) algorithm with Pilot integration. //! //! Balances exploration and exploitation using UCT formula. +//! When a Pilot is provided, it can provide semantic guidance at decision points. use async_trait::async_trait; use std::collections::HashMap; @@ -14,6 +15,7 @@ use super::scorer::NodeScorer; use super::{SearchConfig, SearchResult, SearchTree}; use crate::config::StrategyConfig; use crate::domain::{DocumentTree, NodeId}; +use crate::retrieval::pilot::Pilot; /// Statistics for a node in MCTS. 
#[derive(Debug, Clone, Default)] @@ -150,7 +152,10 @@ impl SearchTree for MctsSearch { tree: &DocumentTree, context: &RetrievalContext, config: &SearchConfig, + _pilot: Option<&dyn Pilot>, ) -> SearchResult { + // Note: Pilot integration for MCTS can be added in Phase 2 + // For now, we keep the original behavior let mut result = SearchResult::default(); let mut stats: HashMap = HashMap::new(); let root = tree.root(); diff --git a/src/retrieval/search/trait.rs b/src/retrieval/search/trait.rs index afbeb970..927753cf 100644 --- a/src/retrieval/search/trait.rs +++ b/src/retrieval/search/trait.rs @@ -8,6 +8,7 @@ use async_trait::async_trait; use super::super::RetrievalContext; use super::super::types::{NavigationStep, SearchPath}; use crate::domain::DocumentTree; +use crate::retrieval::pilot::Pilot; /// Result of a search operation. #[derive(Debug, Clone)] @@ -20,6 +21,8 @@ pub struct SearchResult { pub nodes_visited: usize, /// Number of iterations performed. pub iterations: usize, + /// Number of Pilot interventions. + pub pilot_interventions: usize, } impl Default for SearchResult { @@ -29,6 +32,7 @@ impl Default for SearchResult { trace: Vec::new(), nodes_visited: 0, iterations: 0, + pilot_interventions: 0, } } } @@ -64,6 +68,18 @@ impl Default for SearchConfig { /// /// Implementations provide different strategies for exploring /// the document tree to find relevant content. +/// +/// # Pilot Integration +/// +/// Search algorithms can optionally accept a [`Pilot`] for intelligent +/// navigation guidance at key decision points. When a Pilot is provided, +/// the algorithm consults it at: +/// - Fork points (multiple candidates) +/// - Low confidence situations +/// - Backtracking decisions +/// +/// When no Pilot is provided (None), the algorithm uses its default +/// scoring mechanism. #[async_trait] pub trait SearchTree: Send + Sync { /// Search the tree for relevant nodes. 
@@ -73,6 +89,7 @@ pub trait SearchTree: Send + Sync { /// * `tree` - The document tree to search /// * `context` - Retrieval context with query information /// * `config` - Search configuration + /// * `pilot` - Optional Pilot for navigation guidance /// /// # Returns /// @@ -82,8 +99,19 @@ pub trait SearchTree: Send + Sync { tree: &DocumentTree, context: &RetrievalContext, config: &SearchConfig, + pilot: Option<&dyn Pilot>, ) -> SearchResult; + /// Search without Pilot (uses default algorithm scoring). + async fn search_without_pilot( + &self, + tree: &DocumentTree, + context: &RetrievalContext, + config: &SearchConfig, + ) -> SearchResult { + self.search(tree, context, config, None).await + } + /// Get the name of this search algorithm. fn name(&self) -> &str; } diff --git a/src/retrieval/stages/search.rs b/src/retrieval/stages/search.rs index 648dd295..0283de23 100644 --- a/src/retrieval/stages/search.rs +++ b/src/retrieval/stages/search.rs @@ -1,10 +1,11 @@ // Copyright (c) 2026 vectorless developers // SPDX-License-Identifier: Apache-2.0 -//! Search Stage - Execute tree search. +//! Search Stage - Execute tree search with Pilot integration. //! //! This stage executes the selected search algorithm using -//! the selected retrieval strategy. +//! the selected retrieval strategy. When a Pilot is provided, +//! it can provide semantic guidance at key decision points. 
use async_trait::async_trait; use std::sync::Arc; @@ -12,6 +13,7 @@ use tracing::{info, warn}; use crate::domain::DocumentTree; // LlmClient is used via strategy +use crate::retrieval::pilot::Pilot; use crate::retrieval::RetrievalContext; // Legacy context use crate::retrieval::pipeline::{ CandidateNode, FailurePolicy, PipelineContext, RetrievalStage, SearchAlgorithm, StageOutcome, @@ -22,23 +24,36 @@ use crate::retrieval::search::{ use crate::retrieval::strategy::{KeywordStrategy, LlmStrategy, RetrievalStrategy}; use crate::retrieval::types::StrategyPreference; -/// Search Stage - executes tree search. +/// Search Stage - executes tree search with optional Pilot guidance. /// /// This stage: /// 1. Instantiates the selected search algorithm /// 2. Creates the appropriate strategy -/// 3. Executes search and collects candidates +/// 3. Executes search with optional Pilot intervention +/// 4. Collects candidates +/// +/// # Pilot Integration +/// +/// When a Pilot is provided via [`with_pilot`], the search algorithm +/// can consult it at key decision points for semantic guidance. +/// Without a Pilot, the search uses pure algorithm scoring. /// /// # Example /// /// ```rust,ignore +/// use vectorless::retrieval::pilot::{LlmPilot, PilotConfig}; +/// +/// let pilot = LlmPilot::new(llm_client, PilotConfig::default()); /// let stage = SearchStage::new() +/// .with_pilot(Arc::new(pilot)) /// .with_llm_strategy(llm_strategy); /// ``` pub struct SearchStage { keyword_strategy: KeywordStrategy, llm_strategy: Option>, semantic_strategy: Option>, + /// Pilot for navigation guidance (optional). + pilot: Option>, } impl Default for SearchStage { @@ -48,15 +63,26 @@ impl Default for SearchStage { } impl SearchStage { - /// Create a new search stage. + /// Create a new search stage without Pilot. pub fn new() -> Self { Self { keyword_strategy: KeywordStrategy::new(), llm_strategy: None, semantic_strategy: None, + pilot: None, } } + /// Add Pilot for semantic navigation guidance. 
+ /// + /// When provided, the search algorithm will consult the Pilot + /// at key decision points to get semantic guidance on which + /// branches are most relevant to the query. + pub fn with_pilot(mut self, pilot: Arc<dyn Pilot>) -> Self { + self.pilot = Some(pilot); + self + } + /// Add LLM strategy for complex queries. pub fn with_llm_strategy(mut self, strategy: LlmStrategy) -> Self { self.llm_strategy = Some(Arc::new(strategy)); @@ -69,6 +95,11 @@ impl SearchStage { self } + /// Check if Pilot is available and active. + pub fn has_pilot(&self) -> bool { + self.pilot.as_ref().map(|p| p.is_active()).unwrap_or(false) + } + /// Get the strategy to use based on context. fn get_strategy(&self, ctx: &PipelineContext) -> Arc<dyn RetrievalStrategy> { let preference = ctx.selected_strategy.unwrap_or(StrategyPreference::Auto); @@ -136,7 +167,7 @@ impl SearchStage { #[async_trait] impl RetrievalStage for SearchStage { - fn name(&self) -> &'static str { + fn name(&self) -> &str { "search" } @@ -160,13 +191,20 @@ let start = std::time::Instant::now(); // Get strategy and algorithm - let strategy = self.get_strategy(ctx); + let _strategy = self.get_strategy(ctx); let algorithm = ctx.selected_algorithm.unwrap_or(SearchAlgorithm::Beam); let config = ctx.search_config.clone().unwrap_or_default(); + // Reset Pilot state for new query + if let Some(ref pilot) = self.pilot { + pilot.reset(); + } + info!( - "Executing search: algorithm={:?}, beam_width={}", - algorithm, config.beam_width + "Executing search: algorithm={:?}, beam_width={}, pilot={}", + algorithm, + config.beam_width, + if self.has_pilot() { "enabled" } else { "disabled" } ); // Increment search iteration @@ -188,24 +226,31 @@ ctx.options.sufficiency_check, ); - // Execute search based on algorithm + // Get Pilot reference (or None if not available) + let pilot_ref: Option<&dyn Pilot> = self.pilot.as_deref(); + + // Execute search based on algorithm with Pilot + let result =
match algorithm { SearchAlgorithm::Greedy => { let search = GreedySearch::new(); - search.search(&ctx.tree, &legacy_ctx, &search_config).await + search.search(&ctx.tree, &legacy_ctx, &search_config, pilot_ref).await } SearchAlgorithm::Beam => { let search = BeamSearch::new(); - search.search(&ctx.tree, &legacy_ctx, &search_config).await + search.search(&ctx.tree, &legacy_ctx, &search_config, pilot_ref).await } SearchAlgorithm::Mcts => { // Use beam search as fallback for now let search = BeamSearch::new(); - search.search(&ctx.tree, &legacy_ctx, &search_config).await + search.search(&ctx.tree, &legacy_ctx, &search_config, pilot_ref).await } }; - info!("Search found {} paths", result.paths.len()); + info!( + "Search found {} paths (pilot interventions: {})", + result.paths.len(), + result.pilot_interventions + ); // Update context with results ctx.search_paths = result.paths.clone(); @@ -228,12 +273,14 @@ impl RetrievalStage for SearchStage { #[cfg(test)] mod tests { use super::*; + use crate::retrieval::pilot::NoopPilot; #[test] fn test_search_stage_creation() { let stage = SearchStage::new(); assert!(stage.llm_strategy.is_none()); assert!(stage.semantic_strategy.is_none()); + assert!(!stage.has_pilot()); } #[test] @@ -241,4 +288,13 @@ mod tests { let stage = SearchStage::new(); assert_eq!(stage.depends_on(), vec!["plan"]); } + + #[test] + fn test_search_stage_with_noop_pilot() { + let pilot = Arc::new(NoopPilot::new()); + let stage = SearchStage::new().with_pilot(pilot); + + // NoopPilot is not active + assert!(!stage.has_pilot()); + } }