From f73a66f9a9cd325f1e1dc40a35a331fdec741a00 Mon Sep 17 00:00:00 2001 From: zTgx <747674262@qq.com> Date: Tue, 14 Apr 2026 11:31:51 +0800 Subject: [PATCH] feat(pilot): add binary pruning and pre-filtering for wide nodes Add binary pruning functionality to quickly filter relevant candidates before full scoring. Introduce pre-filtering using NodeScorer to reduce LLM token costs when nodes have many children. - Add Prune InterventionPoint for binary relevance filtering - Implement PrefilterConfig and PruneConfig with configurable thresholds - Add pre-filtering logic in score_candidates_detailed to narrow candidate sets using NodeScorer before LLM calls - Implement binary_prune method in LLM pilot for quick yes/no relevance decisions - Update metrics collection to track pruning interventions - Add comprehensive tests for new configuration options --- rust/src/metrics/pilot.rs | 4 +- rust/src/retrieval/pilot/config.rs | 179 +++++++++++++++++++- rust/src/retrieval/pilot/decision.rs | 3 + rust/src/retrieval/pilot/decision_scorer.rs | 82 ++++++++- rust/src/retrieval/pilot/feedback.rs | 2 + rust/src/retrieval/pilot/llm_pilot.rs | 47 +++++ rust/src/retrieval/pilot/metrics.rs | 2 +- rust/src/retrieval/pilot/mod.rs | 2 +- rust/src/retrieval/pilot/noop.rs | 5 + rust/src/retrieval/pilot/parser.rs | 42 +++++ rust/src/retrieval/pilot/prompts/builder.rs | 9 + rust/src/retrieval/pilot/trait.rs | 11 ++ 12 files changed, 380 insertions(+), 8 deletions(-) diff --git a/rust/src/metrics/pilot.rs b/rust/src/metrics/pilot.rs index fee0e011..8b424935 100644 --- a/rust/src/metrics/pilot.rs +++ b/rust/src/metrics/pilot.rs @@ -18,6 +18,8 @@ pub enum InterventionPoint { Backtrack, /// Evaluating content sufficiency. Evaluate, + /// Binary pruning for wide nodes. + Prune, } /// Helper to store f64 as u64 bits for atomic operations. @@ -87,7 +89,7 @@ impl PilotMetrics { InterventionPoint::Start => { self.start_guidance_calls.fetch_add(1, Ordering::Relaxed); } - InterventionPoint::Fork => { + InterventionPoint::Fork | InterventionPoint::Prune => { self.fork_decisions.fetch_add(1, Ordering::Relaxed); } InterventionPoint::Backtrack => { diff --git a/rust/src/retrieval/pilot/config.rs b/rust/src/retrieval/pilot/config.rs index 14eb01c4..5bcee27b 100644 --- a/rust/src/retrieval/pilot/config.rs +++ b/rust/src/retrieval/pilot/config.rs @@ -27,6 +27,10 @@ pub struct PilotConfig { pub guide_at_backtrack: bool, /// Optional path to custom prompt templates. pub prompt_template_path: Option, + /// Pre-filtering configuration for reducing candidates before Pilot. + pub prefilter: PrefilterConfig, + /// Binary pruning configuration for quick relevance filtering. + pub prune: PruneConfig, } impl Default for PilotConfig { @@ -38,6 +42,8 @@ impl Default for PilotConfig { guide_at_start: true, guide_at_backtrack: true, prompt_template_path: None, + prefilter: PrefilterConfig::default(), + prune: PruneConfig::default(), } } } @@ -51,7 +57,7 @@ impl PilotConfig { } } - /// Create a high-quality config (more LLM calls). + /// Create a high-quality config (more LLM calls, generous pre-filter). pub fn high_quality() -> Self { Self { mode: PilotMode::Aggressive, @@ -71,10 +77,20 @@ impl PilotConfig { guide_at_start: true, guide_at_backtrack: true, prompt_template_path: None, + prefilter: PrefilterConfig { + threshold: 20, + max_to_pilot: 20, + enabled: true, + }, + prune: PruneConfig { + enabled: true, + threshold: 25, + min_keep: 5, + }, } } - /// Create a low-cost config (fewer LLM calls). + /// Create a low-cost config (fewer LLM calls, aggressive pre-filter). pub fn low_cost() -> Self { Self { mode: PilotMode::Conservative, @@ -94,6 +110,16 @@ impl PilotConfig { guide_at_start: false, guide_at_backtrack: true, prompt_template_path: None, + prefilter: PrefilterConfig { + threshold: 8, + max_to_pilot: 8, + enabled: true, + }, + prune: PruneConfig { + enabled: true, + threshold: 12, + min_keep: 2, + }, } } @@ -101,6 +127,16 @@ impl PilotConfig { pub fn algorithm_only() -> Self { Self { mode: PilotMode::AlgorithmOnly, + prefilter: PrefilterConfig { + threshold: 15, + max_to_pilot: 15, + enabled: false, + }, + prune: PruneConfig { + enabled: false, + threshold: 20, + min_keep: 3, + }, ..Default::default() } } @@ -228,6 +264,88 @@ impl InterventionConfig { } } +/// Configuration for NodeScorer-based pre-filtering before Pilot scoring. +/// +/// When a node has many children, sending all to the LLM is wasteful. +/// Pre-filtering uses cheap NodeScorer (keyword/BM25) to narrow the +/// candidate set before expensive Pilot (LLM) scoring. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PrefilterConfig { + /// Minimum number of candidates to trigger pre-filtering. + /// + /// When `candidates.len()` exceeds this threshold, NodeScorer + /// pre-filters before sending to Pilot. + /// Default: 15. + pub threshold: usize, + + /// Maximum number of candidates passed to Pilot after pre-filtering. + /// + /// NodeScorer's top-N are kept; the rest get NodeScorer-only scores. + /// Default: 15. + pub max_to_pilot: usize, + + /// Whether pre-filtering is enabled. + /// Default: true. + pub enabled: bool, +} + +impl Default for PrefilterConfig { + fn default() -> Self { + Self { + threshold: 15, + max_to_pilot: 15, + enabled: true, + } + } +} + +impl PrefilterConfig { + /// Check if pre-filtering should be applied given the candidate count. + pub fn should_prefilter(&self, candidate_count: usize) -> bool { + self.enabled && candidate_count > self.threshold + } +} + +/// Configuration for binary pruning before full Pilot scoring. +/// +/// After P2 pre-filtering, if candidates still exceed this threshold, +/// a lightweight LLM call asks "which are relevant?" before the full +/// scoring call. This reduces the number of candidates that receive +/// expensive detailed scoring. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PruneConfig { + /// Whether binary pruning is enabled. + /// Default: true. + pub enabled: bool, + + /// Trigger threshold — binary prune activates when the candidate + /// count (after P2 pre-filtering) exceeds this value. + /// Default: 20. + pub threshold: usize, + + /// Minimum candidates to keep after pruning, even if LLM says + /// fewer are relevant. Prevents over-aggressive pruning. + /// Default: 3. + pub min_keep: usize, +} + +impl Default for PruneConfig { + fn default() -> Self { + Self { + enabled: true, + threshold: 20, + min_keep: 3, + } + } +} + +impl PruneConfig { + /// Check if binary pruning should be applied given the candidate count. + pub fn should_prune(&self, candidate_count: usize) -> bool { + self.enabled && candidate_count > self.threshold + } +} + #[cfg(test)] mod tests { use super::*; @@ -269,11 +387,68 @@ mod tests { fn test_pilot_config_presets() { let high = PilotConfig::high_quality(); assert_eq!(high.mode, PilotMode::Aggressive); + assert!(high.prefilter.enabled); + assert_eq!(high.prefilter.threshold, 20); let low = PilotConfig::low_cost(); assert_eq!(low.mode, PilotMode::Conservative); + assert!(low.prefilter.enabled); + assert_eq!(low.prefilter.threshold, 8); let algo = PilotConfig::algorithm_only(); assert_eq!(algo.mode, PilotMode::AlgorithmOnly); + assert!(!algo.prefilter.enabled); + } + + #[test] + fn test_prefilter_config_default() { + let cfg = PrefilterConfig::default(); + assert!(cfg.enabled); + assert_eq!(cfg.threshold, 15); + assert_eq!(cfg.max_to_pilot, 15); + } + + #[test] + fn test_prefilter_should_prefilter() { + let cfg = PrefilterConfig::default(); + assert!(!cfg.should_prefilter(15)); // at threshold + assert!(!cfg.should_prefilter(10)); // below + assert!(cfg.should_prefilter(16)); // above + + let disabled = PrefilterConfig { enabled: false, ..Default::default() }; + assert!(!disabled.should_prefilter(100)); + } + + #[test] + fn test_prune_config_default() { + let cfg = PruneConfig::default(); + assert!(cfg.enabled); + assert_eq!(cfg.threshold, 20); + assert_eq!(cfg.min_keep, 3); + } + + #[test] + fn test_prune_should_prune() { + let cfg = PruneConfig::default(); + assert!(!cfg.should_prune(20)); // at threshold + assert!(!cfg.should_prune(15)); // below + assert!(cfg.should_prune(21)); // above + + let disabled = PruneConfig { enabled: false, ..Default::default() }; + assert!(!disabled.should_prune(100)); + } + + #[test] + fn test_pilot_config_presets_prune() { + let high = PilotConfig::high_quality(); + assert!(high.prune.enabled); + assert_eq!(high.prune.threshold, 25); + + let low = PilotConfig::low_cost(); + assert!(low.prune.enabled); + assert_eq!(low.prune.threshold, 12); + + let algo = PilotConfig::algorithm_only(); + assert!(!algo.prune.enabled); } } diff --git a/rust/src/retrieval/pilot/decision.rs b/rust/src/retrieval/pilot/decision.rs index 06587f93..31ee5677 100644 --- a/rust/src/retrieval/pilot/decision.rs +++ b/rust/src/retrieval/pilot/decision.rs @@ -220,6 +220,8 @@ pub enum InterventionPoint { Backtrack, /// Evaluating a specific node for relevance. Evaluate, + /// Binary pruning — quick yes/no relevance filter for wide nodes. + Prune, } impl InterventionPoint { @@ -230,6 +232,7 @@ impl InterventionPoint { Self::Fork => "fork", Self::Backtrack => "backtrack", Self::Evaluate => "evaluate", + Self::Prune => "prune", } } } diff --git a/rust/src/retrieval/pilot/decision_scorer.rs b/rust/src/retrieval/pilot/decision_scorer.rs index 4b9fe930..a1158dce 100644 --- a/rust/src/retrieval/pilot/decision_scorer.rs +++ b/rust/src/retrieval/pilot/decision_scorer.rs @@ -112,6 +112,14 @@ pub struct ScoredCandidate { /// from the Pilot. Use this when the search algorithm needs to /// record why each path step was taken (e.g., for beam search /// reasoning history). +/// +/// # Pre-filtering +/// +/// When a node has many children (exceeding `prefilter.threshold`), +/// NodeScorer pre-filters candidates before sending to Pilot. This +/// reduces LLM token cost and latency. Candidates filtered out still +/// receive NodeScorer-only scores in the final merge, so no results +/// are lost. pub async fn score_candidates_detailed( tree: &DocumentTree, candidates: &[NodeId], @@ -139,20 +147,88 @@ pub async fn score_candidates_detailed( // Determine parent node (last in path) for cache key let parent = path.last().copied().unwrap_or(tree.root()); + // === PRE-FILTERING === + // When candidates exceed the threshold, use NodeScorer to narrow + // the set before sending to Pilot (LLM). Filtered-out candidates + // still get NodeScorer-only scores in the final merge below. + let prefilter_cfg = &p.config().prefilter; + let pilot_candidates: Vec = if prefilter_cfg.should_prefilter(candidates.len()) { + let scorer = NodeScorer::new(ScoringContext::new(query)); + let mut sorted = scorer.score_and_sort(tree, candidates); + let pilot_max = prefilter_cfg.max_to_pilot.min(candidates.len()); + sorted.truncate(pilot_max); + let ids: Vec = sorted.into_iter().map(|(id, _)| id).collect(); + tracing::debug!( + "Pre-filtered: {} candidates -> {} to Pilot (threshold={})", + candidates.len(), + ids.len(), + prefilter_cfg.threshold, + ); + ids + } else { + candidates.to_vec() + }; + + // === BINARY PRUNING === + // After P2 pre-filtering, if candidates still exceed the prune + // threshold, ask Pilot for a quick yes/no filter before the + // expensive full-scoring call. + let prune_cfg = &p.config().prune; + let pilot_candidates = if prune_cfg.should_prune(pilot_candidates.len()) { + let mut prune_state = SearchState::new( + tree, query, path, &pilot_candidates, visited, + ); + prune_state.step_reasons = step_reasons; + + if let Some(relevant_ids) = p.binary_prune(&prune_state).await { + let relevant_set: HashSet = relevant_ids.iter().copied().collect(); + let mut pruned: Vec = pilot_candidates + .iter() + .filter(|id| relevant_set.contains(id)) + .copied() + .collect(); + + // Enforce min_keep to prevent over-aggressive pruning + if pruned.len() < prune_cfg.min_keep { + // Fill from the top of pilot_candidates that weren't pruned + for id in &pilot_candidates { + if pruned.len() >= prune_cfg.min_keep { + break; + } + if !relevant_set.contains(id) { + pruned.push(*id); + } + } + } + + tracing::debug!( + "Binary prune: {} candidates -> {} relevant (min_keep={})", + pilot_candidates.len(), + pruned.len(), + prune_cfg.min_keep, + ); + pruned + } else { + pilot_candidates + } + } else { + pilot_candidates + }; + // Check cache first let decision = if let Some(c) = cache { if let Some(cached) = c.get(query, parent).await { tracing::trace!("Pilot cache hit for parent={:?}", parent); cached } else { - let mut state = SearchState::new(tree, query, path, candidates, visited); + let mut state = SearchState::new(tree, query, path, &pilot_candidates, visited); state.step_reasons = step_reasons; let d = p.decide(&state).await; c.put(query, parent, &d).await; d } } else { - let mut state = SearchState::new(tree, query, path, candidates, visited); + let mut state = SearchState::new(tree, query, path, &pilot_candidates, visited); state.step_reasons = step_reasons; p.decide(&state).await }; @@ -163,7 +239,7 @@ pub async fn score_candidates_detailed( pilot_data.insert(ranked.node_id, (ranked.score, ranked.reason.clone())); } - // Compute NodeScorer fallback scores + // Compute NodeScorer fallback scores for ALL original candidates let scorer_weight = 1.0 - pilot_weight; let confidence = decision.confidence; let effective_pilot = pilot_weight * confidence; diff --git a/rust/src/retrieval/pilot/feedback.rs b/rust/src/retrieval/pilot/feedback.rs index 495fae34..051a3f04 100644 --- a/rust/src/retrieval/pilot/feedback.rs +++ b/rust/src/retrieval/pilot/feedback.rs @@ -180,6 +180,7 @@ impl InterventionStats { InterventionPoint::Fork => &self.fork, InterventionPoint::Backtrack => &self.backtrack, InterventionPoint::Evaluate => &self.evaluate, + InterventionPoint::Prune => &self.fork, // Prune reuses fork stats } } @@ -190,6 +191,7 @@ impl InterventionStats { InterventionPoint::Fork => &mut self.fork, InterventionPoint::Backtrack => &mut self.backtrack, InterventionPoint::Evaluate => &mut self.evaluate, + InterventionPoint::Prune => &mut self.fork, // Prune reuses fork stats } } } diff --git a/rust/src/retrieval/pilot/llm_pilot.rs b/rust/src/retrieval/pilot/llm_pilot.rs index df8e3b02..d2b6d955 100644 --- a/rust/src/retrieval/pilot/llm_pilot.rs +++ b/rust/src/retrieval/pilot/llm_pilot.rs @@ -700,6 +700,53 @@ impl Pilot for LlmPilot { ) } + async fn binary_prune(&self, state: &SearchState<'_>) -> Option> { + if !self.has_budget() { + debug!("Budget exhausted, cannot binary prune"); + return None; + } + + let context = self.context_builder.build(state); + + let candidate_info: Vec = state + .candidates + .iter() + .enumerate() + .filter_map(|(i, &node_id)| { + state.tree.get(node_id).map(|node| super::parser::CandidateInfo { + node_id, + title: node.title.clone(), + index: i, + }) + }) + .collect(); + + let decision = self + .call_llm(InterventionPoint::Prune, &context, &candidate_info) + .await; + + // Extract relevant node IDs from ranked candidates (score > 0.5 means relevant) + let relevant: Vec = decision + .ranked_candidates + .iter() + .filter(|c| c.score > 0.5) + .map(|c| c.node_id) + .collect(); + + if relevant.is_empty() { + debug!("Binary prune: LLM marked no candidates as relevant"); + return None; + } + + debug!( + "Binary prune: {} of {} candidates marked relevant", + relevant.len(), + state.candidates.len() + ); + + Some(relevant) + } + fn config(&self) -> &PilotConfig { &self.config } diff --git a/rust/src/retrieval/pilot/metrics.rs b/rust/src/retrieval/pilot/metrics.rs index b97a1977..cf0f12b3 100644 --- a/rust/src/retrieval/pilot/metrics.rs +++ b/rust/src/retrieval/pilot/metrics.rs @@ -232,7 +232,7 @@ impl MetricsCollector { InterventionPoint::Start => { self.start_interventions.fetch_add(1, Ordering::Relaxed); } - InterventionPoint::Fork => { + InterventionPoint::Fork | InterventionPoint::Prune => { self.fork_interventions.fetch_add(1, Ordering::Relaxed); } InterventionPoint::Backtrack => { diff --git a/rust/src/retrieval/pilot/mod.rs b/rust/src/retrieval/pilot/mod.rs index 0ba28975..fd18f92a 100644 --- a/rust/src/retrieval/pilot/mod.rs +++ b/rust/src/retrieval/pilot/mod.rs @@ -47,7 +47,7 @@ mod r#trait; mod scorer; pub use complexity::detect_with_llm; -pub use config::PilotConfig; +pub use config::{PilotConfig, PrefilterConfig, PruneConfig}; pub use decision::{InterventionPoint, PilotDecision}; pub use decision_scorer::{PilotDecisionCache, ScoredCandidate, score_candidates, score_candidates_detailed}; pub use llm_pilot::LlmPilot; diff --git a/rust/src/retrieval/pilot/noop.rs b/rust/src/retrieval/pilot/noop.rs index fa2fba39..e5159276 100644 --- a/rust/src/retrieval/pilot/noop.rs +++ b/rust/src/retrieval/pilot/noop.rs @@ -84,6 +84,11 @@ impl Pilot for NoopPilot { None } + async fn binary_prune(&self, _state: &SearchState<'_>) -> Option> { + // NoopPilot does not support binary pruning + None + } + fn config(&self) -> &PilotConfig { &self.config } diff --git a/rust/src/retrieval/pilot/parser.rs b/rust/src/retrieval/pilot/parser.rs index f7c2fe85..d79a246f 100644 --- a/rust/src/retrieval/pilot/parser.rs +++ b/rust/src/retrieval/pilot/parser.rs @@ -52,6 +52,12 @@ pub struct LlmResponse { /// Reasoning for the decision. #[serde(default)] pub reasoning: String, + /// Relevant candidate indices from PRUNE response (binary yes/no). + #[serde(default)] + pub relevant_indices: Vec, + /// Alternative field name some LLMs use for relevant indices. + #[serde(default)] + pub relevant: Vec, } /// Custom deserializer for confidence that accepts both float and string. @@ -664,6 +670,42 @@ impl ResponseParser { } } + // Handle PRUNE response format: relevant_indices + if ranked_candidates.is_empty() { + let indices: Vec = if !llm_response.relevant_indices.is_empty() { + llm_response.relevant_indices.clone() + } else if !llm_response.relevant.is_empty() { + llm_response.relevant.clone() + } else { + Vec::new() + }; + + for idx in &indices { + if *idx < candidates.len() { + ranked_candidates.push(RankedCandidate { + node_id: candidates[*idx].node_id, + score: 1.0, // Relevant = high score + reason: Some(format!("Marked relevant (index {})", idx)), + }); + } + } + + // Non-relevant candidates get low score (for completeness) + if !ranked_candidates.is_empty() { + let relevant_ids: std::collections::HashSet = + ranked_candidates.iter().map(|rc| rc.node_id).collect(); + for candidate in candidates { + if !relevant_ids.contains(&candidate.node_id) { + ranked_candidates.push(RankedCandidate { + node_id: candidate.node_id, + score: 0.1, // Not relevant + reason: None, + }); + } + } + } + } + // Convert direction let direction = match llm_response.direction { DirectionResponse::GoDeeper => SearchDirection::GoDeeper { diff --git a/rust/src/retrieval/pilot/prompts/builder.rs b/rust/src/retrieval/pilot/prompts/builder.rs index 0d03f09d..c5301ad2 100644 --- a/rust/src/retrieval/pilot/prompts/builder.rs +++ b/rust/src/retrieval/pilot/prompts/builder.rs @@ -81,6 +81,7 @@ impl PromptBuilder { InterventionPoint::Fork => self.build_fork(context), InterventionPoint::Backtrack => self.build_backtrack(context), InterventionPoint::Evaluate => self.build_evaluate(context), + InterventionPoint::Prune => self.build_fork(context), // Prune reuses fork template } } @@ -175,6 +176,7 @@ impl PromptBuilder { InterventionPoint::Fork => &self.fork_template, InterventionPoint::Backtrack => &self.backtrack_template, InterventionPoint::Evaluate => &self.evaluate_template, + InterventionPoint::Prune => &self.fork_template, // Prune reuses fork template } } @@ -215,6 +217,13 @@ impl PromptBuilder { "direction": "go_deeper|found_answer", "confidence": 0.0-1.0, "reasoning": "explanation" +}"# + } + InterventionPoint::Prune => { + r#"{ + "relevant_indices": [0, 2, 5], + "confidence": 0.0-1.0, + "reasoning": "explanation" }"# } } diff --git a/rust/src/retrieval/pilot/trait.rs b/rust/src/retrieval/pilot/trait.rs index bc8d136a..fc99ee53 100644 --- a/rust/src/retrieval/pilot/trait.rs +++ b/rust/src/retrieval/pilot/trait.rs @@ -194,6 +194,17 @@ pub trait Pilot: Send + Sync { /// Returns `None` if no guidance is available. async fn guide_backtrack(&self, state: &SearchState<'_>) -> Option; + /// Binary prune — quick relevance filter for wide nodes. + /// + /// Called after P2 pre-filtering when candidates still exceed the + /// prune threshold. Asks the LLM a simple yes/no question per + /// candidate instead of full scoring. Returns the subset of + /// candidate node IDs deemed relevant. + /// + /// Returns `None` if no pruning guidance is available (e.g. budget + /// exhausted, not supported). + async fn binary_prune(&self, state: &SearchState<'_>) -> Option>; + /// Get the current configuration. fn config(&self) -> &PilotConfig;