From a27f9b79fd5feba689dbae80728609ccd3ebffd4 Mon Sep 17 00:00:00 2001 From: zTgx <747674262@qq.com> Date: Sun, 5 Apr 2026 18:30:18 +0800 Subject: [PATCH 1/6] refactor(llm): extract executor logic into separate module - Move retry, throttle, and fallback coordination logic from LlmClient into new LlmExecutor module - Replace direct OpenAI client calls with unified executor pattern - Introduce LlmExecutor to coordinate throttle, retry, and fallback mechanisms in a single place - Update LlmClient to delegate execution to LlmExecutor - Add debug logging for LLM completion operations - Update test assertions to use new executor methods --- src/llm/client.rs | 232 +++++----------------- src/llm/executor.rs | 462 ++++++++++++++++++++++++++++++++++++++++++++ src/llm/mod.rs | 2 + 3 files changed, 512 insertions(+), 184 deletions(-) create mode 100644 src/llm/executor.rs diff --git a/src/llm/client.rs b/src/llm/client.rs index 27c81d80..40386094 100644 --- a/src/llm/client.rs +++ b/src/llm/client.rs @@ -3,14 +3,6 @@ //! Unified LLM client with retry and concurrency support. -use async_openai::{ - Client, - config::OpenAIConfig, - types::chat::{ - ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, - CreateChatCompletionRequestArgs, - }, -}; use serde::de::DeserializeOwned; use std::borrow::Cow; use std::sync::Arc; @@ -18,8 +10,8 @@ use tracing::{debug, instrument}; use super::config::LlmConfig; use super::error::{LlmError, LlmResult}; +use super::executor::LlmExecutor; use super::fallback::FallbackChain; -use super::retry::with_retry; use crate::throttle::ConcurrencyController; /// Unified LLM client. @@ -60,21 +52,19 @@ use crate::throttle::ConcurrencyController; /// ``` #[derive(Clone)] pub struct LlmClient { - config: LlmConfig, - concurrency: Option>, - fallback: Option>, + executor: LlmExecutor, } impl std::fmt::Debug for LlmClient { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("LlmClient") - .field("model", &self.config.model) - .field("endpoint", &self.config.endpoint) + .field("model", &self.executor.config().model) + .field("endpoint", &self.executor.config().endpoint) .field( "concurrency", - &self.concurrency.as_ref().map(|c| format!("{:?}", c)), + &self.executor.throttle().map(|c| format!("{:?}", c)), ) - .field("fallback_enabled", &self.fallback.is_some()) + .field("fallback_enabled", &self.executor.fallback().is_some()) .finish() } } @@ -83,9 +73,7 @@ impl LlmClient { /// Create a new LLM client with the given configuration. pub fn new(config: LlmConfig) -> Self { Self { - config, - concurrency: None, - fallback: None, + executor: LlmExecutor::new(config), } } @@ -115,13 +103,13 @@ impl LlmClient { /// .with_concurrency(ConcurrencyController::new(config)); /// ``` pub fn with_concurrency(mut self, controller: ConcurrencyController) -> Self { - self.concurrency = Some(Arc::new(controller)); + self.executor = self.executor.with_throttle(controller); self } /// Add concurrency control from an existing Arc. pub fn with_shared_concurrency(mut self, controller: Arc) -> Self { - self.concurrency = Some(controller); + self.executor = self.executor.with_shared_throttle(controller); self } @@ -139,29 +127,34 @@ impl LlmClient { /// assert!(client.fallback().is_some()); /// ``` pub fn with_fallback(mut self, chain: FallbackChain) -> Self { - self.fallback = Some(Arc::new(chain)); + self.executor = self.executor.with_fallback(chain); self } /// Add fallback chain from an existing Arc. 
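A usage sketch of the refactored client, assembled from the constructors and accessors shown in this diff (`for_model`, `with_concurrency`, `with_fallback`, `executor()`); the `vectorless::llm` and `vectorless::throttle` import paths follow the crate's own doc examples, so treat this as illustrative rather than a verbatim doctest:

```rust
use vectorless::llm::{FallbackChain, FallbackConfig, LlmClient};
use vectorless::throttle::{ConcurrencyConfig, ConcurrencyController};

fn main() {
    // Same builder surface as before the refactor; each call now forwards
    // to the inner LlmExecutor instead of storing the piece on the client.
    let client = LlmClient::for_model("gpt-4o-mini")
        .with_concurrency(ConcurrencyController::new(ConcurrencyConfig::conservative()))
        .with_fallback(FallbackChain::new(FallbackConfig::default()));

    // The coordinating executor is reachable for advanced callers.
    assert!(client.executor().throttle().is_some());
    assert!(client.executor().fallback().is_some());
    assert_eq!(client.config().model, "gpt-4o-mini");
}
```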
pub fn with_shared_fallback(mut self, chain: Arc) -> Self { - self.fallback = Some(chain); + self.executor = self.executor.with_shared_fallback(chain); self } /// Get the configuration. pub fn config(&self) -> &LlmConfig { - &self.config + self.executor.config() } /// Get the concurrency controller (if any). pub fn concurrency(&self) -> Option<&ConcurrencyController> { - self.concurrency.as_deref() + self.executor.throttle() } /// Get the fallback chain (if any). pub fn fallback(&self) -> Option<&FallbackChain> { - self.fallback.as_deref() + self.executor.fallback() + } + + /// Get the underlying executor (for advanced usage). + pub fn executor(&self) -> &LlmExecutor { + &self.executor } /// Complete a prompt with system and user messages. @@ -169,12 +162,15 @@ impl LlmClient { /// This method includes: /// - Automatic rate limiting (if configured) /// - Automatic retry with exponential backoff - #[instrument(skip(self, system, user), fields(model = %self.config.model))] + /// - Automatic fallback on persistent errors (if configured) + #[instrument(skip(self, system, user), fields(model = %self.executor.config().model))] pub async fn complete(&self, system: &str, user: &str) -> LlmResult { - with_retry(&self.config.retry, || async { - self.complete_once(system, user).await - }) - .await + debug!( + system_len = system.len(), + user_len = user.len(), + "Starting LLM completion" + ); + self.executor.complete(system, user).await } /// Complete a prompt with custom max tokens. @@ -184,11 +180,15 @@ impl LlmClient { user: &str, max_tokens: u16, ) -> LlmResult { - with_retry(&self.config.retry, || async { - self.complete_once_with_max_tokens(system, user, max_tokens) - .await - }) - .await + debug!( + system_len = system.len(), + user_len = user.len(), + max_tokens = max_tokens, + "Starting LLM completion with max tokens" + ); + self.executor + .complete_with_max_tokens(system, user, max_tokens) + .await } /// Complete a prompt and parse the response as JSON. @@ -239,152 +239,6 @@ impl LlmClient { self.parse_json(&response) } - /// Single completion attempt (no retry). - async fn complete_once(&self, system: &str, user: &str) -> LlmResult { - // Acquire concurrency permit (rate limiter + semaphore) - let _permit = if let Some(ref cc) = self.concurrency { - Some(cc.acquire().await) - } else { - None - }; - - let api_key = self.config.get_api_key().ok_or_else(|| { - LlmError::Config( - "No API key found. 
Set OPENAI_API_KEY environment variable.".to_string(), - ) - })?; - - let endpoint = self.config.auto_detect_endpoint(); - let model = self.config.auto_detect_model(); - - println!("Using OpenAI API endpoint: {}", endpoint); - println!("Using OpenAI model: {}", model); - - let openai_config = OpenAIConfig::new() - .with_api_key(api_key) - .with_api_base(&endpoint); - - let client = Client::with_config(openai_config); - - // Truncate user prompt if too long - let truncated = self.truncate_prompt(user); - - let request = CreateChatCompletionRequestArgs::default() - .model(&model) - .messages([ - ChatCompletionRequestSystemMessage::from(system).into(), - ChatCompletionRequestUserMessage::from(truncated).into(), - ]) - // .max_tokens(self.config.max_tokens as u16) - .temperature(self.config.temperature) - .build() - .map_err(|e| LlmError::Request(format!("Failed to build request: {}", e)))?; - - debug!("Sending LLM request to {} with model {}", endpoint, model); - - let response = client.chat().create(request).await.map_err(|e| { - let msg = e.to_string(); - LlmError::from_api_message(&msg) - })?; - - let content = response - .choices - .first() - .and_then(|choice| choice.message.content.clone()) - .ok_or(LlmError::NoContent)?; - - debug!("LLM response length: {} chars", content.len()); - - Ok(content) - } - - /// Single completion with custom max tokens. - async fn complete_once_with_max_tokens( - &self, - system: &str, - user: &str, - max_tokens: u16, - ) -> LlmResult { - // Acquire concurrency permit - let _permit = if let Some(ref cc) = self.concurrency { - Some(cc.acquire().await) - } else { - None - }; - - let api_key = self.config.get_api_key().ok_or_else(|| { - LlmError::Config( - "No API key found. Set OPENAI_API_KEY environment variable.".to_string(), - ) - })?; - - let endpoint = self.config.auto_detect_endpoint(); - let model = self.config.auto_detect_model(); - - let openai_config = OpenAIConfig::new() - .with_api_key(api_key) - .with_api_base(&endpoint); - - let client = Client::with_config(openai_config); - - let truncated = self.truncate_prompt(user); - - let request = CreateChatCompletionRequestArgs::default() - .model(&model) - .messages([ - ChatCompletionRequestSystemMessage::from(system).into(), - ChatCompletionRequestUserMessage::from(truncated).into(), - ]) - // .max_tokens(max_tokens) - .temperature(self.config.temperature) - .build() - .map_err(|e| LlmError::Request(format!("Failed to build request: {}", e)))?; - - let response = client.chat().create(request).await.map_err(|e| { - let msg = e.to_string(); - eprintln!("[LLM ERROR] API error: {}", msg); - LlmError::from_api_message(&msg) - })?; - - // Debug: log response structure - eprintln!("[LLM DEBUG] Response: {} choices", response.choices.len()); - if let Some(choice) = response.choices.first() { - eprintln!( - "[LLM DEBUG] First choice: finish_reason={:?}, has_content={}", - choice.finish_reason, - choice.message.content.is_some() - ); - } - - let content = response - .choices - .first() - .and_then(|choice| choice.message.content.clone()) - .ok_or_else(|| { - eprintln!("[LLM ERROR] Response has no content"); - LlmError::NoContent - })?; - - if content.is_empty() { - eprintln!("[LLM WARN] Returned empty content for model: {}", model); - } else { - eprintln!("[LLM DEBUG] Content length: {} chars", content.len()); - } - - Ok(content) - } - - /// Truncate a prompt to a reasonable length. 
- fn truncate_prompt<'a>(&self, text: &'a str) -> &'a str { - // Roughly 4 chars per token, limit to ~30k chars - const MAX_CHARS: usize = 30000; - if text.len() > MAX_CHARS { - &text[..MAX_CHARS] - } else { - text - } - } - /// Parse JSON from LLM response. fn parse_json(&self, text: &str) -> LlmResult { let json_text = self.extract_json(text); @@ -481,7 +335,7 @@ mod tests { #[test] fn test_client_creation() { let client = LlmClient::for_model("gpt-4o"); - assert_eq!(client.config.model, "gpt-4o"); + assert_eq!(client.config().model, "gpt-4o"); } #[test] @@ -491,6 +345,16 @@ mod tests { let controller = ConcurrencyController::new(ConcurrencyConfig::conservative()); let client = LlmClient::for_model("gpt-4o-mini").with_concurrency(controller); - assert!(client.concurrency.is_some()); + assert!(client.concurrency().is_some()); + } + + #[test] + fn test_client_with_fallback() { + use crate::llm::FallbackConfig; + + let fallback = FallbackChain::new(FallbackConfig::default()); + let client = LlmClient::for_model("gpt-4o").with_fallback(fallback); + + assert!(client.fallback().is_some()); } } diff --git a/src/llm/executor.rs b/src/llm/executor.rs new file mode 100644 index 00000000..8ac193f1 --- /dev/null +++ b/src/llm/executor.rs @@ -0,0 +1,462 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Unified executor coordinating throttle, retry, and fallback. +//! +//! This module provides the `LlmExecutor` which coordinates: +//! - **Throttle** — Rate limiting and concurrency control +//! - **Retry** — Exponential backoff on transient errors +//! - **Fallback** — Model/endpoint degradation on persistent failures +//! +//! # Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ LlmExecutor │ +//! │ │ +//! │ execute() ──▶ [Throttle] ──▶ [API Call] ──▶ [Success/Error] │ +//! │ │ │ │ +//! │ acquire permit do request │ +//! │ │ │ +//! │ ┌──────────┴──────────┐ │ +//! │ ▼ ▼ │ +//! │ [Retry] [Fallback] │ +//! │ │ │ │ +//! │ exponential model/endpoint │ +//! │ backoff degradation │ +//! │ │ +//! └─────────────────────────────────────────────────────────────────┘ +//! ``` +//! +//! # Example +//! +//! ```rust,no_run +//! use vectorless::llm::{LlmExecutor, LlmConfig, FallbackChain, FallbackConfig}; +//! use vectorless::throttle::{ConcurrencyController, ConcurrencyConfig}; +//! +//! # #[tokio::main] +//! # async fn main() -> vectorless::llm::LlmResult<()> { +//! let config = LlmConfig::new("gpt-4o"); +//! let throttle = ConcurrencyController::new(ConcurrencyConfig::default()); +//! let fallback = FallbackChain::new(FallbackConfig::default()); +//! +//! let executor = LlmExecutor::new(config) +//! .with_throttle(throttle) +//! .with_fallback(fallback); +//! +//! let result = executor.complete("You are helpful.", "Hello!").await?; +//! # Ok(()) +//! # } +//! ``` + +use std::sync::Arc; +use std::time::Duration; +use tracing::{debug, info, warn}; + +use super::config::LlmConfig; +use super::error::{LlmError, LlmResult}; +use super::fallback::{FallbackChain, FallbackStep}; +use crate::throttle::ConcurrencyController; + +/// Unified executor for LLM operations. +/// +/// Coordinates throttle, retry, and fallback mechanisms. +#[derive(Clone)] +pub struct LlmExecutor { + /// LLM configuration. + config: LlmConfig, + /// Throttle controller (optional). + throttle: Option>, + /// Fallback chain (optional). 
+ fallback: Option>, +} + +impl std::fmt::Debug for LlmExecutor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LlmExecutor") + .field("model", &self.config.model) + .field("endpoint", &self.config.endpoint) + .field("has_throttle", &self.throttle.is_some()) + .field("has_fallback", &self.fallback.is_some()) + .finish() + } +} + +impl LlmExecutor { + /// Create a new executor with the given configuration. + pub fn new(config: LlmConfig) -> Self { + Self { + config, + throttle: None, + fallback: None, + } + } + + /// Create an executor with default configuration. + pub fn with_defaults() -> Self { + Self::new(LlmConfig::default()) + } + + /// Create an executor for a specific model. + pub fn for_model(model: impl Into) -> Self { + Self::new(LlmConfig::new(model)) + } + + /// Add throttle control. + pub fn with_throttle(mut self, controller: ConcurrencyController) -> Self { + self.throttle = Some(Arc::new(controller)); + self + } + + /// Add throttle control from an existing Arc. + pub fn with_shared_throttle(mut self, controller: Arc) -> Self { + self.throttle = Some(controller); + self + } + + /// Add fallback chain. + pub fn with_fallback(mut self, chain: FallbackChain) -> Self { + self.fallback = Some(Arc::new(chain)); + self + } + + /// Add fallback chain from an existing Arc. + pub fn with_shared_fallback(mut self, chain: Arc) -> Self { + self.fallback = Some(chain); + self + } + + /// Get the configuration. + pub fn config(&self) -> &LlmConfig { + &self.config + } + + /// Get the throttle controller (if any). + pub fn throttle(&self) -> Option<&ConcurrencyController> { + self.throttle.as_deref() + } + + /// Get the fallback chain (if any). + pub fn fallback(&self) -> Option<&FallbackChain> { + self.fallback.as_deref() + } + + /// Execute a completion with unified coordination. + /// + /// This method coordinates: + /// 1. Throttle: Acquire permit before API call + /// 2. Retry: Exponential backoff on transient errors + /// 3. Fallback: Model/endpoint degradation on persistent failures + pub async fn complete(&self, system: &str, user: &str) -> LlmResult { + self.execute_with_context(system, user, None).await + } + + /// Execute a completion with custom max tokens. + pub async fn complete_with_max_tokens( + &self, + system: &str, + user: &str, + max_tokens: u16, + ) -> LlmResult { + self.execute_with_context(system, user, Some(max_tokens)) + .await + } + + /// Internal execution with full coordination. 
+ async fn execute_with_context( + &self, + system: &str, + user: &str, + max_tokens: Option, + ) -> LlmResult { + let mut attempts = 0; + let mut current_model = self.config.model.clone(); + let current_endpoint = self.config.auto_detect_endpoint(); + let mut fallback_history: Vec = vec![]; + let mut total_attempts_including_fallback = 0; + + loop { + attempts += 1; + total_attempts_including_fallback += 1; + + // Safety check: prevent infinite loops + const MAX_TOTAL_ATTEMPTS: usize = 20; + if total_attempts_including_fallback > MAX_TOTAL_ATTEMPTS { + warn!( + total_attempts = total_attempts_including_fallback, + "Exceeded maximum total attempts, aborting" + ); + return Err(LlmError::RetryExhausted { + attempts: total_attempts_including_fallback, + last_error: "Exceeded maximum total attempts including fallbacks".to_string(), + }); + } + + // Step 1: Acquire throttle permit + let _permit = self.acquire_throttle_permit().await; + + debug!( + attempt = attempts, + model = %current_model, + endpoint = %current_endpoint, + "Executing LLM request" + ); + + // Step 2: Execute the request + let result = self + .do_request(¤t_model, ¤t_endpoint, system, user, max_tokens) + .await; + + match result { + Ok(response) => { + if fallback_history.is_empty() { + debug!( + attempts = attempts, + "LLM request succeeded without fallback" + ); + } else { + info!( + attempts = attempts, + fallback_steps = fallback_history.len(), + "LLM request succeeded after fallback" + ); + } + return Ok(response); + } + Err(error) => { + // Step 3: Check if we should retry + if self.should_retry(&error, attempts) { + let delay = self.retry_delay(attempts); + warn!( + attempt = attempts, + max_attempts = self.config.retry.max_attempts, + delay_ms = delay.as_millis() as u64, + error = %error, + "LLM call failed, retrying..." + ); + tokio::time::sleep(delay).await; + continue; + } + + // Step 4: Check if we should fallback + if let Some(ref fallback) = self.fallback { + if fallback.should_fallback(&error) { + let mut fell_back = false; + + // Try next model + if let Some(next_model) = fallback.next_model(¤t_model) { + info!( + from_model = %current_model, + to_model = %next_model, + "Falling back to next model" + ); + fallback.record_fallback( + &mut fallback_history, + current_model.clone(), + Some(next_model.clone()), + current_endpoint.clone(), + None, + error.to_string(), + ); + current_model = next_model; + attempts = 0; // Reset retry counter for new model + fell_back = true; + } + + if fell_back { + continue; + } + } + } + + // Step 5: No more retries or fallbacks, return error + warn!( + attempts = attempts, + fallback_steps = fallback_history.len(), + error = %error, + "LLM call failed, no more retries or fallbacks available" + ); + return Err(error); + } + } + } + } + + /// Acquire throttle permit (if configured). + async fn acquire_throttle_permit(&self) -> Option> { + if let Some(ref throttle) = self.throttle { + throttle.acquire().await + } else { + None + } + } + + /// Check if we should retry based on error and attempt count. 
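A compact, synchronous model of the coordination loop in `execute_with_context`: retry the same model while the budget lasts, then move down the fallback list and reset the retry counter. Every name below is an illustrative stand-in, not one of the crate's types:

```rust
// Illustrative stand-ins, not the crate's real error or client types.
#[allow(dead_code)]
#[derive(Debug)]
enum CallError {
    Transient,
    Persistent,
}

fn call(model: &str, attempt: usize) -> Result<String, CallError> {
    // Pretend the primary model keeps rate-limiting and the fallback succeeds.
    if model == "gpt-4o" {
        Err(CallError::Transient)
    } else {
        Ok(format!("ok from {model} (attempt {attempt})"))
    }
}

fn main() {
    let max_attempts = 3;
    let mut fallbacks = ["gpt-4o", "gpt-4o-mini"].into_iter();
    let mut model = fallbacks.next().unwrap();
    let mut attempts = 0;

    let result = loop {
        attempts += 1;
        match call(model, attempts) {
            Ok(text) => break Ok(text),
            // Transient error with retry budget left: retry the same model.
            Err(CallError::Transient) if attempts < max_attempts => continue,
            // Otherwise fall back to the next model and reset the retry counter.
            Err(_) => match fallbacks.next() {
                Some(next) => {
                    model = next;
                    attempts = 0;
                }
                None => break Err("retries and fallbacks exhausted"),
            },
        }
    };
    assert_eq!(result.unwrap(), "ok from gpt-4o-mini (attempt 1)");
}
```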
+ fn should_retry(&self, error: &LlmError, attempts: usize) -> bool { + if attempts >= self.config.retry.max_attempts { + return false; + } + + match error { + LlmError::RateLimit(_) => self.config.retry.retry_on_rate_limit, + LlmError::Timeout(_) => true, + LlmError::Api(msg) => { + let msg_lower = msg.to_lowercase(); + msg_lower.contains("rate limit") + || msg_lower.contains("429") + || msg_lower.contains("503") + || msg_lower.contains("502") + || msg_lower.contains("timeout") + || msg_lower.contains("overloaded") + } + _ => false, + } + } + + /// Calculate retry delay for a given attempt. + fn retry_delay(&self, attempt: usize) -> Duration { + self.config.retry.delay_for_attempt(attempt - 1) + } + + /// Execute the actual API request. + async fn do_request( + &self, + model: &str, + endpoint: &str, + system: &str, + user: &str, + max_tokens: Option, + ) -> LlmResult { + use async_openai::{ + Client, + config::OpenAIConfig, + types::chat::{ + ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, + CreateChatCompletionRequestArgs, + }, + }; + + let api_key = self.config.get_api_key().ok_or_else(|| { + LlmError::Config( + "No API key found. Set OPENAI_API_KEY environment variable.".to_string(), + ) + })?; + + let openai_config = OpenAIConfig::new() + .with_api_key(api_key) + .with_api_base(endpoint); + + let client = Client::with_config(openai_config); + + // Truncate user prompt if too long + let truncated = self.truncate_prompt(user); + + // Build request based on whether max_tokens is specified + let request = if let Some(tokens) = max_tokens { + CreateChatCompletionRequestArgs::default() + .model(model) + .messages([ + ChatCompletionRequestSystemMessage::from(system).into(), + ChatCompletionRequestUserMessage::from(truncated).into(), + ]) + .temperature(self.config.temperature) + .max_tokens(tokens) + .build() + } else { + CreateChatCompletionRequestArgs::default() + .model(model) + .messages([ + ChatCompletionRequestSystemMessage::from(system).into(), + ChatCompletionRequestUserMessage::from(truncated).into(), + ]) + .temperature(self.config.temperature) + .build() + }; + + let request = + request.map_err(|e| LlmError::Request(format!("Failed to build request: {}", e)))?; + + debug!("Sending LLM request to {} with model {}", endpoint, model); + + let response = client.chat().create(request).await.map_err(|e| { + let msg = e.to_string(); + LlmError::from_api_message(&msg) + })?; + + let content = response + .choices + .first() + .and_then(|choice| choice.message.content.clone()) + .ok_or(LlmError::NoContent)?; + + debug!("LLM response length: {} chars", content.len()); + + Ok(content) + } + + /// Truncate a prompt to a reasonable length. 
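The byte-index slice in `truncate_prompt` (`&text[..MAX_CHARS]`) panics if the 30,000-byte cut lands inside a multi-byte UTF-8 character. A boundary-safe variant of the same idea, sketched here as an assumption rather than code from this patch:

```rust
fn truncate_prompt(text: &str) -> &str {
    // Roughly 4 chars per token, limit to ~30k bytes, but never split a char.
    const MAX_CHARS: usize = 30_000;
    if text.len() <= MAX_CHARS {
        return text;
    }
    // Walk back to the nearest char boundary at or below MAX_CHARS.
    let mut end = MAX_CHARS;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}

fn main() {
    let ascii = "a".repeat(40_000);
    assert_eq!(truncate_prompt(&ascii).len(), 30_000);

    // Multi-byte input no longer risks a panic at the cut point.
    let emoji = "🦀".repeat(10_000); // 4 bytes each = 40,000 bytes
    assert!(truncate_prompt(&emoji).len() <= 30_000);
}
```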
+ fn truncate_prompt<'a>(&self, text: &'a str) -> &'a str { + // Roughly 4 chars per token, limit to ~30k chars + const MAX_CHARS: usize = 30000; + if text.len() > MAX_CHARS { + &text[..MAX_CHARS] + } else { + text + } + } +} + +impl Default for LlmExecutor { + fn default() -> Self { + Self::with_defaults() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_executor_creation() { + let executor = LlmExecutor::for_model("gpt-4o"); + assert_eq!(executor.config().model, "gpt-4o"); + assert!(executor.throttle().is_none()); + assert!(executor.fallback().is_none()); + } + + #[test] + fn test_executor_with_throttle() { + use crate::throttle::ConcurrencyConfig; + + let controller = ConcurrencyController::new(ConcurrencyConfig::conservative()); + let executor = LlmExecutor::for_model("gpt-4o-mini").with_throttle(controller); + + assert!(executor.throttle().is_some()); + } + + #[test] + fn test_should_retry() { + let executor = LlmExecutor::with_defaults(); + + // Should retry on timeout + assert!(executor.should_retry(&LlmError::Timeout("test".to_string()), 1)); + + // Should retry on rate limit (if configured) + assert!(executor.should_retry(&LlmError::RateLimit("test".to_string()), 1)); + + // Should not retry on config error + assert!(!executor.should_retry(&LlmError::Config("test".to_string()), 1)); + + // Should not retry after max attempts + assert!(!executor.should_retry(&LlmError::Timeout("test".to_string()), 100)); + } + + #[test] + fn test_retry_delay() { + let executor = LlmExecutor::with_defaults(); + + // First retry attempt (attempt 1 -> delay_for_attempt(0)) + let delay = executor.retry_delay(1); + assert_eq!(delay, Duration::from_millis(500)); + } +} diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 89a6a415..50cd8557 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -66,6 +66,7 @@ mod client; mod config; mod error; +mod executor; mod fallback; mod pool; mod retry; @@ -73,5 +74,6 @@ mod retry; pub use client::LlmClient; pub use config::{LlmConfig, LlmConfigs, RetryConfig}; pub use error::{LlmError, LlmResult}; +pub use executor::LlmExecutor; pub use fallback::{FallbackChain, FallbackConfig, FallbackResult, FallbackStep}; pub use pool::LlmPool; From 092162a7bcb4ca7958fcbd24cdae38b40964e9e5 Mon Sep 17 00:00:00 2001 From: zTgx <747674262@qq.com> Date: Sun, 5 Apr 2026 19:38:59 +0800 Subject: [PATCH 2/6] feat(config): add unified LLM pool and metrics configuration - Introduce LlmPoolConfig for managing summary, retrieval, and pilot client configurations - Add comprehensive retry, throttle, and fallback mechanisms - Implement unified metrics configuration with LlmMetricsConfig, PilotMetricsConfig, and RetrievalMetricsConfig - Add cost calculation capabilities for LLM operations - Consolidate existing LLM-related config types into new modules - Update main Config struct to include llm and metrics fields feat(metrics): implement central metrics hub for unified collection - Create MetricsHub for collecting LLM, Pilot, and Retrieval metrics - Add thread-safe atomic metric tracking for multi-threaded usage - Implement comprehensive reporting capabilities - Add support for token tracking, latency measurement, and cost estimation - Include metrics for rate limits, timeouts, and fallbacks --- src/config/mod.rs | 11 + src/config/types/llm_pool.rs | 454 +++++++++++++++++++ src/config/types/metrics.rs | 212 +++++++++ src/config/types/mod.rs | 50 ++- src/lib.rs | 1 + src/metrics/hub.rs | 353 +++++++++++++++ src/metrics/llm.rs | 206 +++++++++ src/metrics/mod.rs | 63 +++ 
src/metrics/pilot.rs | 248 +++++++++++ src/metrics/retrieval.rs | 253 +++++++++++ src/retrieval/decompose.rs | 741 +++++++++++++++++++++++++++++++ src/retrieval/mod.rs | 7 + src/retrieval/pilot/feedback.rs | 737 ++++++++++++++++++++++++++++++ src/retrieval/pilot/llm_pilot.rs | 81 +++- src/retrieval/pilot/mod.rs | 5 + src/retrieval/search/scorer.rs | 308 ++++++++++++- vectorless.example.toml | 290 ++++++------ 17 files changed, 3837 insertions(+), 183 deletions(-) create mode 100644 src/config/types/llm_pool.rs create mode 100644 src/config/types/metrics.rs create mode 100644 src/metrics/hub.rs create mode 100644 src/metrics/llm.rs create mode 100644 src/metrics/mod.rs create mode 100644 src/metrics/pilot.rs create mode 100644 src/metrics/retrieval.rs create mode 100644 src/retrieval/decompose.rs create mode 100644 src/retrieval/pilot/feedback.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index 4de3984b..42567fa8 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -95,10 +95,20 @@ pub use types::{ // Indexer IndexerConfig, // LLM configs + LlmClientConfig, LlmConfig, + LlmFallbackBehavior, + LlmFallbackConfig, + LlmMetricsConfig, + LlmOnAllFailedBehavior, + LlmPoolConfig, + MetricsConfig, OnAllFailedBehavior, + PilotMetricsConfig, // Retrieval configs RetrievalConfig, + RetrievalMetricsConfig, + RetryConfig, SearchConfig, Severity, // Storage and sufficiency @@ -106,6 +116,7 @@ pub use types::{ StrategyConfig, SufficiencyConfig, SummaryConfig, + ThrottleConfig, ValidationError, }; pub use validator::{ConfigValidator, ValidationRule}; diff --git a/src/config/types/llm_pool.rs b/src/config/types/llm_pool.rs new file mode 100644 index 00000000..18793400 --- /dev/null +++ b/src/config/types/llm_pool.rs @@ -0,0 +1,454 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Unified LLM configuration including pool, retry, throttle, and fallback. +//! +//! This module consolidates all LLM-related configuration into a single +//! cohesive structure that maps directly to the TOML configuration file. + +use serde::{Deserialize, Serialize}; + +/// Unified LLM configuration. +/// +/// Contains all settings for LLM operations including: +/// - Pool of clients for different purposes (summary, retrieval, pilot) +/// - Retry behavior +/// - Throttle/rate limiting +/// - Fallback strategy +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LlmPoolConfig { + /// Summary client configuration. + #[serde(default)] + pub summary: LlmClientConfig, + + /// Retrieval client configuration. + #[serde(default)] + pub retrieval: LlmClientConfig, + + /// Pilot client configuration. + #[serde(default = "default_pilot_config")] + pub pilot: LlmClientConfig, + + /// Default API key (used if not specified per-client). + #[serde(default)] + pub api_key: Option, + + /// Retry configuration. + #[serde(default)] + pub retry: RetryConfig, + + /// Throttle/rate limiting configuration. + #[serde(default)] + pub throttle: ThrottleConfig, + + /// Fallback configuration. 
+ #[serde(default)] + pub fallback: FallbackConfig, +} + +fn default_pilot_config() -> LlmClientConfig { + LlmClientConfig { + model: "gpt-4o-mini".to_string(), + max_tokens: 300, + temperature: 0.0, + ..Default::default() + } +} + +impl Default for LlmPoolConfig { + fn default() -> Self { + Self { + summary: LlmClientConfig::default(), + retrieval: LlmClientConfig { + model: "gpt-4o".to_string(), + max_tokens: 100, + ..Default::default() + }, + pilot: default_pilot_config(), + api_key: None, + retry: RetryConfig::default(), + throttle: ThrottleConfig::default(), + fallback: FallbackConfig::default(), + } + } +} + +impl LlmPoolConfig { + /// Create a new LLM pool config with defaults. + pub fn new() -> Self { + Self::default() + } + + /// Set the default API key. + pub fn with_api_key(mut self, api_key: impl Into) -> Self { + self.api_key = Some(api_key.into()); + self + } + + /// Get API key for a specific client (client-specific or default). + pub fn get_api_key_for(&self, client_key: Option<&str>) -> Option { + // First check client-specific key + if let Some(key) = client_key { + if let Some(ref k) = self.summary.api_key { + if self.summary.model == key { + return Some(k.clone()); + } + } + if let Some(ref k) = self.retrieval.api_key { + if self.retrieval.model == key { + return Some(k.clone()); + } + } + if let Some(ref k) = self.pilot.api_key { + if self.pilot.model == key { + return Some(k.clone()); + } + } + } + // Fall back to default + self.api_key.clone() + } +} + +/// Individual LLM client configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LlmClientConfig { + /// Model name. + #[serde(default = "default_model")] + pub model: String, + + /// API endpoint. + #[serde(default = "default_endpoint")] + pub endpoint: String, + + /// API key (optional, falls back to default). + #[serde(default)] + pub api_key: Option, + + /// Maximum tokens for responses. + #[serde(default = "default_max_tokens")] + pub max_tokens: usize, + + /// Temperature for generation. + #[serde(default = "default_temperature")] + pub temperature: f32, +} + +fn default_model() -> String { + "gpt-4o-mini".to_string() +} + +fn default_endpoint() -> String { + "https://api.openai.com/v1".to_string() +} + +fn default_max_tokens() -> usize { + 200 +} + +fn default_temperature() -> f32 { + 0.0 +} + +impl Default for LlmClientConfig { + fn default() -> Self { + Self { + model: default_model(), + endpoint: default_endpoint(), + api_key: None, + max_tokens: default_max_tokens(), + temperature: default_temperature(), + } + } +} + +impl LlmClientConfig { + /// Create a new client config with defaults. + pub fn new() -> Self { + Self::default() + } + + /// Set the model. + pub fn with_model(mut self, model: impl Into) -> Self { + self.model = model.into(); + self + } + + /// Set the endpoint. + pub fn with_endpoint(mut self, endpoint: impl Into) -> Self { + self.endpoint = endpoint.into(); + self + } + + /// Set the API key. + pub fn with_api_key(mut self, api_key: impl Into) -> Self { + self.api_key = Some(api_key.into()); + self + } + + /// Set the max tokens. + pub fn with_max_tokens(mut self, max_tokens: usize) -> Self { + self.max_tokens = max_tokens; + self + } +} + +/// Retry configuration for LLM calls. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RetryConfig { + /// Maximum number of retry attempts. + #[serde(default = "default_max_attempts")] + pub max_attempts: usize, + + /// Initial delay before first retry (milliseconds). 
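A small sketch of the per-client builder added above, assuming `LlmClientConfig` is reachable via the `vectorless::config` re-exports introduced later in this patch:

```rust
use vectorless::config::LlmClientConfig;

fn main() {
    // Per-purpose client settings; anything left unset keeps its serde default.
    let pilot = LlmClientConfig::new()
        .with_model("gpt-4o-mini")
        .with_endpoint("https://api.openai.com/v1")
        .with_max_tokens(300);

    assert_eq!(pilot.model, "gpt-4o-mini");
    assert_eq!(pilot.max_tokens, 300);
    assert_eq!(pilot.temperature, 0.0); // default_temperature()
}
```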
+ #[serde(default = "default_initial_delay_ms")] + pub initial_delay_ms: u64, + + /// Maximum delay between retries (milliseconds). + #[serde(default = "default_max_delay_ms")] + pub max_delay_ms: u64, + + /// Multiplier for exponential backoff. + #[serde(default = "default_multiplier")] + pub multiplier: f64, + + /// Whether to retry on rate limit errors. + #[serde(default = "default_true")] + pub retry_on_rate_limit: bool, +} + +fn default_max_attempts() -> usize { + 3 +} + +fn default_initial_delay_ms() -> u64 { + 500 +} + +fn default_max_delay_ms() -> u64 { + 30000 +} + +fn default_multiplier() -> f64 { + 2.0 +} + +fn default_true() -> bool { + true +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_attempts: default_max_attempts(), + initial_delay_ms: default_initial_delay_ms(), + max_delay_ms: default_max_delay_ms(), + multiplier: default_multiplier(), + retry_on_rate_limit: default_true(), + } + } +} + +impl RetryConfig { + /// Create a new retry config with defaults. + pub fn new() -> Self { + Self::default() + } + + /// Set the max attempts. + pub fn with_max_attempts(mut self, max_attempts: usize) -> Self { + self.max_attempts = max_attempts; + self + } + + /// Calculate delay for a given attempt (0-indexed). + pub fn delay_for_attempt(&self, attempt: usize) -> std::time::Duration { + let delay_ms = + (self.initial_delay_ms as f64) * self.multiplier.powi(attempt as i32); + let delay_ms = delay_ms.min(self.max_delay_ms as f64); + std::time::Duration::from_millis(delay_ms as u64) + } +} + +/// Throttle/rate limiting configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThrottleConfig { + /// Maximum concurrent LLM API calls. + #[serde(default = "default_max_concurrent")] + pub max_concurrent_requests: usize, + + /// Rate limit: requests per minute. + #[serde(default = "default_rpm")] + pub requests_per_minute: usize, + + /// Enable rate limiting. + #[serde(default = "default_true")] + pub enabled: bool, + + /// Enable semaphore-based concurrency limiting. + #[serde(default = "default_true")] + pub semaphore_enabled: bool, +} + +fn default_max_concurrent() -> usize { + 10 +} + +fn default_rpm() -> usize { + 500 +} + +impl Default for ThrottleConfig { + fn default() -> Self { + Self { + max_concurrent_requests: default_max_concurrent(), + requests_per_minute: default_rpm(), + enabled: default_true(), + semaphore_enabled: default_true(), + } + } +} + +impl ThrottleConfig { + /// Create a new throttle config with defaults. + pub fn new() -> Self { + Self::default() + } + + /// Set the max concurrent requests. + pub fn with_max_concurrent(mut self, max: usize) -> Self { + self.max_concurrent_requests = max; + self + } + + /// Set the requests per minute. + pub fn with_rpm(mut self, rpm: usize) -> Self { + self.requests_per_minute = rpm; + self + } +} + +/// Fallback configuration for LLM calls. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FallbackConfig { + /// Enable fallback mechanism. + #[serde(default = "default_true")] + pub enabled: bool, + + /// Fallback models in priority order. + #[serde(default = "default_fallback_models")] + pub models: Vec, + + /// Fallback endpoints (optional). + #[serde(default)] + pub endpoints: Vec, + + /// Behavior on rate limit error. + #[serde(default)] + pub on_rate_limit: FallbackBehavior, + + /// Behavior on timeout error. + #[serde(default)] + pub on_timeout: FallbackBehavior, + + /// Behavior when all attempts fail. 
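The backoff schedule produced by `delay_for_attempt` with the defaults above (500 ms initial delay, 2.0 multiplier, 30 s cap), assuming `RetryConfig` is reachable through the `vectorless::config` re-exports:

```rust
use vectorless::config::RetryConfig;

fn main() {
    let retry = RetryConfig::default();
    let delays_ms: Vec<u64> = (0..8usize)
        .map(|attempt| retry.delay_for_attempt(attempt).as_millis() as u64)
        .collect();
    // delay(n) = min(initial_delay * multiplier^n, max_delay)
    assert_eq!(delays_ms, [500, 1000, 2000, 4000, 8000, 16000, 30000, 30000]);
}
```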
+ #[serde(default)] + pub on_all_failed: OnAllFailedBehavior, +} + +fn default_fallback_models() -> Vec { + vec!["gpt-4o-mini".to_string(), "glm-4-flash".to_string()] +} + +/// Fallback behavior on errors. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum FallbackBehavior { + /// Retry the same model. + Retry, + /// Immediately fall back to next model. + Fallback, + /// Retry first, then fall back. + #[default] + RetryThenFallback, + /// Fail immediately. + Fail, +} + +/// Behavior when all fallback attempts fail. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum OnAllFailedBehavior { + /// Return an error. + #[default] + ReturnError, + /// Return cached result if available. + ReturnCache, +} + +impl Default for FallbackConfig { + fn default() -> Self { + Self { + enabled: default_true(), + models: default_fallback_models(), + endpoints: Vec::new(), + on_rate_limit: FallbackBehavior::default(), + on_timeout: FallbackBehavior::default(), + on_all_failed: OnAllFailedBehavior::default(), + } + } +} + +impl FallbackConfig { + /// Create a new fallback config with defaults. + pub fn new() -> Self { + Self::default() + } + + /// Disable fallback. + pub fn disabled() -> Self { + Self { + enabled: false, + ..Self::default() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_llm_pool_config_defaults() { + let config = LlmPoolConfig::default(); + assert_eq!(config.summary.model, "gpt-4o-mini"); + assert_eq!(config.retrieval.model, "gpt-4o"); + assert_eq!(config.pilot.model, "gpt-4o-mini"); + assert_eq!(config.retry.max_attempts, 3); + assert_eq!(config.throttle.max_concurrent_requests, 10); + } + + #[test] + fn test_retry_delay_calculation() { + let config = RetryConfig::default(); + + // Initial delay + assert_eq!( + config.delay_for_attempt(0), + std::time::Duration::from_millis(500) + ); + + // Second attempt: 500 * 2 = 1000 + assert_eq!( + config.delay_for_attempt(1), + std::time::Duration::from_millis(1000) + ); + } + + #[test] + fn test_fallback_config_defaults() { + let config = FallbackConfig::default(); + assert!(config.enabled); + assert!(!config.models.is_empty()); + } +} diff --git a/src/config/types/metrics.rs b/src/config/types/metrics.rs new file mode 100644 index 00000000..230686ac --- /dev/null +++ b/src/config/types/metrics.rs @@ -0,0 +1,212 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Metrics configuration for unified observability. + +use serde::{Deserialize, Serialize}; + +/// Unified metrics configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MetricsConfig { + /// Enable metrics collection. + #[serde(default = "default_true")] + pub enabled: bool, + + /// Storage path for persisted metrics. + #[serde(default = "default_storage_path")] + pub storage_path: String, + + /// Retention period in days. + #[serde(default = "default_retention_days")] + pub retention_days: usize, + + /// LLM metrics configuration. + #[serde(default)] + pub llm: LlmMetricsConfig, + + /// Pilot metrics configuration. + #[serde(default)] + pub pilot: PilotMetricsConfig, + + /// Retrieval metrics configuration. 
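The executor in this patch applies a retry-then-fallback policy directly; the sketch below only illustrates what the `FallbackBehavior` variants are meant to select, using the `Llm*`-prefixed aliases that `config/mod.rs` re-exports. The decision helper is hypothetical:

```rust
use vectorless::config::{LlmFallbackBehavior, LlmFallbackConfig};

// Sketch of what each variant is meant to select for a failed call.
fn next_step(behavior: LlmFallbackBehavior, retries_left: bool) -> &'static str {
    match behavior {
        LlmFallbackBehavior::Retry => "retry the same model",
        LlmFallbackBehavior::Fallback => "switch to the next model now",
        LlmFallbackBehavior::RetryThenFallback if retries_left => "retry the same model",
        LlmFallbackBehavior::RetryThenFallback => "switch to the next model now",
        LlmFallbackBehavior::Fail => "give up immediately",
    }
}

fn main() {
    let fallback = LlmFallbackConfig::default();
    // Both error classes default to RetryThenFallback.
    assert_eq!(next_step(fallback.on_rate_limit, true), "retry the same model");
    assert_eq!(next_step(fallback.on_timeout, false), "switch to the next model now");
}
```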
+ #[serde(default)] + pub retrieval: RetrievalMetricsConfig, +} + +fn default_storage_path() -> String { + "./workspace/metrics".to_string() +} + +fn default_retention_days() -> usize { + 30 +} + +fn default_true() -> bool { + true +} + +impl Default for MetricsConfig { + fn default() -> Self { + Self { + enabled: default_true(), + storage_path: default_storage_path(), + retention_days: default_retention_days(), + llm: LlmMetricsConfig::default(), + pilot: PilotMetricsConfig::default(), + retrieval: RetrievalMetricsConfig::default(), + } + } +} + +impl MetricsConfig { + /// Create a new metrics config with defaults. + pub fn new() -> Self { + Self::default() + } + + /// Disable metrics collection. + pub fn disabled() -> Self { + Self { + enabled: false, + ..Self::default() + } + } +} + +/// LLM-specific metrics configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LlmMetricsConfig { + /// Track token usage. + #[serde(default = "default_true")] + pub track_tokens: bool, + + /// Track latency. + #[serde(default = "default_true")] + pub track_latency: bool, + + /// Track estimated cost. + #[serde(default = "default_true")] + pub track_cost: bool, + + /// Cost per 1K input tokens (in USD). + #[serde(default = "default_cost_per_1k_input")] + pub cost_per_1k_input_tokens: f64, + + /// Cost per 1K output tokens (in USD). + #[serde(default = "default_cost_per_1k_output")] + pub cost_per_1k_output_tokens: f64, +} + +fn default_cost_per_1k_input() -> f64 { + 0.00015 // gpt-4o-mini +} + +fn default_cost_per_1k_output() -> f64 { + 0.0006 // gpt-4o-mini +} + +impl Default for LlmMetricsConfig { + fn default() -> Self { + Self { + track_tokens: default_true(), + track_latency: default_true(), + track_cost: default_true(), + cost_per_1k_input_tokens: default_cost_per_1k_input(), + cost_per_1k_output_tokens: default_cost_per_1k_output(), + } + } +} + +impl LlmMetricsConfig { + /// Calculate cost for given tokens. + pub fn calculate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 { + (input_tokens as f64 / 1000.0) * self.cost_per_1k_input_tokens + + (output_tokens as f64 / 1000.0) * self.cost_per_1k_output_tokens + } +} + +/// Pilot-specific metrics configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PilotMetricsConfig { + /// Track Pilot decisions. + #[serde(default = "default_true")] + pub track_decisions: bool, + + /// Track decision accuracy (requires feedback). + #[serde(default = "default_true")] + pub track_accuracy: bool, + + /// Track user feedback. + #[serde(default = "default_true")] + pub track_feedback: bool, +} + +impl Default for PilotMetricsConfig { + fn default() -> Self { + Self { + track_decisions: default_true(), + track_accuracy: default_true(), + track_feedback: default_true(), + } + } +} + +/// Retrieval-specific metrics configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RetrievalMetricsConfig { + /// Track search paths. + #[serde(default = "default_true")] + pub track_paths: bool, + + /// Track relevance scores. + #[serde(default = "default_true")] + pub track_scores: bool, + + /// Track iterations. + #[serde(default = "default_true")] + pub track_iterations: bool, + + /// Track cache hits/misses. 
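A worked example of `calculate_cost` with the gpt-4o-mini default rates above, assuming the `vectorless::config` re-export path:

```rust
use vectorless::config::LlmMetricsConfig;

fn main() {
    let pricing = LlmMetricsConfig::default(); // gpt-4o-mini style rates

    // (1000 / 1000) * $0.00015 + (500 / 1000) * $0.0006 = $0.00045
    let per_call = pricing.calculate_cost(1_000, 500);
    assert!((per_call - 0.00045).abs() < 1e-9);

    // 1M input + 200k output tokens: 1000 * $0.00015 + 200 * $0.0006 = $0.27
    let batch = pricing.calculate_cost(1_000_000, 200_000);
    assert!((batch - 0.27).abs() < 1e-9);
}
```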
+ #[serde(default = "default_true")] + pub track_cache: bool, +} + +impl Default for RetrievalMetricsConfig { + fn default() -> Self { + Self { + track_paths: default_true(), + track_scores: default_true(), + track_iterations: default_true(), + track_cache: default_true(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_metrics_config_defaults() { + let config = MetricsConfig::default(); + assert!(config.enabled); + assert_eq!(config.retention_days, 30); + } + + #[test] + fn test_llm_cost_calculation() { + let config = LlmMetricsConfig::default(); + + // 1000 input + 500 output tokens + let cost = config.calculate_cost(1000, 500); + + // 1 * 0.00015 + 0.5 * 0.0006 = 0.00015 + 0.0003 = 0.00045 + assert!((cost - 0.00045).abs() < 0.000001); + } + + #[test] + fn test_disabled_metrics() { + let config = MetricsConfig::disabled(); + assert!(!config.enabled); + } +} diff --git a/src/config/types/mod.rs b/src/config/types/mod.rs index 78b6c37e..ab397188 100644 --- a/src/config/types/mod.rs +++ b/src/config/types/mod.rs @@ -11,6 +11,8 @@ mod content; mod fallback; mod indexer; mod llm; +mod llm_pool; +mod metrics; mod retrieval; mod storage; @@ -21,6 +23,14 @@ pub use content::ContentAggregatorConfig; pub use fallback::{FallbackBehavior, FallbackConfig, OnAllFailedBehavior}; pub use indexer::IndexerConfig; pub use llm::{LlmConfig, SummaryConfig}; +pub use llm_pool::{ + FallbackBehavior as LlmFallbackBehavior, FallbackConfig as LlmFallbackConfig, + LlmClientConfig, LlmPoolConfig, OnAllFailedBehavior as LlmOnAllFailedBehavior, RetryConfig, + ThrottleConfig, +}; +pub use metrics::{ + LlmMetricsConfig, MetricsConfig, PilotMetricsConfig, RetrievalMetricsConfig, +}; pub use retrieval::{RetrievalConfig, SearchConfig}; pub use storage::{ CacheConfig, CompressionAlgorithm, CompressionConfig, StorageConfig, StrategyConfig, @@ -30,11 +40,19 @@ pub use storage::{ /// Main configuration for vectorless. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { + /// Unified LLM configuration (pool, retry, throttle, fallback). + #[serde(default)] + pub llm: LlmPoolConfig, + + /// Unified metrics configuration. + #[serde(default)] + pub metrics: MetricsConfig, + /// Indexer configuration. #[serde(default)] pub indexer: IndexerConfig, - /// Summary model configuration. + /// Summary model configuration (legacy, prefer llm.summary). #[serde(default)] pub summary: SummaryConfig, @@ -46,11 +64,11 @@ pub struct Config { #[serde(default)] pub storage: StorageConfig, - /// Concurrency control configuration. + /// Concurrency control configuration (legacy, prefer llm.throttle). #[serde(default)] pub concurrency: ConcurrencyConfig, - /// Fallback/error recovery configuration. + /// Fallback/error recovery configuration (legacy, prefer llm.fallback). #[serde(default)] pub fallback: FallbackConfig, } @@ -58,6 +76,8 @@ pub struct Config { impl Default for Config { fn default() -> Self { Self { + llm: LlmPoolConfig::default(), + metrics: MetricsConfig::default(), indexer: IndexerConfig::default(), summary: SummaryConfig::default(), retrieval: RetrievalConfig::default(), @@ -74,6 +94,18 @@ impl Config { Self::default() } + /// Set the LLM pool configuration. + pub fn with_llm(mut self, llm: LlmPoolConfig) -> Self { + self.llm = llm; + self + } + + /// Set the metrics configuration. + pub fn with_metrics(mut self, metrics: MetricsConfig) -> Self { + self.metrics = metrics; + self + } + /// Set the indexer configuration. 
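Wiring the new unified sections into the top-level `Config`, using the `with_llm` / `with_metrics` builders above. Two assumptions here: `Config` itself is re-exported from `vectorless::config`, and `"sk-example"` is a placeholder key:

```rust
use vectorless::config::{Config, LlmPoolConfig, MetricsConfig};

fn main() {
    let config = Config::new()
        .with_llm(LlmPoolConfig::new().with_api_key("sk-example"))
        .with_metrics(MetricsConfig::disabled());

    assert_eq!(config.llm.api_key.as_deref(), Some("sk-example"));
    assert!(!config.metrics.enabled);
    // Legacy sections keep working alongside the new unified ones.
    assert_eq!(config.summary.model, "gpt-4o-mini");
}
```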
pub fn with_indexer(mut self, indexer: IndexerConfig) -> Self { self.indexer = indexer; @@ -307,6 +339,18 @@ mod tests { assert_eq!(config.summary.model, "gpt-4o-mini"); assert_eq!(config.retrieval.model, "gpt-4o"); assert_eq!(config.concurrency.max_concurrent_requests, 10); + // New fields + assert!(config.llm.summary.model == "gpt-4o-mini"); + assert!(config.metrics.enabled); + } + + #[test] + fn test_llm_pool_config_defaults() { + let config = LlmPoolConfig::default(); + assert_eq!(config.summary.model, "gpt-4o-mini"); + assert_eq!(config.retrieval.model, "gpt-4o"); + assert_eq!(config.retry.max_attempts, 3); + assert_eq!(config.throttle.max_concurrent_requests, 10); } #[test] diff --git a/src/lib.rs b/src/lib.rs index 3065f4d0..51657c52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -110,6 +110,7 @@ pub mod document; pub mod error; pub mod index; pub mod llm; +pub mod metrics; pub mod parser; pub mod retrieval; pub mod storage; diff --git a/src/metrics/hub.rs b/src/metrics/hub.rs new file mode 100644 index 00000000..73ab6b75 --- /dev/null +++ b/src/metrics/hub.rs @@ -0,0 +1,353 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Central metrics hub for unified collection. + +use std::sync::Arc; + +use super::llm::{LlmMetrics, LlmMetricsReport}; +use super::pilot::{InterventionPoint, PilotMetrics, PilotMetricsReport}; +use super::retrieval::{RetrievalMetrics, RetrievalMetricsReport}; +use crate::config::MetricsConfig; + +/// Central metrics hub for unified collection. +/// +/// Provides a single point for all metrics collection across: +/// - LLM operations (tokens, latency, cost) +/// - Pilot decisions (accuracy, confidence, feedback) +/// - Retrieval operations (paths, scores, cache) +/// +/// # Thread Safety +/// +/// All metrics use atomic operations and are safe to use from multiple threads. +/// +/// # Example +/// +/// ```rust +/// use vectorless::metrics::{MetricsHub, MetricsConfig}; +/// +/// let config = MetricsConfig::default(); +/// let hub = MetricsHub::new(config); +/// +/// // Record LLM call +/// hub.record_llm_call(100, 50, 150, true); +/// +/// // Record Pilot decision +/// hub.record_pilot_decision(0.85, InterventionPoint::Fork); +/// +/// // Get report +/// let report = hub.generate_report(); +/// ``` +#[derive(Debug)] +pub struct MetricsHub { + config: MetricsConfig, + llm: LlmMetrics, + pilot: PilotMetrics, + retrieval: RetrievalMetrics, +} + +impl MetricsHub { + /// Create a new metrics hub. + pub fn new(config: MetricsConfig) -> Self { + Self { + config, + llm: LlmMetrics::new(), + pilot: PilotMetrics::new(), + retrieval: RetrievalMetrics::new(), + } + } + + /// Create a new metrics hub with defaults. + pub fn with_defaults() -> Self { + Self::new(MetricsConfig::default()) + } + + /// Create an Arc-wrapped metrics hub. + pub fn shared() -> Arc { + Arc::new(Self::with_defaults()) + } + + /// Create an Arc-wrapped metrics hub with config. + pub fn shared_with_config(config: MetricsConfig) -> Arc { + Arc::new(Self::new(config)) + } + + /// Check if metrics are enabled. + pub fn is_enabled(&self) -> bool { + self.config.enabled + } + + /// Get the configuration. + pub fn config(&self) -> &MetricsConfig { + &self.config + } + + // ======================================================================== + // LLM Metrics + // ======================================================================== + + /// Record an LLM call. 
+ pub fn record_llm_call( + &self, + input_tokens: u64, + output_tokens: u64, + latency_ms: u64, + success: bool, + ) { + if !self.config.enabled || !self.config.llm.track_tokens { + return; + } + self.llm + .record_call(input_tokens, output_tokens, latency_ms, success, &self.config.llm); + } + + /// Record an LLM rate limit error. + pub fn record_llm_rate_limit(&self) { + if self.config.enabled { + self.llm.record_rate_limit(); + } + } + + /// Record an LLM timeout error. + pub fn record_llm_timeout(&self) { + if self.config.enabled { + self.llm.record_timeout(); + } + } + + /// Record an LLM fallback trigger. + pub fn record_llm_fallback(&self) { + if self.config.enabled { + self.llm.record_fallback(); + } + } + + /// Get LLM metrics report. + pub fn llm_report(&self) -> LlmMetricsReport { + self.llm.generate_report() + } + + // ======================================================================== + // Pilot Metrics + // ======================================================================== + + /// Record a Pilot decision. + pub fn record_pilot_decision(&self, confidence: f64, point: InterventionPoint) { + if !self.config.enabled || !self.config.pilot.track_decisions { + return; + } + self.pilot + .record_decision(confidence, point, &self.config.pilot); + } + + /// Record feedback on a Pilot decision. + pub fn record_pilot_feedback(&self, was_correct: bool) { + if !self.config.enabled || !self.config.pilot.track_feedback { + return; + } + self.pilot.record_feedback(was_correct, &self.config.pilot); + } + + /// Record a Pilot LLM call. + pub fn record_pilot_llm_call(&self) { + if self.config.enabled { + self.pilot.record_llm_call(); + } + } + + /// Record a Pilot intervention. + pub fn record_pilot_intervention(&self) { + if self.config.enabled { + self.pilot.record_intervention(); + } + } + + /// Record a skipped Pilot intervention. + pub fn record_pilot_intervention_skipped(&self) { + if self.config.enabled { + self.pilot.record_skipped_intervention(); + } + } + + /// Record Pilot budget exhausted. + pub fn record_pilot_budget_exhausted(&self) { + if self.config.enabled { + self.pilot.record_budget_exhausted(); + } + } + + /// Record Pilot fallback to algorithm. + pub fn record_pilot_algorithm_fallback(&self) { + if self.config.enabled { + self.pilot.record_algorithm_fallback(); + } + } + + /// Get Pilot metrics report. + pub fn pilot_report(&self) -> PilotMetricsReport { + self.pilot.generate_report() + } + + // ======================================================================== + // Retrieval Metrics + // ======================================================================== + + /// Record a retrieval query. + pub fn record_retrieval_query( + &self, + iterations: u64, + nodes_visited: u64, + latency_ms: u64, + ) { + if !self.config.enabled { + return; + } + self.retrieval + .record_query(iterations, nodes_visited, latency_ms, &self.config.retrieval); + } + + /// Record a found path. + pub fn record_retrieval_path(&self, length: u64, score: f64) { + if !self.config.enabled { + return; + } + self.retrieval + .record_path(length, score, &self.config.retrieval); + } + + /// Record a cache hit. + pub fn record_cache_hit(&self) { + if !self.config.enabled || !self.config.retrieval.track_cache { + return; + } + self.retrieval.record_cache_hit(&self.config.retrieval); + } + + /// Record a cache miss. 
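A sketch of sharing one `MetricsHub` across concurrent tasks, based on the `shared()` constructor and record methods above; `#[tokio::main]` mirrors the crate's other doc examples, and the worker closure is hypothetical:

```rust
use std::sync::Arc;
use vectorless::metrics::MetricsHub;

#[tokio::main]
async fn main() {
    // One hub, shared by many concurrent workers via Arc; all counters are atomic.
    let hub = MetricsHub::shared();

    let mut handles = Vec::new();
    for i in 0..4u64 {
        let hub = Arc::clone(&hub);
        handles.push(tokio::spawn(async move {
            // input tokens, output tokens, latency (ms), success
            hub.record_llm_call(100 * (i + 1), 50, 120, true);
            hub.record_cache_hit();
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }

    let report = hub.generate_report();
    assert_eq!(report.llm.total_calls, 4);
    println!("tokens used: {}", report.llm.total_tokens);
}
```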
+ pub fn record_cache_miss(&self) { + if !self.config.enabled || !self.config.retrieval.track_cache { + return; + } + self.retrieval.record_cache_miss(&self.config.retrieval); + } + + /// Record a backtrack. + pub fn record_backtrack(&self) { + if self.config.enabled { + self.retrieval.record_backtrack(); + } + } + + /// Record a sufficiency check. + pub fn record_sufficiency_check(&self, was_sufficient: bool) { + if self.config.enabled { + self.retrieval.record_sufficiency_check(was_sufficient); + } + } + + /// Get retrieval metrics report. + pub fn retrieval_report(&self) -> RetrievalMetricsReport { + self.retrieval.generate_report() + } + + // ======================================================================== + // General Operations + // ======================================================================== + + /// Reset all metrics. + pub fn reset(&self) { + self.llm.reset(); + self.pilot.reset(); + self.retrieval.reset(); + } + + /// Generate a complete report. + pub fn generate_report(&self) -> MetricsReport { + MetricsReport { + llm: self.llm_report(), + pilot: self.pilot_report(), + retrieval: self.retrieval_report(), + } + } +} + +impl Default for MetricsHub { + fn default() -> Self { + Self::with_defaults() + } +} + +/// Complete metrics report. +#[derive(Debug, Clone)] +pub struct MetricsReport { + /// LLM metrics. + pub llm: LlmMetricsReport, + /// Pilot metrics. + pub pilot: PilotMetricsReport, + /// Retrieval metrics. + pub retrieval: RetrievalMetricsReport, +} + +impl MetricsReport { + /// Calculate total estimated cost in USD. + pub fn total_cost_usd(&self) -> f64 { + self.llm.estimated_cost_usd + } + + /// Calculate overall success rate. + pub fn overall_success_rate(&self) -> f64 { + let llm_rate = self.llm.success_rate; + let pilot_rate = if self.pilot.total_decisions > 0 { + self.pilot.accuracy + } else { + 1.0 + }; + (llm_rate + pilot_rate) / 2.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_metrics_hub_recording() { + let hub = MetricsHub::with_defaults(); + + // Record various metrics + hub.record_llm_call(100, 50, 150, true); + hub.record_pilot_decision(0.9, InterventionPoint::Fork); + hub.record_retrieval_query(5, 10, 100); + + let report = hub.generate_report(); + + assert_eq!(report.llm.total_calls, 1); + assert_eq!(report.pilot.total_decisions, 1); + assert_eq!(report.retrieval.total_queries, 1); + } + + #[test] + fn test_metrics_hub_disabled() { + let config = MetricsConfig::disabled(); + let hub = MetricsHub::new(config); + + hub.record_llm_call(100, 50, 150, true); + hub.record_pilot_decision(0.9, InterventionPoint::Fork); + + let report = hub.generate_report(); + + assert_eq!(report.llm.total_calls, 0); + assert_eq!(report.pilot.total_decisions, 0); + } + + #[test] + fn test_metrics_hub_reset() { + let hub = MetricsHub::with_defaults(); + + hub.record_llm_call(100, 50, 150, true); + hub.reset(); + + let report = hub.generate_report(); + assert_eq!(report.llm.total_calls, 0); + } +} diff --git a/src/metrics/llm.rs b/src/metrics/llm.rs new file mode 100644 index 00000000..c8dc30f1 --- /dev/null +++ b/src/metrics/llm.rs @@ -0,0 +1,206 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! LLM metrics collection. + +use std::sync::atomic::{AtomicU64, Ordering}; + +use crate::config::LlmMetricsConfig; + +/// LLM metrics tracker. +#[derive(Debug, Default)] +pub struct LlmMetrics { + /// Total number of LLM calls. + pub total_calls: AtomicU64, + /// Number of successful calls. 
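Recording pilot decisions and feedback through the hub, based on the methods above; how `accuracy` is derived lives in the pilot module, so this sketch only prints it rather than asserting a value:

```rust
use vectorless::metrics::{InterventionPoint, MetricsHub};

fn main() {
    let hub = MetricsHub::with_defaults();

    // Decisions taken at different intervention points, with feedback on two.
    hub.record_pilot_decision(0.9, InterventionPoint::Fork);
    hub.record_pilot_decision(0.7, InterventionPoint::Fork);
    hub.record_pilot_decision(0.4, InterventionPoint::Backtrack);
    hub.record_pilot_feedback(true);
    hub.record_pilot_feedback(false);

    let report = hub.pilot_report();
    assert_eq!(report.total_decisions, 3);
    println!("pilot accuracy so far: {:.2}", report.accuracy);
}
```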
+ pub successful_calls: AtomicU64, + /// Number of failed calls. + pub failed_calls: AtomicU64, + /// Total input tokens. + pub total_input_tokens: AtomicU64, + /// Total output tokens. + pub total_output_tokens: AtomicU64, + /// Total latency in milliseconds. + pub total_latency_ms: AtomicU64, + /// Estimated cost in micro-dollars. + pub estimated_cost_micros: AtomicU64, + /// Number of rate limit errors. + pub rate_limit_errors: AtomicU64, + /// Number of timeout errors. + pub timeout_errors: AtomicU64, + /// Number of fallback triggers. + pub fallback_triggers: AtomicU64, +} + +impl LlmMetrics { + /// Create new LLM metrics. + pub fn new() -> Self { + Self::default() + } + + /// Record an LLM call. + pub fn record_call( + &self, + input_tokens: u64, + output_tokens: u64, + latency_ms: u64, + success: bool, + config: &LlmMetricsConfig, + ) { + self.total_calls.fetch_add(1, Ordering::Relaxed); + + if success { + self.successful_calls.fetch_add(1, Ordering::Relaxed); + } else { + self.failed_calls.fetch_add(1, Ordering::Relaxed); + } + + if config.track_tokens { + self.total_input_tokens + .fetch_add(input_tokens, Ordering::Relaxed); + self.total_output_tokens + .fetch_add(output_tokens, Ordering::Relaxed); + } + + if config.track_latency { + self.total_latency_ms + .fetch_add(latency_ms, Ordering::Relaxed); + } + + if config.track_cost { + let cost = config.calculate_cost(input_tokens, output_tokens); + // Store in micro-dollars for precision + let cost_micros = (cost * 1_000_000.0) as u64; + self.estimated_cost_micros + .fetch_add(cost_micros, Ordering::Relaxed); + } + } + + /// Record a rate limit error. + pub fn record_rate_limit(&self) { + self.rate_limit_errors.fetch_add(1, Ordering::Relaxed); + } + + /// Record a timeout error. + pub fn record_timeout(&self) { + self.timeout_errors.fetch_add(1, Ordering::Relaxed); + } + + /// Record a fallback trigger. + pub fn record_fallback(&self) { + self.fallback_triggers.fetch_add(1, Ordering::Relaxed); + } + + /// Reset all metrics. + pub fn reset(&self) { + self.total_calls.store(0, Ordering::Relaxed); + self.successful_calls.store(0, Ordering::Relaxed); + self.failed_calls.store(0, Ordering::Relaxed); + self.total_input_tokens.store(0, Ordering::Relaxed); + self.total_output_tokens.store(0, Ordering::Relaxed); + self.total_latency_ms.store(0, Ordering::Relaxed); + self.estimated_cost_micros.store(0, Ordering::Relaxed); + self.rate_limit_errors.store(0, Ordering::Relaxed); + self.timeout_errors.store(0, Ordering::Relaxed); + self.fallback_triggers.store(0, Ordering::Relaxed); + } + + /// Generate a report snapshot. 
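A self-contained illustration of the micro-dollar encoding used by `estimated_cost_micros`: dollar amounts are scaled to whole micro-dollars so they can be summed lock-free in an `AtomicU64`, at the cost of truncating anything below one micro-dollar per call. The helper below is a stand-in, not the crate's code:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

fn add_cost(total_micros: &AtomicU64, cost_usd: f64) {
    // Scale to micro-dollars; the cast truncates sub-micro-dollar fractions.
    let micros = (cost_usd * 1_000_000.0) as u64;
    total_micros.fetch_add(micros, Ordering::Relaxed);
}

fn main() {
    let total = AtomicU64::new(0);
    // Two calls at $0.00045 each (1000 input + 500 output tokens at default rates).
    add_cost(&total, 0.00045);
    add_cost(&total, 0.00045);

    let usd = total.load(Ordering::Relaxed) as f64 / 1_000_000.0;
    // Within one micro-dollar per recorded call of the true total ($0.0009).
    assert!((usd - 0.0009).abs() < 2.5e-6);
}
```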
+ pub fn generate_report(&self) -> LlmMetricsReport { + let total_calls = self.total_calls.load(Ordering::Relaxed); + let successful = self.successful_calls.load(Ordering::Relaxed); + let failed = self.failed_calls.load(Ordering::Relaxed); + let total_latency = self.total_latency_ms.load(Ordering::Relaxed); + + LlmMetricsReport { + total_calls, + successful_calls: successful, + failed_calls: failed, + success_rate: if total_calls > 0 { + successful as f64 / total_calls as f64 + } else { + 0.0 + }, + total_input_tokens: self.total_input_tokens.load(Ordering::Relaxed), + total_output_tokens: self.total_output_tokens.load(Ordering::Relaxed), + total_tokens: self.total_input_tokens.load(Ordering::Relaxed) + + self.total_output_tokens.load(Ordering::Relaxed), + avg_latency_ms: if total_calls > 0 { + total_latency as f64 / total_calls as f64 + } else { + 0.0 + }, + total_latency_ms: total_latency, + estimated_cost_usd: self.estimated_cost_micros.load(Ordering::Relaxed) as f64 / 1_000_000.0, + rate_limit_errors: self.rate_limit_errors.load(Ordering::Relaxed), + timeout_errors: self.timeout_errors.load(Ordering::Relaxed), + fallback_triggers: self.fallback_triggers.load(Ordering::Relaxed), + } + } +} + +/// LLM metrics report. +#[derive(Debug, Clone)] +pub struct LlmMetricsReport { + /// Total number of LLM calls. + pub total_calls: u64, + /// Number of successful calls. + pub successful_calls: u64, + /// Number of failed calls. + pub failed_calls: u64, + /// Success rate (0.0 - 1.0). + pub success_rate: f64, + /// Total input tokens. + pub total_input_tokens: u64, + /// Total output tokens. + pub total_output_tokens: u64, + /// Total tokens (input + output). + pub total_tokens: u64, + /// Average latency in milliseconds. + pub avg_latency_ms: f64, + /// Total latency in milliseconds. + pub total_latency_ms: u64, + /// Estimated cost in USD. + pub estimated_cost_usd: f64, + /// Number of rate limit errors. + pub rate_limit_errors: u64, + /// Number of timeout errors. + pub timeout_errors: u64, + /// Number of fallback triggers. + pub fallback_triggers: u64, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_llm_metrics_recording() { + let config = LlmMetricsConfig::default(); + let metrics = LlmMetrics::new(); + + metrics.record_call(100, 50, 150, true, &config); + metrics.record_call(200, 100, 300, true, &config); + metrics.record_call(100, 0, 0, false, &config); + + let report = metrics.generate_report(); + assert_eq!(report.total_calls, 3); + assert_eq!(report.successful_calls, 2); + assert_eq!(report.failed_calls, 1); + assert!((report.success_rate - 0.666666).abs() < 0.01); + assert_eq!(report.total_input_tokens, 400); + assert_eq!(report.total_output_tokens, 150); + } + + #[test] + fn test_llm_metrics_reset() { + let config = LlmMetricsConfig::default(); + let metrics = LlmMetrics::new(); + + metrics.record_call(100, 50, 150, true, &config); + metrics.reset(); + + let report = metrics.generate_report(); + assert_eq!(report.total_calls, 0); + } +} diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs new file mode 100644 index 00000000..69104970 --- /dev/null +++ b/src/metrics/mod.rs @@ -0,0 +1,63 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Unified metrics collection for Vectorless. +//! +//! This module provides centralized metrics collection across all components: +//! - **LLM Metrics** — Token usage, latency, cost +//! - **Pilot Metrics** — Decisions, accuracy, feedback +//! 
- **Retrieval Metrics** — Paths, scores, iterations, cache +//! +//! # Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ MetricsHub │ +//! │ │ +//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +//! │ │ LlmMetrics │ │PilotMetrics │ │RetrievalMetrics│ │ +//! │ │ │ │ │ │ │ │ +//! │ │ - tokens │ │ - decisions │ │ - paths │ │ +//! │ │ - latency │ │ - accuracy │ │ - scores │ │ +//! │ │ - cost │ │ - feedback │ │ - cache │ │ +//! │ └─────────────┘ └─────────────┘ └─────────────┘ │ +//! │ │ +//! │ ┌─────────────────────────────────────────────────────────┐ │ +//! │ │ MetricsReport │ │ +//! │ │ │ │ +//! │ │ Aggregated report with all metrics and statistics │ │ +//! │ └─────────────────────────────────────────────────────────┘ │ +//! └─────────────────────────────────────────────────────────────────┘ +//! ``` +//! +//! # Example +//! +//! ```rust +//! use vectorless::metrics::{MetricsHub, MetricsConfig}; +//! +//! let config = MetricsConfig::default(); +//! let hub = MetricsHub::new(config); +//! +//! // Record LLM call +//! hub.record_llm_call(100, 50, 150, true); +//! +//! // Record Pilot decision +//! hub.record_pilot_decision(0.85, InterventionPoint::Fork); +//! +//! // Generate report +//! let report = hub.generate_report(); +//! println!("Total cost: ${:.4}", report.llm.estimated_cost_usd); +//! ``` + +mod hub; +mod llm; +mod pilot; +mod retrieval; + +pub use hub::MetricsHub; +pub use llm::{LlmMetrics, LlmMetricsReport}; +pub use pilot::{InterventionPoint, PilotMetrics, PilotMetricsReport}; +pub use retrieval::{RetrievalMetrics, RetrievalMetricsReport}; + +// Re-export config from config module +pub use crate::config::MetricsConfig; diff --git a/src/metrics/pilot.rs b/src/metrics/pilot.rs new file mode 100644 index 00000000..ccc2ea5e --- /dev/null +++ b/src/metrics/pilot.rs @@ -0,0 +1,248 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Pilot metrics collection. + +use std::sync::atomic::{AtomicU64, Ordering}; + +use crate::config::PilotMetricsConfig; + +/// Intervention point type. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InterventionPoint { + /// At search start. + Start, + /// At a fork (multiple candidates). + Fork, + /// During backtracking. + Backtrack, + /// Evaluating content sufficiency. + Evaluate, +} + +/// Helper to store f64 as u64 bits for atomic operations. +fn f64_to_u64_bits(v: f64) -> u64 { + v.to_bits() +} + +/// Helper to convert u64 bits back to f64. +fn u64_bits_to_f64(v: u64) -> f64 { + f64::from_bits(v) +} + +/// Pilot metrics tracker. +#[derive(Debug, Default)] +pub struct PilotMetrics { + /// Total number of Pilot decisions. + pub total_decisions: AtomicU64, + /// Number of start guidance calls. + pub start_guidance_calls: AtomicU64, + /// Number of fork decisions. + pub fork_decisions: AtomicU64, + /// Number of backtrack guidance calls. + pub backtrack_calls: AtomicU64, + /// Number of evaluate calls. + pub evaluate_calls: AtomicU64, + /// Number of correct decisions (based on feedback). + pub correct_decisions: AtomicU64, + /// Number of incorrect decisions (based on feedback). + pub incorrect_decisions: AtomicU64, + /// Sum of confidence values stored as u64 bits (for atomic ops). + /// We store the sum scaled by 1,000,000 to maintain precision. + pub confidence_sum_scaled: AtomicU64, + /// Number of confidence samples. + pub confidence_count: AtomicU64, + /// Number of LLM calls made by Pilot. 
+ pub llm_calls: AtomicU64, + /// Number of times Pilot intervened. + pub interventions: AtomicU64, + /// Number of times Pilot skipped intervention (algorithm was confident). + pub skipped_interventions: AtomicU64, + /// Number of budget exhausted events. + pub budget_exhausted: AtomicU64, + /// Number of fallback to algorithm. + pub algorithm_fallbacks: AtomicU64, +} + +impl PilotMetrics { + /// Create new Pilot metrics. + pub fn new() -> Self { + Self::default() + } + + /// Record a Pilot decision. + pub fn record_decision(&self, confidence: f64, point: InterventionPoint, config: &PilotMetricsConfig) { + if !config.track_decisions { + return; + } + + self.total_decisions.fetch_add(1, Ordering::Relaxed); + + match point { + InterventionPoint::Start => { + self.start_guidance_calls.fetch_add(1, Ordering::Relaxed); + } + InterventionPoint::Fork => { + self.fork_decisions.fetch_add(1, Ordering::Relaxed); + } + InterventionPoint::Backtrack => { + self.backtrack_calls.fetch_add(1, Ordering::Relaxed); + } + InterventionPoint::Evaluate => { + self.evaluate_calls.fetch_add(1, Ordering::Relaxed); + } + } + + // Update average confidence (store as scaled integer for atomic operations) + let scaled_confidence = (confidence * 1_000_000.0) as u64; + self.confidence_sum_scaled.fetch_add(scaled_confidence, Ordering::Relaxed); + self.confidence_count.fetch_add(1, Ordering::Relaxed); + } + + /// Record feedback on a decision. + pub fn record_feedback(&self, was_correct: bool, config: &PilotMetricsConfig) { + if !config.track_feedback { + return; + } + + if was_correct { + self.correct_decisions.fetch_add(1, Ordering::Relaxed); + } else { + self.incorrect_decisions.fetch_add(1, Ordering::Relaxed); + } + } + + /// Record an LLM call made by Pilot. + pub fn record_llm_call(&self) { + self.llm_calls.fetch_add(1, Ordering::Relaxed); + } + + /// Record an intervention. + pub fn record_intervention(&self) { + self.interventions.fetch_add(1, Ordering::Relaxed); + } + + /// Record a skipped intervention. + pub fn record_skipped_intervention(&self) { + self.skipped_interventions.fetch_add(1, Ordering::Relaxed); + } + + /// Record budget exhausted. + pub fn record_budget_exhausted(&self) { + self.budget_exhausted.fetch_add(1, Ordering::Relaxed); + } + + /// Record algorithm fallback. + pub fn record_algorithm_fallback(&self) { + self.algorithm_fallbacks.fetch_add(1, Ordering::Relaxed); + } + + /// Reset all metrics. + pub fn reset(&self) { + self.total_decisions.store(0, Ordering::Relaxed); + self.start_guidance_calls.store(0, Ordering::Relaxed); + self.fork_decisions.store(0, Ordering::Relaxed); + self.backtrack_calls.store(0, Ordering::Relaxed); + self.evaluate_calls.store(0, Ordering::Relaxed); + self.correct_decisions.store(0, Ordering::Relaxed); + self.incorrect_decisions.store(0, Ordering::Relaxed); + self.confidence_sum_scaled.store(0, Ordering::Relaxed); + self.confidence_count.store(0, Ordering::Relaxed); + self.llm_calls.store(0, Ordering::Relaxed); + self.interventions.store(0, Ordering::Relaxed); + self.skipped_interventions.store(0, Ordering::Relaxed); + self.budget_exhausted.store(0, Ordering::Relaxed); + self.algorithm_fallbacks.store(0, Ordering::Relaxed); + } + + /// Generate a report snapshot. 
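+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (illustrative only; assumes decision tracking is
+    /// enabled in the default `PilotMetricsConfig`). The average confidence
+    /// is recovered from the scaled integer sum:
+    ///
+    /// ```rust,ignore
+    /// let config = PilotMetricsConfig::default();
+    /// let metrics = PilotMetrics::new();
+    /// metrics.record_decision(0.9, InterventionPoint::Fork, &config);
+    /// metrics.record_decision(0.7, InterventionPoint::Fork, &config);
+    /// let report = metrics.generate_report();
+    /// assert_eq!(report.fork_decisions, 2);
+    /// assert!((report.avg_confidence - 0.8).abs() < 1e-6);
+    /// ```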
+ pub fn generate_report(&self) -> PilotMetricsReport { + let total_decisions = self.total_decisions.load(Ordering::Relaxed); + let correct = self.correct_decisions.load(Ordering::Relaxed); + let total_feedback = correct + self.incorrect_decisions.load(Ordering::Relaxed); + let confidence_count = self.confidence_count.load(Ordering::Relaxed); + let confidence_sum_scaled = self.confidence_sum_scaled.load(Ordering::Relaxed); + + PilotMetricsReport { + total_decisions, + start_guidance_calls: self.start_guidance_calls.load(Ordering::Relaxed), + fork_decisions: self.fork_decisions.load(Ordering::Relaxed), + backtrack_calls: self.backtrack_calls.load(Ordering::Relaxed), + evaluate_calls: self.evaluate_calls.load(Ordering::Relaxed), + accuracy: if total_feedback > 0 { + correct as f64 / total_feedback as f64 + } else { + 0.0 + }, + correct_decisions: correct, + incorrect_decisions: self.incorrect_decisions.load(Ordering::Relaxed), + avg_confidence: if confidence_count > 0 { + (confidence_sum_scaled as f64 / 1_000_000.0) / confidence_count as f64 + } else { + 0.0 + }, + llm_calls: self.llm_calls.load(Ordering::Relaxed), + interventions: self.interventions.load(Ordering::Relaxed), + skipped_interventions: self.skipped_interventions.load(Ordering::Relaxed), + budget_exhausted: self.budget_exhausted.load(Ordering::Relaxed), + algorithm_fallbacks: self.algorithm_fallbacks.load(Ordering::Relaxed), + } + } +} + +/// Pilot metrics report. +#[derive(Debug, Clone)] +pub struct PilotMetricsReport { + /// Total number of decisions. + pub total_decisions: u64, + /// Number of start guidance calls. + pub start_guidance_calls: u64, + /// Number of fork decisions. + pub fork_decisions: u64, + /// Number of backtrack calls. + pub backtrack_calls: u64, + /// Number of evaluate calls. + pub evaluate_calls: u64, + /// Decision accuracy based on feedback. + pub accuracy: f64, + /// Number of correct decisions. + pub correct_decisions: u64, + /// Number of incorrect decisions. + pub incorrect_decisions: u64, + /// Average confidence across all decisions. + pub avg_confidence: f64, + /// Number of LLM calls made by Pilot. + pub llm_calls: u64, + /// Number of interventions. + pub interventions: u64, + /// Number of skipped interventions. + pub skipped_interventions: u64, + /// Number of budget exhausted events. + pub budget_exhausted: u64, + /// Number of algorithm fallbacks. + pub algorithm_fallbacks: u64, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pilot_metrics_recording() { + let config = PilotMetricsConfig::default(); + let metrics = PilotMetrics::new(); + + metrics.record_decision(0.9, InterventionPoint::Start, &config); + metrics.record_decision(0.8, InterventionPoint::Fork, &config); + metrics.record_decision(0.7, InterventionPoint::Fork, &config); + + metrics.record_feedback(true, &config); + metrics.record_feedback(false, &config); + + let report = metrics.generate_report(); + assert_eq!(report.total_decisions, 3); + assert_eq!(report.fork_decisions, 2); + assert!((report.accuracy - 0.5).abs() < 0.01); + assert!((report.avg_confidence - 0.8).abs() < 0.01); + } +} diff --git a/src/metrics/retrieval.rs b/src/metrics/retrieval.rs new file mode 100644 index 00000000..56ecb140 --- /dev/null +++ b/src/metrics/retrieval.rs @@ -0,0 +1,253 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Retrieval metrics collection. + +use std::sync::atomic::{AtomicU64, Ordering}; + +use crate::config::RetrievalMetricsConfig; + +/// Retrieval metrics tracker. 
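+///
+/// # Example
+///
+/// A minimal sketch (illustrative only; assumes the default
+/// `RetrievalMetricsConfig` enables iteration, path, and cache tracking, as
+/// exercised by the tests below):
+///
+/// ```rust,ignore
+/// let config = RetrievalMetricsConfig::default();
+/// let metrics = RetrievalMetrics::new();
+/// metrics.record_query(5, 10, 100, &config);
+/// metrics.record_path(3, 0.8, &config);
+/// metrics.record_cache_hit(&config);
+/// let report = metrics.generate_report();
+/// assert_eq!(report.total_queries, 1);
+/// assert!((report.cache_hit_rate - 1.0).abs() < 1e-6);
+/// ```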
+#[derive(Debug, Default)] +pub struct RetrievalMetrics { + /// Total number of queries. + pub total_queries: AtomicU64, + /// Total number of search iterations. + pub total_iterations: AtomicU64, + /// Sum of iterations (for average). + pub iterations_sum: AtomicU64, + /// Total number of nodes visited. + pub nodes_visited: AtomicU64, + /// Total number of paths found. + pub paths_found: AtomicU64, + /// Sum of path lengths (for average). + pub path_length_sum: AtomicU64, + /// Sum of path scores stored as scaled integer (multiply by 1_000_000 for actual value). + pub path_score_sum_scaled: AtomicU64, + /// Number of paths with score >= 0.5. + pub high_score_paths: AtomicU64, + /// Number of paths with score < 0.3. + pub low_score_paths: AtomicU64, + /// Number of cache hits. + pub cache_hits: AtomicU64, + /// Number of cache misses. + pub cache_misses: AtomicU64, + /// Total latency in milliseconds. + pub total_latency_ms: AtomicU64, + /// Number of backtracks. + pub backtracks: AtomicU64, + /// Number of sufficiency checks. + pub sufficiency_checks: AtomicU64, + /// Number of times content was sufficient. + pub sufficient_results: AtomicU64, +} + +impl RetrievalMetrics { + /// Create new retrieval metrics. + pub fn new() -> Self { + Self::default() + } + + /// Record a query. + pub fn record_query(&self, iterations: u64, nodes: u64, latency_ms: u64, config: &RetrievalMetricsConfig) { + self.total_queries.fetch_add(1, Ordering::Relaxed); + + if config.track_iterations { + self.total_iterations.fetch_add(iterations, Ordering::Relaxed); + self.iterations_sum.fetch_add(iterations, Ordering::Relaxed); + } + + if config.track_paths { + self.nodes_visited.fetch_add(nodes, Ordering::Relaxed); + } + + self.total_latency_ms.fetch_add(latency_ms, Ordering::Relaxed); + } + + /// Record a found path. + pub fn record_path(&self, length: u64, score: f64, config: &RetrievalMetricsConfig) { + if !config.track_paths { + return; + } + + self.paths_found.fetch_add(1, Ordering::Relaxed); + self.path_length_sum.fetch_add(length, Ordering::Relaxed); + + if config.track_scores { + let scaled_score = (score * 1_000_000.0) as u64; + self.path_score_sum_scaled.fetch_add(scaled_score, Ordering::Relaxed); + + if score >= 0.5 { + self.high_score_paths.fetch_add(1, Ordering::Relaxed); + } else if score < 0.3 { + self.low_score_paths.fetch_add(1, Ordering::Relaxed); + } + } + } + + /// Record a cache hit. + pub fn record_cache_hit(&self, config: &RetrievalMetricsConfig) { + if config.track_cache { + self.cache_hits.fetch_add(1, Ordering::Relaxed); + } + } + + /// Record a cache miss. + pub fn record_cache_miss(&self, config: &RetrievalMetricsConfig) { + if config.track_cache { + self.cache_misses.fetch_add(1, Ordering::Relaxed); + } + } + + /// Record a backtrack. + pub fn record_backtrack(&self) { + self.backtracks.fetch_add(1, Ordering::Relaxed); + } + + /// Record a sufficiency check. + pub fn record_sufficiency_check(&self, was_sufficient: bool) { + self.sufficiency_checks.fetch_add(1, Ordering::Relaxed); + if was_sufficient { + self.sufficient_results.fetch_add(1, Ordering::Relaxed); + } + } + + /// Reset all metrics. 
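+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (illustrative only): all counters return to zero.
+    ///
+    /// ```rust,ignore
+    /// let config = RetrievalMetricsConfig::default();
+    /// let metrics = RetrievalMetrics::new();
+    /// metrics.record_query(5, 10, 100, &config);
+    /// metrics.reset();
+    /// assert_eq!(metrics.generate_report().total_queries, 0);
+    /// ```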
+ pub fn reset(&self) { + self.total_queries.store(0, Ordering::Relaxed); + self.total_iterations.store(0, Ordering::Relaxed); + self.iterations_sum.store(0, Ordering::Relaxed); + self.nodes_visited.store(0, Ordering::Relaxed); + self.paths_found.store(0, Ordering::Relaxed); + self.path_length_sum.store(0, Ordering::Relaxed); + self.path_score_sum_scaled.store(0, Ordering::Relaxed); + self.high_score_paths.store(0, Ordering::Relaxed); + self.low_score_paths.store(0, Ordering::Relaxed); + self.cache_hits.store(0, Ordering::Relaxed); + self.cache_misses.store(0, Ordering::Relaxed); + self.total_latency_ms.store(0, Ordering::Relaxed); + self.backtracks.store(0, Ordering::Relaxed); + self.sufficiency_checks.store(0, Ordering::Relaxed); + self.sufficient_results.store(0, Ordering::Relaxed); + } + + /// Generate a report snapshot. + pub fn generate_report(&self) -> RetrievalMetricsReport { + let total_queries = self.total_queries.load(Ordering::Relaxed); + let paths_found = self.paths_found.load(Ordering::Relaxed); + let cache_hits = self.cache_hits.load(Ordering::Relaxed); + let cache_misses = self.cache_misses.load(Ordering::Relaxed); + let total_cache = cache_hits + cache_misses; + let sufficiency_checks = self.sufficiency_checks.load(Ordering::Relaxed); + + RetrievalMetricsReport { + total_queries, + total_iterations: self.total_iterations.load(Ordering::Relaxed), + avg_iterations: if total_queries > 0 { + self.iterations_sum.load(Ordering::Relaxed) as f64 / total_queries as f64 + } else { + 0.0 + }, + nodes_visited: self.nodes_visited.load(Ordering::Relaxed), + paths_found, + avg_path_length: if paths_found > 0 { + self.path_length_sum.load(Ordering::Relaxed) as f64 / paths_found as f64 + } else { + 0.0 + }, + avg_path_score: if paths_found > 0 { + (self.path_score_sum_scaled.load(Ordering::Relaxed) as f64 / 1_000_000.0) / paths_found as f64 + } else { + 0.0 + }, + high_score_paths: self.high_score_paths.load(Ordering::Relaxed), + low_score_paths: self.low_score_paths.load(Ordering::Relaxed), + cache_hits, + cache_misses, + cache_hit_rate: if total_cache > 0 { + cache_hits as f64 / total_cache as f64 + } else { + 0.0 + }, + total_latency_ms: self.total_latency_ms.load(Ordering::Relaxed), + avg_latency_ms: if total_queries > 0 { + self.total_latency_ms.load(Ordering::Relaxed) as f64 / total_queries as f64 + } else { + 0.0 + }, + backtracks: self.backtracks.load(Ordering::Relaxed), + sufficiency_checks, + sufficiency_rate: if sufficiency_checks > 0 { + self.sufficient_results.load(Ordering::Relaxed) as f64 / sufficiency_checks as f64 + } else { + 0.0 + }, + } + } +} + +/// Retrieval metrics report. +#[derive(Debug, Clone)] +pub struct RetrievalMetricsReport { + /// Total number of queries. + pub total_queries: u64, + /// Total number of iterations. + pub total_iterations: u64, + /// Average iterations per query. + pub avg_iterations: f64, + /// Total nodes visited. + pub nodes_visited: u64, + /// Total paths found. + pub paths_found: u64, + /// Average path length. + pub avg_path_length: f64, + /// Average path score. + pub avg_path_score: f64, + /// Number of high-score paths (>= 0.5). + pub high_score_paths: u64, + /// Number of low-score paths (< 0.3). + pub low_score_paths: u64, + /// Number of cache hits. + pub cache_hits: u64, + /// Number of cache misses. + pub cache_misses: u64, + /// Cache hit rate. + pub cache_hit_rate: f64, + /// Total latency in milliseconds. + pub total_latency_ms: u64, + /// Average latency per query in milliseconds. 
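+    ///
+    /// Derived as `total_latency_ms / total_queries`, reported as `0.0` when
+    /// no queries have been recorded.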
+ pub avg_latency_ms: f64, + /// Number of backtracks. + pub backtracks: u64, + /// Number of sufficiency checks. + pub sufficiency_checks: u64, + /// Sufficiency rate. + pub sufficiency_rate: f64, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_retrieval_metrics_recording() { + let config = RetrievalMetricsConfig::default(); + let metrics = RetrievalMetrics::new(); + + metrics.record_query(5, 10, 100, &config); + metrics.record_query(3, 8, 80, &config); + + metrics.record_path(3, 0.8, &config); + metrics.record_path(2, 0.2, &config); + + metrics.record_cache_hit(&config); + metrics.record_cache_hit(&config); + metrics.record_cache_miss(&config); + + let report = metrics.generate_report(); + assert_eq!(report.total_queries, 2); + assert_eq!(report.total_iterations, 8); + assert_eq!(report.paths_found, 2); + assert!((report.cache_hit_rate - 0.666).abs() < 0.01); + } +} diff --git a/src/retrieval/decompose.rs b/src/retrieval/decompose.rs new file mode 100644 index 00000000..603c388d --- /dev/null +++ b/src/retrieval/decompose.rs @@ -0,0 +1,741 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Query decomposition for multi-turn retrieval. +//! +//! Complex queries are broken down into simpler sub-queries +//! that can be processed independently and then combined. +//! +//! # Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ Query Decomposition │ +//! ├─────────────────────────────────────────────────────────────────┤ +//! │ │ +//! │ Complex Query ──▶ [Decomposer] ──▶ [Sub-queries] │ +//! │ │ │ │ +//! │ │ ▼ │ +//! │ │ ┌───────────────┐ │ +//! │ │ │ Sub-query 1 │ │ +//! │ │ │ Sub-query 2 │ │ +//! │ │ │ Sub-query 3 │ │ +//! │ │ └───────┬───────┘ │ +//! │ │ │ │ +//! │ └──────────────────────────────────┼─────────────────────┘ +//! │ ▼ │ +//! │ [Result Aggregator] │ +//! │ │ │ +//! │ ▼ │ +//! │ [Final Result] │ +//! └─────────────────────────────────────────────────────────────────┘ +//! ``` +//! +//! # Example +//! +//! ```rust,ignore +//! use vectorless::retrieval::decompose::{QueryDecomposer, DecompositionConfig}; +//! +//! let decomposer = QueryDecomposer::new(config); +//! let result = decomposer.decompose("What is the architecture and how does caching work?").await?; +//! +//! for sub_query in &result.sub_queries { +//! println!("Sub-query: {}", sub_query.text); +//! } +//! ``` + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; +use tracing::{debug, info}; + +use crate::llm::{LlmClient, LlmExecutor}; + +/// Sub-query resulting from decomposition. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SubQuery { + /// The sub-query text. + pub text: String, + /// Estimated complexity of this sub-query. + pub complexity: SubQueryComplexity, + /// Order of execution (lower = higher priority). + pub priority: u8, + /// Dependencies on other sub-queries (indices). + pub depends_on: Vec, + /// Type of sub-query. + pub query_type: SubQueryType, +} + +/// Complexity level for a sub-query. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum SubQueryComplexity { + /// Simple keyword lookup. + Simple, + /// Requires understanding context. + Medium, + /// Requires synthesis or reasoning. + Complex, +} + +impl Default for SubQueryComplexity { + fn default() -> Self { + Self::Simple + } +} + +/// Type of sub-query. 
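+///
+/// During rule-based decomposition the type is inferred from keywords in the
+/// sub-query text (for example, "compare" maps to `Comparison` and "where" to
+/// `Navigation`); LLM-based decomposition returns it directly in the JSON
+/// response.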
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum SubQueryType { + /// Fact lookup (who, what, when). + Fact, + /// Explanation (why, how). + Explanation, + /// Comparison (difference between). + Comparison, + /// Synthesis (summarize, combine). + Synthesis, + /// Navigation (where to find). + Navigation, +} + +impl Default for SubQueryType { + fn default() -> Self { + Self::Fact + } +} + +/// Result of query decomposition. +#[derive(Debug, Clone)] +pub struct DecompositionResult { + /// Original query. + pub original: String, + /// Decomposed sub-queries. + pub sub_queries: Vec, + /// Whether decomposition was needed. + pub was_decomposed: bool, + /// Reason for decomposition decision. + pub reason: String, + /// Estimated total complexity. + pub total_complexity: f32, +} + +impl DecompositionResult { + /// Create a result without decomposition (query is simple enough). + pub fn no_decomposition(query: &str, reason: &str) -> Self { + Self { + original: query.to_string(), + sub_queries: vec![SubQuery { + text: query.to_string(), + complexity: SubQueryComplexity::Simple, + priority: 0, + depends_on: vec![], + query_type: SubQueryType::Fact, + }], + was_decomposed: false, + reason: reason.to_string(), + total_complexity: 0.5, + } + } + + /// Check if decomposition produced multiple queries. + pub fn is_multi_turn(&self) -> bool { + self.sub_queries.len() > 1 + } + + /// Get execution order (topologically sorted). + pub fn execution_order(&self) -> Vec { + if self.sub_queries.len() <= 1 { + return vec![0]; + } + + // Simple topological sort based on dependencies and priority + let mut order: Vec = (0..self.sub_queries.len()).collect(); + order.sort_by(|&a, &b| { + // First sort by dependencies (fewer dependencies first) + let a_deps = self.sub_queries[a].depends_on.len(); + let b_deps = self.sub_queries[b].depends_on.len(); + if a_deps != b_deps { + return a_deps.cmp(&b_deps); + } + // Then by priority (lower priority value first) + self.sub_queries[a] + .priority + .cmp(&self.sub_queries[b].priority) + }); + order + } +} + +/// Configuration for query decomposition. +#[derive(Debug, Clone)] +pub struct DecompositionConfig { + /// Maximum sub-queries to generate. + pub max_sub_queries: usize, + /// Minimum query length to consider for decomposition. + pub min_query_length: usize, + /// Enable LLM-based decomposition. + pub use_llm: bool, + /// Threshold for decomposing (complexity score). + pub complexity_threshold: f32, + /// Enable dependency detection. + pub detect_dependencies: bool, +} + +impl Default for DecompositionConfig { + fn default() -> Self { + Self { + max_sub_queries: 5, + min_query_length: 20, + use_llm: true, + complexity_threshold: 0.7, + detect_dependencies: true, + } + } +} + +/// Query decomposer for multi-turn retrieval. +pub struct QueryDecomposer { + /// Configuration. + config: DecompositionConfig, + /// LLM client for decomposition (optional). + llm_client: Option, + /// LLM executor for unified execution (optional). + llm_executor: Option, +} + +impl Default for QueryDecomposer { + fn default() -> Self { + Self::new(DecompositionConfig::default()) + } +} + +impl QueryDecomposer { + /// Create a new query decomposer. + pub fn new(config: DecompositionConfig) -> Self { + Self { + config, + llm_client: None, + llm_executor: None, + } + } + + /// Add LLM client for enhanced decomposition. 
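+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (illustrative only; `client` stands for an
+    /// already-configured `LlmClient`):
+    ///
+    /// ```rust,ignore
+    /// let decomposer = QueryDecomposer::new(DecompositionConfig::default())
+    ///     .with_llm_client(client);
+    /// let result = decomposer.decompose("Compare A and B, and explain why").await?;
+    /// for sub_query in &result.sub_queries {
+    ///     println!("{}", sub_query.text);
+    /// }
+    /// ```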
+ pub fn with_llm_client(mut self, client: LlmClient) -> Self { + self.llm_client = Some(client); + self + } + + /// Add LLM executor for unified throttle/retry/fallback. + pub fn with_llm_executor(mut self, executor: LlmExecutor) -> Self { + self.llm_executor = Some(executor); + self + } + + /// Decompose a query into sub-queries. + pub async fn decompose(&self, query: &str) -> crate::error::Result { + // Check if decomposition is needed + if !self.should_decompose(query) { + return Ok(DecompositionResult::no_decomposition( + query, + "Query is simple enough, no decomposition needed", + )); + } + + info!("Decomposing complex query: '{}'", query); + + // Try LLM-based decomposition if available + if self.config.use_llm && (self.llm_client.is_some() || self.llm_executor.is_some()) { + match self.llm_decompose(query).await { + Ok(result) => return Ok(result), + Err(e) => { + debug!("LLM decomposition failed, falling back to rule-based: {}", e); + } + } + } + + // Fall back to rule-based decomposition + self.rule_based_decompose(query) + } + + /// Check if a query should be decomposed. + fn should_decompose(&self, query: &str) -> bool { + // Skip short queries + if query.len() < self.config.min_query_length { + return false; + } + + // Calculate complexity score + let complexity = self.calculate_complexity(query); + complexity >= self.config.complexity_threshold + } + + /// Calculate complexity score for a query. + fn calculate_complexity(&self, query: &str) -> f32 { + let mut score = 0.0; + let query_lower = query.to_lowercase(); + + // 1. Multiple questions (question marks or "and" between questions) + let question_count = query.matches('?').count(); + score += (question_count as f32 * 0.3).min(1.0); + + // 2. Multiple clauses (indicated by conjunctions) + let conjunctions = [" and ", " or ", " but ", " also ", " plus "]; + let conjunction_count = conjunctions + .iter() + .filter(|c| query_lower.contains(*c)) + .count(); + score += (conjunction_count as f32 * 0.2).min(0.6); + + // 3. Complex question words + let complex_indicators = [ + "compare", + "contrast", + "difference between", + "relationship between", + "how does", + "why does", + "explain how", + "analyze", + "evaluate", + "synthesize", + ]; + for indicator in &complex_indicators { + if query_lower.contains(indicator) { + score += 0.2; + } + } + + // 4. Length factor + let word_count = query.split_whitespace().count(); + if word_count > 15 { + score += 0.1 * ((word_count - 15) as f32 / 10.0).min(1.0); + } + + score.min(1.0) + } + + /// Rule-based decomposition (no LLM). + fn rule_based_decompose(&self, query: &str) -> crate::error::Result { + let mut sub_queries = Vec::new(); + let query_lower = query.to_lowercase(); + + // Split on common patterns + let patterns = [ + (" and ", " and "), + ("? ", "? 
"), + (" also ", " also "), + (" as well as ", " as well as "), + ]; + + // Check for question splits + if query.contains('?') { + let parts: Vec<&str> = query.split('?').filter(|s| !s.trim().is_empty()).collect(); + for (i, part) in parts.iter().enumerate() { + let text = format!("{}?", part.trim()); + sub_queries.push(SubQuery { + text, + complexity: self.estimate_sub_query_complexity(part), + priority: i as u8, + depends_on: vec![], + query_type: self.detect_query_type(part), + }); + } + } + + // If no questions found, try conjunction split + if sub_queries.is_empty() { + for (pattern, _) in &patterns { + if query_lower.contains(pattern) { + let parts: Vec<&str> = query.split(pattern).filter(|s| !s.trim().is_empty()).collect(); + if parts.len() > 1 { + for (i, part) in parts.iter().enumerate() { + sub_queries.push(SubQuery { + text: part.trim().to_string(), + complexity: self.estimate_sub_query_complexity(part), + priority: i as u8, + depends_on: if i > 0 && self.config.detect_dependencies { + vec![i - 1] + } else { + vec![] + }, + query_type: self.detect_query_type(part), + }); + } + break; + } + } + } + } + + // If still no decomposition, return original + if sub_queries.is_empty() || sub_queries.len() > self.config.max_sub_queries { + return Ok(DecompositionResult::no_decomposition( + query, + "No clear decomposition patterns found", + )); + } + + Ok(DecompositionResult { + original: query.to_string(), + sub_queries, + was_decomposed: true, + reason: "Rule-based decomposition".to_string(), + total_complexity: self.calculate_complexity(query), + }) + } + + /// LLM-based decomposition. + async fn llm_decompose(&self, query: &str) -> crate::error::Result { + let system = r#"You are a query decomposition expert. Break down complex queries into simpler sub-queries. + +Rules: +1. Each sub-query should be answerable independently when possible +2. Preserve the original intent +3. Maximum 5 sub-queries +4. Return JSON format: {"sub_queries": [{"text": "...", "complexity": "simple|medium|complex", "priority": 0-4, "depends_on": [], "query_type": "fact|explanation|comparison|synthesis|navigation"}], "reason": "..."} + +If the query is simple enough, return just one sub-query."#; + + let user = format!("Decompose this query: {}", query); + + let response = if let Some(ref executor) = self.llm_executor { + executor.complete(system, &user).await.map_err(|e| { + crate::error::Error::Llm(format!("LLM executor error: {}", e)) + })? + } else if let Some(ref client) = self.llm_client { + client.complete(system, &user).await.map_err(|e| { + crate::error::Error::Llm(format!("LLM client error: {}", e)) + })? 
+ } else { + return Err(crate::error::Error::Config( + "No LLM client or executor configured".to_string(), + )); + }; + + // Parse the JSON response + #[derive(Deserialize)] + struct DecompositionResponse { + sub_queries: Vec, + reason: String, + } + + let parsed: DecompositionResponse = + serde_json::from_str(&extract_json(&response)).map_err(|e| { + crate::error::Error::Llm(format!("Failed to parse decomposition: {}", e)) + })?; + + if parsed.sub_queries.is_empty() { + return Ok(DecompositionResult::no_decomposition( + query, + "LLM returned empty decomposition", + )); + } + + let sub_queries: Vec = parsed + .sub_queries + .into_iter() + .take(self.config.max_sub_queries) + .collect(); + + Ok(DecompositionResult { + original: query.to_string(), + sub_queries, + was_decomposed: true, + reason: parsed.reason, + total_complexity: self.calculate_complexity(query), + }) + } + + /// Estimate complexity for a sub-query. + fn estimate_sub_query_complexity(&self, text: &str) -> SubQueryComplexity { + let text_lower = text.to_lowercase(); + + // Check for complex indicators + if text_lower.contains("compare") + || text_lower.contains("contrast") + || text_lower.contains("analyze") + || text_lower.contains("evaluate") + || text_lower.contains("synthesize") + { + return SubQueryComplexity::Complex; + } + + // Check for medium complexity + if text_lower.contains("how") + || text_lower.contains("why") + || text_lower.contains("explain") + || text_lower.contains("describe") + { + return SubQueryComplexity::Medium; + } + + SubQueryComplexity::Simple + } + + /// Detect the type of a sub-query. + fn detect_query_type(&self, text: &str) -> SubQueryType { + let text_lower = text.to_lowercase(); + + if text_lower.contains("compare") + || text_lower.contains("difference") + || text_lower.contains("versus") + || text_lower.contains(" vs ") + { + return SubQueryType::Comparison; + } + + if text_lower.contains("why") + || text_lower.contains("how") + || text_lower.contains("explain") + { + return SubQueryType::Explanation; + } + + if text_lower.contains("summarize") + || text_lower.contains("combine") + || text_lower.contains("synthesize") + || text_lower.contains("overall") + { + return SubQueryType::Synthesis; + } + + if text_lower.contains("where") + || text_lower.contains("which section") + || text_lower.contains("find") + { + return SubQueryType::Navigation; + } + + SubQueryType::Fact + } +} + +/// Extract JSON from a potentially verbose LLM response. +fn extract_json(text: &str) -> String { + // Try to find JSON object + if let Some(start) = text.find('{') { + if let Some(end) = text.rfind('}') { + if end > start { + return text[start..=end].to_string(); + } + } + } + text.to_string() +} + +/// Result aggregator for multi-turn retrieval. +#[derive(Debug, Clone)] +pub struct SubQueryResult { + /// The sub-query. + pub query: SubQuery, + /// Retrieved content. + pub content: String, + /// Relevance score. + pub score: f32, + /// Nodes that contributed to the result. + pub source_nodes: Vec, +} + +/// Aggregator for combining sub-query results. +pub struct ResultAggregator { + /// Maximum tokens in final result. + pub max_tokens: usize, + /// Weight by query priority. + pub priority_weight: f32, +} + +impl Default for ResultAggregator { + fn default() -> Self { + Self { + max_tokens: 4000, + priority_weight: 0.3, + } + } +} + +impl ResultAggregator { + /// Create a new result aggregator. + pub fn new() -> Self { + Self::default() + } + + /// Aggregate results from multiple sub-queries. 
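+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (illustrative only; `results` holds one
+    /// `SubQueryResult` per sub-query, and `decomposition` is the matching
+    /// `DecompositionResult` from a prior `decompose` call):
+    ///
+    /// ```rust,ignore
+    /// let aggregator = ResultAggregator::new();
+    /// let combined = aggregator.aggregate(&results, &decomposition);
+    /// println!("{}", combined);
+    /// ```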
+    pub fn aggregate(
+        &self,
+        results: &[SubQueryResult],
+        decomposition: &DecompositionResult,
+    ) -> String {
+        if results.is_empty() {
+            return String::new();
+        }
+
+        if results.len() == 1 {
+            return results[0].content.clone();
+        }
+
+        // Sort by execution order and priority
+        let order = decomposition.execution_order();
+        let sorted_results: Vec<_> = order
+            .iter()
+            .filter_map(|&i| results.iter().find(|r| r.query.text == decomposition.sub_queries[i].text))
+            .collect();
+
+        // Combine results with section headers
+        let mut combined = String::new();
+        let mut total_tokens = 0;
+
+        for result in sorted_results {
+            let section = format!(
+                "\n### {}\n\n{}\n",
+                result.query.text,
+                result.content
+            );
+
+            let section_tokens = section.len() / 4; // Rough estimate
+            if total_tokens + section_tokens > self.max_tokens {
+                // Truncate if needed
+                let remaining = self.max_tokens - total_tokens;
+                if remaining > 100 {
+                    let end_pos = (remaining * 4).min(result.content.len());
+                    combined.push_str(&format!(
+                        "\n### {}\n\n{}\n",
+                        result.query.text,
+                        &result.content[..end_pos]
+                    ));
+                }
+                break;
+            }
+
+            combined.push_str(&section);
+            total_tokens += section_tokens;
+        }
+
+        combined
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_complexity_calculation() {
+        let decomposer = QueryDecomposer::default();
+
+        // Simple query
+        let simple = "What is the architecture?";
+        let simple_score = decomposer.calculate_complexity(simple);
+        assert!(simple_score < 0.5);
+
+        // Complex query
+        let complex = "What is the architecture and how does it compare to other systems?";
+        let complex_score = decomposer.calculate_complexity(complex);
+        assert!(complex_score > simple_score);
+    }
+
+    #[test]
+    fn test_rule_based_decomposition() {
+        let decomposer = QueryDecomposer::default();
+
+        let result = decomposer.rule_based_decompose(
+            "What is the architecture? 
How does caching work?", + ).unwrap(); + + assert!(result.was_decomposed); + assert_eq!(result.sub_queries.len(), 2); + } + + #[test] + fn test_no_decomposition() { + let result = DecompositionResult::no_decomposition( + "What is this?", + "Query is simple", + ); + + assert!(!result.was_decomposed); + assert!(!result.is_multi_turn()); + } + + #[test] + fn test_execution_order() { + let mut result = DecompositionResult::no_decomposition("test", "test"); + result.sub_queries = vec![ + SubQuery { + text: "First".to_string(), + priority: 2, + depends_on: vec![], + query_type: SubQueryType::Fact, + complexity: SubQueryComplexity::Simple, + }, + SubQuery { + text: "Second".to_string(), + priority: 1, + depends_on: vec![0], + query_type: SubQueryType::Fact, + complexity: SubQueryComplexity::Simple, + }, + ]; + result.was_decomposed = true; + + let order = result.execution_order(); + assert_eq!(order, vec![0, 1]); // First should come before Second + } + + #[test] + fn test_query_type_detection() { + let decomposer = QueryDecomposer::default(); + + assert_eq!( + decomposer.detect_query_type("Compare A and B"), + SubQueryType::Comparison + ); + assert_eq!( + decomposer.detect_query_type("Why does this happen?"), + SubQueryType::Explanation + ); + assert_eq!( + decomposer.detect_query_type("Where is the config?"), + SubQueryType::Navigation + ); + } + + #[test] + fn test_result_aggregator() { + let aggregator = ResultAggregator::new(); + + let results = vec![ + SubQueryResult { + query: SubQuery { + text: "First question?".to_string(), + priority: 0, + depends_on: vec![], + query_type: SubQueryType::Fact, + complexity: SubQueryComplexity::Simple, + }, + content: "Answer 1".to_string(), + score: 0.9, + source_nodes: vec![], + }, + SubQueryResult { + query: SubQuery { + text: "Second question?".to_string(), + priority: 1, + depends_on: vec![0], + query_type: SubQueryType::Fact, + complexity: SubQueryComplexity::Simple, + }, + content: "Answer 2".to_string(), + score: 0.8, + source_nodes: vec![], + }, + ]; + + let mut decomposition = DecompositionResult::no_decomposition("test", "test"); + decomposition.sub_queries = results.iter().map(|r| r.query.clone()).collect(); + decomposition.was_decomposed = true; + + let combined = aggregator.aggregate(&results, &decomposition); + assert!(combined.contains("First question")); + assert!(combined.contains("Answer 1")); + } +} diff --git a/src/retrieval/mod.rs b/src/retrieval/mod.rs index 8e7d3a2a..1a87fda7 100644 --- a/src/retrieval/mod.rs +++ b/src/retrieval/mod.rs @@ -48,6 +48,7 @@ //! ``` mod context; +mod decompose; mod pipeline_retriever; mod retriever; mod types; @@ -115,3 +116,9 @@ pub use pilot::{ BudgetConfig, InterventionConfig, InterventionPoint, Pilot, PilotConfig, PilotDecision, PilotMode, RankedCandidate, SearchDirection, SearchState, }; + +// Decompose exports (multi-turn retrieval) +pub use decompose::{ + DecompositionConfig, DecompositionResult, QueryDecomposer, ResultAggregator, SubQuery, + SubQueryComplexity, SubQueryResult, SubQueryType, +}; diff --git a/src/retrieval/pilot/feedback.rs b/src/retrieval/pilot/feedback.rs new file mode 100644 index 00000000..0d2efbdd --- /dev/null +++ b/src/retrieval/pilot/feedback.rs @@ -0,0 +1,737 @@ +// Copyright (c) 2026 vectorless developers +// SPDX-License-Identifier: Apache-2.0 + +//! Pilot feedback learning system. +//! +//! This module provides feedback collection and learning capabilities +//! for the Pilot to improve its decision-making over time. +//! +//! # Architecture +//! +//! ```text +//! 
┌─────────────────────────────────────────────────────────────────┐ +//! │ Feedback Learning System │ +//! ├─────────────────────────────────────────────────────────────────┤ +//! │ │ +//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +//! │ │ Feedback │ │ Feedback │ │ Pilot │ │ +//! │ │ Record │──▶│ Store │──▶│ Learner │ │ +//! │ └─────────────┘ └─────────────┘ └─────────────┘ │ +//! │ │ │ │ +//! │ ▼ ▼ │ +//! │ [Persistence] [Decision Adjustment] │ +//! │ │ +//! └─────────────────────────────────────────────────────────────────┘ +//! ``` +//! +//! # Example +//! +//! ```rust,ignore +//! use vectorless::retrieval::pilot::feedback::{FeedbackStore, FeedbackRecord}; +//! +//! let store = FeedbackStore::new("./feedback_store"); +//! +//! // Record feedback +//! let record = FeedbackRecord::new(decision_id, was_correct, confidence); +//! store.record(record).await?; +//! +//! // Learn from feedback +//! let learner = PilotLearner::new(store); +//! let adjustment = learner.get_adjustment(&context); +//! ``` + +use std::collections::HashMap; +use std::path::Path; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use serde::{Deserialize, Serialize}; +use tracing::{debug, info, warn}; + +use super::decision::InterventionPoint; + +/// Unique identifier for a feedback record. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct FeedbackId(pub u64); + +/// Unique identifier for a decision. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct DecisionId(pub u64); + +/// Feedback record for a Pilot decision. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FeedbackRecord { + /// Unique feedback ID. + pub id: FeedbackId, + /// Associated decision ID. + pub decision_id: DecisionId, + /// Whether the decision was correct. + pub was_correct: bool, + /// Pilot's confidence at decision time. + pub pilot_confidence: f64, + /// Intervention point type. + pub intervention_point: InterventionPoint, + /// Query hash for grouping similar queries. + pub query_hash: u64, + /// Node path hash for context. + pub path_hash: u64, + /// Timestamp of feedback. + pub timestamp_ms: u64, + /// Optional user comment. + pub comment: Option, +} + +impl FeedbackRecord { + /// Create a new feedback record. + pub fn new( + decision_id: DecisionId, + was_correct: bool, + pilot_confidence: f64, + intervention_point: InterventionPoint, + query_hash: u64, + path_hash: u64, + ) -> Self { + static COUNTER: AtomicU64 = AtomicU64::new(1); + let id = FeedbackId(COUNTER.fetch_add(1, Ordering::Relaxed)); + let timestamp_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + + Self { + id, + decision_id, + was_correct, + pilot_confidence, + intervention_point, + query_hash, + path_hash, + timestamp_ms, + comment: None, + } + } + + /// Add a comment to the feedback. + pub fn with_comment(mut self, comment: impl Into) -> Self { + self.comment = Some(comment.into()); + self + } +} + +/// Statistics for a specific context (query/path combination). +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct ContextStats { + /// Total decisions in this context. + pub total: u64, + /// Correct decisions in this context. + pub correct: u64, + /// Average confidence when correct. + pub avg_confidence_correct: f64, + /// Average confidence when incorrect. 
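+    ///
+    /// Maintained as a running average, updated each time feedback with an
+    /// incorrect outcome is recorded.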
+ pub avg_confidence_incorrect: f64, +} + +impl ContextStats { + /// Get accuracy for this context. + pub fn accuracy(&self) -> f64 { + if self.total == 0 { + 0.0 + } else { + self.correct as f64 / self.total as f64 + } + } + + /// Record a new feedback. + fn record(&mut self, was_correct: bool, confidence: f64) { + self.total += 1; + if was_correct { + self.correct += 1; + // Running average + self.avg_confidence_correct = (self.avg_confidence_correct + * (self.correct - 1) as f64 + + confidence) + / self.correct as f64; + } else { + let incorrect = self.total - self.correct; + self.avg_confidence_incorrect = (self.avg_confidence_incorrect + * (incorrect - 1) as f64 + + confidence) + / incorrect as f64; + } + } +} + +/// Statistics for an intervention point type. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct InterventionStats { + /// Start intervention stats. + pub start: ContextStats, + /// Fork intervention stats. + pub fork: ContextStats, + /// Backtrack intervention stats. + pub backtrack: ContextStats, + /// Evaluate intervention stats. + pub evaluate: ContextStats, +} + +impl InterventionStats { + /// Get stats for a specific intervention point. + pub fn get(&self, point: InterventionPoint) -> &ContextStats { + match point { + InterventionPoint::Start => &self.start, + InterventionPoint::Fork => &self.fork, + InterventionPoint::Backtrack => &self.backtrack, + InterventionPoint::Evaluate => &self.evaluate, + } + } + + /// Get mutable stats for a specific intervention point. + fn get_mut(&mut self, point: InterventionPoint) -> &mut ContextStats { + match point { + InterventionPoint::Start => &mut self.start, + InterventionPoint::Fork => &mut self.fork, + InterventionPoint::Backtrack => &mut self.backtrack, + InterventionPoint::Evaluate => &mut self.evaluate, + } + } +} + +/// In-memory feedback store. +/// +/// Stores feedback records and provides statistics for learning. +/// Thread-safe for concurrent access. +#[derive(Debug)] +pub struct FeedbackStore { + /// All feedback records. + records: std::sync::RwLock>, + /// Statistics by intervention point. + intervention_stats: std::sync::RwLock, + /// Statistics by query hash. + query_stats: std::sync::RwLock>, + /// Statistics by path hash. + path_stats: std::sync::RwLock>, + /// Configuration. + config: FeedbackStoreConfig, +} + +/// Configuration for feedback store. +#[derive(Debug, Clone)] +pub struct FeedbackStoreConfig { + /// Maximum records to keep in memory. + pub max_records: usize, + /// Enable persistence to disk. + pub persist: bool, + /// Path for persistence. + pub storage_path: Option, +} + +impl Default for FeedbackStoreConfig { + fn default() -> Self { + Self { + max_records: 10_000, + persist: false, + storage_path: None, + } + } +} + +impl FeedbackStoreConfig { + /// Create config with persistence enabled. + pub fn with_persistence(path: impl Into) -> Self { + Self { + max_records: 10_000, + persist: true, + storage_path: Some(path.into()), + } + } +} + +impl FeedbackStore { + /// Create a new feedback store. + pub fn new(config: FeedbackStoreConfig) -> Self { + Self { + records: std::sync::RwLock::new(Vec::new()), + intervention_stats: std::sync::RwLock::new(InterventionStats::default()), + query_stats: std::sync::RwLock::new(HashMap::new()), + path_stats: std::sync::RwLock::new(HashMap::new()), + config, + } + } + + /// Create an in-memory store without persistence. + pub fn in_memory() -> Self { + Self::new(FeedbackStoreConfig::default()) + } + + /// Record a feedback. 
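+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (illustrative only; the query and path hashes are
+    /// hypothetical placeholder values):
+    ///
+    /// ```rust,ignore
+    /// let store = FeedbackStore::in_memory();
+    /// store.record(FeedbackRecord::new(
+    ///     DecisionId(1),
+    ///     true,
+    ///     0.9,
+    ///     InterventionPoint::Fork,
+    ///     42, // query hash
+    ///     7,  // path hash
+    /// ));
+    /// assert_eq!(store.total_records(), 1);
+    /// ```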
+ pub fn record(&self, feedback: FeedbackRecord) { + // Update intervention stats + { + let mut stats = self.intervention_stats.write().unwrap(); + stats + .get_mut(feedback.intervention_point) + .record(feedback.was_correct, feedback.pilot_confidence); + } + + // Update query stats + { + let mut stats = self.query_stats.write().unwrap(); + stats + .entry(feedback.query_hash) + .or_default() + .record(feedback.was_correct, feedback.pilot_confidence); + } + + // Update path stats + { + let mut stats = self.path_stats.write().unwrap(); + stats + .entry(feedback.path_hash) + .or_default() + .record(feedback.was_correct, feedback.pilot_confidence); + } + + // Store record + { + let mut records = self.records.write().unwrap(); + records.push(feedback); + + // Enforce max records limit + if records.len() > self.config.max_records { + let remove_count = records.len() - self.config.max_records; + records.drain(0..remove_count); + } + } + + debug!( + total_records = self.records.read().unwrap().len(), + "Recorded feedback" + ); + } + + /// Get overall intervention statistics. + pub fn intervention_stats(&self) -> InterventionStats { + self.intervention_stats.read().unwrap().clone() + } + + /// Get statistics for a specific query hash. + pub fn query_stats(&self, query_hash: u64) -> Option { + self.query_stats.read().unwrap().get(&query_hash).cloned() + } + + /// Get statistics for a specific path hash. + pub fn path_stats(&self, path_hash: u64) -> Option { + self.path_stats.read().unwrap().get(&path_hash).cloned() + } + + /// Get total number of feedback records. + pub fn total_records(&self) -> usize { + self.records.read().unwrap().len() + } + + /// Get overall accuracy across all feedback. + pub fn overall_accuracy(&self) -> f64 { + let stats = self.intervention_stats.read().unwrap(); + let total = stats.start.total + + stats.fork.total + + stats.backtrack.total + + stats.evaluate.total; + let correct = stats.start.correct + + stats.fork.correct + + stats.backtrack.correct + + stats.evaluate.correct; + + if total == 0 { + 0.0 + } else { + correct as f64 / total as f64 + } + } + + /// Clear all feedback records. + pub fn clear(&self) { + self.records.write().unwrap().clear(); + *self.intervention_stats.write().unwrap() = InterventionStats::default(); + self.query_stats.write().unwrap().clear(); + self.path_stats.write().unwrap().clear(); + } + + /// Persist feedback to disk (if configured). + pub fn persist(&self) -> std::io::Result<()> { + if !self.config.persist { + return Ok(()); + } + + let path = self.config.storage_path.as_ref().ok_or_else(|| { + std::io::Error::new(std::io::ErrorKind::NotFound, "No storage path configured") + })?; + + let records = self.records.read().unwrap(); + let json = serde_json::to_string_pretty(&*records)?; + std::fs::write(path, json)?; + + info!(path = %path, records = records.len(), "Persisted feedback store"); + Ok(()) + } + + /// Load feedback from disk (if configured). 
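+    ///
+    /// # Example
+    ///
+    /// A minimal round-trip sketch (illustrative only; the path is a
+    /// hypothetical location):
+    ///
+    /// ```rust,ignore
+    /// let config = FeedbackStoreConfig::with_persistence("./feedback.json");
+    /// let store = FeedbackStore::new(config);
+    /// store.load()?;    // restore previously persisted records, if any
+    /// // ... record feedback during a run ...
+    /// store.persist()?; // write the current records back to disk
+    /// ```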
+ pub fn load(&self) -> std::io::Result<()> { + if !self.config.persist { + return Ok(()); + } + + let path = self.config.storage_path.as_ref().ok_or_else(|| { + std::io::Error::new(std::io::ErrorKind::NotFound, "No storage path configured") + })?; + + if !Path::new(path).exists() { + return Ok(()); + } + + let json = std::fs::read_to_string(path)?; + let records: Vec = serde_json::from_str(&json)?; + + // Rebuild stats from records + for record in &records { + // Update intervention stats + self.intervention_stats + .write() + .unwrap() + .get_mut(record.intervention_point) + .record(record.was_correct, record.pilot_confidence); + + // Update query stats + self.query_stats + .write() + .unwrap() + .entry(record.query_hash) + .or_default() + .record(record.was_correct, record.pilot_confidence); + + // Update path stats + self.path_stats + .write() + .unwrap() + .entry(record.path_hash) + .or_default() + .record(record.was_correct, record.pilot_confidence); + } + + *self.records.write().unwrap() = records; + + info!(path = %path, "Loaded feedback store"); + Ok(()) + } +} + +/// Decision adjustment based on learned feedback. +#[derive(Debug, Clone, Copy)] +pub struct DecisionAdjustment { + /// Confidence adjustment (add to pilot confidence). + pub confidence_delta: f64, + /// Whether to skip intervention (algorithm is confident). + pub skip_intervention: bool, + /// Weight to apply to algorithm score vs LLM score. + pub algorithm_weight: f64, +} + +impl Default for DecisionAdjustment { + fn default() -> Self { + Self { + confidence_delta: 0.0, + skip_intervention: false, + algorithm_weight: 0.5, + } + } +} + +/// Pilot learner that adjusts decisions based on feedback. +/// +/// Uses collected feedback to: +/// 1. Adjust confidence thresholds for different intervention points +/// 2. Decide when to skip intervention (trust algorithm) +/// 3. Adjust the weight between algorithm and LLM scores +#[derive(Debug)] +pub struct PilotLearner { + /// Feedback store reference. + store: Arc, + /// Learning configuration. + config: LearnerConfig, +} + +/// Configuration for the pilot learner. +#[derive(Debug, Clone)] +pub struct LearnerConfig { + /// Minimum samples required before adjusting. + pub min_samples: u64, + /// Threshold for high accuracy (trust LLM more). + pub high_accuracy_threshold: f64, + /// Threshold for low accuracy (trust algorithm more). + pub low_accuracy_threshold: f64, + /// Maximum confidence adjustment. + pub max_confidence_delta: f64, +} + +impl Default for LearnerConfig { + fn default() -> Self { + Self { + min_samples: 10, + high_accuracy_threshold: 0.8, + low_accuracy_threshold: 0.5, + max_confidence_delta: 0.2, + } + } +} + +impl PilotLearner { + /// Create a new learner with the given feedback store. + pub fn new(store: Arc) -> Self { + Self { + store, + config: LearnerConfig::default(), + } + } + + /// Create a learner with custom configuration. + pub fn with_config(store: Arc, config: LearnerConfig) -> Self { + Self { store, config } + } + + /// Get decision adjustment for a given context. 
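+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (illustrative only; hash values are hypothetical).
+    /// With fewer than `min_samples` feedback records the default adjustment
+    /// is returned:
+    ///
+    /// ```rust,ignore
+    /// let store = Arc::new(FeedbackStore::in_memory());
+    /// let learner = PilotLearner::new(store.clone());
+    /// let adj = learner.get_adjustment(InterventionPoint::Fork, 42, 7);
+    /// assert!(!adj.skip_intervention);
+    /// ```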
+ pub fn get_adjustment( + &self, + intervention_point: InterventionPoint, + query_hash: u64, + path_hash: u64, + ) -> DecisionAdjustment { + let mut adjustment = DecisionAdjustment::default(); + + // Get intervention-level stats + let intervention_stats = self.store.intervention_stats(); + let point_stats = intervention_stats.get(intervention_point); + + // Not enough samples, use defaults + if point_stats.total < self.config.min_samples { + return adjustment; + } + + let accuracy = point_stats.accuracy(); + + // Adjust based on accuracy + if accuracy >= self.config.high_accuracy_threshold { + // High accuracy: trust LLM more + adjustment.confidence_delta = self.config.max_confidence_delta; + adjustment.algorithm_weight = 0.3; // Favor LLM + } else if accuracy <= self.config.low_accuracy_threshold { + // Low accuracy: trust algorithm more + adjustment.confidence_delta = -self.config.max_confidence_delta; + adjustment.algorithm_weight = 0.7; // Favor algorithm + adjustment.skip_intervention = accuracy < 0.3; // Very low accuracy, skip LLM + } + + // Further refine based on query-specific stats + if let Some(query_stats) = self.store.query_stats(query_hash) { + if query_stats.total >= self.config.min_samples { + let query_accuracy = query_stats.accuracy(); + // Adjust confidence based on query-specific performance + if query_accuracy > accuracy { + adjustment.confidence_delta += 0.05; + } else if query_accuracy < accuracy { + adjustment.confidence_delta -= 0.05; + } + } + } + + // Further refine based on path-specific stats + if let Some(path_stats) = self.store.path_stats(path_hash) { + if path_stats.total >= self.config.min_samples { + let path_accuracy = path_stats.accuracy(); + // If this path has very high accuracy, increase confidence + if path_accuracy > 0.9 { + adjustment.confidence_delta += 0.05; + } + } + } + + // Clamp confidence delta + adjustment.confidence_delta = adjustment + .confidence_delta + .clamp(-self.config.max_confidence_delta, self.config.max_confidence_delta); + + adjustment + } + + /// Get the feedback store. + pub fn store(&self) -> &FeedbackStore { + &self.store + } + + /// Get overall accuracy. + pub fn overall_accuracy(&self) -> f64 { + self.store.overall_accuracy() + } + + /// Check if enough feedback has been collected. 
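+    ///
+    /// Returns `true` once the total number of feedback records across all
+    /// intervention points reaches `LearnerConfig::min_samples`.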
+ pub fn has_sufficient_data(&self) -> bool { + let stats = self.store.intervention_stats(); + let total = stats.start.total + + stats.fork.total + + stats.backtrack.total + + stats.evaluate.total; + total >= self.config.min_samples + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_hash(s: &str) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + s.hash(&mut hasher); + hasher.finish() + } + + #[test] + fn test_feedback_record_creation() { + let record = FeedbackRecord::new( + DecisionId(1), + true, + 0.85, + InterventionPoint::Fork, + make_hash("test query"), + make_hash("/root/child"), + ); + + assert!(record.was_correct); + assert!((record.pilot_confidence - 0.85).abs() < 0.01); + assert!(record.comment.is_none()); + } + + #[test] + fn test_feedback_record_with_comment() { + let record = FeedbackRecord::new( + DecisionId(1), + false, + 0.5, + InterventionPoint::Start, + make_hash("test"), + make_hash("/"), + ) + .with_comment("Wrong direction"); + + assert!(!record.was_correct); + assert_eq!(record.comment, Some("Wrong direction".to_string())); + } + + #[test] + fn test_feedback_store_recording() { + let store = FeedbackStore::in_memory(); + + // Record some feedback + store.record(FeedbackRecord::new( + DecisionId(1), + true, + 0.9, + InterventionPoint::Fork, + make_hash("query1"), + make_hash("/path1"), + )); + + store.record(FeedbackRecord::new( + DecisionId(2), + false, + 0.6, + InterventionPoint::Fork, + make_hash("query1"), + make_hash("/path1"), + )); + + store.record(FeedbackRecord::new( + DecisionId(3), + true, + 0.8, + InterventionPoint::Start, + make_hash("query2"), + make_hash("/"), + )); + + assert_eq!(store.total_records(), 3); + + let stats = store.intervention_stats(); + assert_eq!(stats.fork.total, 2); + assert_eq!(stats.fork.correct, 1); + assert!((stats.fork.accuracy() - 0.5).abs() < 0.01); + + assert_eq!(stats.start.total, 1); + assert_eq!(stats.start.correct, 1); + } + + #[test] + fn test_pilot_learner_adjustment() { + let store = Arc::new(FeedbackStore::in_memory()); + let learner = PilotLearner::new(store.clone()); + + // Not enough data, should return default + let adj = learner.get_adjustment(InterventionPoint::Fork, 0, 0); + assert!((adj.confidence_delta - 0.0).abs() < 0.01); + assert!(!adj.skip_intervention); + + // Add enough feedback with high accuracy + for i in 0..15 { + store.record(FeedbackRecord::new( + DecisionId(i), + true, // All correct + 0.9, + InterventionPoint::Fork, + make_hash("query"), + make_hash("/path"), + )); + } + + // Now should adjust + let adj = learner.get_adjustment(InterventionPoint::Fork, make_hash("query"), 0); + assert!(adj.confidence_delta > 0.0); // Should boost confidence + assert!((adj.algorithm_weight - 0.3).abs() < 0.01); // Should favor LLM + } + + #[test] + fn test_pilot_learner_low_accuracy() { + let store = Arc::new(FeedbackStore::in_memory()); + let learner = PilotLearner::new(store.clone()); + + // Add enough feedback with low accuracy + for i in 0..15 { + store.record(FeedbackRecord::new( + DecisionId(i), + i % 3 == 0, // Only ~33% correct + 0.5, + InterventionPoint::Fork, + 0, + 0, + )); + } + + let adj = learner.get_adjustment(InterventionPoint::Fork, 0, 0); + assert!(adj.confidence_delta < 0.0); // Should reduce confidence + assert!(adj.algorithm_weight > 0.5); // Should favor algorithm + } + + #[test] + fn test_context_stats() { + let mut stats = ContextStats::default(); + + stats.record(true, 0.9); + stats.record(true, 
0.8); + stats.record(false, 0.6); + + assert_eq!(stats.total, 3); + assert_eq!(stats.correct, 2); + assert!((stats.accuracy() - 0.666).abs() < 0.01); + assert!((stats.avg_confidence_correct - 0.85).abs() < 0.01); + assert!((stats.avg_confidence_incorrect - 0.6).abs() < 0.01); + } +} diff --git a/src/retrieval/pilot/llm_pilot.rs b/src/retrieval/pilot/llm_pilot.rs index 414bbb3e..6489510e 100644 --- a/src/retrieval/pilot/llm_pilot.rs +++ b/src/retrieval/pilot/llm_pilot.rs @@ -11,7 +11,8 @@ use std::sync::Arc; use tracing::{debug, info, warn}; use crate::document::DocumentTree; -use crate::llm::LlmClient; +use crate::llm::{LlmClient, LlmExecutor}; +use crate::throttle::ConcurrencyController; use super::budget::BudgetController; use super::builder::ContextBuilder; @@ -37,10 +38,10 @@ use super::r#trait::{Pilot, SearchState}; /// │ │ Builder │─▶│ Builder │─▶│ Parser │ │ /// │ └─────────────┘ └─────────────┘ └─────────────┘ │ /// │ │ -/// │ ┌─────────────┐ ┌─────────────┐ │ -/// │ │ Budget │ │ LLM │ │ -/// │ │ Controller │ │ Client │ │ -/// │ └─────────────┘ └─────────────┘ │ +/// │ ┌─────────────┐ ┌───────────────────────┐ │ +/// │ │ Budget │ │ LlmExecutor │ │ +/// │ │ Controller │ │ (throttle+retry+fall) │ │ +/// │ └─────────────┘ └───────────────────────┘ │ /// └─────────────────────────────────────────────────────────────┘ /// ``` /// @@ -48,19 +49,25 @@ use super::r#trait::{Pilot, SearchState}; /// /// ```rust,ignore /// use vectorless::retrieval::pilot::{LlmPilot, PilotConfig}; -/// use vectorless::llm::LlmClient; +/// use vectorless::llm::{LlmClient, LlmExecutor}; /// /// let client = LlmClient::for_model("gpt-4o-mini"); /// let pilot = LlmPilot::new(client, PilotConfig::default()); /// +/// // Or with executor for unified throttle/retry/fallback +/// let executor = LlmExecutor::for_model("gpt-4o-mini"); +/// let pilot = LlmPilot::with_executor(executor, PilotConfig::default()); +/// /// // Use in search /// if pilot.should_intervene(&state) { /// let decision = pilot.decide(&state).await; /// } /// ``` pub struct LlmPilot { - /// LLM client for making requests. + /// LLM client for making requests (fallback when no executor). client: LlmClient, + /// LLM executor with unified throttle/retry/fallback (optional). + executor: Option>, /// Pilot configuration. config: PilotConfig, /// Budget controller. @@ -90,6 +97,42 @@ impl LlmPilot { Self { client, + executor: None, + config, + budget, + context_builder: ContextBuilder::new(token_budget), + prompt_builder: PromptBuilder::new(), + response_parser: ResponseParser::new(), + } + } + + /// Create a Pilot with LlmExecutor for unified throttle/retry/fallback. + pub fn with_executor(executor: LlmExecutor, config: PilotConfig) -> Self { + let budget = BudgetController::new(config.budget.clone()); + let token_budget = config.budget.max_tokens_per_call; + // Create a fallback client for backwards compatibility + let client = LlmClient::for_model(&executor.config().model); + + Self { + client, + executor: Some(Arc::new(executor)), + config, + budget, + context_builder: ContextBuilder::new(token_budget), + prompt_builder: PromptBuilder::new(), + response_parser: ResponseParser::new(), + } + } + + /// Create a Pilot with shared executor (for sharing throttle/fallback across pilots). 
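+    ///
+    /// Sketch of sharing one executor (one throttle, retry, and fallback
+    /// chain) across several pilots; the model name is a placeholder:
+    ///
+    /// ```rust,ignore
+    /// let executor = Arc::new(LlmExecutor::for_model("gpt-4o-mini"));
+    /// let pilot_a = LlmPilot::with_shared_executor(executor.clone(), PilotConfig::default());
+    /// let pilot_b = LlmPilot::with_shared_executor(executor, PilotConfig::default());
+    /// ```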
+ pub fn with_shared_executor(executor: Arc, config: PilotConfig) -> Self { + let budget = BudgetController::new(config.budget.clone()); + let token_budget = config.budget.max_tokens_per_call; + let client = LlmClient::for_model(&executor.config().model); + + Self { + client, + executor: Some(executor), config, budget, context_builder: ContextBuilder::new(token_budget), @@ -109,6 +152,7 @@ impl LlmPilot { Self { client, + executor: None, config, budget, context_builder, @@ -117,6 +161,17 @@ impl LlmPilot { } } + /// Add an executor to an existing pilot. + pub fn with_executor_mut(mut self, executor: LlmExecutor) -> Self { + self.executor = Some(Arc::new(executor)); + self + } + + /// Check if using LlmExecutor (unified throttle/retry/fallback). + pub fn has_executor(&self) -> bool { + self.executor.is_some() + } + /// Check if budget allows LLM calls. fn has_budget(&self) -> bool { self.budget.can_call() @@ -167,8 +222,16 @@ impl LlmPilot { point, prompt.estimated_tokens ); - // Make LLM call - match self.client.complete(&prompt.system, &prompt.user).await { + // Make LLM call - use executor if available, otherwise use client directly + let result = if let Some(ref executor) = self.executor { + // Use LlmExecutor for unified throttle/retry/fallback + executor.complete(&prompt.system, &prompt.user).await + } else { + // Fallback to direct client call + self.client.complete(&prompt.system, &prompt.user).await + }; + + match result { Ok(response) => { // Record usage (estimate output tokens) let output_tokens = self.estimate_tokens(&response); diff --git a/src/retrieval/pilot/mod.rs b/src/retrieval/pilot/mod.rs index 0a3a9a61..87488e41 100644 --- a/src/retrieval/pilot/mod.rs +++ b/src/retrieval/pilot/mod.rs @@ -56,6 +56,7 @@ mod builder; mod config; mod decision; mod fallback; +mod feedback; mod llm_pilot; mod metrics; mod noop; @@ -68,6 +69,10 @@ pub use builder::{ContextBuilder, PilotContext, TokenBudget}; pub use config::{BudgetConfig, InterventionConfig, PilotConfig, PilotMode}; pub use decision::{InterventionPoint, PilotDecision, RankedCandidate, SearchDirection}; pub use fallback::{FallbackAction, FallbackConfig, FallbackError, FallbackLevel, FallbackManager}; +pub use feedback::{ + ContextStats, DecisionAdjustment, DecisionId, FeedbackId, FeedbackRecord, FeedbackStore, + FeedbackStoreConfig, InterventionStats, LearnerConfig, PilotLearner, +}; pub use llm_pilot::LlmPilot; pub use metrics::{CallRecord, MetricsCollector, PilotMetrics}; pub use noop::NoopPilot; diff --git a/src/retrieval/search/scorer.rs b/src/retrieval/search/scorer.rs index 0d051938..72080a6b 100644 --- a/src/retrieval/search/scorer.rs +++ b/src/retrieval/search/scorer.rs @@ -1,12 +1,77 @@ // Copyright (c) 2026 vectorless developers // SPDX-License-Identifier: Apache-2.0 -//! Node scoring utilities. +//! Node scoring utilities with BM25 support. //! //! Implements the NodeScore formula: `Σ ChunkScore(n) / √(N+1)` +//! with optional BM25 scoring for better relevance ranking. + +use std::collections::HashMap; use crate::document::{DocumentTree, NodeId}; +/// Common English stop words for keyword filtering. 
+const STOPWORDS: &[&str] = &[ + "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", + "have", "has", "had", "do", "does", "did", "will", "would", "could", + "should", "may", "might", "must", "shall", "can", "need", "dare", + "ought", "used", "to", "of", "in", "for", "on", "with", "at", "by", + "from", "as", "into", "through", "during", "before", "after", "above", + "below", "between", "under", "again", "further", "then", "once", + "here", "there", "when", "where", "why", "how", "all", "each", "few", + "more", "most", "other", "some", "such", "no", "nor", "not", "only", + "own", "same", "so", "than", "too", "very", "just", "and", "but", + "if", "or", "because", "until", "while", "about", "what", "which", + "who", "whom", "this", "that", "these", "those", "i", "me", "my", + "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", + "yourself", "yourselves", "he", "him", "his", "himself", "she", "her", + "hers", "herself", "it", "its", "itself", "they", "them", "their", + "theirs", "themselves", +]; + +/// Extract keywords from a query string, filtering stop words. +fn extract_keywords(query: &str) -> Vec { + query + .to_lowercase() + .split(|c: char| !c.is_alphanumeric()) + .filter(|s| { + let s = *s; + !s.is_empty() && s.len() > 1 && !STOPWORDS.contains(&s) + }) + .map(String::from) + .collect() +} + +/// Scoring strategy to use. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ScoringStrategy { + /// Keyword overlap only (fastest). + KeywordOnly, + /// BM25 only (better relevance). + #[default] + BM25, + /// Hybrid: weighted combination of keyword + BM25. + Hybrid, +} + +/// BM25 parameters. +#[derive(Debug, Clone, Copy)] +pub struct Bm25Params { + /// Term frequency saturation parameter (k1). + pub k1: f32, + /// Length normalization parameter (b). + pub b: f32, +} + +impl Default for Bm25Params { + fn default() -> Self { + Self { + k1: 1.2, + b: 0.75, + } + } +} + /// Context for scoring calculations. #[derive(Debug, Clone)] pub struct ScoringContext { @@ -20,6 +85,16 @@ pub struct ScoringContext { pub content_weight: f32, /// Depth penalty factor. pub depth_penalty: f32, + /// Scoring strategy. + pub strategy: ScoringStrategy, + /// BM25 parameters. + pub bm25_params: Bm25Params, + /// Average document length for BM25. + pub avg_doc_len: f32, + /// Document frequency for terms (for IDF). + pub doc_freq: HashMap, + /// Total document count for IDF. + pub doc_count: usize, } impl Default for ScoringContext { @@ -30,6 +105,11 @@ impl Default for ScoringContext { summary_weight: 1.5, content_weight: 1.0, depth_penalty: 0.1, + strategy: ScoringStrategy::default(), + bm25_params: Bm25Params::default(), + avg_doc_len: 100.0, + doc_freq: HashMap::new(), + doc_count: 1, } } } @@ -38,21 +118,100 @@ impl ScoringContext { /// Create a new scoring context with query terms. pub fn new(query: &str) -> Self { Self { - query_terms: query - .to_lowercase() - .split_whitespace() - .map(|s| s.to_string()) - .collect(), + query_terms: extract_keywords(query), ..Default::default() } } + /// Create a context with a specific scoring strategy. + pub fn with_strategy(query: &str, strategy: ScoringStrategy) -> Self { + Self { + query_terms: extract_keywords(query), + strategy, + ..Default::default() + } + } + + /// Set BM25 parameters. + pub fn with_bm25_params(mut self, params: Bm25Params) -> Self { + self.bm25_params = params; + self + } + + /// Set document statistics for BM25. 
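+    ///
+    /// Illustrative sketch of deriving corpus statistics by hand; the `docs`
+    /// collection and the counting loop are hypothetical, only the builder
+    /// call at the end comes from this module:
+    ///
+    /// ```rust,ignore
+    /// use std::collections::{HashMap, HashSet};
+    ///
+    /// let mut doc_freq: HashMap<String, usize> = HashMap::new();
+    /// let mut total_len = 0usize;
+    /// for doc in &docs {
+    ///     total_len += doc.split_whitespace().count();
+    ///     // Count document frequency (how many docs contain the term), not raw term frequency.
+    ///     let terms: HashSet<String> = extract_keywords(doc).into_iter().collect();
+    ///     for term in terms {
+    ///         *doc_freq.entry(term).or_insert(0) += 1;
+    ///     }
+    /// }
+    /// let avg_len = total_len as f32 / docs.len().max(1) as f32;
+    /// let ctx = ScoringContext::with_strategy(&query, ScoringStrategy::BM25)
+    ///     .with_doc_stats(docs.len(), avg_len, doc_freq);
+    /// ```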
+ pub fn with_doc_stats( + mut self, + doc_count: usize, + avg_doc_len: f32, + doc_freq: HashMap, + ) -> Self { + self.doc_count = doc_count.max(1); + self.avg_doc_len = avg_doc_len.max(1.0); + self.doc_freq = doc_freq; + self + } + + /// Calculate term frequency in text. + fn term_frequency(&self, text: &str, term: &str) -> f32 { + text.to_lowercase().matches(term).count() as f32 + } + + /// Calculate IDF (Inverse Document Frequency) for a term. + fn idf(&self, term: &str) -> f32 { + let df = self.doc_freq.get(term).copied().unwrap_or(1) as f32; + let n = self.doc_count as f32; + ((n - df + 0.5) / (df + 0.5) + 1.0).ln() + } + + /// Calculate BM25 score for a single field. + fn bm25_field_score(&self, text: &str) -> f32 { + if self.query_terms.is_empty() { + return 0.0; + } + + let doc_len = text.split_whitespace().count() as f32; + let k1 = self.bm25_params.k1; + let b = self.bm25_params.b; + + let mut score = 0.0; + for term in &self.query_terms { + let tf = self.term_frequency(text, term); + if tf == 0.0 { + continue; + } + + let idf = self.idf(term); + let numerator = tf * (k1 + 1.0); + let denominator = tf + k1 * (1.0 - b + b * doc_len / self.avg_doc_len); + + score += idf * numerator / denominator; + } + + score + } + + /// Calculate keyword overlap score for a text. + fn keyword_overlap(&self, text: &str) -> f32 { + if self.query_terms.is_empty() { + return 0.0; + } + + let text_lower = text.to_lowercase(); + let matches = self + .query_terms + .iter() + .filter(|term| text_lower.contains(term.as_str())) + .count(); + + matches as f32 / self.query_terms.len() as f32 + } + /// Calculate a quick keyword-based score for a node. pub fn quick_score(&self, tree: &DocumentTree, node_id: NodeId) -> f32 { if let Some(node) = tree.get(node_id) { - let title_score = self.term_overlap(&node.title); - let summary_score = self.term_overlap(&node.summary); - let content_score = self.term_overlap(&node.content); + let title_score = self.keyword_overlap(&node.title); + let summary_score = self.keyword_overlap(&node.summary); + let content_score = self.keyword_overlap(&node.content); let base_score = (title_score * self.title_weight + summary_score * self.summary_weight @@ -68,20 +227,44 @@ impl ScoringContext { } } - /// Calculate term overlap between query and text. - fn term_overlap(&self, text: &str) -> f32 { - if self.query_terms.is_empty() { - return 0.0; + /// Calculate BM25 score for a node. + pub fn bm25_score(&self, tree: &DocumentTree, node_id: NodeId) -> f32 { + if let Some(node) = tree.get(node_id) { + let title_score = self.bm25_field_score(&node.title) * self.title_weight; + let summary_score = self.bm25_field_score(&node.summary) * self.summary_weight; + let content_score = self.bm25_field_score(&node.content) * self.content_weight; + + let total_score = title_score + summary_score + content_score; + + // Normalize to [0, 1] range + let max_possible = self.query_terms.len() as f32 * 10.0; // Rough upper bound + let normalized = (total_score / max_possible).clamp(0.0, 1.0); + + // Apply depth penalty + let depth_factor = 1.0 - (node.depth as f32 * self.depth_penalty).min(0.5); + + normalized * depth_factor + } else { + 0.0 } + } - let text_lower = text.to_lowercase(); - let matches = self - .query_terms - .iter() - .filter(|term| text_lower.contains(term.as_str())) - .count(); + /// Calculate hybrid score (keyword + BM25). 
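+    ///
+    /// Illustrative call (the blend is fixed at 40% keyword / 60% BM25 in
+    /// the body below); `tree` and `node_id` are placeholders:
+    ///
+    /// ```rust,ignore
+    /// let ctx = ScoringContext::with_strategy("rust cargo", ScoringStrategy::Hybrid);
+    /// let score = ctx.hybrid_score(&tree, node_id);
+    /// assert!((0.0..=1.0).contains(&score));
+    /// ```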
+ pub fn hybrid_score(&self, tree: &DocumentTree, node_id: NodeId) -> f32 { + let keyword = self.quick_score(tree, node_id); + let bm25 = self.bm25_score(tree, node_id); - matches as f32 / self.query_terms.len() as f32 + // Weighted combination: 40% keyword, 60% BM25 + keyword * 0.4 + bm25 * 0.6 + } + + /// Calculate score based on configured strategy. + pub fn score(&self, tree: &DocumentTree, node_id: NodeId) -> f32 { + match self.strategy { + ScoringStrategy::KeywordOnly => self.quick_score(tree, node_id), + ScoringStrategy::BM25 => self.bm25_score(tree, node_id), + ScoringStrategy::Hybrid => self.hybrid_score(tree, node_id), + } } } @@ -97,9 +280,29 @@ impl NodeScorer { Self { context } } + /// Create a scorer with default context for a query. + pub fn for_query(query: &str) -> Self { + Self::new(ScoringContext::new(query)) + } + + /// Create a scorer with a specific strategy. + pub fn with_strategy(query: &str, strategy: ScoringStrategy) -> Self { + Self::new(ScoringContext::with_strategy(query, strategy)) + } + + /// Get the scoring context. + pub fn context(&self) -> &ScoringContext { + &self.context + } + + /// Get mutable scoring context. + pub fn context_mut(&mut self) -> &mut ScoringContext { + &mut self.context + } + /// Score a single node. pub fn score(&self, tree: &DocumentTree, node_id: NodeId) -> f32 { - self.context.quick_score(tree, node_id) + self.context.score(tree, node_id) } /// Score multiple nodes and return sorted by score (descending). @@ -117,7 +320,7 @@ impl NodeScorer { /// /// Used in the NodeScore formula. pub fn chunk_score(&self, chunk: &str) -> f32 { - self.context.term_overlap(chunk) + self.context.keyword_overlap(chunk) } /// Calculate the full NodeScore using the formula: @@ -150,3 +353,64 @@ impl NodeScorer { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_keywords() { + let keywords = extract_keywords("What is the architecture of vectorless?"); + assert!(keywords.contains(&"architecture".to_string())); + assert!(keywords.contains(&"vectorless".to_string())); + assert!(!keywords.contains(&"what".to_string())); // stopword + assert!(!keywords.contains(&"the".to_string())); // stopword + } + + #[test] + fn test_keyword_overlap() { + let ctx = ScoringContext::new("vectorless architecture"); + + let text = "Vectorless has a unique architecture for document retrieval."; + let score = ctx.keyword_overlap(text); + + assert!(score > 0.5); // Should match both keywords + } + + #[test] + fn test_bm25_scoring() { + let ctx = ScoringContext::with_strategy("rust cargo", ScoringStrategy::BM25); + + let text = "Rust is a programming language. Cargo is its package manager. 
Rust Rust Rust."; + let score = ctx.bm25_field_score(text); + + // Should have higher score due to term frequency + assert!(score > 0.0); + } + + #[test] + fn test_hybrid_scoring() { + let ctx = ScoringContext::with_strategy("test query", ScoringStrategy::Hybrid); + + let keyword_score = ctx.keyword_overlap("test query content"); + let bm25_score = ctx.bm25_field_score("test query content"); + let hybrid = ctx.keyword_overlap("test query content") * 0.4 + + ctx.bm25_field_score("test query content") * 0.6; + + // Hybrid should be between keyword and bm25 scores (roughly) + assert!(hybrid > 0.0); + } + + #[test] + fn test_scorer_creation() { + let scorer = NodeScorer::for_query("test query"); + assert!(!scorer.context().query_terms.is_empty()); + } + + #[test] + fn test_scorer_with_strategy() { + let scorer = NodeScorer::with_strategy("test", ScoringStrategy::BM25); + assert_eq!(scorer.context().strategy, ScoringStrategy::BM25); + } +} + diff --git a/vectorless.example.toml b/vectorless.example.toml index aa097ae6..505f0fb5 100644 --- a/vectorless.example.toml +++ b/vectorless.example.toml @@ -1,228 +1,220 @@ # Vectorless Configuration Example -# Copy this file to config.toml and fill in your API keys +# Copy this file to vectorless.toml and fill in your API keys # # All configuration is loaded from this file only. # No environment variables are used - this ensures explicit, traceable configuration. -[indexer] -# Word count threshold for splitting sections into subsections -subsection_threshold = 300 - -# Maximum tokens to send in a single segmentation request -max_segment_tokens = 3000 +# ============================================================================ +# LLM Configuration (Unified) +# ============================================================================ -# Maximum tokens for each summary -max_summary_tokens = 200 - -# Minimum content tokens required to generate a summary -min_summary_tokens = 20 - -[summary] -# API key - get from your provider +[llm] +# Default API key (can be overridden per client) # api_key = "sk-..." -# API endpoint -# OpenAI: https://api.openai.com/v1 -# ZAI General: https://api.z.ai/api/paas/v4 -# ZAI Coding: https://api.z.ai/api/coding/paas/v4 +# Summary client - used for generating document summaries +[llm.pool.summary] +model = "gpt-4o-mini" +endpoint = "https://api.openai.com/v1" +max_tokens = 200 +temperature = 0.0 + +# Retrieval client - used for navigation decisions +[llm.pool.retrieval] +model = "gpt-4o" endpoint = "https://api.openai.com/v1" +max_tokens = 100 +temperature = 0.0 -# Model for summarization (use cheaper models for indexing) +# Pilot client - used for intelligent navigation +[llm.pool.pilot] model = "gpt-4o-mini" +endpoint = "https://api.openai.com/v1" +max_tokens = 300 +temperature = 0.0 -# Maximum tokens for summary generation -max_tokens = 200 +# Retry configuration +[llm.retry] +max_attempts = 3 +initial_delay_ms = 500 +max_delay_ms = 30000 +multiplier = 2.0 +retry_on_rate_limit = true -# Temperature for summary generation -temperature = 0.0 +# Throttle/rate limiting configuration +[llm.throttle] +max_concurrent_requests = 10 +requests_per_minute = 500 +enabled = true +semaphore_enabled = true -[retrieval] -# API key (optional, defaults to summary.api_key) -# api_key = "sk-..." 
+# Fallback configuration +[llm.fallback] +enabled = true +models = ["gpt-4o-mini", "glm-4-flash"] +# endpoints = [ +# "https://api.openai.com/v1", +# "https://api.z.ai/api/paas/v4" +# ] +on_rate_limit = "retry_then_fallback" +on_timeout = "retry_then_fallback" +on_all_failed = "return_error" -# API endpoint for retrieval -endpoint = "https://api.openai.com/v1" +# ============================================================================ +# Metrics Configuration (Unified) +# ============================================================================ -# Model for retrieval navigation (use smarter models for better results) -model = "gpt-4o" +[metrics] +enabled = true +storage_path = "./workspace/metrics" +retention_days = 30 + +[metrics.llm] +track_tokens = true +track_latency = true +track_cost = true +cost_per_1k_input_tokens = 0.00015 # gpt-4o-mini +cost_per_1k_output_tokens = 0.0006 + +[metrics.pilot] +track_decisions = true +track_accuracy = true +track_feedback = true + +[metrics.retrieval] +track_paths = true +track_scores = true +track_iterations = true +track_cache = true + +# ============================================================================ +# Pilot Configuration +# ============================================================================ + +[pilot] +mode = "Balanced" # Aggressive | Balanced | Conservative | AlgorithmOnly +guide_at_start = true +guide_at_backtrack = true + +[pilot.budget] +max_tokens_per_query = 2000 +max_tokens_per_call = 500 +max_calls_per_query = 5 +max_calls_per_level = 2 +hard_limit = true + +[pilot.intervention] +fork_threshold = 3 +score_gap_threshold = 0.15 +low_score_threshold = 0.3 +max_interventions_per_level = 2 + +[pilot.feedback] +enabled = true +storage_path = "./workspace/feedback" +learning_rate = 0.1 +min_samples_for_learning = 10 -# Number of top results to return -top_k = 3 +# ============================================================================ +# Retrieval Configuration +# ============================================================================ -# Maximum tokens for retrieval context +[retrieval] +model = "gpt-4o" +endpoint = "https://api.openai.com/v1" +top_k = 3 max_tokens = 1000 - -# Temperature for retrieval temperature = 0.0 -# Search algorithm configuration [retrieval.search] -# Number of top-k results to return top_k = 5 - -# Beam width for multi-path search beam_width = 3 - -# Maximum iterations for search algorithms max_iterations = 10 - -# Minimum score to include a path min_score = 0.1 -# Sufficiency checker configuration [retrieval.sufficiency] -# Minimum tokens for sufficiency min_tokens = 500 - -# Target tokens for full sufficiency target_tokens = 2000 - -# Maximum tokens before stopping max_tokens = 4000 - -# Minimum content length (characters) min_content_length = 200 - -# Confidence threshold for LLM judge confidence_threshold = 0.7 -# Cache configuration [retrieval.cache] -# Maximum number of cache entries max_entries = 1000 - -# Time-to-live for cache entries (seconds) ttl_secs = 3600 -# Strategy-specific configuration [retrieval.strategy] -# MCTS exploration weight (sqrt(2) ≈ 1.414) exploration_weight = 1.414 - -# Semantic similarity threshold similarity_threshold = 0.5 - -# High similarity threshold for "answer" decision high_similarity_threshold = 0.8 - -# Low similarity threshold for "explore" decision low_similarity_threshold = 0.3 -# Content aggregator configuration -# Controls how retrieved content is aggregated and returned [retrieval.content] -# Enable/disable content aggregator -# When 
disabled, uses simple content collection (legacy behavior) enabled = true - -# Maximum tokens for aggregated content token_budget = 4000 - -# Minimum relevance score threshold (0.0 - 1.0) -# Content below this threshold will be filtered out min_relevance_score = 0.2 - -# Scoring strategy: "keyword_only" | "keyword_bm25" | "hybrid" -# - keyword_only: Fast keyword matching (no BM25) -# - keyword_bm25: Keyword + BM25 scoring (recommended) -# - hybrid: Keyword + LLM reranking (most accurate, slower) scoring_strategy = "keyword_bm25" - -# Output format: "markdown" | "json" | "tree" | "flat" -# - markdown: Structured markdown with headers (default) -# - json: JSON format for programmatic use -# - tree: Tree structure preserving hierarchy -# - flat: Flat text format output_format = "markdown" - -# Include relevance scores in output (useful for debugging) include_scores = false - -# Minimum budget allocation per depth level (0.0 - 1.0) -# Ensures each tree level gets representation hierarchical_min_per_level = 0.1 - -# Enable content deduplication deduplicate = true - -# Similarity threshold for deduplication (0.0 - 1.0) -# Higher = more aggressive deduplication dedup_threshold = 0.9 +# ============================================================================ +# Multi-turn Retrieval Configuration +# ============================================================================ + +[retrieval.multiturn] +enabled = true +max_sub_queries = 3 +decomposition_model = "gpt-4o-mini" +aggregation_strategy = "merge" # merge | rank | synthesize + +# ============================================================================ +# Storage Configuration +# ============================================================================ + [storage] -# Workspace directory for persisted documents -# -# Structure: -# workspace/ -# ├── _meta.json # Lightweight index -# ├── {doc_id_1}.json # Document 1 -# └── {doc_id_2}.json # Document 2 workspace_dir = "./workspace" - -# LRU cache size (number of documents to keep in memory) cache_size = 100 - -# Enable atomic writes (temp file + rename) -# This prevents data corruption on crash atomic_writes = true - -# Enable file locking for multi-process safety -# Prevents concurrent access from multiple processes file_lock = true - -# Enable checksum verification for data integrity -# Uses SHA-256 to verify file integrity on load checksum_enabled = true -# Compression settings [storage.compression] -# Enable compression for stored documents enabled = false - -# Compression algorithm: "gzip" or "zstd" algorithm = "gzip" - -# Compression level (1-9, higher = better compression but slower) level = 6 +# ============================================================================ +# Indexer Configuration +# ============================================================================ + +[indexer] +subsection_threshold = 300 +max_segment_tokens = 3000 +max_summary_tokens = 200 +min_summary_tokens = 20 + +# ============================================================================ +# Legacy Configuration (deprecated, use llm.* instead) +# ============================================================================ + +[summary] +model = "gpt-4o-mini" +endpoint = "https://api.openai.com/v1" +max_tokens = 200 +temperature = 0.0 + [concurrency] -# Maximum concurrent LLM API calls -# This limits how many requests can be in-flight at the same time max_concurrent_requests = 10 - -# Rate limit: requests per minute -# This is a soft limit using token bucket algorithm requests_per_minute = 500 - -# 
Enable rate limiting (token bucket) enabled = true - -# Enable semaphore-based concurrency limiting semaphore_enabled = true [fallback] -# Enable graceful degradation when LLM calls fail enabled = true - -# Fallback models in priority order -# When primary model fails, system tries these in order models = ["gpt-4o-mini", "glm-4-flash"] - -# Fallback endpoints (optional) -# When primary endpoint fails, system tries these in order -# endpoints = [ -# "https://api.openai.com/v1", -# "https://api.z.ai/api/paas/v4" -# ] - -# Behavior on rate limit error (429) -# Options: retry, fallback, retry_then_fallback, fail on_rate_limit = "retry_then_fallback" - -# Behavior on timeout error -# Options: retry, fallback, retry_then_fallback, fail on_timeout = "retry_then_fallback" - -# Behavior when all attempts fail -# Options: return_error, return_cache on_all_failed = "return_error" From 81d4fda6a853cd83942ac0b16390c73da6b7c911 Mon Sep 17 00:00:00 2001 From: zTgx <747674262@qq.com> Date: Sun, 5 Apr 2026 19:48:17 +0800 Subject: [PATCH 3/6] feat(pilot): add feedback learning system with hash-based querying Add query_hash and path_hash methods to PilotContext for generating hashes used in feedback learning system. Introduce feedback learning capabilities to LlmPilot including: - Add optional PilotLearner field to LlmPilot struct - Implement with_learner and with_feedback_store builder methods - Add helper methods has_learner, learner, and record_feedback - Integrate learner adjustment logic in decision making process - Apply confidence adjustments based on historical feedback data - Skip interventions when learner suggests low historical accuracy The feedback system allows the pilot to learn from past decisions and improve future recommendations through confidence adjustments and strategic intervention skipping. --- src/retrieval/pilot/builder.rs | 18 +++++++++ src/retrieval/pilot/llm_pilot.rs | 67 +++++++++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 1 deletion(-) diff --git a/src/retrieval/pilot/builder.rs b/src/retrieval/pilot/builder.rs index fd9d4581..7b5e2a49 100644 --- a/src/retrieval/pilot/builder.rs +++ b/src/retrieval/pilot/builder.rs @@ -101,6 +101,24 @@ impl PilotContext { && self.path_section.is_empty() && self.candidates_section.is_empty() } + + /// Get a hash of the query for feedback learning. + pub fn query_hash(&self) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + self.query_section.hash(&mut hasher); + hasher.finish() + } + + /// Get a hash of the path for feedback learning. + pub fn path_hash(&self) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + self.path_section.hash(&mut hasher); + hasher.finish() + } } /// Context builder for Pilot LLM calls. diff --git a/src/retrieval/pilot/llm_pilot.rs b/src/retrieval/pilot/llm_pilot.rs index 6489510e..9f64e1c0 100644 --- a/src/retrieval/pilot/llm_pilot.rs +++ b/src/retrieval/pilot/llm_pilot.rs @@ -18,6 +18,7 @@ use super::budget::BudgetController; use super::builder::ContextBuilder; use super::config::PilotConfig; use super::decision::{InterventionPoint, PilotDecision}; +use super::feedback::{DecisionAdjustment, FeedbackRecord, FeedbackStore, PilotLearner}; use super::parser::ResponseParser; use super::prompts::PromptBuilder; use super::r#trait::{Pilot, SearchState}; @@ -78,6 +79,8 @@ pub struct LlmPilot { prompt_builder: PromptBuilder, /// Response parser. 
     response_parser: ResponseParser,
+    /// Feedback learner for improving decisions (optional).
+    learner: Option<Arc<PilotLearner>>,
 }
 
 impl std::fmt::Debug for LlmPilot {
@@ -103,6 +106,7 @@ impl LlmPilot {
             context_builder: ContextBuilder::new(token_budget),
             prompt_builder: PromptBuilder::new(),
             response_parser: ResponseParser::new(),
+            learner: None,
         }
     }
 
@@ -121,6 +125,7 @@ impl LlmPilot {
             context_builder: ContextBuilder::new(token_budget),
             prompt_builder: PromptBuilder::new(),
             response_parser: ResponseParser::new(),
+            learner: None,
         }
     }
 
@@ -138,6 +143,7 @@ impl LlmPilot {
             context_builder: ContextBuilder::new(token_budget),
             prompt_builder: PromptBuilder::new(),
             response_parser: ResponseParser::new(),
+            learner: None,
         }
     }
 
@@ -158,6 +164,7 @@ impl LlmPilot {
             context_builder,
             prompt_builder,
             response_parser: ResponseParser::new(),
+            learner: None,
         }
     }
 
@@ -167,11 +174,42 @@ impl LlmPilot {
         self
     }
 
+    /// Add a feedback learner to the pilot.
+    pub fn with_learner(mut self, learner: Arc<PilotLearner>) -> Self {
+        self.learner = Some(learner);
+        self
+    }
+
+    /// Add a feedback learner from a feedback store.
+    pub fn with_feedback_store(mut self, store: Arc<FeedbackStore>) -> Self {
+        self.learner = Some(Arc::new(PilotLearner::new(store)));
+        self
+    }
+
     /// Check if using LlmExecutor (unified throttle/retry/fallback).
     pub fn has_executor(&self) -> bool {
         self.executor.is_some()
     }
 
+    /// Check if using feedback learner.
+    pub fn has_learner(&self) -> bool {
+        self.learner.is_some()
+    }
+
+    /// Get the feedback learner (if any).
+    pub fn learner(&self) -> Option<&PilotLearner> {
+        self.learner.as_deref()
+    }
+
+    /// Record feedback for a decision.
+    pub fn record_feedback(&self, record: FeedbackRecord) {
+        if let Some(ref learner) = self.learner {
+            let decision_id = record.decision_id;
+            learner.store().record(record);
+            debug!("Recorded feedback for decision {:?}", decision_id);
+        }
+    }
+
     /// Check if budget allows LLM calls.
fn has_budget(&self) -> bool { self.budget.can_call() @@ -217,6 +255,23 @@ impl LlmPilot { return self.default_decision(candidates, point); } + // Get learner adjustment if available + let adjustment = if let Some(ref learner) = self.learner { + let query_hash = context.query_hash(); + let path_hash = context.path_hash(); + Some(learner.get_adjustment(point, query_hash, path_hash)) + } else { + None + }; + + // Check if learner suggests skipping intervention + if let Some(ref adj) = adjustment { + if adj.skip_intervention { + debug!("Learner suggests skipping intervention (low historical accuracy)"); + return self.default_decision(candidates, point); + } + } + debug!( "Calling LLM for {:?} point (estimated: {} tokens)", point, prompt.estimated_tokens @@ -239,7 +294,17 @@ impl LlmPilot { .record_usage(prompt.estimated_tokens, output_tokens, 0); // Parse response - let decision = self.response_parser.parse(&response, candidates, point); + let mut decision = self.response_parser.parse(&response, candidates, point); + + // Apply learner adjustment if available + if let Some(ref adj) = adjustment { + decision.confidence = + (decision.confidence + adj.confidence_delta as f32).clamp(0.0, 1.0); + debug!( + "Applied learner adjustment: confidence_delta={:.2}, algorithm_weight={:.2}", + adj.confidence_delta, adj.algorithm_weight + ); + } info!( "LLM decision: direction={:?}, confidence={:.2}, candidates={}", From 39b31ecfe8d6eb65df1fbbd165922371c516a4c4 Mon Sep 17 00:00:00 2001 From: zTgx <747674262@qq.com> Date: Sun, 5 Apr 2026 19:49:30 +0800 Subject: [PATCH 4/6] docs: add feedback learning design document and example - Add comprehensive design document for Feedback Learning system including architecture, core components, and integration details - Implement detailed documentation covering FeedbackRecord, FeedbackStore, and PilotLearner components - Add usage examples demonstrating feedback collection and learning - Include system diagrams showing data flow and component interactions - Document configuration options and learning strategies - Provide example code showcasing feedback learning implementation --- docs/design/feedback-learning.md | 581 +++++++++++++++++++++++++++++++ examples/feedback_learning.rs | 145 ++++++++ 2 files changed, 726 insertions(+) create mode 100644 docs/design/feedback-learning.md create mode 100644 examples/feedback_learning.rs diff --git a/docs/design/feedback-learning.md b/docs/design/feedback-learning.md new file mode 100644 index 00000000..9b79bc3d --- /dev/null +++ b/docs/design/feedback-learning.md @@ -0,0 +1,581 @@ +# Feedback Learning 设计文档 + +> Pilot 反馈学习系统 - 从用户反馈中持续改进决策 + +## 概述 + +Feedback Learning 是 Pilot 的学习子系统,通过收集用户对检索结果的反馈,持续优化 Pilot 的决策能力。系统会追踪不同场景下的决策准确性,并据此调整后续决策的置信度和策略。 + +### 设计目标 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ 设计目标 │ +├─────────────────────────────────────────────────────────────────┤ +│ 1. 收集反馈 - 记录用户对检索结果的评价 │ +│ 2. 学习模式 - 识别在哪些场景下 Pilot 表现好/差 │ +│ 3. 调整决策 - 根据历史表现调整置信度和策略 │ +│ 4. 持续改进 - 随着数据积累,决策质量逐步提升 │ +└─────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 1. 
整体架构 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Feedback Learning 系统架构 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌───────────────────────────────────────────────────────────────────────┐ │ +│ │ 数据流 │ │ +│ │ │ │ +│ │ 检索完成 │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ Feedback │────▶│ Feedback │────▶│ Pilot │ │ │ +│ │ │ Record │ │ Store │ │ Learner │ │ │ +│ │ └─────────────┘ └─────────────┘ └──────┬──────┘ │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌─────────────┐ │ │ +│ │ │ Decision │ │ │ +│ │ │ Adjustment │ │ │ +│ │ └─────────────┘ │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ 下次检索决策 │ │ +│ │ │ │ +│ └───────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 2. 核心组件 + +### 2.1 FeedbackRecord - 反馈记录 + +```rust +/// 反馈记录 +pub struct FeedbackRecord { + /// 唯一反馈 ID + pub id: FeedbackId, + /// 关联的决策 ID + pub decision_id: DecisionId, + /// 决策是否正确 + pub was_correct: bool, + /// Pilot 当时的置信度 + pub pilot_confidence: f64, + /// 介入点类型 + pub intervention_point: InterventionPoint, + /// 查询哈希(用于聚合相似查询) + pub query_hash: u64, + /// 路径哈希(用于聚合相似路径) + pub path_hash: u64, + /// 时间戳 + pub timestamp_ms: u64, + /// 可选的用户评论 + pub comment: Option, +} +``` + +### 2.2 FeedbackStore - 反馈存储 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ FeedbackStore 架构 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ FeedbackStore │ │ +│ │ │ │ +│ │ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ │ +│ │ │ records │ │ intervention_ │ │ query_stats │ │ │ +│ │ │ Vec │ │ stats │ │ HashMap │ │ │ +│ │ └─────────────────┘ └─────────────────┘ └─────────────────┘ │ │ +│ │ │ │ +│ │ ┌─────────────────┐ │ │ +│ │ │ path_stats │ │ │ +│ │ │ HashMap │ │ │ +│ │ └─────────────────┘ │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 统计维度: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 1. 按 InterventionPoint 聚合 │ │ +│ │ - START / FORK / BACKTRACK / EVALUATE 各自的准确率 │ │ +│ │ │ │ +│ │ 2. 按 Query 聚合 │ │ +│ │ - 相似查询的历史表现 │ │ +│ │ │ │ +│ │ 3. 按 Path 聚合 │ │ +│ │ - 相似路径的历史表现 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 2.3 PilotLearner - 学习器 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ PilotLearner 工作原理 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 输入: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ - intervention_point: 当前介入点类型 │ │ +│ │ - query_hash: 查询的哈希值 │ │ +│ │ - path_hash: 路径的哈希值 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 查询历史统计: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 1. 获取 intervention_point 的整体准确率 │ │ +│ │ 2. 获取 query_hash 的特定准确率(如有) │ │ +│ │ 3. 
获取 path_hash 的特定准确率(如有) │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 输出 DecisionAdjustment: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ pub struct DecisionAdjustment { │ │ +│ │ /// 置信度调整(加到 Pilot 置信度上) │ │ +│ │ pub confidence_delta: f64, │ │ +│ │ /// 是否跳过介入(信任算法) │ │ +│ │ pub skip_intervention: bool, │ │ +│ │ /// 算法权重 vs LLM 权重 │ │ +│ │ pub algorithm_weight: f64, │ │ +│ │ } │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 3. 学习策略 + +### 3.1 准确率阈值 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 准确率阈值策略 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 配置参数: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ min_samples: 10 // 最小样本数才开始调整 │ │ +│ │ high_accuracy_threshold: 0.8 // 高准确率阈值 │ │ +│ │ low_accuracy_threshold: 0.5 // 低准确率阈值 │ │ +│ │ max_confidence_delta: 0.2 // 最大置信度调整幅度 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 决策逻辑: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ │ │ +│ │ if accuracy >= high_accuracy_threshold (0.8): │ │ +│ │ // 高准确率:信任 LLM,提升置信度 │ │ +│ │ confidence_delta = +0.2 │ │ +│ │ algorithm_weight = 0.3 // 更依赖 LLM │ │ +│ │ │ │ +│ │ elif accuracy <= low_accuracy_threshold (0.5): │ │ +│ │ // 低准确率:信任算法,降低置信度 │ │ +│ │ confidence_delta = -0.2 │ │ +│ │ algorithm_weight = 0.7 // 更依赖算法 │ │ +│ │ │ │ +│ │ if accuracy < 0.3: │ │ +│ │ // 非常低:跳过 LLM 调用,完全用算法 │ │ +│ │ skip_intervention = true │ │ +│ │ │ │ +│ │ else: │ │ +│ │ // 中等准确率:保持默认 │ │ +│ │ confidence_delta = 0.0 │ │ +│ │ algorithm_weight = 0.5 │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 3.2 多层统计融合 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 多层统计融合策略 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 三层统计: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ │ │ +│ │ Layer 1: InterventionPoint 级别(粗粒度) │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ 例如: FORK 点整体准确率 = 0.75 │ │ │ +│ │ │ 影响: 基础调整 │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ Layer 2: Query 级别(中粒度) │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ 例如: 相似查询的准确率 = 0.85 │ │ │ +│ │ │ 影响: 如果高于整体,额外 +0.05 置信度 │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ Layer 3: Path 级别(细粒度) │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ 例如: 相似路径的准确率 = 0.92 │ │ │ +│ │ │ 影响: 如果非常高,额外 +0.05 置信度 │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 融合示例: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ │ │ +│ │ 场景: FORK 点,相似查询,相似路径 │ │ +│ │ │ │ +│ │ 1. FORK 整体准确率 0.75 → confidence_delta = +0.1 │ │ +│ │ 2. 查询特定准确率 0.85 > 0.75 → confidence_delta += 0.05 │ │ +│ │ 3. 
路径特定准确率 0.92 > 0.9 → confidence_delta += 0.05 │ │ +│ │ │ │ +│ │ 最终: confidence_delta = +0.2 (达到上限) │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 4. 与 LlmPilot 的集成 + +### 4.1 集成点 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ LlmPilot 与 Learner 集成 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ LlmPilot 结构: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ pub struct LlmPilot { │ │ +│ │ client: LlmClient, │ │ +│ │ executor: Option>, │ │ +│ │ config: PilotConfig, │ │ +│ │ budget: BudgetController, │ │ +│ │ context_builder: ContextBuilder, │ │ +│ │ prompt_builder: PromptBuilder, │ │ +│ │ response_parser: ResponseParser, │ │ +│ │ learner: Option>, // ← 反馈学习器 │ │ +│ │ } │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 关键方法: │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ │ │ +│ │ // 添加学习器 │ │ +│ │ pub fn with_learner(self, learner: Arc) -> Self │ │ +│ │ │ │ +│ │ // 从反馈存储创建学习器 │ │ +│ │ pub fn with_feedback_store(self, store: Arc) -> Self│ │ +│ │ │ │ +│ │ // 记录反馈 │ │ +│ │ pub fn record_feedback(&self, record: FeedbackRecord) │ │ +│ │ │ │ +│ │ // 获取学习器(只读) │ │ +│ │ pub fn learner(&self) -> Option<&PilotLearner> │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 4.2 决策流程 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 带学习的决策流程 │ +└─────────────────────────────────────────────────────────────────────────────┘ + + ┌─────────────────┐ + │ call_llm() │ + └────────┬────────┘ + │ + ▼ + ┌──────────────────────────────┐ + │ 1. 构建上下文 (ContextBuilder) │ + │ - query_section │ + │ - path_section │ + │ - candidates_section │ + └──────────────┬───────────────┘ + │ + ▼ + ┌──────────────────────────────┐ + │ 2. 获取学习器调整 │ + │ if learner.is_some() { │ + │ query_hash = ctx.hash() │ + │ path_hash = ctx.hash() │ + │ adjustment = learner │ + │ .get_adjustment( │ + │ point, │ + │ query_hash, │ + │ path_hash │ + │ ) │ + │ } │ + └──────────────┬───────────────┘ + │ + ▼ + ┌──────────────────────────────┐ + │ 3. 检查是否跳过介入 │ + │ if adjustment.skip { │ + │ return default_decision │ + │ } │ + └──────────────┬───────────────┘ + │ + ▼ + ┌──────────────────────────────┐ + │ 4. 调用 LLM 获取决策 │ + │ decision = llm.complete() │ + └──────────────┬───────────────┘ + │ + ▼ + ┌──────────────────────────────┐ + │ 5. 应用学习器调整 │ + │ decision.confidence += │ + │ adjustment.confidence │ + │ .delta │ + └──────────────┬───────────────┘ + │ + ▼ + ┌─────────────────┐ + │ 返回调整后决策 │ + └─────────────────┘ +``` + +--- + +## 5. 使用示例 + +### 5.1 基本使用 + +```rust +use std::sync::Arc; +use vectorless::retrieval::pilot::{ + LlmPilot, PilotConfig, + FeedbackStore, FeedbackRecord, PilotLearner, +}; +use vectorless::llm::LlmClient; + +// 1. 创建反馈存储 +let store = Arc::new(FeedbackStore::in_memory()); + +// 2. 创建带学习器的 Pilot +let client = LlmClient::for_model("gpt-4o-mini"); +let pilot = LlmPilot::new(client, PilotConfig::default()) + .with_feedback_store(store.clone()); + +// 3. 执行检索(Pilot 会自动应用学习调整) +let decision = pilot.decide(&state).await; + +// 4. 
记录用户反馈 +let record = FeedbackRecord::new( + decision_id, + was_correct, // 用户评价 + decision.confidence as f64, + InterventionPoint::Fork, + query_hash, + path_hash, +); +pilot.record_feedback(record); + +// 5. 后续检索会自动利用历史反馈改进决策 +``` + +### 5.2 持久化反馈 + +```rust +use vectorless::retrieval::pilot::feedback::FeedbackStoreConfig; + +// 创建带持久化的反馈存储 +let config = FeedbackStoreConfig::with_persistence("./data/feedback.json"); +let store = Arc::new(FeedbackStore::new(config)); + +// 启动时加载历史反馈 +store.load()?; + +// 定期保存 +store.persist()?; +``` + +### 5.3 查看学习效果 + +```rust +// 获取整体准确率 +let accuracy = learner.overall_accuracy(); +println!("Overall accuracy: {:.2}%", accuracy * 100.0); + +// 获取各介入点的统计 +let stats = store.intervention_stats(); +println!("Fork accuracy: {:.2}%", stats.fork.accuracy() * 100.0); +println!("Start accuracy: {:.2}%", stats.start.accuracy() * 100.0); + +// 检查是否有足够数据 +if learner.has_sufficient_data() { + println!("Learner has sufficient data for adjustments"); +} +``` + +--- + +## 6. 配置选项 + +```rust +/// 反馈存储配置 +pub struct FeedbackStoreConfig { + /// 最大记录数(内存限制) + pub max_records: usize, + /// 是否持久化 + pub persist: bool, + /// 持久化路径 + pub storage_path: Option, +} + +/// 学习器配置 +pub struct LearnerConfig { + /// 最小样本数(少于此数不调整) + pub min_samples: u64, + /// 高准确率阈值 + pub high_accuracy_threshold: f64, + /// 低准确率阈值 + pub low_accuracy_threshold: f64, + /// 最大置信度调整幅度 + pub max_confidence_delta: f64, +} + +impl Default for LearnerConfig { + fn default() -> Self { + Self { + min_samples: 10, + high_accuracy_threshold: 0.8, + low_accuracy_threshold: 0.5, + max_confidence_delta: 0.2, + } + } +} +``` + +--- + +## 7. 实现细节 + +### 7.1 哈希计算 + +```rust +impl PilotContext { + /// 计算查询哈希(用于聚合相似查询) + pub fn query_hash(&self) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + self.query_section.hash(&mut hasher); + hasher.finish() + } + + /// 计算路径哈希(用于聚合相似路径) + pub fn path_hash(&self) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + self.path_section.hash(&mut hasher); + hasher.finish() + } +} +``` + +### 7.2 统计计算 + +```rust +impl ContextStats { + /// 计算准确率 + pub fn accuracy(&self) -> f64 { + if self.total == 0 { + 0.0 + } else { + self.correct as f64 / self.total as f64 + } + } + + /// 记录新反馈(增量更新) + fn record(&mut self, was_correct: bool, confidence: f64) { + self.total += 1; + if was_correct { + self.correct += 1; + // 增量更新平均置信度 + self.avg_confidence_correct = + (self.avg_confidence_correct * (self.correct - 1) as f64 + confidence) + / self.correct as f64; + } else { + let incorrect = self.total - self.correct; + self.avg_confidence_incorrect = + (self.avg_confidence_incorrect * (incorrect - 1) as f64 + confidence) + / incorrect as f64; + } + } +} +``` + +--- + +## 8. 未来扩展 + +### 8.1 可能的改进方向 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 未来扩展方向 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 1. 语义相似度聚合 │ +│ ┌─────────────────────────────────────────────────────────────────┐ │ +│ │ 当前: 使用精确哈希聚合 │ │ +│ │ 未来: 使用 embedding 计算语义相似度,聚合语义相近的查询 │ │ +│ └─────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 2. 时间衰减 │ +│ ┌─────────────────────────────────────────────────────────────────┐ │ +│ │ 当前: 所有历史反馈等权重 │ │ +│ │ 未来: 近期反馈权重更高,旧反馈逐渐衰减 │ │ +│ └─────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 3. 
Online Learning                                                       │
+│    ┌─────────────────────────────────────────────────────────────────┐    │
+│    │ Current: offline analysis, online application                   │    │
+│    │ Future: update model parameters in real time                    │    │
+│    └─────────────────────────────────────────────────────────────────┘    │
+│                                                                            │
+│  4. Personalized Learning                                                  │
+│    ┌─────────────────────────────────────────────────────────────────┐    │
+│    │ Current: global learning                                        │    │
+│    │ Future: learn separately per user / scenario                    │    │
+│    └─────────────────────────────────────────────────────────────────┘    │
+│                                                                            │
+└─────────────────────────────────────────────────────────────────────────────┘
+```
+
+---
+
+## 9. Code Structure
+
+```
+src/retrieval/pilot/
+├── mod.rs          # Module entry point
+├── feedback.rs     # FeedbackStore, PilotLearner implementations
+├── llm_pilot.rs    # LlmPilot (integrates the learner)
+├── builder.rs      # ContextBuilder (adds hash methods)
+└── ...
+```

diff --git a/examples/feedback_learning.rs b/examples/feedback_learning.rs
new file mode 100644
index 00000000..88a15f4e
--- /dev/null
+++ b/examples/feedback_learning.rs
@@ -0,0 +1,145 @@
+// Copyright (c) 2026 vectorless developers
+// SPDX-License-Identifier: Apache-2.0
+
+//! Feedback Learning example.
+//!
+//! This example demonstrates how to use the feedback learning system
+//! to improve Pilot decision quality over time.
+//!
+//! # What you'll learn:
+//! - How to create a FeedbackStore for collecting feedback
+//! - How to integrate PilotLearner with LlmPilot
+//! - How to record user feedback for decisions
+//! - How the learner automatically adjusts decisions
+//!
+//! # Key concepts:
+//!
+//! ## Feedback Flow
+//! ```text
+//! Retrieval → Decision → User Feedback → FeedbackStore
+//!     ↑                                       ↓
+//!     └──────── PilotLearner ────────┘
+//!          (adjusts confidence)
+//! ```
+//!
+//! ## Learning Effect
+//! - High accuracy scenarios → Pilot confidence boosted
+//! - Low accuracy scenarios → Algorithm trusted more
+//! - Very low accuracy → Intervention skipped entirely
+
+use std::sync::Arc;
+use vectorless::document::DocumentTree;
+use vectorless::llm::LlmClient;
+use vectorless::retrieval::pilot::feedback::{
+    FeedbackRecord, FeedbackStore, FeedbackStoreConfig, LearnerConfig, PilotLearner,
+    DecisionId, SubQueryComplexity, SubQueryType,
+};
+use vectorless::retrieval::pilot::{InterventionPoint, LlmPilot, PilotConfig};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    println!("=== Feedback Learning Example ===\n");
+
+    // 1. Create FeedbackStore with in-memory storage
+    let store = Arc::new(FeedbackStore::in_memory());
+    println!("✓ Created FeedbackStore (in-memory)");
+
+    // 2. Create Learner with custom configuration
+    let learner_config = LearnerConfig {
+        min_samples: 5,               // Need 5 samples before adjusting
+        high_accuracy_threshold: 0.8, // 80%+ accuracy = boost confidence
+        low_accuracy_threshold: 0.5,  // 50%- accuracy = reduce confidence
+        max_confidence_delta: 0.2,    // Max adjustment ±0.2
+    };
+    let learner = Arc::new(PilotLearner::with_config(store.clone(), learner_config));
+    println!("✓ Created PilotLearner with custom config");
+
+    // 3. Create LlmPilot with feedback learning
+    let client = LlmClient::for_model("gpt-4o-mini");
+    let pilot = LlmPilot::new(client, PilotConfig::default()).with_learner(learner.clone());
+    println!("✓ Created LlmPilot with feedback learner");
+
+    // 4. Simulate some retrieval operations with feedback
+    println!("\n=== Simulating Retrieval with Feedback ===\n");
+
+    // Simulate 10 retrieval operations
+    for i in 0..10 {
+        let decision_id = DecisionId(i);
+        let was_correct = i % 3 != 0; // 66% accuracy
+        let confidence = 0.7 + (i as f64 * 0.02);
+
+        // Create feedback record
+        let record = FeedbackRecord::new(
+            decision_id,
+            was_correct,
+            confidence,
+            InterventionPoint::Fork,
+            12345, // query_hash
+            67890, // path_hash
+        );
+
+        // Record feedback
+        pilot.record_feedback(record);
+
+        println!(
+            "Decision {}: {} (confidence: {:.2})",
+            i,
+            if was_correct { "✓ Correct" } else { "✗ Incorrect" },
+            confidence
+        );
+    }
+
+    // 5. View learning statistics
+    println!("\n=== Learning Statistics ===\n");
+
+    let stats = store.intervention_stats();
+    println!("Fork Point Statistics:");
+    println!("  Total decisions: {}", stats.fork.total);
+    println!("  Correct: {}", stats.fork.correct);
+    println!("  Accuracy: {:.1}%", stats.fork.accuracy() * 100.0);
+    println!(
+        "  Avg confidence (correct): {:.2}",
+        stats.fork.avg_confidence_correct
+    );
+    println!(
+        "  Avg confidence (incorrect): {:.2}",
+        stats.fork.avg_confidence_incorrect
+    );
+
+    let overall = store.overall_accuracy();
+    println!("\nOverall accuracy: {:.1}%", overall * 100.0);
+    println!("Total records: {}", store.total_records());
+
+    // 6. Check if learner has enough data
+    println!("\n=== Learner Status ===\n");
+    if learner.has_sufficient_data() {
+        println!("✓ Learner has sufficient data for adjustments");
+
+        // Get adjustment for similar context
+        let adjustment = learner.get_adjustment(InterventionPoint::Fork, 12345, 67890);
+        println!("\nAdjustment for similar context:");
+        println!("  Confidence delta: {:.3}", adjustment.confidence_delta);
+        println!("  Algorithm weight: {:.2}", adjustment.algorithm_weight);
+        println!(
+            "  Skip intervention: {}",
+            adjustment.skip_intervention
+        );
+    } else {
+        println!("✗ Learner needs more data before adjusting");
+    }
+
+    // 7. Demonstrate persistence (optional)
+    println!("\n=== Persistence (Optional) ===\n");
+
+    let persistent_config = FeedbackStoreConfig::with_persistence("/tmp/feedback.json");
+    let persistent_store = FeedbackStore::new(persistent_config);
+
+    // In a real app, you would:
+    // - Load existing feedback at startup: persistent_store.load()?
+    // - Save periodically: persistent_store.persist()?
+
+    println!("To enable persistence, create FeedbackStore with:");
+    println!("  FeedbackStoreConfig::with_persistence(\"/path/to/feedback.json\")");
+
+    println!("\n=== Example Complete ===");
+    Ok(())
+}

From 59b6ed7b3f5bfa529c5cc4c3c78a44369211be08 Mon Sep 17 00:00:00 2001
From: zTgx <747674262@qq.com>
Date: Sun, 5 Apr 2026 19:51:49 +0800
Subject: [PATCH 5/6] refactor(examples): update imports and variable naming in
 feedback learning example

- Remove unused DocumentTree import from vectorless::document module
- Reorganize imports for better grouping and readability
- Move PilotConfig import to the appropriate location with other pilot-related imports
- Rename persistent_store variable to _persistent_store to indicate it's intentionally unused
---
 examples/feedback_learning.rs | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/examples/feedback_learning.rs b/examples/feedback_learning.rs
index 88a15f4e..3848e534 100644
--- a/examples/feedback_learning.rs
+++ b/examples/feedback_learning.rs
@@ -28,13 +28,11 @@
 //! - Very low accuracy → Intervention skipped entirely
 
 use std::sync::Arc;
-use vectorless::document::DocumentTree;
 use vectorless::llm::LlmClient;
-use vectorless::retrieval::pilot::feedback::{
-    FeedbackRecord, FeedbackStore, FeedbackStoreConfig, LearnerConfig, PilotLearner,
-    DecisionId, SubQueryComplexity, SubQueryType,
+use vectorless::retrieval::pilot::{
+    FeedbackRecord, FeedbackStore, FeedbackStoreConfig, InterventionPoint, LearnerConfig,
+    PilotLearner, DecisionId, LlmPilot, PilotConfig,
 };
-use vectorless::retrieval::pilot::{InterventionPoint, LlmPilot, PilotConfig};
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
     println!("=== Feedback Learning Example ===\n");
@@ -131,7 +129,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     println!("\n=== Persistence (Optional) ===\n");
 
     let persistent_config = FeedbackStoreConfig::with_persistence("/tmp/feedback.json");
-    let persistent_store = FeedbackStore::new(persistent_config);
+    let _persistent_store = FeedbackStore::new(persistent_config);
 
     // In a real app, you would:
     // - Load existing feedback at startup: persistent_store.load()?

From e5b0732f3a52dda228d110af6fc3bc9b03fcd401 Mon Sep 17 00:00:00 2001
From: zTgx <747674262@qq.com>
Date: Sun, 5 Apr 2026 19:58:59 +0800
Subject: [PATCH 6/6] docs(readme): remove star call-to-action and update
 development warning
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Remove the star call-to-action message and update the early development
warning by removing the ⭐ emoji prefix.

docs(architecture): enhance diagram with comprehensive system design

Update architecture.svg to include:
- Expand canvas size from 720px to 800px height
- Add detailed configuration section with TOML settings
- Include LLM executor with throttle control and retry mechanisms
- Add query decomposition and pilot guidance components
- Implement unified metrics hub with LLM, pilot, retrieval, and feedback statistics
- Add feedback learning loop from user feedback to decision adjustment
- Include design philosophy section highlighting zero vectors, algorithm+LLM hybrid
  approach, feedback learning, and multi-turn support
- Replace keyword/semantic/LLM strategies with keyword-only, BM25, and hybrid scoring strategies
- Update workspace to include feedback store component
---
 README.md                    |  15 +--
 docs/design/architecture.svg | 173 ++++++++++++++++++++---------------
 2 files changed, 101 insertions(+), 87 deletions(-)

diff --git a/README.md b/README.md
index a89e2f28..42742a46 100644
--- a/README.md
+++ b/README.md
@@ -13,9 +13,7 @@
 Ultra performant document intelligence engine for RAG, written in **Rust**.
 Zero vector database, zero embedding model — just LLM-powered tree navigation. Incremental indexing and multi-format support out-of-box.
 
-⭐ **Drop a star to help us grow!**
-
-**⚠️ Early Development**: This project is in active development. The API and features are likely to evolve, and breaking changes may occur.
+**Early Development**: This project is in active development. The API and features are likely to evolve, and breaking changes may occur.
 
 ## Why Vectorless?
 
@@ -134,19 +132,12 @@ async fn main() -> vectorless::Result<()> {
 
 ## Examples
 
-See the [examples/](examples/) directory for complete working examples
-
+See the [examples/](examples/) directory for complete working examples.
 
 ## Architecture
-
-### Pilot Architecture
-
-![Pilot Architecture](docs/design/pilot-architecture.svg)
-
-### System Overview
-
 ![Architecture](docs/design/architecture.svg)
+
 
 ## Contributing
 
 Contributions are welcome!
diff --git a/docs/design/architecture.svg b/docs/design/architecture.svg
index 860e28f9..cb782610 100644
--- a/docs/design/architecture.svg
+++ b/docs/design/architecture.svg
 [Updated "Vectorless Architecture" diagram: Engine Client with Config (TOML) and a Workspace including a Feedback Store; Retrieval Pipeline (Analyze / Plan / Search / Judge) with NeedMoreData/Backtrack; Scoring Strategies (Keyword Only, BM25, Hybrid default at 40% keyword + 60% BM25); Query Decomposition; Pilot + Feedback Learning; LLM Executor (throttle, retry with backoff, fallback chain); Unified Metrics Hub (LLM, Pilot, Retrieval, Feedback stats); feedback loop from user feedback through the store and learner to adjusted decisions; Design Philosophy panel (Zero Vectors, Algorithm + LLM, Feedback Learning, Multi-turn Support).]
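
The example in patch 4 only describes the persistence flow in comments. The following standalone sketch (not part of the patch series) pieces that flow together from the names the patches themselves introduce; it assumes `FeedbackStore::load()` and `FeedbackStore::persist()` exist with the `Result`-returning signatures implied by the example's comments.

```rust
// Sketch only: assumes FeedbackStore::load()/persist() behave as the
// example's comments suggest (load at startup, persist before shutdown).
use std::sync::Arc;
use vectorless::llm::LlmClient;
use vectorless::retrieval::pilot::{
    DecisionId, FeedbackRecord, FeedbackStore, FeedbackStoreConfig, InterventionPoint,
    LearnerConfig, LlmPilot, PilotConfig, PilotLearner,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Back the store with a JSON file instead of keeping feedback in memory.
    let config = FeedbackStoreConfig::with_persistence("/tmp/feedback.json");
    let store = Arc::new(FeedbackStore::new(config));

    // Load feedback from earlier sessions before any decisions are made.
    store.load()?;

    // Same learner/pilot wiring as in the in-memory example.
    let learner_config = LearnerConfig {
        min_samples: 5,
        high_accuracy_threshold: 0.8,
        low_accuracy_threshold: 0.5,
        max_confidence_delta: 0.2,
    };
    let learner = Arc::new(PilotLearner::with_config(store.clone(), learner_config));
    let pilot = LlmPilot::new(LlmClient::for_model("gpt-4o-mini"), PilotConfig::default())
        .with_learner(learner.clone());

    // Run retrievals, then record one FeedbackRecord per judged decision.
    let record = FeedbackRecord::new(
        DecisionId(0),
        true,                    // was_correct
        0.85,                    // confidence
        InterventionPoint::Fork,
        12345,                   // query_hash
        67890,                   // path_hash
    );
    pilot.record_feedback(record);

    // Flush to disk so the next session starts from the accumulated statistics.
    store.persist()?;
    Ok(())
}
```

In a long-running service the same pattern applies: load once at startup, record feedback as decisions are judged, and persist periodically or at shutdown, as the example's comments recommend.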