diff --git a/rust/src/retrieval/content/scorer.rs b/rust/src/retrieval/content/scorer.rs
index 389225df..3472a733 100644
--- a/rust/src/retrieval/content/scorer.rs
+++ b/rust/src/retrieval/content/scorer.rs
@@ -161,8 +161,12 @@ impl RelevanceScorer {
     pub fn score_chunk(&self, chunk: &ContentChunk, ctx: &ScoringContext) -> ContentRelevance {
         let mut components = ScoreComponents::default();
 
-        // 1. Keyword score
-        components.keyword_score = self.compute_keyword_score(&chunk.content);
+        // 1. Keyword score (title + content combined)
+        components.keyword_score = self.compute_keyword_score(&format!(
+            "{} {}",
+            chunk.title,
+            chunk.content
+        ));
 
         // 2. BM25 score (if enabled)
         if matches!(
diff --git a/rust/src/retrieval/pilot/llm_pilot.rs b/rust/src/retrieval/pilot/llm_pilot.rs
index a7bb62a9..408397fe 100644
--- a/rust/src/retrieval/pilot/llm_pilot.rs
+++ b/rust/src/retrieval/pilot/llm_pilot.rs
@@ -77,8 +77,12 @@ pub struct LlmPilot {
     executor: Option<Arc<LlmExecutor>>,
     /// Pilot configuration.
     config: PilotConfig,
-    /// Budget controller.
+    /// Budget controller for per-level call tracking.
    budget: BudgetController,
+    /// Shared pipeline budget — the primary budget source when set.
+    /// When available, Pilot checks this before making LLM calls and
+    /// records token consumption here.
+    pipeline_budget: parking_lot::RwLock<Option<Arc<RetrievalBudgetController>>>,
     /// Context builder.
     context_builder: ContextBuilder,
     /// Prompt builder.
@@ -111,6 +115,7 @@ impl LlmPilot {
             executor: None,
             config,
             budget,
+            pipeline_budget: parking_lot::RwLock::new(None),
             context_builder: ContextBuilder::new(token_budget),
             prompt_builder: PromptBuilder::new(),
             response_parser: ResponseParser::new(),
@@ -131,6 +136,7 @@ impl LlmPilot {
             executor: Some(Arc::new(executor)),
             config,
             budget,
+            pipeline_budget: parking_lot::RwLock::new(None),
             context_builder: ContextBuilder::new(token_budget),
             prompt_builder: PromptBuilder::new(),
             response_parser: ResponseParser::new(),
@@ -150,6 +156,7 @@ impl LlmPilot {
             executor: Some(executor),
             config,
             budget,
+            pipeline_budget: parking_lot::RwLock::new(None),
             context_builder: ContextBuilder::new(token_budget),
             prompt_builder: PromptBuilder::new(),
             response_parser: ResponseParser::new(),
@@ -172,6 +179,7 @@ impl LlmPilot {
             executor: None,
             config,
             budget,
+            pipeline_budget: parking_lot::RwLock::new(None),
             context_builder,
             prompt_builder,
             response_parser: ResponseParser::new(),
@@ -208,6 +216,17 @@ impl LlmPilot {
         self
     }
 
+    /// Set the shared pipeline budget controller.
+    ///
+    /// When set, this becomes the primary budget gate for LLM calls.
+    /// The Pilot's own BudgetController still tracks per-level call counts,
+    /// but token consumption is recorded against the pipeline budget.
+    /// Call this at query time (not construction time) since the pipeline
+    /// budget is created per-query.
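+    ///
+    /// Wiring sketch (the orchestrator below does this once per query;
+    /// names as in this diff):
+    /// ```rust,ignore
+    /// let budget = Arc::new(RetrievalBudgetController::new(max_tokens));
+    /// pilot.set_pipeline_budget(budget.clone());
+    /// ```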
+    pub fn set_pipeline_budget(&self, budget: Arc<RetrievalBudgetController>) {
+        *self.pipeline_budget.write() = Some(budget);
+    }
+
     /// Check if using LlmExecutor (unified throttle/retry/fallback).
     pub fn has_executor(&self) -> bool {
         self.executor.is_some()
@@ -259,7 +278,17 @@ impl LlmPilot {
     }
 
     /// Check if budget allows LLM calls.
+    ///
+    /// Checks the shared pipeline budget first (if set), then falls back
+    /// to the Pilot's own per-call budget.
     fn has_budget(&self) -> bool {
+        // Primary: check pipeline budget
+        if let Some(ref pb) = *self.pipeline_budget.read() {
+            if pb.status().should_stop() {
+                return false;
+            }
+        }
+        // Secondary: check Pilot's own call-level budget
         self.budget.can_call()
     }
 
@@ -335,22 +364,6 @@ impl LlmPilot {
         }
     }
 
-        println!(
-            "[DEBUG] LlmPilot::call_llm() - point={:?}, estimated_tokens={}",
-            point, prompt.estimated_tokens
-        );
-        println!(
-            "[DEBUG] LlmPilot::call_llm() - SYSTEM PROMPT:\n{}",
-            prompt.system
-        );
-        println!(
-            "[DEBUG] LlmPilot::call_llm() - USER PROMPT:\n{}",
-            prompt.user
-        );
-        println!(
-            "[DEBUG] LlmPilot::call_llm() - candidates count: {}",
-            candidates.len()
-        );
         debug!(
             "Calling LLM for {:?} point (estimated: {} tokens)",
             point, prompt.estimated_tokens
@@ -358,35 +371,28 @@ impl LlmPilot {
 
         // Make LLM call - use executor if available, otherwise use client directly
         let result = if let Some(ref executor) = self.executor {
-            println!("[DEBUG] LlmPilot::call_llm() - using LlmExecutor");
             // Use LlmExecutor for unified throttle/retry/fallback
             executor.complete(&prompt.system, &prompt.user).await
         } else {
-            println!("[DEBUG] LlmPilot::call_llm() - using direct client");
             // Fallback to direct client call
             self.client.complete(&prompt.system, &prompt.user).await
         };
 
         match result {
             Ok(response) => {
-                println!(
-                    "[DEBUG] LlmPilot::call_llm() - RAW LLM RESPONSE:\n{}",
-                    response
-                );
                 // Record usage (estimate output tokens)
                 let output_tokens = self.estimate_tokens(&response);
+                let total_tokens = prompt.estimated_tokens + output_tokens;
                 self.budget
                     .record_usage(prompt.estimated_tokens, output_tokens, 0);
+                // Also record in pipeline budget if shared
+                if let Some(ref pb) = *self.pipeline_budget.read() {
+                    pb.record_tokens(total_tokens);
+                }
+
                 // Parse response
                 let mut decision = self.response_parser.parse(&response, candidates, point);
-                println!(
-                    "[DEBUG] LlmPilot::call_llm() - PARSED DECISION: confidence={:.2}, ranked={}, direction={:?}, reasoning={}",
-                    decision.confidence,
-                    decision.ranked_candidates.len(),
-                    std::mem::discriminant(&decision.direction),
-                    decision.reasoning.chars().take(100).collect::<String>()
-                );
 
                 // Apply learner adjustment if available
                 if let Some(ref adj) = adjustment {
@@ -525,13 +531,11 @@ impl Pilot for LlmPilot {
     fn should_intervene(&self, state: &SearchState<'_>) -> bool {
         // Check mode
         if !self.config.mode.uses_llm() {
-            println!("[DEBUG] LlmPilot::should_intervene() - mode doesn't use LLM");
             return false;
         }
 
         // Check budget
         if !self.has_budget() {
-            println!("[DEBUG] LlmPilot::should_intervene() - budget exhausted");
             debug!("Budget exhausted, skipping intervention");
             return false;
         }
@@ -540,11 +544,6 @@ impl Pilot for LlmPilot {
 
         // Condition 1: Fork point with enough candidates
         if state.candidates.len() > intervention.fork_threshold {
-            println!(
-                "[DEBUG] LlmPilot::should_intervene() - YES: fork point with {} candidates (threshold={})",
-                state.candidates.len(),
-                intervention.fork_threshold
-            );
             debug!(
                 "Intervening: fork point with {} candidates",
                 state.candidates.len()
@@ -554,20 +553,12 @@ impl Pilot for LlmPilot {
 
         // Condition 2: Scores are too close (algorithm uncertain)
         if self.scores_are_close(state) {
-            println!(
-                "[DEBUG] LlmPilot::should_intervene() - YES: scores are close (best={:.2})",
-                state.best_score
-            );
             debug!("Intervening: scores are close");
             return true;
         }
 
         // Condition 3: Low confidence (best score too low)
         if intervention.is_low_confidence(state.best_score) {
-            println!(
-                "[DEBUG] LlmPilot::should_intervene() - YES: low confidence (best_score={:.2}, threshold={:.2})",
-                state.best_score, intervention.low_score_threshold
-            );
             debug!(
                 "Intervening: low confidence (best_score={:.2})",
                 state.best_score
@@ -577,26 +568,15 @@ impl Pilot for LlmPilot {
 
         // Condition 4: Backtracking and guide_at_backtrack is enabled
         if state.is_backtracking && self.config.guide_at_backtrack {
-            println!("[DEBUG] LlmPilot::should_intervene() - YES: backtracking");
             debug!("Intervening: backtracking");
             return true;
         }
 
-        println!(
-            "[DEBUG] LlmPilot::should_intervene() - NO: candidates={}, best_score={:.2}",
-            state.candidates.len(),
-            state.best_score
-        );
         false
     }
 
     async fn decide(&self, state: &SearchState<'_>) -> PilotDecision {
         let point = self.get_intervention_point(state);
-        println!(
-            "[DEBUG] LlmPilot::decide() - intervention_point={:?}, candidates={}",
-            point,
-            state.candidates.len()
-        );
 
         // Build context
         let context = self.context_builder.build(state);
@@ -619,16 +599,7 @@ impl Pilot for LlmPilot {
             .collect();
 
         // Make LLM call
-        let decision = self.call_llm(point, &context, &candidate_info).await;
-
-        println!(
-            "[DEBUG] LlmPilot::decide() - result: confidence={:.2}, direction={:?}, ranked={}",
-            decision.confidence,
-            std::mem::discriminant(&decision.direction),
-            decision.ranked_candidates.len()
-        );
-
-        decision
+        self.call_llm(point, &context, &candidate_info).await
     }
 
     async fn guide_start(
         &self,
         tree: &DocumentTree,
         query: &str,
         start_node: NodeId,
     ) -> Option<PilotDecision> {
-        println!(
-            "[DEBUG] LlmPilot::guide_start() called, query='{}', start_node={:?}",
-            query, start_node
-        );
-
         // Check if guide_at_start is enabled
         if !self.config.guide_at_start {
-            println!("[DEBUG] LlmPilot::guide_start() - guide_at_start=false, skipping");
             return None;
         }
 
         // Check budget
         if !self.has_budget() {
-            println!("[DEBUG] LlmPilot::guide_start() - budget exhausted, skipping");
             debug!("Budget exhausted, cannot guide start");
             return None;
         }
@@ -664,10 +628,6 @@ impl Pilot for LlmPilot {
             debug!("Start node has no children, no guidance needed");
             return None;
         }
-        println!(
-            "[DEBUG] LlmPilot::guide_start() - {} children candidates from start_node",
-            node_ids.len()
-        );
 
         // Build CandidateInfo with titles
         let candidates: Vec<CandidateInfo> = node_ids
@@ -683,30 +643,12 @@ impl Pilot for LlmPilot {
             .collect();
 
         // Make LLM call
-        println!("[DEBUG] LlmPilot::guide_start() - calling LLM...");
         let decision = self
             .call_llm(InterventionPoint::Start, &context, &candidates)
             .await;
 
-        println!(
-            "[DEBUG] LlmPilot::guide_start() - LLM returned: confidence={:.2}, ranked_candidates={}, reasoning='{}'",
-            decision.confidence,
-            decision.ranked_candidates.len(),
-            decision.reasoning.chars().take(100).collect::<String>()
-        );
-
-        // Debug: show top ranked candidates
-        for (i, rc) in decision.ranked_candidates.iter().enumerate().take(3) {
-            if let Some(node) = tree.get(rc.node_id) {
-                println!(
-                    "[DEBUG] Ranked {}: node_id={:?}, score={:.3}, title='{}'",
-                    i, rc.node_id, rc.score, node.title
-                );
-            }
-        }
-
         info!(
-            "Pilot start guidance: confidence={}, candidates={}",
+            "Pilot start guidance: confidence={:.2}, candidates={}",
             decision.confidence,
             decision.ranked_candidates.len()
         );
@@ -764,8 +706,13 @@ impl Pilot for LlmPilot {
     fn reset(&self) {
         self.budget.reset();
+        *self.pipeline_budget.write() = None;
         debug!("LlmPilot reset for new query");
     }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
 }
 
 #[cfg(test)]
diff --git a/rust/src/retrieval/pilot/trait.rs b/rust/src/retrieval/pilot/trait.rs
index 54936b9a..3f4b868a 100644
--- a/rust/src/retrieval/pilot/trait.rs
+++ b/rust/src/retrieval/pilot/trait.rs
@@ -202,6 +202,14 @@ pub trait Pilot: Send + Sync {
     /// Called at the start of each new search to reset
     /// budget counters, caches, and other per-query state.
     fn reset(&self);
+
+    /// Downcast support for shared budget injection.
+    ///
+    /// The default implementation returns a reference to `()`, so
+    /// downcasts to a concrete pilot type simply fail.
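+    ///
+    /// Implementors that support injection override this to return `self`;
+    /// callers can then downcast (a sketch of the orchestrator-side use):
+    /// ```rust,ignore
+    /// if let Some(llm_pilot) = pilot.as_any().downcast_ref::<LlmPilot>() {
+    ///     llm_pilot.set_pipeline_budget(budget.clone());
+    /// }
+    /// ```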
+    fn as_any(&self) -> &dyn std::any::Any {
+        // Default: no downcast support
+        &()
+    }
 }
 
 /// Extension trait for Pilot with utility methods.
diff --git a/rust/src/retrieval/pipeline/context.rs b/rust/src/retrieval/pipeline/context.rs
index b2a2745a..484f41ca 100644
--- a/rust/src/retrieval/pipeline/context.rs
+++ b/rust/src/retrieval/pipeline/context.rs
@@ -205,7 +205,8 @@ pub struct PipelineContext {
     /// Optional Pilot for navigation guidance.
     pub pilot: Option<Arc<dyn Pilot>>,
     /// Adaptive token budget controller for the entire pipeline.
-    pub budget_controller: RetrievalBudgetController,
+    /// Shared via Arc so Pilot can read/check the same budget.
+    pub budget_controller: Arc<RetrievalBudgetController>,
     /// Tiered reasoning cache (L1 exact, L2 path pattern, L3 strategy score).
     pub reasoning_cache: Arc<ReasoningCache>,
@@ -259,6 +260,9 @@ pub struct PipelineContext {
     /// Fingerprint of candidate node IDs from previous evaluate call.
     /// Used to detect stagnant loops (same candidates → same evaluation).
     pub prev_candidate_fingerprint: Option<Fingerprint>,
+    /// Per-node content cache to avoid duplicate computation.
+    /// Populated by `aggregate_content()`, read by `build_response()`.
+    pub node_content_cache: HashMap<NodeId, String>,
 
     // ============ Final Result ============
     /// Final retrieval response.
     pub result: Option<RetrieveResponse>,
@@ -282,7 +286,7 @@ impl PipelineContext {
     ) -> Self {
         // Build retrieval index for efficient operations
         let retrieval_index = Some(tree.build_retrieval_index());
-        let budget_controller = RetrievalBudgetController::new(options.max_tokens);
+        let budget_controller = Arc::new(RetrievalBudgetController::new(options.max_tokens));
 
         Self {
             query: query.into(),
@@ -311,6 +315,7 @@ impl PipelineContext {
             accumulated_content: String::new(),
             token_count: 0,
             prev_candidate_fingerprint: None,
+            node_content_cache: HashMap::new(),
             result: None,
             stage_results: HashMap::new(),
             metrics: RetrievalMetrics::default(),
diff --git a/rust/src/retrieval/pipeline/mod.rs b/rust/src/retrieval/pipeline/mod.rs
index 6726a8ce..5c84a509 100644
--- a/rust/src/retrieval/pipeline/mod.rs
+++ b/rust/src/retrieval/pipeline/mod.rs
@@ -40,7 +40,7 @@
 mod orchestrator;
 mod outcome;
 mod stage;
 
-pub use budget::BudgetStatus;
+pub use budget::{BudgetStatus, RetrievalBudgetController};
 pub use context::{CandidateNode, PipelineContext, SearchAlgorithm, SearchConfig};
 pub use orchestrator::RetrievalOrchestrator;
 pub use outcome::StageOutcome;
diff --git a/rust/src/retrieval/pipeline/orchestrator.rs b/rust/src/retrieval/pipeline/orchestrator.rs
index 704629e3..c42b5cbe 100644
--- a/rust/src/retrieval/pipeline/orchestrator.rs
+++ b/rust/src/retrieval/pipeline/orchestrator.rs
@@ -328,6 +328,13 @@ impl RetrievalOrchestrator {
             ctx = ctx.with_document_graph(graph);
         }
 
+        // Share the pipeline budget with the Pilot (unified budget)
+        if let Some(ref pilot) = self.pilot {
+            if let Some(llm_pilot) = pilot.as_any().downcast_ref::<LlmPilot>() {
+                llm_pilot.set_pipeline_budget(ctx.budget_controller.clone());
+            }
+        }
+
         // Track execution state
         let mut backtrack_count = 0;
         let mut total_iterations = 0;
@@ -613,6 +620,13 @@ impl RetrievalOrchestrator {
             ctx = ctx.with_document_graph(graph);
         }
 
+        // Share the pipeline budget with the Pilot (unified budget)
+        if let Some(ref pilot) = self.pilot {
+            if let Some(llm_pilot) = pilot.as_any().downcast_ref::<LlmPilot>() {
+                llm_pilot.set_pipeline_budget(ctx.budget_controller.clone());
+            }
+        }
+
         let mut backtrack_count = 0;
         let mut total_iterations = 0;
         let mut group_idx = 0;
@@ -908,6 +922,13 @@ impl RetrievalOrchestrator {
             ctx = ctx.with_document_graph(graph);
         }
 
+        // Share the pipeline budget with the Pilot (unified budget)
+        if let Some(ref pilot) = self.pilot {
+            if let Some(llm_pilot) = pilot.as_any().downcast_ref::<LlmPilot>() {
+                llm_pilot.set_pipeline_budget(ctx.budget_controller.clone());
+            }
+        }
+
         let mut backtrack_count = 0;
         let mut total_iterations = 0;
         let mut group_idx = 0;
diff --git a/rust/src/retrieval/stages/evaluate.rs b/rust/src/retrieval/stages/evaluate.rs
index c9f05800..972f9667 100644
--- a/rust/src/retrieval/stages/evaluate.rs
+++ b/rust/src/retrieval/stages/evaluate.rs
@@ -98,14 +98,9 @@ impl EvaluateStage {
     /// Aggregate content from candidates.
     ///
-    /// When content aggregator is enabled:
-    /// - Uses relevance scoring for content selection
-    /// - Respects token budget
-    /// - Prioritizes high-relevance content
-    ///
-    /// Otherwise falls back to simple collection:
-    /// - Collects node's own content + descendant leaf content
-    fn aggregate_content(&self, ctx: &PipelineContext) -> (String, usize) {
+    /// Populates `ctx.node_content_cache` with per-node content so that
+    /// `build_response()` can reuse it without recomputing leaf traversal.
+    fn aggregate_content(&self, ctx: &mut PipelineContext) -> (String, usize) {
         // Use ContentAggregator if configured
         if let Some(ref aggregator) = self.content_aggregator {
             use crate::retrieval::content::CandidateNode;
@@ -124,47 +119,62 @@ impl EvaluateStage {
             return (result.content, result.tokens_used);
         }
 
-        // Fallback: simple content collection
+        // Simple content collection with per-node caching
         self.aggregate_content_simple(ctx)
     }
 
-    /// Simple content aggregation (legacy behavior).
-    fn aggregate_content_simple(&self, ctx: &PipelineContext) -> (String, usize) {
+    /// Simple content aggregation with per-node caching.
+    ///
+    /// Computes each candidate's content once and stores it in
+    /// `ctx.node_content_cache` for reuse by `build_response()`.
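+    ///
+    /// After this runs, `build_response()` can do a single lookup per
+    /// candidate (sketch):
+    /// ```rust,ignore
+    /// let content = ctx.node_content_cache.get(&candidate.node_id).cloned();
+    /// ```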
+    fn aggregate_content_simple(&self, ctx: &mut PipelineContext) -> (String, usize) {
         let mut content_parts = Vec::new();
         let mut total_tokens = 0;
 
         for candidate in &ctx.candidates {
             if let Some(node) = ctx.tree.get(candidate.node_id) {
-                // Add title
-                content_parts.push(format!("## {}\n", node.title));
-
-                // Always collect all content: own content + descendant leaf content
-                let mut has_content = false;
-
-                // Add node's own content if available
-                if !node.content.is_empty() {
-                    content_parts.push(format!("{}\n\n", node.content));
-                    has_content = true;
+                // Build per-node content (own + leaf descendants)
+                let node_content = self.build_node_content(&ctx.tree, candidate.node_id);
+
+                // Cache for build_response reuse
+                ctx.node_content_cache.insert(candidate.node_id, node_content.clone());
+
+                // Add to aggregated content
+                if !node_content.is_empty() {
+                    content_parts.push(format!("## {}\n", node.title));
+                    content_parts.push(format!("{}\n\n", node_content));
+                    total_tokens += estimate_tokens(&node_content);
+                } else if !node.summary.is_empty() {
+                    content_parts.push(format!("## {}\n", node.title));
+                    content_parts.push(format!("{}\n\n", node.summary));
+                    total_tokens += estimate_tokens(&node.summary);
                 }
+            }
+        }
 
-                // Also collect content from leaf descendants (for intermediate nodes)
-                let leaf_content = self.collect_leaf_content(&ctx.tree, candidate.node_id);
-                if !leaf_content.is_empty() {
-                    content_parts.push(format!("{}\n\n", leaf_content));
-                    has_content = true;
-                }
+        (content_parts.join(""), total_tokens)
+    }
 
-                // Fall back to summary only if no content available
-                if !has_content && !node.summary.is_empty() {
-                    content_parts.push(format!("{}\n\n", node.summary));
-                }
+    /// Build content for a single node (own content + leaf descendants).
+    fn build_node_content(
+        &self,
+        tree: &crate::document::DocumentTree,
+        node_id: crate::document::NodeId,
+    ) -> String {
+        let mut parts = Vec::new();
 
-                // Estimate tokens
-                total_tokens += estimate_tokens(&content_parts.last().unwrap_or(&String::new()));
+        if let Some(node) = tree.get(node_id) {
+            if !node.content.is_empty() {
+                parts.push(node.content.clone());
             }
         }
 
-        (content_parts.join(""), total_tokens)
+        let leaf_content = self.collect_leaf_content(tree, node_id);
+        if !leaf_content.is_empty() {
+            parts.push(leaf_content);
+        }
+
+        parts.join("\n\n")
     }
 
     /// Collect content from leaf descendants of a node (excluding the node itself).
@@ -227,30 +237,27 @@ impl EvaluateStage {
     }
 
     /// Build the final response.
+    ///
+    /// Reads per-node content from `ctx.node_content_cache` populated
+    /// during `aggregate_content()` — no duplicate leaf traversal.
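+    /// A cache miss falls back to `build_node_content()` inline, so the
+    /// output is identical either way.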
     fn build_response(&self, ctx: &PipelineContext) -> RetrieveResponse {
         let mut results = Vec::new();
 
         for candidate in &ctx.candidates {
             if let Some(node) = ctx.tree.get(candidate.node_id) {
-                // Build content: node's own content + all descendant leaf content
                 let content = if ctx.options.include_content {
-                    let mut content_parts = Vec::new();
-
-                    // Add node's own content
-                    if !node.content.is_empty() {
-                        content_parts.push(node.content.clone());
-                    }
-
-                    // Add content from leaf descendants
-                    let leaf_content = self.collect_leaf_content(&ctx.tree, candidate.node_id);
-                    if !leaf_content.is_empty() {
-                        content_parts.push(leaf_content);
-                    }
-
-                    if content_parts.is_empty() {
-                        None
-                    } else {
-                        Some(content_parts.join("\n\n"))
+                    // Read from cache — computed once in aggregate_content()
+                    match ctx.node_content_cache.get(&candidate.node_id) {
+                        Some(cached) if !cached.is_empty() => Some(cached.clone()),
+                        _ => {
+                            // Cache miss (edge case): compute inline
+                            let built = self.build_node_content(&ctx.tree, candidate.node_id);
+                            if built.is_empty() {
+                                None
+                            } else {
+                                Some(built)
+                            }
+                        }
                     }
                 } else {
                     None
diff --git a/rust/src/retrieval/stages/search.rs b/rust/src/retrieval/stages/search.rs
index 31e6ae9a..fcef9052 100644
--- a/rust/src/retrieval/stages/search.rs
+++ b/rust/src/retrieval/stages/search.rs
@@ -48,7 +48,6 @@ use crate::retrieval::types::{
 pub struct SearchStage {
     keyword_strategy: KeywordStrategy,
     llm_strategy: Option<Arc<LlmStrategy>>,
-    semantic_strategy: Option<Arc<SemanticStrategy>>,
     hybrid_strategy: Option<Arc<HybridStrategy>>,
     /// Pilot for navigation guidance (optional).
     pilot: Option<Arc<dyn Pilot>>,
@@ -70,7 +69,6 @@ impl SearchStage {
         Self {
             keyword_strategy: KeywordStrategy::new(),
             llm_strategy: None,
-            semantic_strategy: None,
             hybrid_strategy: None,
             pilot: None,
             llm_client: None,
@@ -99,12 +97,6 @@ impl SearchStage {
         self
     }
 
-    /// Add semantic strategy for embedding-based search.
-    pub fn with_semantic_strategy(mut self, strategy: Arc<SemanticStrategy>) -> Self {
-        self.semantic_strategy = Some(strategy);
-        self
-    }
-
     /// Add hybrid strategy (BM25 + LLM refinement).
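+    ///
+    /// Builder-style, e.g. `SearchStage::new().with_hybrid_strategy(strategy)`
+    /// (sketch; `strategy` is any `Arc<HybridStrategy>` you construct).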
     pub fn with_hybrid_strategy(mut self, strategy: Arc<HybridStrategy>) -> Self {
         self.hybrid_strategy = Some(strategy);
         self
     }
@@ -135,15 +127,6 @@ impl SearchStage {
                 info!("Using Keyword strategy");
                 Arc::new(self.keyword_strategy.clone())
             }
-            StrategyPreference::ForceSemantic => {
-                if let Some(ref strategy) = self.semantic_strategy {
-                    info!("Using Semantic strategy");
-                    strategy.clone()
-                } else {
-                    warn!("Semantic strategy requested but not available, falling back to Keyword");
-                    Arc::new(self.keyword_strategy.clone())
-                }
-            }
             StrategyPreference::ForceLlm => {
                 if let Some(ref strategy) = self.llm_strategy {
                     info!("Using LLM strategy");
@@ -607,6 +590,28 @@ impl RetrievalStage for SearchStage {
             .locate(&ctx.query, &ctx.tree, &top_level_nodes)
             .await;
 
+        // === L2 Cache boost: boost cues whose paths have historical success ===
+        let doc_key = format!("{:?}", ctx.tree.root());
+        let l2_paths = ctx.reasoning_cache.l2_top_paths(&doc_key, 5);
+        if !l2_paths.is_empty() {
+            for cue in &mut cues {
+                if let Some(node) = ctx.tree.get(cue.root) {
+                    let node_path = node.title.as_str();
+                    if let Some((_, cached_conf)) = l2_paths
+                        .iter()
+                        .find(|(path, _)| node_path.contains(path.as_str()) || path.contains(node_path))
+                    {
+                        // Blend current confidence with historical: 60% current + 40% cached
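+                        // e.g. current 0.80 with cached 0.50 → 0.6 * 0.80 + 0.4 * 0.50 = 0.68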
+                        let prior = cue.confidence;
+                        cue.confidence = prior * 0.6 + cached_conf * 0.4;
+                        debug!(
+                            "L2 cache boost for '{}': {:.3} → {:.3}",
+                            node_path, prior, cue.confidence
+                        );
+                    }
+                }
+            }
+        }
+
         debug!("ToCNavigator returned {} cues", cues.len());
 
         // Inject structure hints from Analyze stage as high-priority cues
@@ -682,6 +687,44 @@ impl RetrievalStage for SearchStage {
                 .collect();
             tracker.record_hits(&hits);
         }
+
+        // === L3 Cache boost: use cached strategy scores to refine candidates ===
+        for candidate in &mut ctx.candidates {
+            if let Some(node) = ctx.tree.get(candidate.node_id) {
+                let content_fp = crate::utils::fingerprint::Fingerprint::from_str(&node.content);
+                if let Some((cached_score, _strategy)) =
+                    ctx.reasoning_cache.l3_get(&content_fp)
+                {
+                    // Blend: if L3 has a higher score for this node, boost it
+                    if cached_score > candidate.score {
+                        candidate.score = (candidate.score + cached_score) / 2.0;
+                    }
+                }
+            }
+        }
+        // Re-sort after L3 boost
+        ctx.candidates.sort_by(|a, b| {
+            b.score
+                .partial_cmp(&a.score)
+                .unwrap_or(std::cmp::Ordering::Equal)
+        });
+
+        // Store L3 scores for future queries
+        for candidate in &ctx.candidates {
+            if let Some(node) = ctx.tree.get(candidate.node_id) {
+                if !node.content.is_empty() {
+                    let content_fp =
+                        crate::utils::fingerprint::Fingerprint::from_str(&node.content);
+                    ctx.reasoning_cache.l3_store(
+                        content_fp,
+                        candidate.score,
+                        ctx.selected_strategy
+                            .map(|s| format!("{:?}", s))
+                            .unwrap_or_else(|| "auto".to_string()),
+                    );
+                }
+            }
+        }
+
         // Estimate tokens consumed by this search iteration (content-based heuristic)
         let search_tokens: usize = ctx
             .candidates
@@ -810,7 +853,6 @@ mod tests {
     fn test_search_stage_creation() {
         let stage = SearchStage::new();
         assert!(stage.llm_strategy.is_none());
-        assert!(stage.semantic_strategy.is_none());
         assert!(!stage.has_pilot());
     }
 
diff --git a/rust/src/retrieval/strategy/cross_document.rs b/rust/src/retrieval/strategy/cross_document.rs
index 97a475d6..40871057 100644
--- a/rust/src/retrieval/strategy/cross_document.rs
+++ b/rust/src/retrieval/strategy/cross_document.rs
@@ -221,6 +221,9 @@ impl CrossDocumentStrategy {
     }
 
     /// Search a single document and return results.
+    ///
+    /// Performs depth-first traversal: evaluates top-level nodes first,
+    /// then recursively explores children of high-scoring nodes.
     async fn search_document(
         &self,
         doc: &DocumentEntry,
         context: &RetrievalContext,
     ) -> Vec<(NodeId, NodeEvaluation)> {
         let root_id = doc.tree.root();
         let children = doc.tree.children(root_id);
 
-        // Evaluate top-level nodes to find entry points
-        let evaluations = self
+        // Phase 1: Evaluate top-level nodes
+        let top_evaluations = self
             .inner
             .evaluate_nodes(&doc.tree, &children, context)
             .await;
 
-        // Collect results with scores above threshold
         let mut scored_nodes: Vec<(NodeId, NodeEvaluation)> = children
             .into_iter()
-            .zip(evaluations.into_iter())
+            .zip(top_evaluations.into_iter())
             .filter(|(_, eval)| eval.score >= self.config.min_score)
             .collect();
 
+        // Phase 2: Depth traversal — explore children of high-scoring nodes
+        let high_score_nodes: Vec<NodeId> = scored_nodes
+            .iter()
+            .filter(|(_, eval)| eval.score >= self.config.min_score * 1.5)
+            .map(|(id, _)| *id)
+            .collect();
+
+        for node_id in high_score_nodes {
+            let depth_results = self
+                .search_subtree(&doc.tree, node_id, context, 0, 2)
+                .await;
+            scored_nodes.extend(depth_results);
+        }
+
         // Sort by score descending
         scored_nodes.sort_by(|a, b| {
             b.1.score
@@ -249,6 +265,9 @@ impl CrossDocumentStrategy {
                 .unwrap_or(std::cmp::Ordering::Equal)
         });
 
+        // Deduplicate by node_id (keep the first, highest-scoring occurrence)
+        let mut seen = std::collections::HashSet::new();
+        scored_nodes.retain(|(id, _)| seen.insert(*id));
+
         // Limit results per document
         scored_nodes.truncate(self.config.max_results_per_doc);
@@ -262,6 +281,55 @@ impl CrossDocumentStrategy {
         }
     }
 
+    /// Recursively search a subtree, evaluating children of high-scoring nodes.
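+    ///
+    /// Children are re-evaluated with the inner strategy; only nodes scoring
+    /// at least `min_score * 1.5` are expanded further, down to `max_depth`
+    /// levels (2 at the call site above).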
+    fn search_subtree<'a>(
+        &'a self,
+        tree: &'a DocumentTree,
+        parent_id: NodeId,
+        context: &'a RetrievalContext,
+        current_depth: usize,
+        max_depth: usize,
+    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Vec<(NodeId, NodeEvaluation)>> + Send + 'a>> {
+        Box::pin(async move {
+            if current_depth >= max_depth {
+                return Vec::new();
+            }
+
+            let children = tree.children(parent_id);
+            if children.is_empty() {
+                return Vec::new();
+            }
+
+            let evaluations = self
+                .inner
+                .evaluate_nodes(tree, &children, context)
+                .await;
+
+            let mut results = Vec::new();
+            let mut explore_further = Vec::new();
+
+            for (node_id, eval) in children.into_iter().zip(evaluations.into_iter()) {
+                if eval.score >= self.config.min_score {
+                    results.push((node_id, eval.clone()));
+                }
+                // Only explore deeper if score is promising
+                if eval.score >= self.config.min_score * 1.5 {
+                    explore_further.push(node_id);
+                }
+            }
+
+            // Recurse into promising children
+            for child_id in explore_further {
+                let deeper = self
+                    .search_subtree(tree, child_id, context, current_depth + 1, max_depth)
+                    .await;
+                results.extend(deeper);
+            }
+
+            results
+        })
+    }
+
     /// Merge results from all documents.
     fn merge_results(
         &self,
diff --git a/rust/src/retrieval/strategy/llm.rs b/rust/src/retrieval/strategy/llm.rs
index e83b6b76..41cd8987 100644
--- a/rust/src/retrieval/strategy/llm.rs
+++ b/rust/src/retrieval/strategy/llm.rs
@@ -4,6 +4,8 @@
 //! LLM-based retrieval strategy.
 //!
 //! Uses an LLM for deep reasoning about node relevance with ToC context.
+//! Supports batch evaluation — all sibling nodes are scored in a single
+//! LLM call instead of one call per node.
 
 use async_trait::async_trait;
 use serde::Deserialize;
@@ -14,7 +16,31 @@
 use super::r#trait::{NodeEvaluation, RetrievalStrategy, StrategyCapabilities};
 use crate::document::{DocumentTree, NodeId, TocView};
 use crate::llm::LlmClient;
 
-/// LLM response for navigation decision.
+/// LLM response for a single node in batch evaluation.
+#[derive(Debug, Clone, Deserialize)]
+struct NodeScore {
+    /// 1-based index matching the order in the prompt.
+    index: usize,
+    /// Relevance score (0-100, will be normalized to 0-1).
+    relevance: u8,
+    /// Decision: "answer", "explore", or "skip".
+    action: String,
+    /// Optional reasoning (the prompt asks for a "reason" key).
+    #[serde(default, alias = "reason")]
+    reasoning: Option<String>,
+}
+
+/// LLM response for batch node evaluation.
+#[derive(Debug, Clone, Deserialize)]
+struct BatchResponse {
+    /// Analysis reasoning.
+    #[serde(default)]
+    reasoning: String,
+    /// Scored nodes.
+    nodes: Vec<NodeScore>,
+}
+
+/// LLM response for single-node evaluation (fallback).
 #[derive(Debug, Clone, Deserialize)]
 struct NavigationResponse {
     /// Relevance score (0-100, will be normalized to 0-1).
@@ -31,6 +57,12 @@
 /// Uses an LLM to reason about which nodes are most relevant
 /// to the query. Includes ToC context for better navigation decisions.
 ///
+/// # Batch Evaluation
+///
+/// When multiple nodes need scoring, they are sent in a single LLM call
+/// instead of one call per node. This reduces latency from O(N) LLM calls
+/// to O(1).
+///
 /// # Example
 ///
 /// ```rust,no_run
@@ -45,8 +77,10 @@
 pub struct LlmStrategy {
     /// The LLM client.
     client: LlmClient,
-    /// System prompt for navigation.
+    /// System prompt for single-node navigation.
     system_prompt: String,
+    /// System prompt for batch evaluation.
+    batch_system_prompt: String,
     /// ToC view generator.
     toc_view: TocView,
     /// Whether to include ToC context in prompts.
@@ -59,6 +93,7 @@ impl LlmStrategy {
         Self {
             client,
             system_prompt: Self::default_system_prompt(),
+            batch_system_prompt: Self::default_batch_system_prompt(),
             toc_view: TocView::new(),
             include_toc: true,
         }
@@ -81,7 +116,7 @@ impl LlmStrategy {
         self
     }
 
-    /// Default system prompt for navigation.
+    /// Default system prompt for single-node navigation.
     fn default_system_prompt() -> String {
         r#"You are a document navigation assistant. Your task is to help find the most relevant sections in a document tree.
 
 Respond in JSON format:
@@ -95,6 +130,29 @@
 Be concise and focused on finding the most relevant information."#.to_string()
     }
 
+    /// Default system prompt for batch node evaluation.
+    fn default_batch_system_prompt() -> String {
+        r#"You are a document navigation assistant. Score the relevance of multiple document sections against a user query.
+
+CRITICAL: Respond with ONLY valid JSON (no markdown code blocks).
+
+Response format:
+{
+  "reasoning": "Brief analysis of the query",
+  "nodes": [
+    {"index": 1, "relevance": 85, "action": "answer", "reason": "Why relevant"},
+    {"index": 2, "relevance": 30, "action": "skip", "reason": "Why not relevant"}
+  ]
+}
+
+Rules:
+- index: MUST be the number from [N] brackets in the input
+- relevance: 0-100 (how relevant this section is to the query)
+- action: one of "answer", "explore", "skip"
+- Score ALL provided nodes, not just the top ones
+- Be concise in reasons"#.to_string()
+    }
+
     /// Build the navigation prompt for a single node.
     fn build_prompt(
         &self,
@@ -144,7 +202,62 @@ impl LlmStrategy {
         )
     }
 
-    /// Parse LLM response to evaluation.
+    /// Build a batch prompt that presents all nodes at once.
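+    ///
+    /// Each section is rendered as an indexed entry, e.g. (illustrative):
+    ///
+    /// ```text
+    /// [1] Title: "Installation"
+    ///     Summary: "How to install ..."
+    ///     Depth: 1, Children: 3
+    /// ```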
+    fn build_batch_prompt(
+        &self,
+        tree: &DocumentTree,
+        node_ids: &[NodeId],
+        context: &RetrievalContext,
+    ) -> String {
+        // Collect node descriptions
+        let node_descriptions: Vec<String> = node_ids
+            .iter()
+            .enumerate()
+            .filter_map(|(i, &node_id)| {
+                let node = tree.get(node_id)?;
+                let children = tree.children(node_id);
+                // Truncate at char boundaries so multi-byte UTF-8 can't panic the slice
+                let summary: String = if node.summary.is_empty() {
+                    node.content.chars().take(200).collect()
+                } else {
+                    node.summary.clone()
+                };
+                Some(format!(
+                    "[{}] Title: \"{}\"\n    Summary: \"{}\"\n    Depth: {}, Children: {}",
+                    i + 1,
+                    node.title,
+                    summary,
+                    node.depth,
+                    children.len()
+                ))
+            })
+            .collect();
+
+        let nodes_str = node_descriptions.join("\n\n");
+
+        // Optional ToC context from the first node's parent scope
+        let toc_context = if self.include_toc && !node_ids.is_empty() {
+            let toc = self.toc_view.generate_from(tree, node_ids[0]);
+            let toc_markdown = self.toc_view.format_markdown(&toc);
+            let toc_preview: String = toc_markdown.chars().take(800).collect();
+            format!("\n\nDocument ToC:\n{}\n", toc_preview)
+        } else {
+            String::new()
+        };
+
+        format!(
+            "USER QUERY: {}\n{}SECTIONS TO SCORE ({} entries):\n{}\n\nScore ALL sections. Respond with ONLY the JSON object:",
+            context.query,
+            toc_context,
+            node_ids.len(),
+            nodes_str
+        )
+    }
+
+    /// Parse LLM response to evaluation for a single node.
     fn parse_response(
         &self,
         response: &str,
@@ -204,6 +317,73 @@ impl LlmStrategy {
             )),
         }
     }
+
+    /// Parse a batch LLM response into per-node evaluations.
+    ///
+    /// Returns evaluations in the same order as the input `node_ids`.
+    /// Nodes that the LLM didn't score get a default evaluation.
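+    ///
+    /// Accepts the shape requested by `default_batch_system_prompt()`, e.g.:
+    ///
+    /// ```json
+    /// {"reasoning": "...", "nodes": [{"index": 1, "relevance": 85, "action": "answer", "reason": "..."}]}
+    /// ```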
+    fn parse_batch_response(
+        &self,
+        response: &str,
+        tree: &DocumentTree,
+        node_ids: &[NodeId],
+    ) -> Vec<NodeEvaluation> {
+        // Try JSON parse
+        if let Ok(batch) = serde_json::from_str::<BatchResponse>(response) {
+            let mut evaluations = vec![
+                NodeEvaluation {
+                    score: 0.3,
+                    decision: NavigationDecision::ExploreMore,
+                    reasoning: Some("Not scored by LLM (batch fallback)".to_string()),
+                };
+                node_ids.len()
+            ];
+
+            for node_score in batch.nodes {
+                let idx = node_score.index.saturating_sub(1);
+                if idx < node_ids.len() {
+                    let node_id = node_ids[idx];
+                    let score = (node_score.relevance as f32 / 100.0).clamp(0.0, 1.0);
+                    let decision = match node_score.action.to_lowercase().as_str() {
+                        "answer" => NavigationDecision::ThisIsTheAnswer,
+                        "explore" => {
+                            if tree.is_leaf(node_id) {
+                                NavigationDecision::ThisIsTheAnswer
+                            } else {
+                                NavigationDecision::ExploreMore
+                            }
+                        }
+                        _ => NavigationDecision::Skip,
+                    };
+                    evaluations[idx] = NodeEvaluation {
+                        score,
+                        decision,
+                        reasoning: node_score.reasoning,
+                    };
+                }
+            }
+
+            return evaluations;
+        }
+
+        // Fallback: could not parse batch, return defaults
+        tracing::warn!(
+            "Failed to parse batch LLM response, using defaults for {} nodes",
+            node_ids.len()
+        );
+        node_ids
+            .iter()
+            .map(|&node_id| NodeEvaluation {
+                score: 0.5,
+                decision: if tree.is_leaf(node_id) {
+                    NavigationDecision::ThisIsTheAnswer
+                } else {
+                    NavigationDecision::ExploreMore
+                },
+                reasoning: Some("Batch parse fallback".to_string()),
+            })
+            .collect()
+    }
 }
 
 #[async_trait]
@@ -239,13 +419,38 @@ impl RetrievalStrategy for LlmStrategy {
     async fn evaluate_nodes(
         &self,
         tree: &DocumentTree,
         node_ids: &[NodeId],
         context: &RetrievalContext,
     ) -> Vec<NodeEvaluation> {
-        // Evaluate each node individually
-        // TODO: Could be optimized with batch prompts
-        let mut results = Vec::with_capacity(node_ids.len());
-        for node_id in node_ids {
-            results.push(self.evaluate_node(tree, *node_id, context).await);
+        if node_ids.is_empty() {
+            return Vec::new();
+        }
+
+        // Single node: use the simpler single-node prompt
+        if node_ids.len() == 1 {
+            return vec![self.evaluate_node(tree, node_ids[0], context).await];
+        }
+
+        // Batch: send all nodes in one LLM call
+        let prompt = self.build_batch_prompt(tree, node_ids, context);
+
+        match self
+            .client
+            .complete(&self.batch_system_prompt, &prompt)
+            .await
+        {
+            Ok(response) => self.parse_batch_response(&response, tree, node_ids),
+            Err(e) => {
+                tracing::warn!(
+                    "Batch LLM evaluation failed for {} nodes, falling back to single evaluation: {}",
+                    node_ids.len(),
+                    e
+                );
+                // Fallback: evaluate individually (still works, just slower)
+                let mut results = Vec::with_capacity(node_ids.len());
+                for &node_id in node_ids {
+                    results.push(self.evaluate_node(tree, node_id, context).await);
+                }
+                results
+            }
         }
-        results
     }
 
     fn name(&self) -> &'static str {
diff --git a/rust/src/retrieval/strategy/mod.rs b/rust/src/retrieval/strategy/mod.rs
index 27e3a8c3..44bdf880 100644
--- a/rust/src/retrieval/strategy/mod.rs
+++ b/rust/src/retrieval/strategy/mod.rs
@@ -6,7 +6,6 @@
 //! This module provides several retrieval strategies:
 //!
 //! - **KeywordStrategy**: Fast keyword matching using TF-IDF
-//! - **SemanticStrategy**: Embedding-based semantic similarity
 //! - **LlmStrategy**: LLM-powered reasoning with ToC context
 //! - **HybridStrategy**: BM25 pre-filter + LLM refinement (recommended)
 //! - **CrossDocumentStrategy**: Multi-document retrieval with result aggregation
@@ -17,7 +16,6 @@
 mod hybrid;
 mod keyword;
 mod llm;
 mod page_range;
-mod semantic;
 mod r#trait;
 
 pub use hybrid::{HybridConfig, HybridStrategy};
diff --git a/rust/src/retrieval/strategy/semantic.rs b/rust/src/retrieval/strategy/semantic.rs
deleted file mode 100644
index 1e924538..00000000
--- a/rust/src/retrieval/strategy/semantic.rs
+++ /dev/null
@@ -1,281 +0,0 @@
-// Copyright (c) 2026 vectorless developers
-// SPDX-License-Identifier: Apache-2.0
-
-//! Semantic (embedding-based) retrieval strategy.
-//!
-//! Uses vector embeddings for semantic similarity matching.
-
-use async_trait::async_trait;
-
-use super::super::RetrievalContext;
-use super::super::types::{NavigationDecision, QueryComplexity};
-use super::r#trait::{NodeEvaluation, RetrievalStrategy, StrategyCapabilities};
-use crate::config::StrategyConfig;
-use crate::document::{DocumentTree, NodeId};
-
-/// Embedding model trait for semantic strategies.
-#[async_trait]
-pub trait EmbeddingModel: Send + Sync {
-    /// Generate embedding for a text.
-    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
-
-    /// Generate embeddings for multiple texts (batch).
-    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
-
-    /// Get the dimension of embeddings.
-    fn dimension(&self) -> usize;
-}
-
-/// Embedding generation error.
-#[derive(Debug, thiserror::Error)]
-pub enum EmbeddingError {
-    #[error("Failed to generate embedding: {0}")]
-    GenerationFailed(String),
-    #[error("Invalid input: {0}")]
-    InvalidInput(String),
-}
-
-/// Semantic retrieval strategy using embeddings.
-///
-/// Compares query embeddings with node content/summary embeddings
-/// to find semantically similar content.
-pub struct SemanticStrategy {
-    /// The embedding model to use.
-    model: Box<dyn EmbeddingModel>,
-    /// Whether to cache embeddings.
-    cache_embeddings: bool,
-    /// Similarity threshold for considering a node relevant.
-    similarity_threshold: f32,
-    /// High similarity threshold for "answer" decision.
-    high_similarity_threshold: f32,
-    /// Low similarity threshold for "explore" decision.
-    low_similarity_threshold: f32,
-}
-
-impl SemanticStrategy {
-    /// Create a new semantic strategy with the given embedding model.
-    pub fn new(model: Box<dyn EmbeddingModel>) -> Self {
-        Self::with_config(model, &StrategyConfig::default())
-    }
-
-    /// Create with configuration.
-    pub fn with_config(model: Box<dyn EmbeddingModel>, config: &StrategyConfig) -> Self {
-        Self {
-            model,
-            cache_embeddings: true,
-            similarity_threshold: config.similarity_threshold,
-            high_similarity_threshold: config.high_similarity_threshold,
-            low_similarity_threshold: config.low_similarity_threshold,
-        }
-    }
-
-    /// Set whether to cache embeddings.
-    pub fn with_cache(mut self, cache: bool) -> Self {
-        self.cache_embeddings = cache;
-        self
-    }
-
-    /// Set the similarity threshold.
-    pub fn with_threshold(mut self, threshold: f32) -> Self {
-        self.similarity_threshold = threshold;
-        self
-    }
-
-    /// Calculate cosine similarity between two vectors.
-    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
-        if a.len() != b.len() || a.is_empty() {
-            return 0.0;
-        }
-
-        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
-        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
-        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
-
-        if mag_a == 0.0 || mag_b == 0.0 {
-            0.0
-        } else {
-            dot / (mag_a * mag_b)
-        }
-    }
-
-    /// Get text to embed for a node.
-    fn get_embedding_text(tree: &DocumentTree, node_id: NodeId) -> String {
-        if let Some(node) = tree.get(node_id) {
-            // Prefer summary if available, otherwise use content
-            if !node.summary.is_empty() {
-                format!("{}: {}", node.title, node.summary)
-            } else if !node.content.is_empty() {
-                // Truncate long content
-                let content = if node.content.len() > 500 {
-                    &node.content[..500]
-                } else {
-                    &node.content
-                };
-                format!("{}: {}", node.title, content)
-            } else {
-                node.title.clone()
-            }
-        } else {
-            String::new()
-        }
-    }
-}
-
-#[async_trait]
-impl RetrievalStrategy for SemanticStrategy {
-    async fn evaluate_node(
-        &self,
-        tree: &DocumentTree,
-        node_id: NodeId,
-        context: &RetrievalContext,
-    ) -> NodeEvaluation {
-        let node_text = Self::get_embedding_text(tree, node_id);
-
-        if node_text.is_empty() {
-            return NodeEvaluation {
-                score: 0.0,
-                decision: NavigationDecision::Skip,
-                reasoning: Some("Empty node".to_string()),
-            };
-        }
-
-        // Get embeddings
-        let query_embedding = match self.model.embed(&context.query).await {
-            Ok(e) => e,
-            Err(e) => {
-                return NodeEvaluation {
-                    score: 0.0,
-                    decision: NavigationDecision::Skip,
-                    reasoning: Some(format!("Embedding error: {}", e)),
-                };
-            }
-        };
-
-        let node_embedding = match self.model.embed(&node_text).await {
-            Ok(e) => e,
-            Err(e) => {
-                return NodeEvaluation {
-                    score: 0.0,
-                    decision: NavigationDecision::Skip,
-                    reasoning: Some(format!("Embedding error: {}", e)),
-                };
-            }
-        };
-
-        // Calculate similarity
-        let similarity = Self::cosine_similarity(&query_embedding, &node_embedding);
-
-        // Determine decision based on similarity
-        let decision = if similarity > self.high_similarity_threshold {
-            NavigationDecision::ThisIsTheAnswer
-        } else if similarity > self.similarity_threshold {
-            if tree.is_leaf(node_id) {
-                NavigationDecision::ThisIsTheAnswer
-            } else {
-                NavigationDecision::ExploreMore
-            }
-        } else if similarity > self.low_similarity_threshold {
-            NavigationDecision::ExploreMore
-        } else {
-            NavigationDecision::Skip
-        };
-
-        NodeEvaluation {
-            score: similarity,
-            decision,
-            reasoning: Some(format!("Semantic similarity: {:.3}", similarity)),
-        }
-    }
-
-    async fn evaluate_nodes(
-        &self,
-        tree: &DocumentTree,
-        node_ids: &[NodeId],
-        context: &RetrievalContext,
-    ) -> Vec<NodeEvaluation> {
-        // Get query embedding once
-        let query_embedding = match self.model.embed(&context.query).await {
-            Ok(e) => e,
-            Err(e) => {
-                return node_ids
-                    .iter()
-                    .map(|_| NodeEvaluation {
-                        score: 0.0,
-                        decision: NavigationDecision::Skip,
-                        reasoning: Some(format!("Embedding error: {}", e)),
-                    })
-                    .collect();
-            }
-        };
-
-        // Collect all node texts
-        let texts: Vec<String> = node_ids
-            .iter()
-            .map(|&id| Self::get_embedding_text(tree, id))
-            .collect();
-
-        // Batch embed all nodes
-        let node_embeddings = match self.model.embed_batch(&texts).await {
-            Ok(e) => e,
-            Err(e) => {
-                return node_ids
-                    .iter()
-                    .map(|_| NodeEvaluation {
-                        score: 0.0,
-                        decision: NavigationDecision::Skip,
-                        reasoning: Some(format!("Embedding error: {}", e)),
-                    })
-                    .collect();
-            }
-        };
-
-        // Calculate similarities and determine decisions
-        node_ids
-            .iter()
-            .zip(node_embeddings.iter())
-            .map(|(&node_id, node_embedding)| {
-                let similarity = Self::cosine_similarity(&query_embedding, node_embedding);
-
-                let decision = if similarity > 0.8 {
-                    NavigationDecision::ThisIsTheAnswer
-                } else if similarity > self.similarity_threshold {
-                    if tree.is_leaf(node_id) {
-                        NavigationDecision::ThisIsTheAnswer
-                    } else {
-                        NavigationDecision::ExploreMore
-                    }
-                } else if similarity > 0.3 {
-                    NavigationDecision::ExploreMore
-                } else {
-                    NavigationDecision::Skip
-                };
-
-                NodeEvaluation {
-                    score: similarity,
-                    decision,
-                    reasoning: Some(format!("Semantic similarity: {:.3}", similarity)),
-                }
-            })
-            .collect()
-    }
-
-    fn name(&self) -> &'static str {
-        "semantic"
-    }
-
-    fn capabilities(&self) -> StrategyCapabilities {
-        StrategyCapabilities {
-            uses_llm: false,
-            uses_embeddings: true,
-            supports_sufficiency: true,
-            typical_latency_ms: 50,
-        }
-    }
-
-    fn suitable_for_complexity(&self, complexity: QueryComplexity) -> bool {
-        matches!(
-            complexity,
-            QueryComplexity::Simple | QueryComplexity::Medium
-        )
-    }
-}
diff --git a/rust/src/retrieval/types.rs b/rust/src/retrieval/types.rs
index ec3e25dd..1c99e79c 100644
--- a/rust/src/retrieval/types.rs
+++ b/rust/src/retrieval/types.rs
@@ -36,9 +36,6 @@ pub enum StrategyPreference {
     /// Force keyword-based strategy (fast, no LLM).
     ForceKeyword,
 
-    /// Force semantic strategy (embedding-based).
-    ForceSemantic,
-
     /// Force LLM strategy (deep reasoning).
     ForceLlm,