diff --git a/rust/src/retrieval/pilot/complexity.rs b/rust/src/retrieval/pilot/complexity.rs index 5d77ca5b..8348b4f8 100644 --- a/rust/src/retrieval/pilot/complexity.rs +++ b/rust/src/retrieval/pilot/complexity.rs @@ -25,10 +25,7 @@ const USER_PROMPT: &str = include_str!("prompts/user_complexity.txt"); /// Detect query complexity using LLM. /// /// Returns `None` if the LLM call fails (caller should fall back to heuristic). -pub async fn detect_with_llm( - client: &LlmClient, - query: &str, -) -> Option { +pub async fn detect_with_llm(client: &LlmClient, query: &str) -> Option { let user = USER_PROMPT.replace("{query}", query); let resp: ComplexityResponse = client diff --git a/rust/src/retrieval/pilot/config.rs b/rust/src/retrieval/pilot/config.rs index 5bcee27b..1d2c3a04 100644 --- a/rust/src/retrieval/pilot/config.rs +++ b/rust/src/retrieval/pilot/config.rs @@ -413,9 +413,12 @@ mod tests { let cfg = PrefilterConfig::default(); assert!(!cfg.should_prefilter(15)); // at threshold assert!(!cfg.should_prefilter(10)); // below - assert!(cfg.should_prefilter(16)); // above + assert!(cfg.should_prefilter(16)); // above - let disabled = PrefilterConfig { enabled: false, ..Default::default() }; + let disabled = PrefilterConfig { + enabled: false, + ..Default::default() + }; assert!(!disabled.should_prefilter(100)); } @@ -432,9 +435,12 @@ mod tests { let cfg = PruneConfig::default(); assert!(!cfg.should_prune(20)); // at threshold assert!(!cfg.should_prune(15)); // below - assert!(cfg.should_prune(21)); // above + assert!(cfg.should_prune(21)); // above - let disabled = PruneConfig { enabled: false, ..Default::default() }; + let disabled = PruneConfig { + enabled: false, + ..Default::default() + }; assert!(!disabled.should_prune(100)); } diff --git a/rust/src/retrieval/pilot/decision_scorer.rs b/rust/src/retrieval/pilot/decision_scorer.rs index a1158dce..0169169a 100644 --- a/rust/src/retrieval/pilot/decision_scorer.rs +++ b/rust/src/retrieval/pilot/decision_scorer.rs @@ -89,7 +89,15 @@ pub async fn score_candidates( step_reasons: Option<&[Option]>, ) -> Vec<(NodeId, f32)> { let scored = score_candidates_detailed( - tree, candidates, query, pilot, path, visited, pilot_weight, cache, step_reasons, + tree, + candidates, + query, + pilot, + path, + visited, + pilot_weight, + cache, + step_reasons, ) .await; scored.into_iter().map(|s| (s.node_id, s.score)).collect() @@ -175,9 +183,7 @@ pub async fn score_candidates_detailed( // expensive full-scoring call. let prune_cfg = &p.config().prune; let pilot_candidates = if prune_cfg.should_prune(pilot_candidates.len()) { - let mut prune_state = SearchState::new( - tree, query, path, &pilot_candidates, visited, - ); + let mut prune_state = SearchState::new(tree, query, path, &pilot_candidates, visited); prune_state.step_reasons = step_reasons; if let Some(relevant_ids) = p.binary_prune(&prune_state).await { @@ -250,7 +256,8 @@ pub async fn score_candidates_detailed( .iter() .map(|&node_id| { let algo_score = scorer.score(tree, node_id); - let (p_score, reason) = pilot_data.get(&node_id) + let (p_score, reason) = pilot_data + .get(&node_id) .map(|(s, r)| (*s, r.clone())) .unwrap_or((0.0, None)); @@ -261,11 +268,19 @@ pub async fn score_candidates_detailed( algo_score }; - ScoredCandidate { node_id, score: final_score, reason } + ScoredCandidate { + node_id, + score: final_score, + reason, + } }) .collect(); - scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)); + scored.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); scored } @@ -289,7 +304,11 @@ fn score_with_scorer_detailed( scorer .score_and_sort(tree, candidates) .into_iter() - .map(|(node_id, score)| ScoredCandidate { node_id, score, reason: None }) + .map(|(node_id, score)| ScoredCandidate { + node_id, + score, + reason: None, + }) .collect() } diff --git a/rust/src/retrieval/pilot/llm_pilot.rs b/rust/src/retrieval/pilot/llm_pilot.rs index d2b6d955..88e2ee03 100644 --- a/rust/src/retrieval/pilot/llm_pilot.rs +++ b/rust/src/retrieval/pilot/llm_pilot.rs @@ -713,11 +713,14 @@ impl Pilot for LlmPilot { .iter() .enumerate() .filter_map(|(i, &node_id)| { - state.tree.get(node_id).map(|node| super::parser::CandidateInfo { - node_id, - title: node.title.clone(), - index: i, - }) + state + .tree + .get(node_id) + .map(|node| super::parser::CandidateInfo { + node_id, + title: node.title.clone(), + index: i, + }) }) .collect(); diff --git a/rust/src/retrieval/pilot/mod.rs b/rust/src/retrieval/pilot/mod.rs index fd18f92a..b14aa997 100644 --- a/rust/src/retrieval/pilot/mod.rs +++ b/rust/src/retrieval/pilot/mod.rs @@ -43,13 +43,13 @@ mod metrics; mod noop; mod parser; mod prompts; -mod r#trait; mod scorer; +mod r#trait; pub use complexity::detect_with_llm; -pub use config::{PilotConfig, PrefilterConfig, PruneConfig}; +pub use config::PilotConfig; pub use decision::{InterventionPoint, PilotDecision}; -pub use decision_scorer::{PilotDecisionCache, ScoredCandidate, score_candidates, score_candidates_detailed}; +pub use decision_scorer::{PilotDecisionCache, score_candidates, score_candidates_detailed}; pub use llm_pilot::LlmPilot; -pub use r#trait::{Pilot, SearchState}; pub use scorer::{NodeScorer, ScoringContext}; +pub use r#trait::{Pilot, SearchState}; diff --git a/rust/src/retrieval/search/beam.rs b/rust/src/retrieval/search/beam.rs index 265a484c..649ef1ee 100644 --- a/rust/src/retrieval/search/beam.rs +++ b/rust/src/retrieval/search/beam.rs @@ -20,10 +20,10 @@ use tracing::debug; use super::super::RetrievalContext; use super::super::types::{NavigationDecision, NavigationStep, SearchPath}; -use crate::retrieval::pilot::{PilotDecisionCache, score_candidates, score_candidates_detailed}; use super::{SearchConfig, SearchResult, SearchTree}; use crate::document::{DocumentTree, NodeId}; use crate::retrieval::pilot::{Pilot, SearchState}; +use crate::retrieval::pilot::{PilotDecisionCache, score_candidates, score_candidates_detailed}; /// Maximum entries in the fallback stack relative to beam width. const FALLBACK_STACK_MULTIPLIER: usize = 3; @@ -91,7 +91,11 @@ impl BeamSearch { if let Some(min_idx) = fallback_stack .iter() .enumerate() - .min_by(|(_, a), (_, b)| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)) + .min_by(|(_, a), (_, b)| { + a.score + .partial_cmp(&b.score) + .unwrap_or(std::cmp::Ordering::Equal) + }) .map(|(i, _)| i) { if entry.score > fallback_stack[min_idx].score { @@ -114,7 +118,11 @@ impl BeamSearch { let max_idx = fallback_stack .iter() .enumerate() - .max_by(|(_, a), (_, b)| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)) + .max_by(|(_, a), (_, b)| { + a.score + .partial_cmp(&b.score) + .unwrap_or(std::cmp::Ordering::Equal) + }) .map(|(i, _)| i)?; Some(fallback_stack.swap_remove(max_idx)) } @@ -287,11 +295,8 @@ impl BeamSearch { .unwrap_or(std::cmp::Ordering::Equal) }); - let mut current_beam: Vec = sorted_initial - .iter() - .take(beam_width) - .cloned() - .collect(); + let mut current_beam: Vec = + sorted_initial.iter().take(beam_width).cloned().collect(); // Remaining candidates go to fallback stack for path in sorted_initial.iter().skip(beam_width) { @@ -421,16 +426,14 @@ impl BeamSearch { // Keep top beam_width in the beam, shelve the rest let mut beam_candidates = next_beam; - let overflow: Vec = beam_candidates.split_off(beam_width.min(beam_candidates.len())); + let overflow: Vec = + beam_candidates.split_off(beam_width.min(beam_candidates.len())); for path in overflow { let score = path.score; Self::push_fallback( &mut fallback_stack, - FallbackEntry { - path, - score, - }, + FallbackEntry { path, score }, config.min_score, config.fallback_score_ratio, max_fallback_size, @@ -566,18 +569,33 @@ mod tests { BeamSearch::push_fallback( &mut stack, - FallbackEntry { path: SearchPath::from_node(id0, 0.3), score: 0.3 }, - 0.1, 0.5, 100, + FallbackEntry { + path: SearchPath::from_node(id0, 0.3), + score: 0.3, + }, + 0.1, + 0.5, + 100, ); BeamSearch::push_fallback( &mut stack, - FallbackEntry { path: SearchPath::from_node(id1, 0.7), score: 0.7 }, - 0.1, 0.5, 100, + FallbackEntry { + path: SearchPath::from_node(id1, 0.7), + score: 0.7, + }, + 0.1, + 0.5, + 100, ); BeamSearch::push_fallback( &mut stack, - FallbackEntry { path: SearchPath::from_node(id2, 0.5), score: 0.5 }, - 0.1, 0.5, 100, + FallbackEntry { + path: SearchPath::from_node(id2, 0.5), + score: 0.5, + }, + 0.1, + 0.5, + 100, ); assert_eq!(stack.len(), 3); @@ -603,16 +621,26 @@ mod tests { // Score 0.01 with threshold 0.1 * 0.5 = 0.05 → should be rejected BeamSearch::push_fallback( &mut stack, - FallbackEntry { path: SearchPath::from_node(id0, 0.01), score: 0.01 }, - 0.1, 0.5, 100, + FallbackEntry { + path: SearchPath::from_node(id0, 0.01), + score: 0.01, + }, + 0.1, + 0.5, + 100, ); assert_eq!(stack.len(), 0, "Score below threshold should be rejected"); // Score 0.06 with threshold 0.05 → should be accepted BeamSearch::push_fallback( &mut stack, - FallbackEntry { path: SearchPath::from_node(id1, 0.06), score: 0.06 }, - 0.1, 0.5, 100, + FallbackEntry { + path: SearchPath::from_node(id1, 0.06), + score: 0.06, + }, + 0.1, + 0.5, + 100, ); assert_eq!(stack.len(), 1, "Score above threshold should be accepted"); } @@ -628,21 +656,36 @@ mod tests { // Fill to capacity (max_size=2) BeamSearch::push_fallback( &mut stack, - FallbackEntry { path: SearchPath::from_node(id0, 0.3), score: 0.3 }, - 0.1, 0.5, 2, + FallbackEntry { + path: SearchPath::from_node(id0, 0.3), + score: 0.3, + }, + 0.1, + 0.5, + 2, ); BeamSearch::push_fallback( &mut stack, - FallbackEntry { path: SearchPath::from_node(id1, 0.5), score: 0.5 }, - 0.1, 0.5, 2, + FallbackEntry { + path: SearchPath::from_node(id1, 0.5), + score: 0.5, + }, + 0.1, + 0.5, + 2, ); assert_eq!(stack.len(), 2); // Push a higher-score entry → should evict the lowest (0.3) BeamSearch::push_fallback( &mut stack, - FallbackEntry { path: SearchPath::from_node(id2, 0.8), score: 0.8 }, - 0.1, 0.5, 2, + FallbackEntry { + path: SearchPath::from_node(id2, 0.8), + score: 0.8, + }, + 0.1, + 0.5, + 2, ); assert_eq!(stack.len(), 2); diff --git a/rust/src/retrieval/search/greedy.rs b/rust/src/retrieval/search/greedy.rs index 2bd72ca6..f644e986 100644 --- a/rust/src/retrieval/search/greedy.rs +++ b/rust/src/retrieval/search/greedy.rs @@ -13,10 +13,10 @@ use tracing::debug; use super::super::RetrievalContext; use super::super::types::{NavigationDecision, NavigationStep, SearchPath}; -use crate::retrieval::pilot::{PilotDecisionCache, score_candidates}; use super::{SearchConfig, SearchResult, SearchTree}; use crate::document::{DocumentTree, NodeId}; use crate::retrieval::pilot::Pilot; +use crate::retrieval::pilot::{PilotDecisionCache, score_candidates}; /// Pure Pilot search — Pilot picks the best child at each layer. /// diff --git a/rust/src/retrieval/search/mcts.rs b/rust/src/retrieval/search/mcts.rs index 1ab48480..3470af73 100644 --- a/rust/src/retrieval/search/mcts.rs +++ b/rust/src/retrieval/search/mcts.rs @@ -20,10 +20,10 @@ use tracing::debug; use super::super::RetrievalContext; use super::super::types::{NavigationDecision, NavigationStep, SearchPath}; -use crate::retrieval::pilot::{PilotDecisionCache, score_candidates, NodeScorer, ScoringContext}; use super::{SearchConfig, SearchResult, SearchTree}; use crate::document::{DocumentTree, NodeId}; use crate::retrieval::pilot::Pilot; +use crate::retrieval::pilot::{NodeScorer, PilotDecisionCache, ScoringContext, score_candidates}; /// Statistics for a node in MCTS. #[derive(Debug, Clone, Default)] diff --git a/rust/src/retrieval/stages/analyze.rs b/rust/src/retrieval/stages/analyze.rs index 34d93352..cc0e6d50 100644 --- a/rust/src/retrieval/stages/analyze.rs +++ b/rust/src/retrieval/stages/analyze.rs @@ -147,8 +147,7 @@ impl AnalyzeStage { /// Enable query decomposition and LLM-based complexity detection. pub fn with_llm_client(mut self, client: crate::llm::LlmClient) -> Self { // Use LLM client for complexity detection - self.complexity_detector = - ComplexityDetector::with_llm_client(client.clone()); + self.complexity_detector = ComplexityDetector::with_llm_client(client.clone()); // Also enable query decomposition if self.query_decomposer.is_none() { self.query_decomposer = diff --git a/rust/src/retrieval/types.rs b/rust/src/retrieval/types.rs index b9a3d17e..7a7baa1e 100644 --- a/rust/src/retrieval/types.rs +++ b/rust/src/retrieval/types.rs @@ -660,7 +660,12 @@ impl SearchPath { /// Extend the path with a new node and a reason for choosing it. #[must_use] - pub fn extend_with_reason(&self, node_id: NodeId, score: f32, reason: impl Into) -> Self { + pub fn extend_with_reason( + &self, + node_id: NodeId, + score: f32, + reason: impl Into, + ) -> Self { let mut nodes = self.nodes.clone(); let mut step_reasons = self.step_reasons.clone(); nodes.push(node_id);