diff --git a/Cargo.toml b/Cargo.toml
index cd8fca41..950779db 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,7 +20,6 @@ futures = "0.3"
 # Serialization
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
-toml = "0.8"
 
 # Error handling
 thiserror = "2"
@@ -36,7 +35,7 @@ uuid = { version = "1.10", features = ["v4", "serde"] }
 chrono = { version = "0.4", default-features = false, features = ["serde", "clock"] }
 
 # Logging
-tracing = "0.1"
+tracing = { version = "0.1", features = ["attributes"] }
 
 # Rate limiting
 governor = "0.6"
diff --git a/README.md b/README.md
index 556ba5f3..b9a34abd 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ vectorless = "0.1"
 ```
 
 ```rust
-use vectorless::client::{EngineBuilder, IndexContext, QueryContext};
+use vectorless::{EngineBuilder, IndexContext, QueryContext};
 
 #[tokio::main]
 async fn main() -> vectorless::Result<()> {
diff --git a/docs/docs/examples/batch-indexing.mdx b/docs/docs/examples/batch-indexing.mdx
index 847e738b..b9f77adc 100644
--- a/docs/docs/examples/batch-indexing.mdx
+++ b/docs/docs/examples/batch-indexing.mdx
@@ -47,7 +47,7 @@ asyncio.run(main())
 ## Rust
 
 ```rust
-use vectorless::client::{Engine, EngineBuilder, IndexContext};
+use vectorless::{Engine, EngineBuilder, IndexContext};
 
 #[tokio::main]
 async fn main() -> vectorless::Result<()> {
diff --git a/docs/docs/examples/quick-query.mdx b/docs/docs/examples/quick-query.mdx
index 07f66390..9a39c82c 100644
--- a/docs/docs/examples/quick-query.mdx
+++ b/docs/docs/examples/quick-query.mdx
@@ -48,7 +48,7 @@ asyncio.run(main())
 ## Rust
 
 ```rust
-use vectorless::client::{Engine, EngineBuilder, IndexContext, QueryContext};
+use vectorless::{Engine, EngineBuilder, IndexContext, QueryContext};
 use vectorless::StrategyPreference;
 
 #[tokio::main]
diff --git a/docs/docs/features/synonym-expansion.mdx b/docs/docs/features/synonym-expansion.mdx
index 527bc70e..9a88d901 100644
--- a/docs/docs/features/synonym-expansion.mdx
+++ b/docs/docs/features/synonym-expansion.mdx
@@ -46,7 +46,7 @@ opts = IndexOptions(enable_synonym_expansion=False)
 The synonym expansion is controlled via `ReasoningIndexConfig`:
 
 ```rust
-use vectorless::document::ReasoningIndexConfig;
+use vectorless::ReasoningIndexConfig;
 
 let config = ReasoningIndexConfig::default()
     .with_synonym_expansion(true);
diff --git a/docs/docs/indexing/incremental.mdx b/docs/docs/indexing/incremental.mdx
index d50ae3b3..482cd2b5 100644
--- a/docs/docs/indexing/incremental.mdx
+++ b/docs/docs/indexing/incremental.mdx
@@ -32,7 +32,7 @@ result = await engine.index(ctx)
 ### Rust
 
 ```rust
-use vectorless::client::{IndexContext, IndexMode};
+use vectorless::{IndexContext, IndexMode};
 
 let ctx = IndexContext::from_path("./report.pdf")
     .with_mode(IndexMode::Incremental);
diff --git a/docs/docs/sdk/rust.mdx b/docs/docs/sdk/rust.mdx
index 2136cafa..1117302d 100644
--- a/docs/docs/sdk/rust.mdx
+++ b/docs/docs/sdk/rust.mdx
@@ -16,7 +16,7 @@ vectorless = "0.1"
 ## Engine
 
 ```rust
-use vectorless::client::{Engine, EngineBuilder};
+use vectorless::{Engine, EngineBuilder};
 
 let engine = EngineBuilder::new()
     .with_key("sk-...")
@@ -29,7 +29,7 @@ let engine = EngineBuilder::new()
 ## Indexing
 
 ```rust
-use vectorless::client::{IndexContext, IndexOptions, IndexMode};
+use vectorless::{IndexContext, IndexOptions, IndexMode};
 
 // From a file
 let result = engine.index(IndexContext::from_path("./report.pdf")).await?;
@@ -51,7 +51,7 @@ let result = engine.index(
 ## Querying
 
 ```rust
-use vectorless::client::QueryContext;
+use vectorless::QueryContext;
 use vectorless::StrategyPreference;
 
 let result = engine.query(
diff --git a/docs/src/pages/index.tsx b/docs/src/pages/index.tsx
index c75abfd7..cae26e92 100644
--- a/docs/src/pages/index.tsx
+++ b/docs/src/pages/index.tsx
@@ -107,7 +107,7 @@ async def main():
 
 asyncio.run(main())`;
 
-const RUST_CODE = `use vectorless::client::{EngineBuilder, IndexContext, QueryContext};
+const RUST_CODE = `use vectorless::{EngineBuilder, IndexContext, QueryContext};
 
 #[tokio::main]
 async fn main() -> vectorless::Result<()> {
diff --git a/python/src/config.rs b/python/src/config.rs
index 93a0552e..ce601311 100644
--- a/python/src/config.rs
+++ b/python/src/config.rs
@@ -55,14 +55,14 @@ impl PyConfig {
     ///
     /// Default: 10
     fn set_max_concurrent_requests(&mut self, max: usize) {
-        self.inner.concurrency.max_concurrent_requests = max;
+        self.inner.llm.throttle.max_concurrent_requests = max;
     }
 
     /// Set the rate limit (requests per minute).
     ///
     /// Default: 500
     fn set_requests_per_minute(&mut self, rpm: usize) {
-        self.inner.concurrency.requests_per_minute = rpm;
+        self.inner.llm.throttle.requests_per_minute = rpm;
     }
 
     /// Set the maximum iterations for retrieval search.
@@ -70,13 +70,6 @@ impl PyConfig {
         self.inner.retrieval.search.max_iterations = max;
     }
 
-    /// Set the retrieval temperature.
-    ///
-    /// Default: 0.0
-    fn set_temperature(&mut self, temp: f32) {
-        self.inner.retrieval.temperature = temp;
-    }
-
     /// Enable or disable metrics collection.
     ///
     /// Default: True
diff --git a/python/src/context.rs b/python/src/context.rs
index 3eedc6f9..bdf37c6c 100644
--- a/python/src/context.rs
+++ b/python/src/context.rs
@@ -5,7 +5,7 @@
 
 use pyo3::prelude::*;
 
-use ::vectorless::client::{DocumentFormat, IndexContext, IndexMode, IndexOptions, QueryContext};
+use ::vectorless::{DocumentFormat, IndexContext, IndexMode, IndexOptions, QueryContext};
 
 use super::error::VectorlessError;
 
@@ -31,7 +31,6 @@ fn parse_format(format: &str) -> PyResult<DocumentFormat> {
 ///     mode: Indexing mode - "default", "force", or "incremental".
 ///     generate_summaries: Whether to generate summaries. Default: True.
 ///     generate_description: Whether to generate document description. Default: False.
-///     include_text: Whether to include node text in the tree. Default: True.
 ///     generate_ids: Whether to generate node IDs. Default: True.
 ///     enable_synonym_expansion: Whether to expand keywords with LLM-generated
 ///         synonyms during indexing. Improves recall for differently-worded queries.
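The parameter list above now matches the trimmed option set (`include_text` is gone). For orientation, a sketch of the equivalent Rust-side construction; the field names mirror the Python keyword arguments and the assignments later in this hunk, while the public struct-literal syntax and the `Default` impl are assumptions, not confirmed by this diff:

```rust
use vectorless::{IndexMode, IndexOptions};

// Hypothetical construction; assumes IndexOptions has public fields
// (as the `opts.generate_summaries = ...` assignments below suggest)
// and implements Default.
let opts = IndexOptions {
    mode: IndexMode::Default,
    generate_summaries: true,
    generate_description: false,
    generate_ids: true,
    enable_synonym_expansion: false,
    ..IndexOptions::default()
};
```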
@@ -45,12 +44,11 @@ pub struct PyIndexOptions {
 #[pymethods]
 impl PyIndexOptions {
     #[new]
-    #[pyo3(signature = (mode="default", generate_summaries=true, generate_description=false, include_text=true, generate_ids=true, enable_synonym_expansion=false))]
+    #[pyo3(signature = (mode="default", generate_summaries=true, generate_description=false, generate_ids=true, enable_synonym_expansion=false))]
     fn new(
         mode: &str,
         generate_summaries: bool,
         generate_description: bool,
-        include_text: bool,
         generate_ids: bool,
         enable_synonym_expansion: bool,
     ) -> PyResult<Self> {
@@ -71,7 +69,6 @@ impl PyIndexOptions {
         }
         opts.generate_summaries = generate_summaries;
         opts.generate_description = generate_description;
-        opts.include_text = include_text;
         opts.generate_ids = generate_ids;
         opts.enable_synonym_expansion = enable_synonym_expansion;
         Ok(Self { inner: opts })
@@ -79,7 +76,7 @@ impl PyIndexOptions {
 
     fn __repr__(&self) -> String {
         format!(
-            "IndexOptions(mode='{}', generate_summaries={}, generate_description={}, include_text={}, generate_ids={}, enable_synonym_expansion={})",
+            "IndexOptions(mode='{}', generate_summaries={}, generate_description={}, generate_ids={}, enable_synonym_expansion={})",
             match self.inner.mode {
                 IndexMode::Default => "default",
                 IndexMode::Force => "force",
@@ -87,7 +84,6 @@ impl PyIndexOptions {
             },
             self.inner.generate_summaries,
             self.inner.generate_description,
-            self.inner.include_text,
             self.inner.generate_ids,
             self.inner.enable_synonym_expansion,
         )
@@ -270,18 +266,6 @@ impl PyQueryContext {
         Self { inner: ctx }
     }
 
-    /// Set whether to include the reasoning chain.
-    fn with_include_reasoning(&self, include: bool) -> Self {
-        let ctx = self.inner.clone().with_include_reasoning(include);
-        Self { inner: ctx }
-    }
-
-    /// Set the maximum tree traversal depth.
-    fn with_depth_limit(&self, depth: usize) -> Self {
-        let ctx = self.inner.clone().with_depth_limit(depth);
-        Self { inner: ctx }
-    }
-
     fn __repr__(&self) -> String {
         "QueryContext(...)".to_string()
     }
diff --git a/python/src/document.rs b/python/src/document.rs
index eee70c0e..d5652fba 100644
--- a/python/src/document.rs
+++ b/python/src/document.rs
@@ -5,7 +5,7 @@
 
 use pyo3::prelude::*;
 
-use ::vectorless::client::DocumentInfo;
+use ::vectorless::DocumentInfo;
 
 /// Information about an indexed document.
 #[pyclass(name = "DocumentInfo")]
diff --git a/python/src/engine.rs b/python/src/engine.rs
index 8f7dc015..1f8ae870 100644
--- a/python/src/engine.rs
+++ b/python/src/engine.rs
@@ -8,7 +8,7 @@ use pyo3_async_runtimes::tokio::future_into_py;
 use std::sync::Arc;
 use tokio::runtime::Runtime;
 
-use ::vectorless::client::{Engine, EngineBuilder, IndexContext, QueryContext};
+use ::vectorless::{Engine, EngineBuilder, IndexContext, QueryContext};
 
 use super::config::PyConfig;
 use super::context::{PyIndexContext, PyQueryContext};
@@ -229,9 +229,6 @@ impl PyEngine {
     }
 
     /// Generate a complete metrics report.
-    ///
-    /// Returns:
-    ///     MetricsReport with LLM, Pilot, and Retrieval metrics.
     fn metrics_report(&self) -> PyMetricsReport {
         run_metrics_report(Arc::clone(&self.inner))
     }
diff --git a/python/src/error.rs b/python/src/error.rs
index d128ce5a..e4a977b8 100644
--- a/python/src/error.rs
+++ b/python/src/error.rs
@@ -6,7 +6,7 @@ use pyo3::exceptions::PyException;
 use pyo3::prelude::*;
 
-use ::vectorless::error::Error as RustError;
+use ::vectorless::Error as RustError;
 
 /// Python exception for vectorless errors.
 #[pyclass(extends = PyException, subclass)]
diff --git a/python/src/graph.rs b/python/src/graph.rs
index 1aacd47f..a424316f 100644
--- a/python/src/graph.rs
+++ b/python/src/graph.rs
@@ -5,9 +5,7 @@
 
 use pyo3::prelude::*;
 
-use ::vectorless::graph::{
-    DocumentGraph, DocumentGraphNode, EdgeEvidence, GraphEdge, WeightedKeyword,
-};
+use ::vectorless::{DocumentGraph, DocumentGraphNode, EdgeEvidence, GraphEdge, WeightedKeyword};
 
 /// A keyword with weight from document analysis.
 #[pyclass(name = "WeightedKeyword")]
diff --git a/python/src/metrics.rs b/python/src/metrics.rs
index 669511cb..27a71dcb 100644
--- a/python/src/metrics.rs
+++ b/python/src/metrics.rs
@@ -5,9 +5,7 @@
 
 use pyo3::prelude::*;
 
-use ::vectorless::metrics::{
-    LlmMetricsReport, MetricsReport, PilotMetricsReport, RetrievalMetricsReport,
-};
+use ::vectorless::{LlmMetricsReport, MetricsReport, PilotMetricsReport, RetrievalMetricsReport};
 
 /// LLM usage metrics report.
 #[pyclass(name = "LlmMetricsReport")]
@@ -360,11 +358,6 @@ impl PyMetricsReport {
         self.inner.total_cost_usd()
     }
 
-    /// Overall success rate (0.0 - 1.0).
-    fn overall_success_rate(&self) -> f64 {
-        self.inner.overall_success_rate()
-    }
-
     fn __repr__(&self) -> String {
         format!(
             "MetricsReport(llm_calls={}, cost=${:.4}, queries={})",
diff --git a/python/src/results.rs b/python/src/results.rs
index fe780a4c..14735e71 100644
--- a/python/src/results.rs
+++ b/python/src/results.rs
@@ -5,8 +5,8 @@
 
 use pyo3::prelude::*;
 
-use ::vectorless::client::{FailedItem, IndexItem, IndexResult, QueryResult, QueryResultItem};
-use ::vectorless::metrics::IndexMetrics;
+use ::vectorless::IndexMetrics;
+use ::vectorless::{FailedItem, IndexItem, IndexResult, QueryResult, QueryResultItem};
 
 // ============================================================
 // QueryResultItem
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index d6984f6e..09ab38fe 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -22,7 +22,6 @@ futures = { workspace = true }
 # Serialization
 serde = { workspace = true }
 serde_json = { workspace = true }
-toml = { workspace = true }
 
 # Error handling
 thiserror = { workspace = true }
diff --git a/rust/examples/events.rs b/rust/examples/events.rs
index 51398da8..3db97706 100644
--- a/rust/examples/events.rs
+++ b/rust/examples/events.rs
@@ -22,8 +22,8 @@ use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, Ordering};
 
-use vectorless::client::{EngineBuilder, IndexContext, QueryContext};
-use vectorless::events::{EventEmitter, IndexEvent, QueryEvent};
+use vectorless::{EngineBuilder, IndexContext, QueryContext};
+use vectorless::{EventEmitter, IndexEvent, QueryEvent};
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
diff --git a/rust/examples/flow.rs b/rust/examples/flow.rs
index 57d92891..fbc89423 100644
--- a/rust/examples/flow.rs
+++ b/rust/examples/flow.rs
@@ -20,8 +20,7 @@
 //! cargo run --example flow
 //! ```
 
-use vectorless::EngineBuilder;
-use vectorless::client::{IndexContext, IndexOptions, QueryContext};
+use vectorless::{EngineBuilder, IndexContext, IndexOptions, QueryContext};
 
 /// Sample markdown content for demonstration.
 const SAMPLE_MARKDOWN: &str = r#"
diff --git a/rust/src/client/builder.rs b/rust/src/client/builder.rs
index b0c035cf..5f08af95 100644
--- a/rust/src/client/builder.rs
+++ b/rust/src/client/builder.rs
@@ -7,15 +7,16 @@
 //! [`Engine`] instances with sensible defaults.
 use crate::{
-    client::engine::Engine, config::Config, events::EventEmitter, retrieval::PipelineRetriever,
-    storage::Workspace,
+    client::engine::Engine, config::Config, events::EventEmitter, metrics::MetricsHub,
+    retrieval::PipelineRetriever, storage::Workspace,
 };
 
 /// Builder for creating a [`Engine`] client.
 ///
-/// `api_key`, `model` and `endpoint` are **required**.
+/// `api_key`, `model` and `endpoint` are **required** for simple usage.
+/// Advanced users can provide a pre-built [`Config`] via [`with_config`](EngineBuilder::with_config).
 ///
-/// # Example
+/// # Example (simple)
 ///
 /// ```rust,no_run
 /// use vectorless::client::EngineBuilder;
 ///
@@ -31,6 +32,25 @@ use crate::{
 /// Ok(())
 /// }
 /// ```
+///
+/// # Example (advanced)
+///
+/// ```rust,ignore
+/// use vectorless::client::EngineBuilder;
+/// use vectorless::config::{Config, LlmConfig, SlotConfig};
+///
+/// let config = Config::new().with_llm(
+///     LlmConfig::new("gpt-4o")
+///         .with_api_key("sk-...")
+///         .with_endpoint("https://api.openai.com/v1")
+///         .with_index(SlotConfig::fast().with_model("gpt-4o-mini"))
+/// );
+///
+/// let engine = EngineBuilder::new()
+///     .with_config(config)
+///     .build()
+///     .await?;
+/// ```
 #[derive(Debug)]
 pub struct EngineBuilder {
     /// Custom configuration for advanced tuning.
@@ -63,7 +83,7 @@ impl EngineBuilder {
     }
 
     // ============================================================
-    // Basic Configuration
+    // Configuration
     // ============================================================
 
     /// Set a custom configuration.
@@ -85,25 +105,10 @@ impl EngineBuilder {
     }
 
     // ============================================================
-    // LLM Configuration
+    // LLM Configuration (simple overrides)
     // ============================================================
 
-    /// Set the LLM API key. **Required**.
-    ///
-    /// # Example
-    ///
-    /// ```rust,no_run
-    /// use vectorless::client::EngineBuilder;
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), vectorless::BuildError> {
-    /// let engine = EngineBuilder::new()
-    ///     .with_key("sk-...")
-    ///     .build()
-    ///     .await?;
-    /// # Ok(())
-    /// # }
-    /// ```
+    /// Set the LLM API key. **Required** (unless provided via Config).
     #[must_use]
     pub fn with_key(mut self, key: impl Into<String>) -> Self {
         self.api_key = Some(key.into());
@@ -111,23 +116,6 @@ impl EngineBuilder {
     }
 
     /// Set the LLM model name.
-    ///
-    /// Default: "gpt-4o".
-    ///
-    /// # Example
-    ///
-    /// ```rust,no_run
-    /// use vectorless::client::EngineBuilder;
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), vectorless::BuildError> {
-    /// let engine = EngineBuilder::new()
-    ///     .with_model("gpt-4o-mini")
-    ///     .build()
-    ///     .await?;
-    /// # Ok(())
-    /// # }
-    /// ```
     #[must_use]
     pub fn with_model(mut self, model: impl Into<String>) -> Self {
         self.model = Some(model.into());
@@ -135,24 +123,6 @@ impl EngineBuilder {
     }
 
     /// Set a custom LLM endpoint URL.
-    ///
-    /// Use this for OpenAI-compatible APIs (e.g., Azure OpenAI, local models).
-    ///
-    /// # Example
-    ///
-    /// ```rust,no_run
-    /// use vectorless::client::EngineBuilder;
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), vectorless::BuildError> {
-    /// let engine = EngineBuilder::new()
-    ///     .with_model("deepseek-chat")
-    ///     .with_endpoint("https://api.deepseek.com/v1")
-    ///     .build()
-    ///     .await?;
-    /// # Ok(())
-    /// # }
-    /// ```
     #[must_use]
     pub fn with_endpoint(mut self, url: impl Into<String>) -> Self {
         self.endpoint = Some(url.into());
@@ -160,17 +130,14 @@ impl EngineBuilder {
     }
 
     // ============================================================
-    // Retrieval Configuration
+    // Build
     // ============================================================
 
     /// Build the Engine client.
     ///
-    /// `api_key` and `model` must be provided via builder methods or config file.
-    ///
     /// # Errors
     ///
     /// Returns a [`BuildError`] if:
-    /// - Configuration loading fails
     /// - Workspace creation fails
     /// - Required `api_key` or `model` is missing
+    /// - Required `endpoint` is missing
     ///
@@ -184,6 +151,7 @@ impl EngineBuilder {
     /// let engine = EngineBuilder::new()
     ///     .with_key("sk-...")
     ///     .with_model("gpt-4o")
+    ///     .with_endpoint("https://api.openai.com/v1")
     ///     .build()
     ///     .await?;
     /// # Ok(())
@@ -193,51 +161,22 @@ impl EngineBuilder {
         // Load user-provided or default configuration
         let mut config = self.config.unwrap_or_default();
 
-        // Apply individual overrides to LlmPoolConfig (primary) + legacy config (compat)
+        // Apply simple overrides — write once, no dual-writing
         if let Some(api_key) = self.api_key {
-            config.llm.api_key = Some(api_key.clone());
-            // Legacy compat
-            config.retrieval.api_key = Some(api_key.clone());
-            config.summary.api_key = Some(api_key);
+            config.llm.api_key = Some(api_key);
         }
         if let Some(model) = self.model {
-            // Apply model to pool slots
-            if config.llm.index.model.is_empty() {
-                config.llm.index.model = model.clone();
-            }
-            if config.llm.retrieval.model.is_empty() {
-                config.llm.retrieval.model = model.clone();
-            }
-            if config.llm.pilot.model.is_empty() {
-                config.llm.pilot.model = model.clone();
-            }
-            // Legacy compat
-            config.retrieval.model = model.clone();
-            config.summary.model = model;
+            config.llm.model = model;
         }
         if let Some(endpoint) = self.endpoint {
-            config.llm.endpoint = Some(endpoint.clone());
-            // Legacy compat
-            config.retrieval.endpoint = endpoint.clone();
-            config.summary.endpoint = endpoint;
+            config.llm.endpoint = Some(endpoint);
         }
+
         // Validate required settings
-        let resolved_key = config
-            .llm
-            .api_key
-            .as_ref()
-            .or_else(|| config.llm.retrieval.api_key.as_ref())
-            .or_else(|| config.summary.api_key.as_ref())
-            .or_else(|| config.retrieval.api_key.as_ref());
-        if resolved_key.is_none() {
+        if config.llm.api_key.is_none() {
             return Err(BuildError::MissingApiKey);
         }
-        let retrieval_model = if config.llm.retrieval.model.is_empty() {
-            &config.retrieval.model
-        } else {
-            &config.llm.retrieval.model
-        };
-        if retrieval_model.is_empty() {
+        if config.llm.model.is_empty() {
             return Err(BuildError::MissingModel);
         }
         if config.llm.endpoint.is_none() {
@@ -249,17 +188,9 @@ impl EngineBuilder {
             .await
             .map_err(|e| BuildError::Workspace(e.to_string()))?;
 
-        // Build LlmPool from config.llm — centralizes all LLM client creation
-        let llm_configs: crate::llm::LlmConfigs = config.llm.clone().into();
-        let pool = {
-            let controller = crate::throttle::ConcurrencyController::new(
-                crate::throttle::ConcurrencyConfig::new()
-                    .with_max_concurrent_requests(config.concurrency.max_concurrent_requests)
-                    .with_requests_per_minute(config.concurrency.requests_per_minute)
-                    .with_enabled(config.concurrency.enabled),
-            );
-            crate::llm::LlmPool::new(llm_configs).with_concurrency(controller)
-        };
+        // Build LlmPool from unified LlmConfig (shared metrics hub)
+        let metrics_hub = std::sync::Arc::new(MetricsHub::with_defaults());
+        let pool = crate::llm::LlmPool::from_config(&config.llm, Some(metrics_hub.clone()));
 
         // Indexer uses pool.index()
         let indexer = crate::client::indexer::IndexerClient::with_llm(pool.index().clone());
@@ -278,7 +209,7 @@ impl EngineBuilder {
         // Build engine
         let events = self.events.unwrap_or_default();
 
-        Engine::with_components(config, workspace, retriever, indexer, events)
+        Engine::with_components(config, workspace, retriever, indexer, events, metrics_hub)
             .await
             .map_err(|e| BuildError::Other(e.to_string()))
     }
@@ -298,11 +229,11 @@ pub enum BuildError {
     Workspace(String),
 
     /// Missing API key.
-    #[error("Missing API key: call .with_key(\"sk-...\") or set api_key in config file")]
+    #[error("Missing API key: call .with_key(\"sk-...\") or set api_key in config")]
     MissingApiKey,
 
     /// Missing model name.
-    #[error("Missing model: call .with_model(\"gpt-4o\") or set model in config file")]
+    #[error("Missing model: call .with_model(\"gpt-4o\") or set model in config")]
     MissingModel,
 
     /// Missing endpoint URL.
diff --git a/rust/src/client/engine.rs b/rust/src/client/engine.rs
index 6cc4e207..a607263c 100644
--- a/rust/src/client/engine.rs
+++ b/rust/src/client/engine.rs
@@ -37,7 +37,12 @@
 //! # }
 //! ```
 
-use std::{collections::HashMap, sync::Arc};
+use std::{
+    collections::HashMap,
+    sync::Arc,
+    sync::Mutex,
+    sync::atomic::{AtomicBool, AtomicU32, Ordering},
+};
 
 use futures::StreamExt;
 use tracing::info;
@@ -61,10 +66,18 @@ use super::{
     indexer::IndexerClient,
     query_context::{QueryContext, QueryScope},
     retriever::RetrieverClient,
-    types::{DocumentInfo, FailedItem, IndexItem, IndexMode, IndexResult, QueryResult},
+    types::{
+        DocumentInfo, FailedItem, IndexItem, IndexMode, IndexResult, QueryResult, QueryResultItem,
+    },
     workspace::WorkspaceClient,
 };
 
+/// Shared cancel state: `true` means cancelled.
+type CancelFlag = Arc<AtomicBool>;
+
+/// Max consecutive graph rebuild failures before giving up.
+const GRAPH_REBUILD_MAX_FAILURES: u32 = 3;
+
 /// The main Engine client.
 ///
 /// Provides high-level operations for document indexing and retrieval.
@@ -89,16 +102,22 @@ pub struct Engine {
     retriever: RetrieverClient,
 
     /// Workspace client for persistence.
-    workspace: Option<WorkspaceClient>,
-
-    /// Workspace root directory (for checkpoint path).
-    workspace_dir: Option<std::path::PathBuf>,
-
-    /// Event emitter.
-    events: EventEmitter,
+    workspace: WorkspaceClient,
 
     /// Central metrics hub for unified collection.
     metrics_hub: Arc<MetricsHub>,
+
+    /// Whether the document graph needs rebuilding (set after index, consumed in query).
+    graph_dirty: Arc<AtomicBool>,
+
+    /// Consecutive graph rebuild failures — skip rebuild after threshold.
+    graph_fail_count: Arc<AtomicU32>,
+
+    /// Shared cancel flag — set by `cancel()`, checked by long-running operations.
+    cancelled: CancelFlag,
+
+    /// Active operation count so callers can watch in-flight work drain after `cancel()`.
+    active_ops: Arc<Mutex<usize>>,
 }
 
 impl Engine {
@@ -113,16 +132,15 @@ impl Engine {
         retriever: PipelineRetriever,
         indexer: IndexerClient,
         events: EventEmitter,
+        metrics_hub: Arc<MetricsHub>,
     ) -> Result<Self> {
         let config = Arc::new(config);
-        let workspace_dir = Some(std::path::PathBuf::from(&config.storage.workspace_dir));
 
         // Attach event emitter to indexer
         let indexer = indexer.with_events(events.clone());
 
         // Create retriever client
-        let retriever =
-            RetrieverClient::new(retriever, Arc::clone(&config)).with_events(events.clone());
+        let retriever = RetrieverClient::new(retriever).with_events(events.clone());
 
         // Create workspace client
         let workspace_client = WorkspaceClient::new(workspace)
@@ -133,10 +151,12 @@ impl Engine {
             config,
             indexer,
             retriever,
-            workspace: Some(workspace_client),
-            workspace_dir,
-            events,
-            metrics_hub: Arc::new(MetricsHub::with_defaults()),
+            workspace: workspace_client,
+            metrics_hub,
+            graph_dirty: Arc::new(AtomicBool::new(false)),
+            graph_fail_count: Arc::new(AtomicU32::new(0)),
+            cancelled: Arc::new(AtomicBool::new(false)),
+            active_ops: Arc::new(Mutex::new(0)),
         })
     }
 
     // ============================================================
     // Document Indexing
     // ============================================================
 
-    /// Index a document.
+    /// Index one or more documents.
     ///
     /// Accepts an [`IndexContext`] that specifies the source (file path,
-    /// content string, or bytes) and indexing options.
+    /// directory, content string, or bytes) and indexing options.
+    /// Multiple sources are indexed in parallel.
     ///
     /// Returns an [`IndexResult`] containing the indexed document metadata.
-    ///
-    /// # Example
-    ///
-    /// ```rust,no_run
-    /// use vectorless::client::{EngineBuilder, IndexContext};
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let engine = EngineBuilder::new()
-    ///     .with_key("sk-...")
-    ///     .with_model("gpt-4o")
-    ///     .build()
-    ///     .await?;
-    ///
-    /// let result = engine.index(IndexContext::from_path("./doc.md")).await?;
-    /// println!("Indexed: {}", result.doc_id().unwrap());
-    /// # Ok(())
-    /// # }
-    /// ```
+    #[tracing::instrument(skip_all, fields(sources = ctx.sources.len()))]
     pub async fn index(&self, ctx: IndexContext) -> Result<IndexResult> {
+        self.check_cancel()?;
         if ctx.is_empty() {
-            return Err(Error::Config("No document sources provided".to_string()));
+            return Err(Error::Config("No document sources provided".into()));
         }
 
-        // Single source: no need for concurrency overhead
-        if ctx.sources.len() == 1 {
-            let source = &ctx.sources[0];
+        let _guard = self.inc_active();
+        let timeout_secs = ctx.options.timeout_secs;
+
+        self.with_timeout(timeout_secs, async move {
+            let concurrency = self
+                .config
+                .llm
+                .throttle
+                .max_concurrent_requests
+                .min(ctx.sources.len());
+
             let (items, failed) = self
-                .process_source(source, &ctx.options, ctx.name.as_deref())
+                .process_sources(&ctx.sources, &ctx.options, ctx.name.as_deref(), concurrency)
                 .await;
+
             if items.is_empty() && !failed.is_empty() {
                 return Err(Error::Config(format!(
-                    "All {} source(s) failed to index: {}",
+                    "All {} source(s) failed: {}",
                     failed.len(),
                     failed
                         .iter()
                         .map(|f| format!("{} ({})", f.source, f.error))
                         .collect::<Vec<_>>()
                         .join("; ")
                 )));
             }
-            if !items.is_empty() {
-                if let Err(e) = self.rebuild_graph().await {
-                    tracing::warn!("Graph rebuild failed: {}", e);
-                }
-            }
-            return Ok(IndexResult::with_partial(items, failed));
-        }
 
+            // Mark graph as dirty — will be lazily rebuilt on next query()
+            // Also reset failure count so the new data gets a fresh rebuild attempt.
+            if !items.is_empty() && self.config.graph.enabled {
+                self.graph_dirty.store(true, Ordering::Relaxed);
+                self.graph_fail_count.store(0, Ordering::Relaxed);
+            }
 
-        // Multiple sources: parallel indexing
-        let concurrency = self
-            .config
-            .concurrency
-            .max_concurrent_requests
-            .min(ctx.sources.len());
+            Ok(IndexResult::with_partial(items, failed))
+        })
+        .await
+    }
 
+    /// Process multiple sources in parallel.
+    async fn process_sources(
+        &self,
+        sources: &[IndexSource],
+        options: &super::types::IndexOptions,
+        name: Option<&str>,
+        concurrency: usize,
+    ) -> (Vec<IndexItem>, Vec<FailedItem>) {
         let results: Vec<(Vec<IndexItem>, Vec<FailedItem>)> =
-            futures::stream::iter(ctx.sources.iter().cloned())
+            futures::stream::iter(sources.iter().cloned())
                 .map(|source| {
-                    let options = ctx.options.clone();
-                    let name = ctx.name.clone();
+                    let options = options.clone();
+                    let name = name.map(str::to_string);
                     let engine = self.clone();
                     async move {
                         engine
                             .process_source(&source, &options, name.as_deref())
                             .await
                     }
                 })
                 .buffer_unordered(concurrency)
                 .collect()
                 .await;
 
-        let mut items = Vec::new();
-        let mut failed = Vec::new();
-        for (ok, err) in results {
-            items.extend(ok);
-            failed.extend(err);
-        }
-
-        if items.is_empty() && !failed.is_empty() {
-            return Err(Error::Config(format!(
-                "All {} source(s) failed to index: {}",
-                failed.len(),
-                failed
-                    .iter()
-                    .map(|f| format!("{} ({})", f.source, f.error))
-                    .collect::<Vec<_>>()
-                    .join("; ")
-            )));
-        }
-
-        // Rebuild document graph after successful batch index
-        if !items.is_empty() {
-            if let Err(e) = self.rebuild_graph().await {
-                tracing::warn!("Graph rebuild failed: {}", e);
-            }
-        }
-
-        Ok(IndexResult::with_partial(items, failed))
+        results.into_iter().fold(
+            (Vec::new(), Vec::new()),
+            |(mut items, mut failed), (ok, err)| {
+                items.extend(ok);
+                failed.extend(err);
+                (items, failed)
+            },
+        )
     }
 
     /// Process a single source — resolve action and index.
     ///
     /// Returns `(items, failed)`.
+    #[tracing::instrument(skip_all, fields(source = %source))]
     async fn process_source(
         &self,
         source: &IndexSource,
         options: &super::types::IndexOptions,
         name: Option<&str>,
     ) -> (Vec<IndexItem>, Vec<FailedItem>) {
+        if self.is_cancelled() {
+            return (
+                Vec::new(),
+                vec![FailedItem::new(
+                    source.to_string(),
+                    "Operation cancelled".to_string(),
+                )],
+            );
+        }
+
         let source_label = source.to_string();
 
         match self.resolve_index_action(source, options).await {
@@ -277,43 +288,19 @@ impl Engine {
                 )
             }
             Ok(IndexAction::FullIndex { existing_id }) => {
-                match self.indexer.index(source, name, options).await {
+                let pipeline_options = self.build_pipeline_options(options, source);
+                match self
+                    .index_with_retry(source, name, pipeline_options.clone(), None)
+                    .await
+                {
                     Ok(doc) => {
-                        let pipeline_options = self.build_pipeline_options(options, doc.format);
-                        let metrics = doc.metrics.clone();
-                        let item = IndexItem::new(
-                            doc.id.clone(),
-                            doc.name.clone(),
-                            doc.format.clone(),
-                            doc.description.clone(),
-                            doc.page_count,
-                        )
-                        .with_source_path(
-                            doc.source_path
-                                .as_ref()
-                                .map(|p| p.to_string_lossy().to_string())
-                                .unwrap_or_default(),
+                        self.index_and_persist(
+                            doc,
+                            &pipeline_options,
+                            &source_label,
+                            existing_id.as_deref(),
                         )
-                        .with_metrics_opt(metrics);
-                        let persisted = self
-                            .indexer
-                            .to_persisted_with_options(doc, &pipeline_options);
-
-                        if let Some(ref workspace) = self.workspace {
-                            if let Err(e) = workspace.save(&persisted).await {
-                                return (
-                                    Vec::new(),
-                                    vec![FailedItem::new(&source_label, e.to_string())],
-                                );
-                            }
-                            // Clean up old document after successful save (atomic: save-first, then remove old)
-                            if let Some(old_id) = &existing_id {
-                                let _ = workspace.remove(old_id).await;
-                            }
-                        }
-
-                        info!("Indexed document: {}", item.doc_id);
-                        (vec![item], Vec::new())
+                        .await
                     }
                     Err(e) => {
                         tracing::warn!("Failed to index {}: {}", source_label, e);
@@ -329,45 +316,15 @@ impl Engine {
                 existing_id,
             }) => {
                 info!("Incremental update for: {}", source_label);
+                let pipeline_options = self.build_pipeline_options(options, source);
                 match self
-                    .indexer
-                    .index_with_existing(source, name, options, Some(&old_tree))
+                    .index_with_retry(source, name, pipeline_options.clone(), Some(&old_tree))
                     .await
                 {
                     Ok(mut doc) => {
                         doc.id = existing_id.clone();
-                        let pipeline_options = self.build_pipeline_options(options, doc.format);
-                        let metrics = doc.metrics.clone();
-                        let item = IndexItem::new(
-                            doc.id.clone(),
-                            doc.name.clone(),
-                            doc.format.clone(),
-                            doc.description.clone(),
-                            doc.page_count,
-                        )
-                        .with_source_path(
-                            doc.source_path
-                                .as_ref()
-                                .map(|p| p.to_string_lossy().to_string())
-                                .unwrap_or_default(),
-                        )
-                        .with_metrics_opt(metrics);
-                        let persisted = self
-                            .indexer
-                            .to_persisted_with_options(doc, &pipeline_options);
-
-                        if let Some(ref workspace) = self.workspace {
-                            // save() is atomic (write-lock + put), no need to remove first
-                            if let Err(e) = workspace.save(&persisted).await {
-                                return (
-                                    Vec::new(),
-                                    vec![FailedItem::new(&source_label, e.to_string())],
-                                );
-                            }
-                        }
-
-                        info!("Incrementally updated: {}", item.doc_id);
-                        (vec![item], Vec::new())
+                        self.index_and_persist(doc, &pipeline_options, &source_label, None)
+                            .await
                     }
                     Err(e) => {
                         tracing::warn!("Incremental update failed for {}: {}", source_label, e);
@@ -388,6 +345,105 @@ impl Engine {
         }
     }
 
+    /// Index with retry on retryable errors.
+    ///
+    /// Reads `config.llm.retry` for backoff parameters.
+    /// Returns `Err` only after all retries are exhausted or the error
+    /// is not retryable.
+    async fn index_with_retry(
+        &self,
+        source: &IndexSource,
+        name: Option<&str>,
+        pipeline_options: PipelineOptions,
+        existing_tree: Option<&DocumentTree>,
+    ) -> Result<super::indexed_document::IndexedDocument> {
+        let retry = &self.config.llm.retry;
+        let max_attempts = retry.max_attempts.max(1);
+
+        for attempt in 0..max_attempts {
+            if self.is_cancelled() {
+                return Err(Error::Config("Operation cancelled".into()));
+            }
+
+            let result = if let Some(tree) = existing_tree {
+                self.indexer
+                    .index_with_existing(source, name, pipeline_options.clone(), Some(tree))
+                    .await
+            } else {
+                self.indexer
+                    .index(source, name, pipeline_options.clone())
+                    .await
+            };
+
+            match result {
+                Ok(doc) => return Ok(doc),
+                Err(e) if e.is_retryable() && attempt + 1 < max_attempts => {
+                    let delay = retry.delay_for_attempt(attempt);
+                    tracing::warn!(
+                        attempt,
+                        max_attempts,
+                        ?delay,
+                        "Retryable error indexing, retrying: {e}"
+                    );
+                    tokio::time::sleep(delay).await;
+                }
+                Err(e) => return Err(e),
+            }
+        }
+
+        // Unreachable: with max_attempts clamped to >= 1 the loop always returns via Ok/Err
+        unreachable!()
+    }
+
+    /// Convert an [`IndexedDocument`] to an [`IndexItem`] and persist it.
+    ///
+    /// If `old_id` is provided, the old document is removed after a
+    /// successful save (atomic save-first, then remove old).
+    async fn index_and_persist(
+        &self,
+        doc: super::indexed_document::IndexedDocument,
+        pipeline_options: &PipelineOptions,
+        source_label: &str,
+        old_id: Option<&str>,
+    ) -> (Vec<IndexItem>, Vec<FailedItem>) {
+        let item = Self::build_index_item(&doc);
+        let persisted = IndexerClient::to_persisted(doc, pipeline_options).await;
+
+        if let Err(e) = self.workspace.save(&persisted).await {
+            return (
+                Vec::new(),
+                vec![FailedItem::new(source_label, e.to_string())],
+            );
+        }
+        // Clean up old document after successful save
+        if let Some(old_id) = old_id {
+            if let Err(e) = self.workspace.remove(old_id).await {
+                tracing::warn!("Failed to remove old document {}: {}", old_id, e);
+            }
+        }
+
+        info!("Indexed document: {}", item.doc_id);
+        (vec![item], Vec::new())
+    }
+
+    /// Build an [`IndexItem`] from an [`IndexedDocument`](super::indexed_document::IndexedDocument).
+    fn build_index_item(doc: &super::indexed_document::IndexedDocument) -> IndexItem {
+        IndexItem::new(
+            doc.id.clone(),
+            doc.name.clone(),
+            doc.format.clone(),
+            doc.description.clone(),
+            doc.page_count,
+        )
+        .with_source_path(
+            doc.source_path
+                .as_ref()
+                .map(|p| p.to_string_lossy().to_string())
+                .unwrap_or_default(),
+        )
+        .with_metrics_opt(doc.metrics.clone())
+    }
+
     // ============================================================
     // Document Querying
     // ============================================================
 
@@ -396,97 +452,118 @@ impl Engine {
     ///
     /// Accepts a [`QueryContext`] that specifies the query text and scope
     /// (single document, multiple documents, or entire workspace).
-    ///
-    /// # Example
-    ///
-    /// ```rust,no_run
-    /// use vectorless::client::{EngineBuilder, QueryContext};
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let engine = EngineBuilder::new()
-    ///     .with_key("sk-...")
-    ///     .with_model("gpt-4o")
-    ///     .build()
-    ///     .await?;
-    ///
-    /// // Single document
-    /// let result = engine.query(
-    ///     QueryContext::new("What is the total revenue?")
-    ///         .with_doc_ids(vec!["doc-123".to_string()])
-    /// ).await?;
-    ///
-    /// if let Some(item) = result.single() {
-    ///     println!("Answer: {}", item.content);
-    /// }
-    ///
-    /// // Entire workspace
-    /// let result = engine.query(
-    ///     QueryContext::new("Summarize all documents")
-    /// ).await?;
-    /// for item in &result.items {
-    ///     println!("{}: score={}", item.doc_id, item.score);
-    /// }
-    /// # Ok(())
-    /// # }
-    /// ```
+    #[tracing::instrument(skip_all, fields(query = %ctx.query))]
     pub async fn query(&self, ctx: QueryContext) -> Result<QueryResult> {
-        let doc_ids = self.resolve_scope(&ctx.scope).await?;
-        let mut options = ctx.to_retrieve_options(&self.config);
-
-        // Load document graph for graph-aware retrieval (if enabled)
-        if self.config.graph.enabled {
-            if let Some(ref workspace) = self.workspace {
-                if let Ok(Some(graph)) = workspace.get_graph().await {
+        self.check_cancel()?;
+        let _guard = self.inc_active();
+        let timeout_secs = ctx.timeout_secs;
+
+        self.with_timeout(timeout_secs, async move {
+            let doc_ids = self.resolve_scope(&ctx.scope).await?;
+            let mut options = ctx.to_retrieve_options(&self.config);
+
+            // Lazy graph rebuild: only rebuild if index() marked it dirty
+            if self.config.graph.enabled {
+                let fail_count = self.graph_fail_count.load(Ordering::Relaxed);
+                let should_try = fail_count < GRAPH_REBUILD_MAX_FAILURES;
+
+                if self.graph_dirty.swap(false, Ordering::Relaxed) {
+                    if should_try {
+                        if let Err(e) = self.rebuild_graph().await {
+                            let count = self.graph_fail_count.fetch_add(1, Ordering::Relaxed) + 1;
+                            tracing::warn!(count, "Graph rebuild failed: {e}");
+                            // Re-mark dirty so next query retries
+                            self.graph_dirty.store(true, Ordering::Relaxed);
+                        } else {
+                            // Reset failure count on success
+                            self.graph_fail_count.store(0, Ordering::Relaxed);
+                        }
+                    } else {
+                        tracing::warn!(
+                            count = fail_count,
+                            "Skipping graph rebuild after {} consecutive failures",
+                            fail_count
+                        );
+                    }
+                }
+                // Load (now up-to-date) graph for retrieval
+                if let Ok(Some(graph)) = self.workspace.get_graph().await {
                     options = options.with_document_graph(Arc::new(graph));
                 }
             }
-        }
 
-        let mut items = Vec::with_capacity(doc_ids.len());
-        let mut failed = Vec::new();
-
-        // TODO: if doc_ids.len() > 1, consider parallelizing queries across documents (with concurrency limit)
-        for doc_id in doc_ids {
-            let (tree, reasoning_index) = match self.get_structure(&doc_id).await {
-                Ok((t, ri)) => (t, ri),
-                Err(e) => {
-                    tracing::warn!("Skipping document {}: {}", doc_id, e);
-                    failed.push(FailedItem::new(&doc_id, e.to_string()));
-                    continue;
-                }
-            };
+            // Query documents in parallel (with concurrency limit)
+            let concurrency = self.config.llm.throttle.max_concurrent_requests;
+            let query = ctx.query.clone();
+            let cancelled = Arc::clone(&self.cancelled);
+
+            let results: Vec<(String, std::result::Result<QueryResultItem, String>)> =
+                futures::stream::iter(doc_ids.into_iter())
+                    .map(|doc_id| {
+                        let engine = self.clone();
+                        let options = options.clone();
+                        let query = query.clone();
+                        let cancelled = Arc::clone(&cancelled);
+                        async move {
+                            if cancelled.load(Ordering::Relaxed) {
+                                return (doc_id, Err("Operation cancelled".to_string()));
+                            }
 
-            match self
-                .retriever
-                .query_with_reasoning_index(&tree, &ctx.query, &options, reasoning_index)
-                .await
-            {
-                Ok(mut result) => {
-                    result.doc_id = doc_id;
-                    items.push(result);
-                }
-                Err(e) => {
-                    tracing::warn!("Query failed for {}: {}", doc_id, e);
-                    failed.push(FailedItem::new(&doc_id, e.to_string()));
+                            let (tree, reasoning_index) = match engine.get_structure(&doc_id).await
+                            {
+                                Ok(t) => t,
+                                Err(e) => return (doc_id, Err(e.to_string())),
+                            };
+
+                            match engine
+                                .retriever
+                                .query_with_reasoning_index(
+                                    &tree,
+                                    &query,
+                                    &options,
+                                    reasoning_index,
+                                )
+                                .await
+                            {
+                                Ok(mut result) => {
+                                    result.doc_id = doc_id.clone();
+                                    (doc_id, Ok(result))
+                                }
+                                Err(e) => (doc_id, Err(e.to_string())),
+                            }
+                        }
+                    })
+                    .buffer_unordered(concurrency)
+                    .collect()
+                    .await;
+
+            let mut items = Vec::new();
+            let mut failed = Vec::new();
+            for (doc_id, result) in results {
+                match result {
+                    Ok(item) => items.push(item),
+                    Err(e) => {
+                        tracing::warn!("Query failed for {}: {}", doc_id, e);
+                        failed.push(FailedItem::new(&doc_id, e));
+                    }
                 }
             }
-        }
 
-        // If everything failed, return error
-        if items.is_empty() && !failed.is_empty() {
-            return Err(Error::Config(format!(
-                "Query failed for all {} document(s): {}",
-                failed.len(),
-                failed
-                    .iter()
-                    .map(|f| format!("{} ({})", f.source, f.error))
-                    .collect::<Vec<_>>()
-                    .join("; ")
-            )));
-        }
+            if items.is_empty() && !failed.is_empty() {
+                return Err(Error::Config(format!(
+                    "Query failed for all {} document(s): {}",
+                    failed.len(),
+                    failed
+                        .iter()
+                        .map(|f| format!("{} ({})", f.source, f.error))
+                        .collect::<Vec<_>>()
+                        .join("; ")
+                )));
+            }
 
-        Ok(QueryResult::with_partial(items, failed))
+            Ok(QueryResult::with_partial(items, failed))
+        })
+        .await
     }
 
     /// Query a document with streaming results.
@@ -522,44 +599,24 @@ impl Engine {
 
     /// Get a list of all indexed documents.
     pub async fn list(&self) -> Result<Vec<DocumentInfo>> {
-        let workspace = self
-            .workspace
-            .as_ref()
-            .ok_or_else(|| Error::Config("No workspace configured".to_string()))?;
-
-        workspace.list().await
+        self.workspace.list().await
     }
 
     /// Remove a document from the workspace.
     pub async fn remove(&self, doc_id: &str) -> Result<bool> {
-        let workspace = self
-            .workspace
-            .as_ref()
-            .ok_or_else(|| Error::Config("No workspace configured".to_string()))?;
-
-        workspace.remove(doc_id).await
+        self.workspace.remove(doc_id).await
     }
 
     /// Check if a document exists in the workspace.
     pub async fn exists(&self, doc_id: &str) -> Result<bool> {
-        let workspace = self
-            .workspace
-            .as_ref()
-            .ok_or_else(|| Error::Config("No workspace configured".to_string()))?;
-
-        workspace.exists(doc_id).await
+        self.workspace.exists(doc_id).await
    }
 
     /// Remove all documents from the workspace.
     ///
     /// Returns the number of documents removed.
     pub async fn clear(&self) -> Result<usize> {
-        let workspace = self
-            .workspace
-            .as_ref()
-            .ok_or_else(|| Error::Config("No workspace configured".to_string()))?;
-
-        workspace.clear().await
+        self.workspace.clear().await
     }
 
     /// Get the cross-document relationship graph.
     ///
     /// The graph is automatically rebuilt after indexing documents.
     /// Returns `None` if no graph has been built yet.
     pub async fn get_graph(&self) -> Result<Option<DocumentGraph>> {
-        let workspace = self
-            .workspace
-            .as_ref()
-            .ok_or_else(|| Error::Config("No workspace configured".to_string()))?;
-
-        workspace.get_graph().await
+        self.workspace.get_graph().await
     }
 
     /// Generate a complete metrics report.
@@ -583,24 +635,79 @@ impl Engine {
         self.metrics_hub.generate_report()
     }
 
+    /// Cancel all in-flight `index()` and `query()` operations.
+    ///
+    /// After calling this, running operations will return at the next
+    /// convenient point with a cancellation error. New operations will
+    /// also fail until [`reset_cancel`](Self::reset_cancel) is called.
+    pub fn cancel(&self) {
+        self.cancelled.store(true, Ordering::Relaxed);
+        tracing::info!("Cancellation requested");
+    }
+
+    /// Reset the cancel flag so new operations can proceed.
+    pub fn reset_cancel(&self) {
+        self.cancelled.store(false, Ordering::Relaxed);
+        tracing::info!("Cancel flag reset");
+    }
+
+    /// Returns `true` if cancellation has been requested.
+    pub fn is_cancelled(&self) -> bool {
+        self.cancelled.load(Ordering::Relaxed)
+    }
+
     // ============================================================
     // Internal
     // ============================================================
 
+    /// Check cancel flag, returning an error if cancelled.
+    fn check_cancel(&self) -> Result<()> {
+        if self.cancelled.load(Ordering::Relaxed) {
+            return Err(Error::Config("Operation cancelled".into()));
+        }
+        Ok(())
+    }
+
+    /// Increment active operation counter. Returns a guard that decrements on drop.
+    fn inc_active(&self) -> ActiveGuard {
+        let mut ops = self.active_ops.lock().unwrap();
+        *ops += 1;
+        ActiveGuard {
+            active_ops: Arc::clone(&self.active_ops),
+        }
+    }
+
+    /// Get current active operation count.
+    pub fn active_operations(&self) -> usize {
+        *self.active_ops.lock().unwrap()
+    }
+
+    /// Run a future with an optional timeout.
+    /// If `timeout_secs` is `Some`, wraps the future in `tokio::time::timeout`.
+    async fn with_timeout<F, T>(&self, timeout_secs: Option<u64>, fut: F) -> Result<T>
+    where
+        F: std::future::Future<Output = Result<T>>,
+    {
+        match timeout_secs {
+            Some(secs) => {
+                match tokio::time::timeout(std::time::Duration::from_secs(secs), fut).await {
+                    Ok(result) => result,
+                    Err(_) => Err(Error::Config(format!("Operation timed out after {secs}s"))),
+                }
+            }
+            None => fut.await,
+        }
+    }
+
     /// Get document structure (tree) and optional reasoning index. Internal use only.
     pub(crate) async fn get_structure(
         &self,
         doc_id: &str,
     ) -> Result<(DocumentTree, Option<ReasoningIndex>)> {
-        let workspace = self
-            .workspace
-            .as_ref()
-            .ok_or_else(|| Error::Config("No workspace configured".to_string()))?;
-
-        let doc = workspace
-            .load(doc_id)
-            .await?
-            .ok_or_else(|| Error::DocumentNotFound(format!("Document not found: {}", doc_id)))?;
+        let doc =
+            self.workspace.load(doc_id).await?.ok_or_else(|| {
+                Error::DocumentNotFound(format!("Document not found: {}", doc_id))
+            })?;
 
         Ok((doc.tree, doc.reasoning_index))
     }
@@ -619,18 +726,31 @@ impl Engine {
         }
     }
 
-    /// Build pipeline options from client IndexOptions and detected format.
+    /// Build pipeline options for pipeline execution (with checkpoint dir).
+    ///
+    /// This is the single source of truth for pipeline configuration.
     fn build_pipeline_options(
         &self,
         options: &super::types::IndexOptions,
-        format: crate::index::parse::DocumentFormat,
+        source: &IndexSource,
     ) -> PipelineOptions {
-        use crate::index::SummaryStrategy;
-        let checkpoint_dir = self.workspace_dir.as_ref().map(|p| p.join("checkpoints"));
+        use crate::index::{IndexMode, ReasoningIndexConfig, SummaryStrategy};
+
+        let format = match source {
+            IndexSource::Path(path) => self
+                .indexer
+                .detect_format_from_path(path)
+                .unwrap_or(crate::index::parse::DocumentFormat::Markdown),
+            IndexSource::Content { format, .. } => *format,
+            IndexSource::Bytes { format, .. } => *format,
+        };
+
+        let checkpoint_dir = Some(self.config.storage.checkpoint_dir.clone());
+
         PipelineOptions {
             mode: match format {
-                crate::index::parse::DocumentFormat::Markdown => crate::index::IndexMode::Markdown,
-                crate::index::parse::DocumentFormat::Pdf => crate::index::IndexMode::Pdf,
+                crate::index::parse::DocumentFormat::Markdown => IndexMode::Markdown,
+                crate::index::parse::DocumentFormat::Pdf => IndexMode::Pdf,
             },
             generate_ids: options.generate_ids,
             summary_strategy: if options.generate_summaries {
                 SummaryStrategy::full()
             } else {
                 SummaryStrategy::none()
             },
             generate_description: options.generate_description,
             checkpoint_dir,
+            reasoning_index: ReasoningIndexConfig {
+                enable_synonym_expansion: options.enable_synonym_expansion,
+                ..ReasoningIndexConfig::default()
+            },
+            concurrency: self.config.llm.throttle.to_runtime_config(),
             ..Default::default()
         }
     }
 
@@ -640,65 +760,22 @@ impl Engine {
-    /// Rebuild the document graph after indexing, if graph is enabled.
-    async fn rebuild_graph(&self) -> Result<()> {
-        if !self.config.graph.enabled {
-            return Ok(());
-        }
-        let workspace = match self.workspace {
-            Some(ref ws) => ws,
-            None => return Ok(()),
-        };
-
-        // Load all documents and extract keyword profiles
-        let doc_ids = workspace.inner().list_documents().await;
-        let mut builder = crate::graph::DocumentGraphBuilder::new(self.config.graph.clone());
-
-        for doc_id in &doc_ids {
-            if let Some(doc) = workspace.load(doc_id).await? {
-                let keywords = Self::extract_keywords_from_doc(&doc);
-                builder.add_document(
-                    &doc.meta.id,
-                    &doc.meta.name,
-                    &doc.meta.format,
-                    doc.meta.node_count,
-                    keywords,
-                );
-            }
-        }
-
-        let graph = builder.build();
-        workspace.set_graph(&graph).await?;
-        Ok(())
-    }
-
-    /// Extract keyword → weight map from a persisted document's ReasoningIndex.
-    fn extract_keywords_from_doc(doc: &PersistedDocument) -> HashMap<String, f32> {
-        let mut keywords = HashMap::new();
-        if let Some(ref ri) = doc.reasoning_index {
-            for (kw, entries) in ri.all_topic_entries() {
-                let weight: f32 =
-                    entries.iter().map(|e| e.weight).sum::<f32>() / entries.len().max(1) as f32;
-                keywords.insert(kw.clone(), weight);
-            }
-        }
-        keywords
-    }
-
     /// Resolve what action to take for a source.
     async fn resolve_index_action(
         &self,
         source: &IndexSource,
         options: &super::types::IndexOptions,
     ) -> Result<IndexAction> {
-        let workspace = match self.workspace {
-            Some(ref ws) => ws,
-            None => return Ok(IndexAction::FullIndex { existing_id: None }),
-        };
+        let workspace = &self.workspace;
 
         // Force mode always re-indexes from scratch
         if options.mode == IndexMode::Force {
@@ -735,7 +812,7 @@ impl Engine {
         }
 
         // Incremental mode: load stored document and delegate to resolver
-        let current_bytes = match std::fs::read(path) {
+        let current_bytes = match tokio::fs::read(path).await {
             Ok(b) => b,
             Err(_) => return Ok(IndexAction::FullIndex { existing_id: None }),
         };
@@ -747,7 +824,7 @@ impl Engine {
 
         let format = crate::index::parse::DocumentFormat::from_extension(&stored_doc.meta.format)
             .unwrap_or(crate::index::parse::DocumentFormat::Markdown);
-        let pipeline_options = self.build_pipeline_options(options, format);
+        let pipeline_options = self.build_pipeline_options(options, source);
 
         // If logic fingerprint changed, remove old doc before full reprocess
         let action =
@@ -758,6 +835,55 @@ impl Engine {
 
         Ok(action)
     }
+
+    /// Rebuild the document graph after indexing, if graph is enabled.
+    async fn rebuild_graph(&self) -> Result<()> {
+        if !self.config.graph.enabled {
+            return Ok(());
+        }
+
+        // Load all documents in parallel and extract keyword profiles
+        let doc_ids = self.workspace.inner().list_documents().await;
+        let concurrency = self.config.llm.throttle.max_concurrent_requests;
+
+        let loaded: Vec<Option<PersistedDocument>> = futures::stream::iter(doc_ids.iter().cloned())
+            .map(|doc_id| {
+                let ws = self.workspace.clone();
+                async move { ws.load(&doc_id).await.ok().flatten() }
+            })
+            .buffer_unordered(concurrency)
+            .collect()
+            .await;
+
+        let mut builder = crate::graph::DocumentGraphBuilder::new(self.config.graph.clone());
+        for doc in loaded.into_iter().flatten() {
+            let keywords = Self::extract_keywords_from_doc(&doc);
+            builder.add_document(
+                &doc.meta.id,
+                &doc.meta.name,
+                &doc.meta.format,
+                doc.meta.node_count,
+                keywords,
+            );
+        }
+
+        let graph = builder.build();
+        self.workspace.set_graph(&graph).await?;
+        Ok(())
+    }
+
+    /// Extract keyword → weight map from a persisted document's ReasoningIndex.
+    fn extract_keywords_from_doc(doc: &PersistedDocument) -> HashMap<String, f32> {
+        let mut keywords = HashMap::new();
+        if let Some(ref ri) = doc.reasoning_index {
+            for (kw, entries) in ri.all_topic_entries() {
+                let weight: f32 =
+                    entries.iter().map(|e| e.weight).sum::<f32>() / entries.len().max(1) as f32;
+                keywords.insert(kw.clone(), weight);
+            }
+        }
+        keywords
+    }
 }
 
 impl Clone for Engine {
@@ -767,28 +893,132 @@ impl Clone for Engine {
     fn clone(&self) -> Self {
         Self {
            config: self.config.clone(),
             indexer: self.indexer.clone(),
             retriever: self.retriever.clone(),
             workspace: self.workspace.clone(),
-            workspace_dir: self.workspace_dir.clone(),
-            events: self.events.clone(),
             metrics_hub: Arc::clone(&self.metrics_hub),
+            graph_dirty: Arc::clone(&self.graph_dirty),
+            graph_fail_count: Arc::clone(&self.graph_fail_count),
+            cancelled: Arc::clone(&self.cancelled),
+            active_ops: Arc::clone(&self.active_ops),
         }
     }
 }
 
+/// RAII guard that decrements `active_ops` on drop.
+struct ActiveGuard {
+    active_ops: Arc<Mutex<usize>>,
+}
+
+impl Drop for ActiveGuard {
+    fn drop(&mut self) {
+        let mut ops = self.active_ops.lock().unwrap();
+        *ops = ops.saturating_sub(1);
+    }
+}
+
 impl std::fmt::Debug for Engine {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("Engine")
-            .field("has_workspace", &self.workspace.is_some())
-            .finish_non_exhaustive()
+        f.debug_struct("Engine").finish_non_exhaustive()
     }
 }
 
 #[cfg(test)]
 mod tests {
-    use super::super::EngineBuilder;
+    use super::*;
+    use crate::client::types::IndexMode;
+
+    // ── Cancel ────────────────────────────────────────────────────────────
 
     #[test]
-    fn test_engine_builder() {
-        let builder = EngineBuilder::new();
-        let _ = builder;
+    fn test_cancel_flag() {
+        // We can't construct a full Engine without async + LLM, so test the
+        // underlying primitives directly.
+        let flag = Arc::new(AtomicBool::new(false));
+        assert!(!flag.load(Ordering::Relaxed));
+
+        flag.store(true, Ordering::Relaxed);
+        assert!(flag.load(Ordering::Relaxed));
+
+        flag.store(false, Ordering::Relaxed);
+        assert!(!flag.load(Ordering::Relaxed));
+    }
+
+    #[test]
+    fn test_graph_dirty_flag() {
+        let dirty = Arc::new(AtomicBool::new(false));
+        assert!(!dirty.load(Ordering::Relaxed));
+
+        // Simulate: index marks dirty
+        dirty.store(true, Ordering::Relaxed);
+
+        // Simulate: query swaps to false and rebuilds
+        let was_dirty = dirty.swap(false, Ordering::Relaxed);
+        assert!(was_dirty);
+        assert!(!dirty.load(Ordering::Relaxed));
+    }
+
+    #[test]
+    fn test_active_guard_decrement() {
+        let active_ops: Arc<Mutex<usize>> = Arc::new(Mutex::new(0));
+
+        // Increment
+        {
+            let mut ops = active_ops.lock().unwrap();
+            *ops += 1;
+        }
+
+        assert_eq!(*active_ops.lock().unwrap(), 1);
+
+        // Drop guard (simulate ActiveGuard drop)
+        {
+            let mut ops = active_ops.lock().unwrap();
+            *ops = ops.saturating_sub(1);
+        }
+
+        assert_eq!(*active_ops.lock().unwrap(), 0);
+    }
+
+    // ── resolve_index_action Default mode ──────────────────────────────────
+
+    // We can't call resolve_index_action without a workspace, but we can
+    // verify IndexMode equality logic used inside.
+    #[test]
+    fn test_index_mode_force_skips_incremental() {
+        let mode = IndexMode::Force;
+        assert_eq!(mode, IndexMode::Force);
+        assert_ne!(mode, IndexMode::Default);
+        assert_ne!(mode, IndexMode::Incremental);
+    }
+
+    // ── build_index_item ────────────────────────────────────────────────
+
+    // build_index_item only transforms data — no I/O.
+    use crate::client::indexed_document::IndexedDocument;
+
+    fn make_doc() -> IndexedDocument {
+        IndexedDocument::new("test-id", crate::index::parse::DocumentFormat::Markdown)
+            .with_name("test.md")
+            .with_description("test doc")
+            .with_source_path(std::path::PathBuf::from("/tmp/test.md"))
+    }
+
+    #[test]
+    fn test_build_index_item() {
+        let doc = make_doc();
+        let item = Engine::build_index_item(&doc);
+
+        assert_eq!(item.doc_id, "test-id");
+        assert_eq!(item.name, "test.md");
+        assert_eq!(item.format, crate::index::parse::DocumentFormat::Markdown);
+        assert_eq!(item.description, Some("test doc".to_string()));
+        assert_eq!(item.source_path, Some("/tmp/test.md".to_string()));
+        assert!(item.metrics.is_none());
+    }
+
+    #[test]
+    fn test_build_index_item_no_source_path() {
+        let doc = IndexedDocument::new("id", crate::index::parse::DocumentFormat::Pdf);
+        let item = Engine::build_index_item(&doc);
+
+        assert_eq!(item.source_path, Some(String::new())); // unwrap_or_default
+        assert_eq!(item.format, crate::index::parse::DocumentFormat::Pdf);
    }
 }
diff --git a/rust/src/client/index_context.rs b/rust/src/client/index_context.rs
index 1ee324f1..30cb2502 100644
--- a/rust/src/client/index_context.rs
+++ b/rust/src/client/index_context.rs
@@ -146,7 +146,7 @@ impl IndexContext {
     /// Internal: scan a directory for supported document files.
     fn scan_dir(dir: impl Into<PathBuf>, recursive: bool) -> Self {
         let dir = dir.into();
-        let supported_extensions = ["md", "pdf"];
+        let supported_extensions = DocumentFormat::SUPPORTED_EXTENSIONS;
 
         if !dir.exists() {
             tracing::warn!("Directory not found: {}", dir.display());
diff --git a/rust/src/client/indexed_document.rs b/rust/src/client/indexed_document.rs
new file mode 100644
index 00000000..58560644
--- /dev/null
+++ b/rust/src/client/indexed_document.rs
@@ -0,0 +1,122 @@
+// Copyright (c) 2026 vectorless developers
+// SPDX-License-Identifier: Apache-2.0
+
+//! Internal intermediate type produced by the indexing pipeline.
+//!
+//! [`IndexedDocument`] is an internal-only type that carries data from
+//! [`IndexerClient`](super::indexer::IndexerClient) to [`Engine`](super::Engine).
+//! It is **not** part of the public API.
+
+use std::path::PathBuf;
+
+use crate::document::DocumentTree;
+use crate::index::parse::DocumentFormat;
+use crate::metrics::IndexMetrics;
+use crate::storage::PageContent;
+
+/// An indexed document with its tree structure and metadata.
+///
+/// Internal intermediate produced by the indexing pipeline and consumed
+/// by [`Engine`](super::Engine) to create a [`PersistedDocument`](crate::storage::PersistedDocument).
+#[derive(Debug, Clone)]
+pub(crate) struct IndexedDocument {
+    /// Unique document identifier.
+    pub id: String,
+
+    /// Document format.
+    pub format: DocumentFormat,
+
+    /// Document name/title.
+    pub name: String,
+
+    /// Document description (generated by LLM).
+    pub description: Option<String>,
+
+    /// Source file path.
+    pub source_path: Option<PathBuf>,
+
+    /// Page count (for PDFs).
+    pub page_count: Option<usize>,
+
+    /// The document tree structure.
+    pub tree: Option<DocumentTree>,
+
+    /// Per-page content (for PDFs).
+    pub pages: Vec<PageContent>,
+
+    /// Indexing pipeline metrics.
+    pub metrics: Option<IndexMetrics>,
+
+    /// Pre-computed reasoning index for retrieval acceleration.
+    pub reasoning_index: Option<crate::index::ReasoningIndex>,
+}
+
+impl IndexedDocument {
+    /// Create a new indexed document.
+    pub fn new(id: impl Into<String>, format: DocumentFormat) -> Self {
+        Self {
+            id: id.into(),
+            format,
+            name: String::new(),
+            description: None,
+            source_path: None,
+            page_count: None,
+            tree: None,
+            pages: Vec::new(),
+            metrics: None,
+            reasoning_index: None,
+        }
+    }
+
+    /// Set the document name.
+    pub fn with_name(mut self, name: impl Into<String>) -> Self {
+        self.name = name.into();
+        self
+    }
+
+    /// Set the document description.
+    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
+        self.description = Some(desc.into());
+        self
+    }
+
+    /// Set the source path.
+    pub fn with_source_path(mut self, path: impl Into<PathBuf>) -> Self {
+        self.source_path = Some(path.into());
+        self
+    }
+
+    /// Set the page count.
+    pub fn with_page_count(mut self, count: usize) -> Self {
+        self.page_count = Some(count);
+        self
+    }
+
+    /// Set the document tree.
+    pub fn with_tree(mut self, tree: DocumentTree) -> Self {
+        self.tree = Some(tree);
+        self
+    }
+
+    /// Set the indexing metrics.
+    pub fn with_metrics(mut self, metrics: IndexMetrics) -> Self {
+        self.metrics = Some(metrics);
+        self
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_indexed_document() {
+        let doc = IndexedDocument::new("doc-1", DocumentFormat::Markdown)
+            .with_name("Test Document")
+            .with_description("A test document");
+
+        assert_eq!(doc.id, "doc-1");
+        assert_eq!(doc.name, "Test Document");
+        assert!(doc.tree.is_none());
+    }
+}
diff --git a/rust/src/client/indexer.rs b/rust/src/client/indexer.rs
index 79c89b37..fa6a314f 100644
--- a/rust/src/client/indexer.rs
+++ b/rust/src/client/indexer.rs
@@ -28,14 +28,12 @@ use uuid::Uuid;
 
 use crate::error::{Error, Result};
 use crate::index::parse::DocumentFormat;
-use crate::index::{
-    IndexInput, IndexMode, PipelineExecutor, PipelineOptions, ReasoningIndexConfig, SummaryStrategy,
-};
+use crate::index::{IndexInput, IndexMode, PipelineExecutor, PipelineOptions};
 use crate::llm::LlmClient;
 use crate::storage::{DocumentMeta, PersistedDocument};
 
 use super::index_context::IndexSource;
-use super::types::{IndexOptions, IndexedDocument};
+use super::indexed_document::IndexedDocument;
 use crate::events::{EventEmitter, IndexEvent};
 
 /// Document indexing client.
@@ -61,6 +59,14 @@ impl IndexerClient {
         }
     }
 
+    /// Create with a custom executor factory (for testing).
+    pub(crate) fn with_factory(factory: Arc<dyn Fn() -> PipelineExecutor + Send + Sync>) -> Self {
+        Self {
+            executor_factory: factory,
+            events: EventEmitter::new(),
+        }
+    }
+
     /// Create with event emitter.
     pub fn with_events(mut self, events: EventEmitter) -> Self {
         self.events = events;
@@ -68,46 +74,51 @@ impl IndexerClient {
     }
 
     /// Index a document from an index context.
+    ///
+    /// The caller provides fully constructed [`PipelineOptions`]
+    /// (including checkpoint dir, reasoning config, etc.).
     pub async fn index(
         &self,
         source: &IndexSource,
         name: Option<&str>,
-        options: &IndexOptions,
+        pipeline_options: PipelineOptions,
     ) -> Result<IndexedDocument> {
-        self.index_with_existing(source, name, options, None).await
+        self.index_with_existing(source, name, pipeline_options, None)
+            .await
     }
 
     /// Index a document, optionally reusing an existing tree for incremental updates.
+    ///
+    /// The caller provides fully constructed [`PipelineOptions`].
     pub async fn index_with_existing(
         &self,
         source: &IndexSource,
         name: Option<&str>,
-        options: &IndexOptions,
+        mut pipeline_options: PipelineOptions,
         existing_tree: Option<&crate::DocumentTree>,
     ) -> Result<IndexedDocument> {
+        pipeline_options.existing_tree = existing_tree.cloned();
         match source {
-            IndexSource::Path(path) => {
-                self.index_from_path(path, name, options, existing_tree)
-                    .await
-            }
+            IndexSource::Path(path) => self.index_from_path(path, name, pipeline_options).await,
             IndexSource::Content { data, format } => {
-                self.index_from_content(data, *format, name, options, existing_tree)
+                self.index_from_content(data, *format, name, pipeline_options)
                     .await
             }
             IndexSource::Bytes { data, format } => {
-                self.index_from_bytes(data, *format, name, options, existing_tree)
+                self.index_from_bytes(data, *format, name, pipeline_options)
                     .await
             }
         }
     }
 
     /// Index from a file path.
+    ///
+    /// Uses the format from `PipelineOptions.mode` — no redundant detection.
     async fn index_from_path(
         &self,
         path: &Path,
         name: Option<&str>,
-        options: &IndexOptions,
-        existing_tree: Option<&crate::DocumentTree>,
+        pipeline_options: PipelineOptions,
     ) -> Result<IndexedDocument> {
         let path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
 
@@ -126,31 +137,19 @@ impl IndexerClient {
             tracing::warn!("{}", warning);
         }
 
-        // Emit start event
-        self.events.emit_index(IndexEvent::Started {
-            path: path.display().to_string(),
-        });
-
-        // Generate document ID
-        let doc_id = Uuid::new_v4().to_string();
-
-        // Detect format from extension
-        let format = self.detect_format_from_path(&path)?;
-        self.events
-            .emit_index(IndexEvent::FormatDetected { format });
-
-        info!("Indexing {:?} document: {}", format, path.display());
-
-        // Build pipeline options
-        let pipeline_options =
-            self.build_pipeline_options_with_existing(options, format, existing_tree.cloned());
+        // Resolve format from pipeline options (set by Engine) — no re-detection
+        let format = Self::format_from_mode(&pipeline_options.mode);
 
-        // Create pipeline input and execute
         let input = IndexInput::file(&path);
-        let mut executor = (self.executor_factory)();
-        let result = executor.execute(input, pipeline_options).await?;
-
-        self.build_indexed_document(doc_id, result, format, name, Some(&path))
+        self.run_pipeline(
+            input,
+            format,
+            &path.display().to_string(),
+            name,
+            Some(&path),
+            pipeline_options,
+        )
+        .await
     }
 
     /// Index from content string.
@@ -159,8 +158,7 @@ impl IndexerClient {
         content: &str,
         format: DocumentFormat,
         name: Option<&str>,
-        options: &IndexOptions,
-        existing_tree: Option<&crate::DocumentTree>,
+        pipeline_options: PipelineOptions,
     ) -> Result<IndexedDocument> {
         // Validate content before indexing
         let validation = crate::utils::validate_content(content, format);
@@ -174,24 +172,16 @@ impl IndexerClient {
             ));
         }
 
-        self.events.emit_index(IndexEvent::Started {
-            path: name.unwrap_or("content").to_string(),
-        });
-
-        let doc_id = Uuid::new_v4().to_string();
-        self.events
-            .emit_index(IndexEvent::FormatDetected { format });
-
-        info!("Indexing {:?} document from content", format);
-
-        let pipeline_options =
-            self.build_pipeline_options_with_existing(options, format, existing_tree.cloned());
-
         let input = IndexInput::content(content);
-        let mut executor = (self.executor_factory)();
-        let result = executor.execute(input, pipeline_options).await?;
-
-        self.build_indexed_document(doc_id, result, format, name, None)
+        self.run_pipeline(
+            input,
+            format,
+            name.unwrap_or("content"),
+            name,
+            None,
+            pipeline_options,
+        )
+        .await
     }
 
     /// Index from binary data.
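+    ///
+    /// Call-site sketch (internal API; argument values are hypothetical):
+    /// ```rust,ignore
+    /// let doc = self
+    ///     .index_from_bytes(&bytes, DocumentFormat::Pdf, Some("report"), pipeline_options)
+    ///     .await?;
+    /// ```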
@@ -200,8 +190,7 @@ impl IndexerClient {
         bytes: &[u8],
         format: DocumentFormat,
         name: Option<&str>,
-        options: &IndexOptions,
-        existing_tree: Option<&crate::DocumentTree>,
+        pipeline_options: PipelineOptions,
     ) -> Result<IndexedDocument> {
         // Validate bytes before indexing
         let validation = crate::utils::validate_bytes(bytes, format);
@@ -215,56 +204,49 @@ impl IndexerClient {
             ));
         }
 
-        self.events.emit_index(IndexEvent::Started {
-            path: name.unwrap_or("bytes").to_string(),
-        });
-
-        let doc_id = Uuid::new_v4().to_string();
-        self.events
-            .emit_index(IndexEvent::FormatDetected { format });
-
         info!(
             "Indexing {:?} document from bytes ({} bytes)",
             format,
             bytes.len()
         );
 
-        let pipeline_options =
-            self.build_pipeline_options_with_existing(options, format, existing_tree.cloned());
-
         let input = IndexInput::bytes(bytes);
-        let mut executor = (self.executor_factory)();
-        let result = executor.execute(input, pipeline_options).await?;
-
-        self.build_indexed_document(doc_id, result, format, name, None)
+        self.run_pipeline(
+            input,
+            format,
+            name.unwrap_or("bytes"),
+            name,
+            None,
+            pipeline_options,
+        )
+        .await
     }
 
-    /// Build pipeline options with optional existing tree for incremental updates.
-    fn build_pipeline_options_with_existing(
+    /// Common pipeline execution: emit events → run pipeline → build result.
+    #[tracing::instrument(skip_all, fields(format = ?format, source = %source_label))]
+    async fn run_pipeline(
         &self,
-        options: &IndexOptions,
+        input: IndexInput,
         format: DocumentFormat,
-        existing_tree: Option<DocumentTree>,
-    ) -> PipelineOptions {
-        PipelineOptions {
-            mode: match format {
-                DocumentFormat::Markdown => IndexMode::Markdown,
-                DocumentFormat::Pdf => IndexMode::Pdf,
-            },
-            generate_ids: options.generate_ids,
-            summary_strategy: if options.generate_summaries {
-                SummaryStrategy::full()
-            } else {
-                SummaryStrategy::none()
-            },
-            generate_description: options.generate_description,
-            reasoning_index: ReasoningIndexConfig {
-                enable_synonym_expansion: options.enable_synonym_expansion,
-                ..ReasoningIndexConfig::default()
-            },
-            existing_tree,
-            ..Default::default()
-        }
+        source_label: &str,
+        name: Option<&str>,
+        path: Option<&Path>,
+        pipeline_options: PipelineOptions,
+    ) -> Result<IndexedDocument> {
+        self.events.emit_index(IndexEvent::Started {
+            path: source_label.to_string(),
+        });
+
+        let doc_id = Uuid::new_v4().to_string();
+        self.events
+            .emit_index(IndexEvent::FormatDetected { format });
+
+        info!("Indexing {:?} document: {}", format, source_label);
+
+        let mut executor = (self.executor_factory)();
+        let result = executor.execute(input, pipeline_options).await?;
+
+        self.build_indexed_document(doc_id, result, format, name, path)
     }
 
     /// Build indexed document from pipeline result.
@@ -316,16 +298,32 @@ impl IndexerClient {
         Ok(doc)
     }
 
+    /// Resolve `DocumentFormat` from `PipelineOptions.mode`.
+    ///
+    /// Falls back to Markdown for `Auto` mode (the engine resolves
+    /// `Auto` to a concrete format before calling the indexer).
+    fn format_from_mode(mode: &IndexMode) -> DocumentFormat {
+        match mode {
+            IndexMode::Markdown => DocumentFormat::Markdown,
+            IndexMode::Pdf => DocumentFormat::Pdf,
+            IndexMode::Auto => DocumentFormat::Markdown,
+        }
+    }
+
     /// Detect document format from file extension.
-    fn detect_format_from_path(&self, path: &Path) -> Result<DocumentFormat> {
+    pub(crate) fn detect_format_from_path(&self, path: &Path) -> Result<DocumentFormat> {
         let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
         DocumentFormat::from_extension(ext)
             .ok_or_else(|| Error::Parse(format!("Unsupported format: {}", ext)))
     }
 
-    /// Convert IndexedDocument to PersistedDocument, storing fingerprints from pipeline options.
-    pub fn to_persisted_with_options(
-        &self,
+    /// Convert [`IndexedDocument`] to [`PersistedDocument`].
+    ///
+    /// This is an associated function — it does not depend on client state.
+    /// Stores content and logic fingerprints from the pipeline options.
+    ///
+    /// Uses async file I/O to avoid blocking the tokio runtime.
+    pub async fn to_persisted(
         doc: IndexedDocument,
         pipeline_options: &PipelineOptions,
     ) -> PersistedDocument {
@@ -338,9 +336,9 @@ impl IndexerClient {
         )
         .with_description(doc.description.clone().unwrap_or_default());
 
-        // Compute content fingerprint for incremental indexing
+        // Compute content fingerprint for incremental indexing (async I/O)
         if let Some(ref path) = doc.source_path {
-            if let Ok(bytes) = std::fs::read(path) {
+            if let Ok(bytes) = tokio::fs::read(path).await {
                 let fp = crate::utils::fingerprint::Fingerprint::from_bytes(&bytes);
                 meta = meta.with_fingerprint(fp);
             }
diff --git a/rust/src/client/mod.rs b/rust/src/client/mod.rs
index ce00ff34..903316fa 100644
--- a/rust/src/client/mod.rs
+++ b/rust/src/client/mod.rs
@@ -68,9 +68,11 @@
 mod builder;
 mod engine;
 mod index_context;
+mod indexed_document;
 mod indexer;
 mod query_context;
 mod retriever;
+pub(crate) mod test_support;
 mod types;
 mod workspace;
 
@@ -93,8 +95,8 @@ pub use query_context::QueryContext;
 // ============================================================
 
 pub use types::{
-    ClientError, DocumentInfo, FailedItem, IndexItem, IndexMode, IndexOptions, IndexResult,
-    QueryResult, QueryResultItem,
+    DocumentInfo, FailedItem, IndexItem, IndexMode, IndexOptions, IndexResult, QueryResult,
+    QueryResultItem,
 };
 
 // ============================================================
diff --git a/rust/src/client/query_context.rs b/rust/src/client/query_context.rs
index 24d5d8b0..3b8f0726 100644
--- a/rust/src/client/query_context.rs
+++ b/rust/src/client/query_context.rs
@@ -56,10 +56,12 @@ pub struct QueryContext {
     pub(crate) max_tokens: Option<usize>,
     /// Retrieval strategy override.
     pub(crate) strategy: Option<StrategyPreference>,
-    /// Whether to include the reasoning chain in the result.
+    /// Whether to include the pilot reasoning chain in the result.
     pub(crate) include_reasoning: bool,
-    /// Maximum tree traversal depth.
+    /// Maximum tree traversal depth for the pilot.
     pub(crate) depth_limit: Option<usize>,
+    /// Per-operation timeout (seconds). `None` means no timeout.
+    pub(crate) timeout_secs: Option<u64>,
 }
 
 impl QueryContext {
@@ -72,6 +74,7 @@ impl QueryContext {
             strategy: None,
             include_reasoning: true,
             depth_limit: None,
+            timeout_secs: None,
         }
     }
 
@@ -102,18 +105,24 @@ impl QueryContext {
         self
     }
 
-    /// Set whether to include the reasoning chain.
+    /// Set whether to include the pilot reasoning chain.
     pub fn with_include_reasoning(mut self, include: bool) -> Self {
         self.include_reasoning = include;
         self
     }
 
-    /// Set the maximum tree traversal depth.
+    /// Set the maximum tree traversal depth for the pilot.
     pub fn with_depth_limit(mut self, depth: usize) -> Self {
         self.depth_limit = Some(depth);
         self
     }
 
+    /// Set per-operation timeout in seconds.
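+    ///
+    /// Example sketch:
+    /// ```rust,ignore
+    /// let ctx = QueryContext::new("What changed in Q3?")
+    ///     .with_timeout_secs(30);
+    /// ```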
+    pub fn with_timeout_secs(mut self, secs: u64) -> Self {
+        self.timeout_secs = Some(secs);
+        self
+    }
+
     /// Convert to internal `RetrieveOptions`, merging with engine config.
     pub(crate) fn to_retrieve_options(&self, config: &Config) -> RetrieveOptions {
         let mut opts = RetrieveOptions::new()
@@ -153,7 +162,6 @@ mod tests {
     fn test_query_context_new() {
         let ctx = QueryContext::new("What is this?");
         assert_eq!(ctx.query, "What is this?");
-        assert!(ctx.include_reasoning);
     }
 
     #[test]
@@ -194,10 +202,18 @@ mod tests {
             .with_doc_ids(vec!["doc-1".to_string()])
             .with_max_tokens(4000)
             .with_include_reasoning(false)
-            .with_depth_limit(5);
+            .with_depth_limit(5)
+            .with_timeout_secs(60);
 
         assert_eq!(ctx.max_tokens, Some(4000));
         assert!(!ctx.include_reasoning);
         assert_eq!(ctx.depth_limit, Some(5));
+        assert_eq!(ctx.timeout_secs, Some(60));
+    }
+
+    #[test]
+    fn test_query_context_timeout_default() {
+        let ctx = QueryContext::new("test");
+        assert_eq!(ctx.timeout_secs, None);
     }
 }
diff --git a/rust/src/client/retriever.rs b/rust/src/client/retriever.rs
index 3e826d0a..6e612571 100644
--- a/rust/src/client/retriever.rs
+++ b/rust/src/client/retriever.rs
@@ -22,7 +22,6 @@ use std::sync::Arc;
 use tracing::info;
 
 use super::types::QueryResultItem;
-use crate::config::Config;
 use crate::document::{DocumentTree, ReasoningIndex};
 use crate::error::{Error, Result};
 use crate::events::{EventEmitter, QueryEvent};
@@ -36,24 +35,16 @@ pub(crate) struct RetrieverClient {
     /// Pipeline retriever.
     retriever: Arc<PipelineRetriever>,
 
-    /// Configuration reference.
-    config: Arc<Config>,
-
     /// Event emitter.
     events: EventEmitter,
-
-    /// Default retrieval options.
-    default_options: RetrieveOptions,
 }
 
 impl RetrieverClient {
     /// Create a new retriever client.
-    pub fn new(retriever: crate::retrieval::PipelineRetriever, config: Arc<Config>) -> Self {
+    pub fn new(retriever: crate::retrieval::PipelineRetriever) -> Self {
         Self {
             retriever: Arc::new(retriever),
-            config,
             events: EventEmitter::new(),
-            default_options: RetrieveOptions::default(),
         }
     }
 
@@ -64,6 +55,7 @@ impl RetrieverClient {
     }
 
+    #[tracing::instrument(skip_all, fields(question = %question))]
     /// Query a document tree with optional reasoning index for fast-path lookup.
     ///
     /// # Errors
     ///
@@ -146,24 +138,12 @@ impl RetrieverClient {
 
         info!("Streaming query: {:?}", question);
 
-        let (handle, rx) = self.retriever.retrieve_streaming(tree, question, options);
-
-        // Spawn a sidecar task that forwards events to the EventEmitter
-        let events = self.events.clone();
-        let question_owned = question.to_string();
-        tokio::spawn(async move {
-            // The handle will complete when the streaming task finishes.
-            // We don't need to forward events individually here since
-            // the primary channel (rx) is returned to the caller.
-            // The EventEmitter events are already emitted above for Started.
-            // The caller can consume rx for detailed streaming events.
-            let _ = handle.await;
-            events.emit_query(QueryEvent::Complete {
-                total_results: 0,
-                confidence: 0.0,
-            });
-            let _ = question_owned; // suppress unused warning
-        });
+        let (_handle, rx) = self.retriever.retrieve_streaming(tree, question, options);
+
+        // Note: The Complete event is NOT emitted via EventEmitter here because
+        // the streaming handle returns () — the actual result flows through the
+        // rx channel as RetrieveEvent::Completed { response }. Callers who need
+        // completion metrics should consume the channel.
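+        //
+        // Caller-side sketch (assumes `rx` yields RetrieveEvent values):
+        //
+        //     while let Some(event) = rx.recv().await {
+        //         if let RetrieveEvent::Completed { response } = event {
+        //             // record completion metrics here
+        //         }
+        //     }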
         Ok(rx)
     }
 
@@ -209,9 +189,7 @@ impl Clone for RetrieverClient {
     fn clone(&self) -> Self {
         Self {
             retriever: Arc::clone(&self.retriever),
-            config: Arc::clone(&self.config),
             events: self.events.clone(),
-            default_options: self.default_options.clone(),
         }
     }
 }
@@ -222,9 +200,7 @@ mod tests {
 
     #[test]
     fn test_retriever_client_creation() {
-        let config = Arc::new(Config::default());
         let retriever = crate::retrieval::PipelineRetriever::new();
-        let client = RetrieverClient::new(retriever, config);
-        assert!(client.default_options.top_k > 0);
+        let _client = RetrieverClient::new(retriever);
     }
 }
diff --git a/rust/src/client/test_support.rs b/rust/src/client/test_support.rs
new file mode 100644
index 00000000..6b936024
--- /dev/null
+++ b/rust/src/client/test_support.rs
@@ -0,0 +1,52 @@
+// Copyright (c) 2026 vectorless developers
+// SPDX-License-Identifier: Apache-2.0
+
+//! Test-only helpers for constructing Engine instances without a real LLM.
+//!
+//! This module is exposed via `vectorless::__test_support` and should **only**
+//! be used in integration tests.
+
+use std::sync::Arc;
+
+use crate::client::engine::Engine;
+use crate::client::indexer::IndexerClient;
+use crate::config::Config;
+use crate::events::EventEmitter;
+use crate::index::PipelineExecutor;
+use crate::metrics::MetricsHub;
+use crate::retrieval::PipelineRetriever;
+use crate::storage::Workspace;
+
+/// Build an `Engine` with a no-LLM pipeline for integration testing.
+///
+/// The pipeline skips enhance/summary stages but exercises:
+/// parse → build → validate → split → enrich → optimize.
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let tmp = tempfile::tempdir().unwrap();
+/// let engine = vectorless::__test_support::build_test_engine(tmp.path()).await;
+/// ```
+pub async fn build_test_engine(workspace_dir: &std::path::Path) -> Engine {
+    let config = Config::default();
+
+    // No-LLM indexer: pipeline without enhance stage
+    let executor_factory: Arc<dyn Fn() -> PipelineExecutor + Send + Sync> =
+        Arc::new(|| PipelineExecutor::new());
+    let indexer = IndexerClient::with_factory(executor_factory);
+
+    let workspace = Workspace::new(workspace_dir).await.unwrap();
+    let retriever = PipelineRetriever::new();
+
+    Engine::with_components(
+        config,
+        workspace,
+        retriever,
+        indexer,
+        EventEmitter::new(),
+        Arc::new(MetricsHub::with_defaults()),
+    )
+    .await
+    .unwrap()
+}
diff --git a/rust/src/client/types.rs b/rust/src/client/types.rs
index 5c638846..16503054 100644
--- a/rust/src/client/types.rs
+++ b/rust/src/client/types.rs
@@ -6,114 +6,10 @@
 //! This module contains all types exposed in the public API.
 
 use serde::{Deserialize, Serialize};
-use std::path::PathBuf;
 
-use crate::document::DocumentTree;
 use crate::index::parse::DocumentFormat;
 use crate::metrics::IndexMetrics;
 
-// ============================================================
-// Document Types
-// ============================================================
-
-/// An indexed document with its tree structure and metadata.
-#[derive(Debug, Clone)]
-pub struct IndexedDocument {
-    /// Unique document identifier.
-    pub id: String,
-
-    /// Document format.
-    pub format: DocumentFormat,
-
-    /// Document name/title.
-    pub name: String,
-
-    /// Document description (generated by LLM).
-    pub description: Option<String>,
-
-    /// Source file path.
-    pub source_path: Option<PathBuf>,
-
-    /// Page count (for PDFs).
-    pub page_count: Option<usize>,
-
-    /// The document tree structure.
-    pub tree: Option<DocumentTree>,
-
-    /// Per-page content (for PDFs).
-    pub pages: Vec<PageContent>,
-
-    /// Indexing pipeline metrics.
-    pub metrics: Option<IndexMetrics>,
-
-    /// Pre-computed reasoning index for retrieval acceleration.
-    pub reasoning_index: Option<crate::document::ReasoningIndex>,
-}
-
-impl IndexedDocument {
-    /// Create a new indexed document.
-    pub fn new(id: impl Into<String>, format: DocumentFormat) -> Self {
-        Self {
-            id: id.into(),
-            format,
-            name: String::new(),
-            description: None,
-            source_path: None,
-            page_count: None,
-            tree: None,
-            pages: Vec::new(),
-            metrics: None,
-            reasoning_index: None,
-        }
-    }
-
-    /// Set the document name.
-    pub fn with_name(mut self, name: impl Into<String>) -> Self {
-        self.name = name.into();
-        self
-    }
-
-    /// Set the document description.
-    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
-        self.description = Some(desc.into());
-        self
-    }
-
-    /// Set the source path.
-    pub fn with_source_path(mut self, path: impl Into<PathBuf>) -> Self {
-        self.source_path = Some(path.into());
-        self
-    }
-
-    /// Set the page count.
-    pub fn with_page_count(mut self, count: usize) -> Self {
-        self.page_count = Some(count);
-        self
-    }
-
-    /// Set the document tree.
-    pub fn with_tree(mut self, tree: DocumentTree) -> Self {
-        self.tree = Some(tree);
-        self
-    }
-
-    /// Set the indexing metrics.
-    pub fn with_metrics(mut self, metrics: IndexMetrics) -> Self {
-        self.metrics = Some(metrics);
-        self
-    }
-}
-
-/// Content for a single page.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PageContent {
-    /// Page number (1-based).
-    pub page: usize,
-
-    /// Page text content.
-    pub content: String,
-}
-
 // ============================================================
 // Partial Success
 // ============================================================
@@ -175,9 +71,6 @@ pub struct IndexOptions {
     /// Whether to generate summaries using LLM.
     pub generate_summaries: bool,
 
-    /// Whether to include node text in the tree.
-    pub include_text: bool,
-
     /// Whether to generate node IDs.
     pub generate_ids: bool,
 
@@ -188,6 +81,9 @@ pub struct IndexOptions {
     /// during reasoning index construction. Improves recall for
     /// queries that use different wording than the document.
     pub enable_synonym_expansion: bool,
+
+    /// Per-operation timeout (seconds). `None` means no timeout.
+    pub timeout_secs: Option<u64>,
 }
 
 impl Default for IndexOptions {
@@ -195,10 +91,10 @@ impl Default for IndexOptions {
         Self {
             mode: IndexMode::Default,
             generate_summaries: true,
-            include_text: true,
             generate_ids: true,
             generate_description: false,
             enable_synonym_expansion: true,
+            timeout_secs: None,
         }
     }
 }
@@ -232,6 +128,12 @@ impl IndexOptions {
         self.mode = mode;
         self
     }
+
+    /// Set per-operation timeout in seconds.
+    pub fn with_timeout_secs(mut self, secs: u64) -> Self {
+        self.timeout_secs = Some(secs);
+        self
+    }
 }
 
 // ============================================================
@@ -483,45 +385,10 @@ impl DocumentInfo {
     }
 }
 
-// ============================================================
-// Error Types
-// ============================================================
-
-/// Client error types.
-#[derive(Debug, Clone, thiserror::Error)]
-pub enum ClientError {
-    /// Document not found.
-    #[error("Document not found: {0}")]
-    NotFound(String),
-
-    /// Invalid operation.
-    #[error("Invalid operation: {0}")]
-    InvalidOperation(String),
-
-    /// Configuration error.
-    #[error("Configuration error: {0}")]
-    Config(String),
-
-    /// Timeout error.
-    #[error("Operation timed out")]
-    Timeout,
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
 
-    #[test]
-    fn test_indexed_document() {
-        let doc = IndexedDocument::new("doc-1", DocumentFormat::Markdown)
-            .with_name("Test Document")
-            .with_description("A test document");
-
-        assert_eq!(doc.id, "doc-1");
-        assert_eq!(doc.name, "Test Document");
-        assert!(doc.tree.is_none());
-    }
-
     #[test]
     fn test_index_options() {
         let options = IndexOptions::new()
@@ -532,6 +399,15 @@ mod tests {
         assert_eq!(options.mode, IndexMode::Force);
     }
 
+    #[test]
+    fn test_index_options_timeout() {
+        let opts = IndexOptions::new().with_timeout_secs(30);
+        assert_eq!(opts.timeout_secs, Some(30));
+
+        let default = IndexOptions::default();
+        assert_eq!(default.timeout_secs, None);
+    }
+
     #[test]
     fn test_query_result() {
         let result = QueryResult::new();
diff --git a/rust/src/client/workspace.rs b/rust/src/client/workspace.rs
index 7f7b570c..db296493 100644
--- a/rust/src/client/workspace.rs
+++ b/rust/src/client/workspace.rs
@@ -68,12 +68,23 @@ impl WorkspaceClient {
 
     /// Save a document to the workspace.
     ///
+    /// If a document with the same ID already exists, logs a warning
+    /// (this can happen during concurrent indexing of the same source).
+    ///
     /// # Errors
     ///
     /// Returns an error if the workspace write fails.
     pub async fn save(&self, doc: &PersistedDocument) -> Result<()> {
         let doc_id = doc.meta.id.clone();
 
+        if self.workspace.contains(&doc_id).await {
+            tracing::warn!(
+                doc_id,
+                name = %doc.meta.name,
+                "Overwriting existing document — possible concurrent index of the same source"
+            );
+        }
+
         self.workspace.add(doc).await?;
 
         info!("Saved document: {}", doc_id);
@@ -90,20 +101,15 @@ impl WorkspaceClient {
     ///
     /// Returns an error if the workspace read fails.
     pub async fn load(&self, doc_id: &str) -> Result<Option<PersistedDocument>> {
-        if !self.workspace.contains(doc_id).await {
-            return Ok(None);
-        }
-
         let doc = self.workspace.load_and_cache(doc_id).await?;
-        let cache_hit = doc.is_some();
 
-        if let Some(ref _doc) = doc {
-            debug!("Loaded document: {} (cache={})", doc_id, cache_hit);
+        if let Some(ref _d) = doc {
+            debug!("Loaded document: {}", doc_id);
         }
 
         self.events.emit_workspace(WorkspaceEvent::Loaded {
             doc_id: doc_id.to_string(),
-            cache_hit,
+            cache_hit: doc.is_some(),
         });
 
         Ok(doc)
@@ -194,19 +200,23 @@ impl WorkspaceClient {
     /// Returns an error if the workspace write fails.
     pub async fn clear(&self) -> Result<usize> {
         let doc_ids = self.workspace.list_documents().await;
-        let count = doc_ids.len();
+        let mut removed = 0usize;
 
         for doc_id in &doc_ids {
-            let _ = self.workspace.remove(doc_id).await;
+            match self.workspace.remove(doc_id).await {
+                Ok(true) => removed += 1,
+                Ok(false) => {}
+                Err(e) => tracing::warn!("Failed to remove document {}: {}", doc_id, e),
+            }
         }
 
-        if count > 0 {
-            info!("Cleared workspace: {} documents removed", count);
+        if removed > 0 {
+            info!("Cleared workspace: {removed} documents removed");
             self.events
-                .emit_workspace(WorkspaceEvent::Cleared { count });
+                .emit_workspace(WorkspaceEvent::Cleared { count: removed });
         }
 
-        Ok(count)
+        Ok(removed)
     }
 
     /// Get the underlying workspace Arc (for advanced use).
diff --git a/rust/src/config/loader.rs b/rust/src/config/loader.rs
deleted file mode 100644
index 6f8f6d13..00000000
--- a/rust/src/config/loader.rs
+++ /dev/null
@@ -1,200 +0,0 @@
-// Copyright (c) 2026 vectorless developers
-// SPDX-License-Identifier: Apache-2.0
-
-//! Configuration loader.
-//!
-//! Loads configuration from TOML files.
-//!
-//! # Configuration Priority
-//!
-//! Configuration is loaded in this order (later overrides earlier):
-//! 1. Default configuration
-//! 2. Config file(s)
-//!
-//! # Example
-//!
-//! ```rust,no_run
-//! use vectorless::config::{ConfigLoader, Config};
-//!
-//! // Load from file
-//! let config = ConfigLoader::new()
-//!     .file("config.toml")
-//!     .load()?;
-//!
-//! // Load with validation
-//! let config = ConfigLoader::new()
-//!     .file("config.toml")
-//!     .with_validation(true)
-//!     .load()?;
-//!
-//! // Layered configuration
-//! let config = ConfigLoader::new()
-//!     .file("default.toml")
-//!     .file("production.toml")
-//!     .with_validation(true)
-//!     .load()?;
-//! # Ok::<(), vectorless::config::ConfigError>(())
-//! ```
-
-use std::path::{Path, PathBuf};
-use thiserror::Error;
-
-use super::merge::Merge;
-use super::types::Config;
-use super::validator::ConfigValidator;
-
-/// Configuration loading errors.
-#[derive(Debug, Error)]
-pub enum ConfigError {
-    /// Failed to read configuration file.
-    #[error("Failed to read config file: {0}")]
-    Io(#[from] std::io::Error),
-
-    /// Failed to parse TOML.
-    #[error("Failed to parse config: {0}")]
-    Parse(#[from] toml::de::Error),
-
-    /// Configuration file not found.
-    #[error("Config file not found: {0}")]
-    NotFound(PathBuf),
-
-    /// Invalid configuration value.
-    #[error("Invalid configuration: {0}")]
-    Invalid(String),
-
-    /// Configuration validation failed.
-    #[error("{0}")]
-    Validation(#[from] super::types::ConfigValidationError),
-}
-
-/// Configuration loader.
-#[derive(Debug)]
-pub struct ConfigLoader {
-    /// Configuration file paths (loaded in order, later files override earlier).
-    files: Vec<PathBuf>,
-
-    /// Whether to validate after loading.
-    validate: bool,
-
-    /// Custom validator (optional).
-    validator: Option<ConfigValidator>,
-}
-
-impl Default for ConfigLoader {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl ConfigLoader {
-    /// Create a new configuration loader with defaults.
-    pub fn new() -> Self {
-        Self {
-            files: Vec::new(),
-            validate: false,
-            validator: None,
-        }
-    }
-
-    /// Specify a configuration file to load.
-    ///
-    /// Multiple files can be specified; later files override earlier ones.
-    pub fn file<P: AsRef<Path>>(mut self, path: P) -> Self {
-        self.files.push(path.as_ref().to_path_buf());
-        self
-    }
-
-    /// Specify multiple configuration files.
-    pub fn files<I, P>(mut self, paths: I) -> Self
-    where
-        I: IntoIterator<Item = P>,
-        P: AsRef<Path>,
-    {
-        self.files
-            .extend(paths.into_iter().map(|p| p.as_ref().to_path_buf()));
-        self
-    }
-
-    /// Enable or disable validation after loading.
-    pub fn with_validation(mut self, validate: bool) -> Self {
-        self.validate = validate;
-        self
-    }
-
-    /// Set a custom validator.
-    pub fn with_validator(mut self, validator: ConfigValidator) -> Self {
-        self.validator = Some(validator);
-        self
-    }
-
-    /// Load the configuration.
-    ///
-    /// # Behavior
-    ///
-    /// 1. Start with default configuration
-    /// 2. Load and merge each specified file (in order)
-    /// 3. Validate configuration (if enabled)
-    ///
-    /// # Errors
-    ///
-    /// Returns an error if:
-    /// - A specified file doesn't exist
-    /// - A file can't be parsed as valid TOML
-    /// - Validation fails (when enabled)
-    pub fn load(self) -> Result<Config, ConfigError> {
-        let mut config = Config::default();
-
-        // Load and merge each file
-        for path in &self.files {
-            if path.exists() {
-                let content = std::fs::read_to_string(path)?;
-                let file_config: Config = toml::from_str(&content)?;
-                config.merge(&file_config, super::merge::MergeStrategy::Replace);
-            } else {
-                return Err(ConfigError::NotFound(path.clone()));
-            }
-        }
-
-        // Validate if requested
-        if self.validate {
-            let validator = self.validator.unwrap_or_default();
-            validator.validate(&config)?;
-        }
-
-        Ok(config)
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_default_config() {
-        let config = Config::default();
-        assert_eq!(config.indexer.subsection_threshold, 300);
-        assert!(config.summary.model.is_empty());
-        assert!(config.retrieval.model.is_empty());
-    }
-
-    #[test]
-    fn test_config_loader_defaults() {
-        let config = ConfigLoader::new().load().unwrap();
-        assert_eq!(config.indexer.subsection_threshold, 300);
-    }
-
-    #[test]
-    fn test_config_loader_not_found() {
-        let result = ConfigLoader::new().file("nonexistent_config.toml").load();
-
-        assert!(result.is_err());
-        assert!(matches!(result.unwrap_err(), ConfigError::NotFound(_)));
-    }
-
-    #[test]
-    fn test_config_loader_with_validation() {
-        let config = ConfigLoader::new().with_validation(true).load().unwrap();
-
-        assert!(config.retrieval.model.is_empty());
-    }
-}
diff --git a/rust/src/config/merge.rs b/rust/src/config/merge.rs
deleted file mode 100644
index 7e524aad..00000000
--- a/rust/src/config/merge.rs
+++ /dev/null
@@ -1,354 +0,0 @@
-// Copyright (c) 2026 vectorless developers
-// SPDX-License-Identifier: Apache-2.0
-
-//! Configuration merging.
-//!
-//! This module provides utilities for merging multiple configurations,
-//! enabling layered configuration from multiple sources.
-
-use super::types::{
-    CacheConfig, ConcurrencyConfig, Config, ContentAggregatorConfig, FallbackConfig, IndexerConfig,
-    RetrievalConfig, SearchConfig, StorageConfig, StrategyConfig, SufficiencyConfig, SummaryConfig,
-};
-
-/// Configuration merge strategy.
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum MergeStrategy {
-    /// Replace with source value.
-    Replace,
-    /// Keep existing value if present (don't overwrite).
-    KeepExisting,
-    /// Recursively merge nested structures.
-    Recursive,
-}
-
-/// Trait for configuration merging.
-pub trait Merge {
-    /// Merge another configuration into this one.
- fn merge(&mut self, other: &Self, strategy: MergeStrategy); -} - -impl Merge for Config { - fn merge(&mut self, other: &Self, strategy: MergeStrategy) { - self.indexer.merge(&other.indexer, strategy); - self.summary.merge(&other.summary, strategy); - self.retrieval.merge(&other.retrieval, strategy); - self.storage.merge(&other.storage, strategy); - self.concurrency.merge(&other.concurrency, strategy); - self.fallback.merge(&other.fallback, strategy); - } -} - -impl Merge for IndexerConfig { - fn merge(&mut self, other: &Self, strategy: MergeStrategy) { - if strategy == MergeStrategy::Replace || self.subsection_threshold == 300 { - self.subsection_threshold = other.subsection_threshold; - } - if strategy == MergeStrategy::Replace || self.max_segment_tokens == 3000 { - self.max_segment_tokens = other.max_segment_tokens; - } - if strategy == MergeStrategy::Replace || self.max_summary_tokens == 200 { - self.max_summary_tokens = other.max_summary_tokens; - } - if strategy == MergeStrategy::Replace || self.min_summary_tokens == 20 { - self.min_summary_tokens = other.min_summary_tokens; - } - } -} - -impl Merge for SummaryConfig { - fn merge(&mut self, other: &Self, strategy: MergeStrategy) { - if strategy == MergeStrategy::Replace || self.model == "gpt-4o-mini" { - self.model = other.model.clone(); - } - if strategy == MergeStrategy::Replace || self.endpoint == "https://api.openai.com/v1" { - self.endpoint = other.endpoint.clone(); - } - // Always merge API keys if present - if other.api_key.is_some() { - self.api_key = other.api_key.clone(); - } - if strategy == MergeStrategy::Replace || self.max_tokens == 200 { - self.max_tokens = other.max_tokens; - } - if strategy == MergeStrategy::Replace || self.temperature == 0.0 { - self.temperature = other.temperature; - } - } -} - -impl Merge for RetrievalConfig { - fn merge(&mut self, other: &Self, strategy: MergeStrategy) { - if strategy == MergeStrategy::Replace || self.model == "gpt-4o" { - self.model = other.model.clone(); - } - if strategy == MergeStrategy::Replace || self.endpoint == "https://api.openai.com/v1" { - self.endpoint = other.endpoint.clone(); - } - if other.api_key.is_some() { - self.api_key = other.api_key.clone(); - } - if strategy == MergeStrategy::Replace || self.max_tokens == 1000 { - self.max_tokens = other.max_tokens; - } - if strategy == MergeStrategy::Replace || self.temperature == 0.0 { - self.temperature = other.temperature; - } - if strategy == MergeStrategy::Replace || self.top_k == 3 { - self.top_k = other.top_k; - } - - self.search.merge(&other.search, strategy); - self.sufficiency.merge(&other.sufficiency, strategy); - self.cache.merge(&other.cache, strategy); - self.strategy.merge(&other.strategy, strategy); - self.content.merge(&other.content, strategy); - } -} - -impl Merge for SearchConfig { - fn merge(&mut self, other: &Self, strategy: MergeStrategy) { - if strategy == MergeStrategy::Replace || self.top_k == 5 { - self.top_k = other.top_k; - } - if strategy == MergeStrategy::Replace || self.beam_width == 3 { - self.beam_width = other.beam_width; - } - if strategy == MergeStrategy::Replace || self.max_iterations == 10 { - self.max_iterations = other.max_iterations; - } - if strategy == MergeStrategy::Replace || (self.min_score - 0.1).abs() < f32::EPSILON { - self.min_score = other.min_score; - } - } -} - -impl Merge for SufficiencyConfig { - fn merge(&mut self, other: &Self, strategy: MergeStrategy) { - if strategy == MergeStrategy::Replace || self.min_tokens == 500 { - self.min_tokens = other.min_tokens; - } - 
if strategy == MergeStrategy::Replace || self.target_tokens == 2000 { - self.target_tokens = other.target_tokens; - } - if strategy == MergeStrategy::Replace || self.max_tokens == 4000 { - self.max_tokens = other.max_tokens; - } - if strategy == MergeStrategy::Replace || self.min_content_length == 200 { - self.min_content_length = other.min_content_length; - } - if strategy == MergeStrategy::Replace - || (self.confidence_threshold - 0.7).abs() < f32::EPSILON - { - self.confidence_threshold = other.confidence_threshold; - } - } -} - -impl Merge for CacheConfig { - fn merge(&mut self, other: &Self, strategy: MergeStrategy) { - if strategy == MergeStrategy::Replace || self.max_entries == 1000 { - self.max_entries = other.max_entries; - } - if strategy == MergeStrategy::Replace || self.ttl_secs == 3600 { - self.ttl_secs = other.ttl_secs; - } - } -} - -impl Merge for StrategyConfig { - fn merge(&mut self, other: &Self, strategy: MergeStrategy) { - if strategy == MergeStrategy::Replace || (self.exploration_weight - 1.414).abs() < 0.001 { - self.exploration_weight = other.exploration_weight; - } - if strategy == MergeStrategy::Replace - || (self.similarity_threshold - 0.5).abs() < f32::EPSILON - { - self.similarity_threshold = other.similarity_threshold; - } - if strategy == MergeStrategy::Replace - || (self.high_similarity_threshold - 0.8).abs() < f32::EPSILON - { - self.high_similarity_threshold = other.high_similarity_threshold; - } - if strategy == MergeStrategy::Replace - || (self.low_similarity_threshold - 0.3).abs() < f32::EPSILON - { - self.low_similarity_threshold = other.low_similarity_threshold; - } - } -} - -impl Merge for ContentAggregatorConfig { - fn merge(&mut self, other: &Self, strategy: MergeStrategy) { - if other.enabled != self.enabled { - self.enabled = other.enabled; - } - if strategy == MergeStrategy::Replace || self.token_budget == 4000 { - self.token_budget = other.token_budget; - } - if strategy == MergeStrategy::Replace - || (self.min_relevance_score - 0.2).abs() < f32::EPSILON - { - self.min_relevance_score = other.min_relevance_score; - } - if strategy == MergeStrategy::Replace || self.scoring_strategy == "keyword_bm25" { - self.scoring_strategy = other.scoring_strategy.clone(); - } - if strategy == MergeStrategy::Replace || self.output_format == "markdown" { - self.output_format = other.output_format.clone(); - } - if other.include_scores != self.include_scores { - self.include_scores = other.include_scores; - } - if strategy == MergeStrategy::Replace - || (self.hierarchical_min_per_level - 0.1).abs() < f32::EPSILON - { - self.hierarchical_min_per_level = other.hierarchical_min_per_level; - } - if other.deduplicate != self.deduplicate { - self.deduplicate = other.deduplicate; - } - if strategy == MergeStrategy::Replace || (self.dedup_threshold - 0.9).abs() < f32::EPSILON { - self.dedup_threshold = other.dedup_threshold; - } - } -} - -impl Merge for StorageConfig { - fn merge(&mut self, other: &Self, strategy: MergeStrategy) { - if strategy == MergeStrategy::Replace { - self.workspace_dir = other.workspace_dir.clone(); - } - } -} - -impl Merge for ConcurrencyConfig { - fn merge(&mut self, other: &Self, strategy: MergeStrategy) { - if strategy == MergeStrategy::Replace || self.max_concurrent_requests == 10 { - self.max_concurrent_requests = other.max_concurrent_requests; - } - if strategy == MergeStrategy::Replace || self.requests_per_minute == 500 { - self.requests_per_minute = other.requests_per_minute; - } - if other.enabled != self.enabled { - self.enabled = 
other.enabled;
-        }
-        if other.semaphore_enabled != self.semaphore_enabled {
-            self.semaphore_enabled = other.semaphore_enabled;
-        }
-    }
-}
-
-impl Merge for FallbackConfig {
-    fn merge(&mut self, other: &Self, strategy: MergeStrategy) {
-        if other.enabled != self.enabled {
-            self.enabled = other.enabled;
-        }
-        if !other.models.is_empty() {
-            self.models = other.models.clone();
-        }
-        if !other.endpoints.is_empty() {
-            self.endpoints = other.endpoints.clone();
-        }
-        if strategy == MergeStrategy::Replace {
-            self.on_rate_limit = other.on_rate_limit;
-            self.on_timeout = other.on_timeout;
-            self.on_all_failed = other.on_all_failed;
-            self.max_retries = other.max_retries;
-            self.initial_retry_delay_ms = other.initial_retry_delay_ms;
-            self.max_retry_delay_ms = other.max_retry_delay_ms;
-            self.retry_multiplier = other.retry_multiplier;
-        }
-    }
-}
-
-/// Configuration overlay for layered configuration.
-///
-/// Allows building a configuration from multiple sources,
-/// with later overlays taking precedence.
-#[derive(Debug, Clone)]
-pub struct ConfigOverlay {
-    /// Base configuration.
-    base: Config,
-    /// Overlay configurations (applied in order).
-    overlays: Vec<Config>,
-}
-
-impl ConfigOverlay {
-    /// Create a new overlay with a base configuration.
-    pub fn new(base: Config) -> Self {
-        Self {
-            base,
-            overlays: Vec::new(),
-        }
-    }
-
-    /// Add an overlay configuration.
-    pub fn overlay(mut self, config: Config) -> Self {
-        self.overlays.push(config);
-        self
-    }
-
-    /// Resolve all overlays into a final configuration.
-    pub fn resolve(self) -> Config {
-        let mut result = self.base;
-        for overlay in self.overlays {
-            result.merge(&overlay, MergeStrategy::Replace);
-        }
-        result
-    }
-}
-
-impl Default for ConfigOverlay {
-    fn default() -> Self {
-        Self::new(Config::default())
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_config_merge() {
-        let mut base = Config::default();
-        let mut overlay = Config::default();
-
-        overlay.retrieval.top_k = 10;
-        overlay.summary.model = "gpt-4o".to_string();
-
-        base.merge(&overlay, MergeStrategy::Replace);
-
-        assert_eq!(base.retrieval.top_k, 10);
-        assert_eq!(base.summary.model, "gpt-4o");
-    }
-
-    #[test]
-    fn test_config_overlay() {
-        let mut overlay1 = Config::default();
-        overlay1.retrieval.top_k = 5;
-
-        let mut overlay2 = Config::default();
-        overlay2.retrieval.top_k = 10;
-
-        let config = ConfigOverlay::new(Config::default())
-            .overlay(overlay1)
-            .overlay(overlay2)
-            .resolve();
-
-        assert_eq!(config.retrieval.top_k, 10);
-    }
-
-    #[test]
-    fn test_merge_keeps_api_keys() {
-        let mut base = Config::default();
-        let mut overlay = Config::default();
-
-        overlay.summary.api_key = Some("test-key".to_string());
-
-        base.merge(&overlay, MergeStrategy::Replace);
-
-        assert_eq!(base.summary.api_key, Some("test-key".to_string()));
-    }
-}
diff --git a/rust/src/config/mod.rs b/rust/src/config/mod.rs
index f6d26927..f90f2af9 100644
--- a/rust/src/config/mod.rs
+++ b/rust/src/config/mod.rs
@@ -6,15 +6,12 @@
 //! Users configure vectorless via [`EngineBuilder`](crate::client::EngineBuilder) methods,
 //! not by directly interacting with this module.
-mod loader; -mod merge; mod types; mod validator; pub use types::Config; pub(crate) use types::{ - CacheConfig, CompressionAlgorithm, ConcurrencyConfig, FallbackBehavior, FallbackConfig, - IndexerConfig, LlmClientConfig, LlmConfig, LlmMetricsConfig, LlmPoolConfig, MetricsConfig, - OnAllFailedBehavior, PilotMetricsConfig, RetrievalConfig, RetrievalMetricsConfig, - SufficiencyConfig, SummaryConfig, + CacheConfig, CompressionAlgorithm, FallbackBehavior, FallbackConfig, IndexerConfig, LlmConfig, + LlmMetricsConfig, MetricsConfig, OnAllFailedBehavior, PilotMetricsConfig, + RetrievalMetricsConfig, SlotConfig, SufficiencyConfig, }; diff --git a/rust/src/config/types/concurrency.rs b/rust/src/config/types/concurrency.rs deleted file mode 100644 index c4172ba8..00000000 --- a/rust/src/config/types/concurrency.rs +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) 2026 vectorless developers -// SPDX-License-Identifier: Apache-2.0 - -//! Concurrency control configuration types. - -use serde::{Deserialize, Serialize}; - -/// Concurrency control configuration. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ConcurrencyConfig { - /// Maximum concurrent LLM API calls. - #[serde(default = "default_max_concurrent_requests")] - pub max_concurrent_requests: usize, - - /// Rate limit: requests per minute. - #[serde(default = "default_requests_per_minute")] - pub requests_per_minute: usize, - - /// Whether rate limiting is enabled. - #[serde(default = "default_true")] - pub enabled: bool, - - /// Whether semaphore-based concurrency limiting is enabled. - #[serde(default = "default_true")] - pub semaphore_enabled: bool, -} - -fn default_max_concurrent_requests() -> usize { - 10 -} - -fn default_requests_per_minute() -> usize { - 500 -} - -fn default_true() -> bool { - true -} - -impl Default for ConcurrencyConfig { - fn default() -> Self { - Self { - max_concurrent_requests: default_max_concurrent_requests(), - requests_per_minute: default_requests_per_minute(), - enabled: default_true(), - semaphore_enabled: default_true(), - } - } -} - -impl ConcurrencyConfig { - /// Create a new config with defaults. - pub fn new() -> Self { - Self::default() - } - - /// Set the maximum concurrent requests. - pub fn with_max_concurrent_requests(mut self, max: usize) -> Self { - self.max_concurrent_requests = max; - self - } - - /// Set the requests per minute rate limit. - pub fn with_requests_per_minute(mut self, rpm: usize) -> Self { - self.requests_per_minute = rpm; - self - } - - /// Enable or disable rate limiting. - pub fn with_enabled(mut self, enabled: bool) -> Self { - self.enabled = enabled; - self - } - - /// Enable or disable semaphore. - pub fn with_semaphore_enabled(mut self, enabled: bool) -> Self { - self.semaphore_enabled = enabled; - self - } - - /// Convert to the runtime concurrency config. 
-    pub fn to_runtime_config(&self) -> crate::throttle::ConcurrencyConfig {
-        crate::throttle::ConcurrencyConfig {
-            max_concurrent_requests: self.max_concurrent_requests,
-            requests_per_minute: self.requests_per_minute,
-            enabled: self.enabled,
-            semaphore_enabled: self.semaphore_enabled,
-        }
-    }
-}
-
-impl From<ConcurrencyConfig> for crate::throttle::ConcurrencyConfig {
-    fn from(config: ConcurrencyConfig) -> Self {
-        config.to_runtime_config()
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_concurrency_config_defaults() {
-        let config = ConcurrencyConfig::default();
-        assert_eq!(config.max_concurrent_requests, 10);
-        assert_eq!(config.requests_per_minute, 500);
-        assert!(config.enabled);
-        assert!(config.semaphore_enabled);
-    }
-
-    #[test]
-    fn test_concurrency_config_builder() {
-        let config = ConcurrencyConfig::new()
-            .with_max_concurrent_requests(20)
-            .with_requests_per_minute(1000)
-            .with_enabled(false);
-
-        assert_eq!(config.max_concurrent_requests, 20);
-        assert_eq!(config.requests_per_minute, 1000);
-        assert!(!config.enabled);
-    }
-}
diff --git a/rust/src/config/types/fallback.rs b/rust/src/config/types/fallback.rs
deleted file mode 100644
index 96c2fd89..00000000
--- a/rust/src/config/types/fallback.rs
+++ /dev/null
@@ -1,233 +0,0 @@
-// Copyright (c) 2026 vectorless developers
-// SPDX-License-Identifier: Apache-2.0
-
-//! Fallback and error recovery configuration types.
-
-use serde::{Deserialize, Serialize};
-
-/// Fallback behavior when encountering errors.
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub enum FallbackBehavior {
-    /// Only retry with the same model/endpoint.
-    Retry,
-    /// Immediately switch to fallback model/endpoint.
-    Fallback,
-    /// Retry first, then fallback if still failing.
-    RetryThenFallback,
-    /// Fail immediately without retry or fallback.
-    Fail,
-}
-
-impl Default for FallbackBehavior {
-    fn default() -> Self {
-        Self::RetryThenFallback
-    }
-}
-
-/// Behavior when all fallback attempts fail.
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub enum OnAllFailedBehavior {
-    /// Return the error to the caller.
-    ReturnError,
-    /// Try to return cached result if available.
-    ReturnCache,
-}
-
-impl Default for OnAllFailedBehavior {
-    fn default() -> Self {
-        Self::ReturnError
-    }
-}
-
-/// Fallback configuration for error recovery.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct FallbackConfig {
-    /// Whether fallback is enabled.
-    #[serde(default = "default_true")]
-    pub enabled: bool,
-
-    /// Fallback models in priority order.
-    #[serde(default = "default_fallback_models")]
-    pub models: Vec<String>,
-
-    /// Fallback endpoints in priority order.
-    #[serde(default)]
-    pub endpoints: Vec<String>,
-
-    /// Behavior on rate limit error (429).
-    #[serde(default)]
-    pub on_rate_limit: FallbackBehavior,
-
-    /// Behavior on timeout error.
-    #[serde(default)]
-    pub on_timeout: FallbackBehavior,
-
-    /// Behavior when all attempts fail.
-    #[serde(default)]
-    pub on_all_failed: OnAllFailedBehavior,
-
-    /// Maximum retry attempts.
-    #[serde(default = "default_max_retries")]
-    pub max_retries: usize,
-
-    /// Initial retry delay in milliseconds.
-    #[serde(default = "default_initial_retry_delay_ms")]
-    pub initial_retry_delay_ms: u64,
-
-    /// Maximum retry delay in milliseconds.
-    #[serde(default = "default_max_retry_delay_ms")]
-    pub max_retry_delay_ms: u64,
-
-    /// Retry delay multiplier (exponential backoff).
-    #[serde(default = "default_retry_multiplier")]
-    pub retry_multiplier: f32,
-}
-
-fn default_fallback_models() -> Vec<String> {
-    vec!["gpt-4o-mini".to_string(), "glm-4-flash".to_string()]
-}
-
-fn default_max_retries() -> usize {
-    3
-}
-
-fn default_initial_retry_delay_ms() -> u64 {
-    1000
-}
-
-fn default_max_retry_delay_ms() -> u64 {
-    30000
-}
-
-fn default_retry_multiplier() -> f32 {
-    2.0
-}
-
-impl Default for FallbackConfig {
-    fn default() -> Self {
-        Self {
-            enabled: default_true(),
-            models: default_fallback_models(),
-            endpoints: Vec::new(),
-            on_rate_limit: FallbackBehavior::default(),
-            on_timeout: FallbackBehavior::default(),
-            on_all_failed: OnAllFailedBehavior::default(),
-            max_retries: default_max_retries(),
-            initial_retry_delay_ms: default_initial_retry_delay_ms(),
-            max_retry_delay_ms: default_max_retry_delay_ms(),
-            retry_multiplier: default_retry_multiplier(),
-        }
-    }
-}
-
-fn default_true() -> bool {
-    true
-}
-
-impl FallbackConfig {
-    /// Create a new fallback config with defaults.
-    pub fn new() -> Self {
-        Self::default()
-    }
-
-    /// Disable fallback entirely.
-    pub fn disabled() -> Self {
-        Self {
-            enabled: false,
-            ..Self::default()
-        }
-    }
-
-    /// Set fallback models.
-    pub fn with_models(mut self, models: Vec<String>) -> Self {
-        self.models = models;
-        self
-    }
-
-    /// Set fallback endpoints.
-    pub fn with_endpoints(mut self, endpoints: Vec<String>) -> Self {
-        self.endpoints = endpoints;
-        self
-    }
-
-    /// Set behavior on rate limit.
-    pub fn with_on_rate_limit(mut self, behavior: FallbackBehavior) -> Self {
-        self.on_rate_limit = behavior;
-        self
-    }
-
-    /// Set behavior on timeout.
-    pub fn with_on_timeout(mut self, behavior: FallbackBehavior) -> Self {
-        self.on_timeout = behavior;
-        self
-    }
-
-    /// Set behavior when all attempts fail.
-    pub fn with_on_all_failed(mut self, behavior: OnAllFailedBehavior) -> Self {
-        self.on_all_failed = behavior;
-        self
-    }
-
-    /// Set maximum retries.
-    pub fn with_max_retries(mut self, max: usize) -> Self {
-        self.max_retries = max;
-        self
-    }
-
-    /// Calculate retry delay with exponential backoff.
-    pub fn calculate_retry_delay(&self, attempt: usize) -> std::time::Duration {
-        let delay_ms = if attempt == 0 {
-            self.initial_retry_delay_ms
-        } else {
-            let delay =
-                self.initial_retry_delay_ms as f32 * self.retry_multiplier.powi(attempt as i32);
-            delay.min(self.max_retry_delay_ms as f32) as u64
-        };
-        std::time::Duration::from_millis(delay_ms)
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_fallback_config_defaults() {
-        let config = FallbackConfig::default();
-        assert!(config.enabled);
-        assert_eq!(config.models.len(), 2);
-        assert_eq!(config.on_rate_limit, FallbackBehavior::RetryThenFallback);
-        assert_eq!(config.max_retries, 3);
-    }
-
-    #[test]
-    fn test_fallback_config_disabled() {
-        let config = FallbackConfig::disabled();
-        assert!(!config.enabled);
-    }
-
-    #[test]
-    fn test_fallback_behavior_serde() {
-        let behavior = FallbackBehavior::RetryThenFallback;
-        let json = serde_json::to_string(&behavior).unwrap();
-        assert_eq!(json, "\"retry_then_fallback\"");
-
-        let decoded: FallbackBehavior = serde_json::from_str(&json).unwrap();
-        assert_eq!(decoded, behavior);
-    }
-
-    #[test]
-    fn test_retry_delay_calculation() {
-        let config = FallbackConfig::default();
-
-        let d0 = config.calculate_retry_delay(0);
-        let d1 = config.calculate_retry_delay(1);
-        let d2 = config.calculate_retry_delay(2);
-
-        assert_eq!(d0.as_millis(), 1000);
-        assert_eq!(d1.as_millis(), 2000);
-        assert_eq!(d2.as_millis(), 4000);
-    }
-}
diff --git a/rust/src/config/types/llm.rs b/rust/src/config/types/llm.rs
deleted file mode 100644
index a3a0a285..00000000
--- a/rust/src/config/types/llm.rs
+++ /dev/null
@@ -1,206 +0,0 @@
-// Copyright (c) 2026 vectorless developers
-// SPDX-License-Identifier: Apache-2.0
-
-//! LLM configuration types for summary and retrieval.
-
-use serde::{Deserialize, Serialize};
-
-/// Generic LLM configuration.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct LlmConfig {
-    /// Model name (e.g., "gpt-4o-mini", "claude-3-haiku").
-    #[serde(default)]
-    pub model: String,
-
-    /// API endpoint.
-    #[serde(default)]
-    pub endpoint: String,
-
-    /// API key.
-    #[serde(default)]
-    pub api_key: Option<String>,
-
-    /// Maximum tokens for responses.
-    #[serde(default = "default_max_tokens")]
-    pub max_tokens: usize,
-
-    /// Temperature for generation.
-    #[serde(default = "default_temperature")]
-    pub temperature: f32,
-}
-
-fn default_max_tokens() -> usize {
-    1000
-}
-
-fn default_temperature() -> f32 {
-    0.0
-}
-
-impl Default for LlmConfig {
-    fn default() -> Self {
-        Self {
-            model: String::new(),
-            endpoint: String::new(),
-            api_key: None,
-            max_tokens: default_max_tokens(),
-            temperature: default_temperature(),
-        }
-    }
-}
-
-impl LlmConfig {
-    /// Create a new LLM config with defaults.
-    pub fn new() -> Self {
-        Self::default()
-    }
-
-    /// Set the model.
-    pub fn with_model(mut self, model: impl Into<String>) -> Self {
-        self.model = model.into();
-        self
-    }
-
-    /// Set the endpoint.
-    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
-        self.endpoint = endpoint.into();
-        self
-    }
-
-    /// Set the API key.
-    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
-        self.api_key = Some(api_key.into());
-        self
-    }
-
-    /// Set the maximum tokens.
-    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
-        self.max_tokens = max_tokens;
-        self
-    }
-
-    /// Set the temperature.
-    pub fn with_temperature(mut self, temperature: f32) -> Self {
-        self.temperature = temperature;
-        self
-    }
-
-    /// Get the API key from config.
-    pub fn get_api_key(&self) -> Option<&str> {
-        self.api_key.as_deref()
-    }
-}
-
-/// Summary model configuration.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct SummaryConfig {
-    /// Model name for summarization.
-    #[serde(default)]
-    pub model: String,
-
-    /// API endpoint for summary model.
-    #[serde(default)]
-    pub endpoint: String,
-
-    /// API key.
-    #[serde(default)]
-    pub api_key: Option<String>,
-
-    /// Maximum tokens for summary generation.
-    #[serde(default = "default_max_summary_tokens")]
-    pub max_tokens: usize,
-
-    /// Temperature for summary generation.
-    #[serde(default = "default_temperature")]
-    pub temperature: f32,
-}
-
-fn default_max_summary_tokens() -> usize {
-    200
-}
-
-impl Default for SummaryConfig {
-    fn default() -> Self {
-        Self {
-            model: String::new(),
-            endpoint: String::new(),
-            api_key: None,
-            max_tokens: default_max_summary_tokens(),
-            temperature: default_temperature(),
-        }
-    }
-}
-
-impl SummaryConfig {
-    /// Create a new summary config with defaults.
-    pub fn new() -> Self {
-        Self::default()
-    }
-
-    /// Set the model.
-    pub fn with_model(mut self, model: impl Into<String>) -> Self {
-        self.model = model.into();
-        self
-    }
-
-    /// Set the endpoint.
-    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
-        self.endpoint = endpoint.into();
-        self
-    }
-
-    /// Set the API key.
-    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
-        self.api_key = Some(api_key.into());
-        self
-    }
-
-    /// Set the maximum tokens.
-    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
-        self.max_tokens = max_tokens;
-        self
-    }
-
-    /// Convert to generic LLM config.
-    pub fn to_llm_config(&self) -> LlmConfig {
-        LlmConfig {
-            model: self.model.clone(),
-            endpoint: self.endpoint.clone(),
-            api_key: self.api_key.clone(),
-            max_tokens: self.max_tokens,
-            temperature: self.temperature,
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_llm_config_defaults() {
-        let config = LlmConfig::default();
-        assert!(config.model.is_empty());
-        assert!(config.endpoint.is_empty());
-        assert!(config.api_key.is_none());
-    }
-
-    #[test]
-    fn test_llm_config_builder() {
-        let config = LlmConfig::new()
-            .with_model("gpt-4o")
-            .with_api_key("test-key")
-            .with_max_tokens(2000);
-
-        assert_eq!(config.model, "gpt-4o");
-        assert_eq!(config.api_key, Some("test-key".to_string()));
-        assert_eq!(config.max_tokens, 2000);
-    }
-
-    #[test]
-    fn test_summary_config() {
-        let config = SummaryConfig::default();
-        assert!(config.model.is_empty());
-        assert_eq!(config.max_tokens, 200);
-    }
-}
diff --git a/rust/src/config/types/llm_pool.rs b/rust/src/config/types/llm_pool.rs
index d77d1241..fc092a12 100644
--- a/rust/src/config/types/llm_pool.rs
+++ b/rust/src/config/types/llm_pool.rs
@@ -1,74 +1,117 @@
 // Copyright (c) 2026 vectorless developers
 // SPDX-License-Identifier: Apache-2.0
 
-//! Unified LLM configuration including pool, retry, throttle, and fallback.
+//! Unified LLM configuration.
 //!
 //! This module consolidates all LLM-related configuration into a single
-//! cohesive structure that maps directly to the TOML configuration file.
+//! cohesive structure. Users configure via [`EngineBuilder`](crate::client::EngineBuilder)
+//! for simple cases, or construct [`LlmConfig`] programmatically for advanced use.
 
 use serde::{Deserialize, Serialize};
 
-/// Unified LLM configuration.
+/// Unified LLM configuration — the single entry point for all LLM settings.
 ///
-/// Contains all settings for LLM operations including:
-/// - Pool of clients for different purposes (index, retrieval, pilot)
-/// - Retry behavior
-/// - Throttle/rate limiting
-/// - Fallback strategy
+/// Contains:
+/// - Global credentials (`api_key`, `model`, `endpoint`)
+/// - Per-purpose slot overrides (`index`, `retrieval`, `pilot`)
+/// - Infrastructure settings (`retry`, `throttle`, `fallback`)
+///
+/// # Simple usage (via EngineBuilder)
+///
+/// ```rust,no_run
+/// use vectorless::client::EngineBuilder;
+///
+/// # async fn example() -> Result<(), vectorless::BuildError> {
+/// let engine = EngineBuilder::new()
+///     .with_key("sk-...")
+///     .with_model("gpt-4o")
+///     .with_endpoint("https://api.openai.com/v1")
+///     .build()
+///     .await?;
+/// # Ok(())
+/// # }
+/// ```
+///
+/// # Advanced usage (programmatic config)
+///
+/// ```rust,ignore
+/// use vectorless::config::{Config, LlmConfig, SlotConfig};
+///
+/// let config = Config::new().with_llm(
+///     LlmConfig::new("gpt-4o")
+///         .with_api_key("sk-...")
+///         .with_endpoint("https://api.openai.com/v1")
+///         .with_index(SlotConfig::fast().with_model("gpt-4o-mini"))
+///         .with_retrieval(SlotConfig::default().with_max_tokens(200))
+/// );
+/// ```
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct LlmPoolConfig {
-    /// Index client configuration (used during document indexing).
-    #[serde(default, alias = "summary")]
-    pub index: LlmClientConfig,
-
-    /// Retrieval client configuration.
+pub struct LlmConfig {
+    /// API key — **required**.
     #[serde(default)]
-    pub retrieval: LlmClientConfig,
-
-    /// Pilot client configuration.
-    #[serde(default = "default_pilot_config")]
-    pub pilot: LlmClientConfig,
+    pub api_key: Option<String>,
 
-    /// Default API key (used if not specified per-client).
+    /// Default model name — **required**.
+    ///
+    /// Individual slots can override this via [`SlotConfig::model`].
     #[serde(default)]
-    pub api_key: Option<String>,
+    pub model: String,
 
-    /// Default API endpoint (used if not specified per-client).
+    /// API endpoint URL — **required**.
     #[serde(default)]
     pub endpoint: Option<String>,
 
-    /// Retry configuration.
+    /// Index slot (document indexing / summarization).
+    /// Uses a fast, cost-effective model by default.
+    #[serde(default)]
+    pub index: SlotConfig,
+
+    /// Retrieval slot (document navigation).
+    /// Uses the default model.
+    #[serde(default = "default_retrieval_slot")]
+    pub retrieval: SlotConfig,
+
+    /// Pilot slot (navigation guidance).
+    /// Uses a fast model with higher token limit.
+    #[serde(default = "default_pilot_slot")]
+    pub pilot: SlotConfig,
+
+    /// Retry configuration for LLM calls.
     #[serde(default)]
     pub retry: RetryConfig,
 
-    /// Throttle/rate limiting configuration.
+    /// Throttle / rate-limiting configuration.
     #[serde(default)]
     pub throttle: ThrottleConfig,
 
-    /// Fallback configuration.
+    /// Fallback configuration for error recovery.
     #[serde(default)]
     pub fallback: FallbackConfig,
 }
 
-fn default_pilot_config() -> LlmClientConfig {
-    LlmClientConfig {
+fn default_retrieval_slot() -> SlotConfig {
+    SlotConfig {
+        max_tokens: 100,
+        ..SlotConfig::default()
+    }
+}
+
+fn default_pilot_slot() -> SlotConfig {
+    SlotConfig {
         max_tokens: 300,
-        temperature: 0.0,
-        ..Default::default()
+        ..SlotConfig::default()
     }
 }
 
-impl Default for LlmPoolConfig {
+impl Default for LlmConfig {
     fn default() -> Self {
         Self {
-            index: LlmClientConfig::default(),
-            retrieval: LlmClientConfig {
-                max_tokens: 100,
-                ..Default::default()
-            },
-            pilot: default_pilot_config(),
             api_key: None,
+            model: String::new(),
             endpoint: None,
+            index: SlotConfig::default(),
+            retrieval: default_retrieval_slot(),
+            pilot: default_pilot_slot(),
             retry: RetryConfig::default(),
             throttle: ThrottleConfig::default(),
             fallback: FallbackConfig::default(),
@@ -76,71 +119,100 @@
         }
     }
 }
 
-impl LlmPoolConfig {
-    /// Create a new LLM pool config with defaults.
-    pub fn new() -> Self {
-        Self::default()
+impl LlmConfig {
+    /// Create a new config with a specific model.
+    pub fn new(model: impl Into<String>) -> Self {
+        Self {
+            model: model.into(),
+            ..Self::default()
+        }
     }
 
-    /// Set the default API key.
-    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
-        self.api_key = Some(api_key.into());
+    /// Set the API key.
+    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
+        self.api_key = Some(key.into());
         self
     }
 
-    /// Get API key for a specific client (client-specific or default).
-    pub fn get_api_key_for(&self, client_key: Option<&str>) -> Option<String> {
-        // First check client-specific key
-        if let Some(key) = client_key {
-            if let Some(ref k) = self.index.api_key {
-                if self.index.model == key {
-                    return Some(k.clone());
-                }
-            }
-            if let Some(ref k) = self.retrieval.api_key {
-                if self.retrieval.model == key {
-                    return Some(k.clone());
-                }
-            }
-            if let Some(ref k) = self.pilot.api_key {
-                if self.pilot.model == key {
-                    return Some(k.clone());
-                }
-            }
-        }
-        // Fall back to default
-        self.api_key.clone()
+    /// Set the default model.
+    pub fn with_model(mut self, model: impl Into<String>) -> Self {
+        self.model = model.into();
+        self
     }
 
-    /// Resolve API key: client-specific first, then default.
-    pub fn resolved_api_key(&self, client: &LlmClientConfig) -> Option<String> {
-        client.api_key.clone().or_else(|| self.api_key.clone())
+    /// Set the endpoint URL.
+    pub fn with_endpoint(mut self, url: impl Into<String>) -> Self {
+        self.endpoint = Some(url.into());
+        self
     }
 
-    /// Resolve endpoint: client-specific first, then default.
-    pub fn resolved_endpoint(&self, client: &LlmClientConfig) -> String {
-        if !client.endpoint.is_empty() {
-            client.endpoint.clone()
-        } else {
-            self.endpoint.clone().unwrap_or_default()
-        }
+    /// Set the index slot configuration.
+    pub fn with_index(mut self, slot: SlotConfig) -> Self {
+        self.index = slot;
+        self
     }
-}
 
-/// Individual LLM client configuration.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct LlmClientConfig {
-    /// Model name.
-    #[serde(default)]
-    pub model: String,
+    /// Set the retrieval slot configuration.
+    pub fn with_retrieval(mut self, slot: SlotConfig) -> Self {
+        self.retrieval = slot;
+        self
+    }
 
-    /// API endpoint.
-    #[serde(default)]
-    pub endpoint: String,
+    /// Set the pilot slot configuration.
+    pub fn with_pilot(mut self, slot: SlotConfig) -> Self {
+        self.pilot = slot;
+        self
+    }
 
-    /// API key (optional, falls back to default).
+    /// Set the retry configuration.
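+    ///
+    /// Illustrative sketch (not from this patch):
+    /// `.with_retry(RetryConfig::new().with_max_attempts(5))`.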
+    pub fn with_retry(mut self, retry: RetryConfig) -> Self {
+        self.retry = retry;
+        self
+    }
+
+    /// Set the throttle configuration.
+    pub fn with_throttle(mut self, throttle: ThrottleConfig) -> Self {
+        self.throttle = throttle;
+        self
+    }
+
+    /// Set the fallback configuration.
+    pub fn with_fallback(mut self, fallback: FallbackConfig) -> Self {
+        self.fallback = fallback;
+        self
+    }
+
+    /// Convenience: set max concurrent requests (delegates to throttle).
+    pub fn with_max_concurrent(mut self, max: usize) -> Self {
+        self.throttle.max_concurrent_requests = max;
+        self
+    }
+
+    /// Resolve the effective model for a given slot.
+    ///
+    /// Returns the slot-specific model if set, otherwise the default model.
+    pub fn resolve_model(&self, slot: &SlotConfig) -> String {
+        slot.model.clone().unwrap_or_else(|| self.model.clone())
+    }
+}
+
+/// Per-purpose LLM slot override.
+///
+/// Controls model selection and generation parameters for a specific
+/// LLM usage (index, retrieval, or pilot).
+///
+/// - `model`: Override the default model (optional).
+/// - `max_tokens`: Maximum response tokens.
+/// - `temperature`: Generation temperature.
+///
+/// `api_key` and `endpoint` are **not** here — they are always inherited
+/// from the parent [`LlmConfig`].
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SlotConfig {
+    /// Override the default model for this purpose.
+    /// When `None`, uses [`LlmConfig::model`].
     #[serde(default)]
-    pub api_key: Option<String>,
+    pub model: Option<String>,
 
     /// Maximum tokens for responses.
     #[serde(default = "default_max_tokens")]
@@ -159,39 +231,33 @@ fn default_temperature() -> f32 {
     0.0
 }
 
-impl Default for LlmClientConfig {
+impl Default for SlotConfig {
     fn default() -> Self {
         Self {
-            model: String::new(),
-            endpoint: String::new(),
-            api_key: None,
+            model: None,
             max_tokens: default_max_tokens(),
             temperature: default_temperature(),
         }
     }
 }
 
-impl LlmClientConfig {
-    /// Create a new client config with defaults.
+impl SlotConfig {
+    /// Create a new slot config with defaults.
     pub fn new() -> Self {
         Self::default()
     }
 
-    /// Set the model.
-    pub fn with_model(mut self, model: impl Into<String>) -> Self {
-        self.model = model.into();
-        self
-    }
-
-    /// Set the endpoint.
-    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
-        self.endpoint = endpoint.into();
-        self
+    /// Create a "fast" preset (low tokens).
+    pub fn fast() -> Self {
+        Self {
+            max_tokens: 100,
+            ..Self::default()
+        }
     }
 
-    /// Set the API key.
-    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
-        self.api_key = Some(api_key.into());
+    /// Set the model override.
+    pub fn with_model(mut self, model: impl Into<String>) -> Self {
+        self.model = Some(model.into());
         self
     }
 
@@ -200,8 +266,18 @@
         self.max_tokens = max_tokens;
         self
     }
+
+    /// Set the temperature.
+    pub fn with_temperature(mut self, temperature: f32) -> Self {
+        self.temperature = temperature;
+        self
+    }
 }
 
+// ============================================================
+// Supporting configuration types
+// ============================================================
+
 /// Retry configuration for LLM calls.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct RetryConfig {
@@ -276,9 +352,20 @@
         let delay_ms = delay_ms.min(self.max_delay_ms as f64);
         std::time::Duration::from_millis(delay_ms as u64)
     }
+
+    /// Convert to the runtime retry config (used by llm module).
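+    ///
+    /// Sketch of the intended use (assumes the caller holds a full `Config`;
+    /// not part of this patch):
+    ///
+    /// ```rust,ignore
+    /// let runtime = config.llm.retry.to_runtime_config();
+    /// assert_eq!(runtime.max_attempts, config.llm.retry.max_attempts);
+    /// ```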
+    pub fn to_runtime_config(&self) -> crate::llm::config::RetryConfig {
+        crate::llm::config::RetryConfig {
+            max_attempts: self.max_attempts,
+            initial_delay_ms: self.initial_delay_ms,
+            max_delay_ms: self.max_delay_ms,
+            multiplier: self.multiplier,
+            retry_on_rate_limit: self.retry_on_rate_limit,
+        }
+    }
 }
 
-/// Throttle/rate limiting configuration.
+/// Throttle / rate-limiting configuration.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ThrottleConfig {
     /// Maximum concurrent LLM API calls.
@@ -334,9 +421,45 @@
         self.requests_per_minute = rpm;
         self
     }
+
+    /// Convert to the runtime concurrency config.
+    pub fn to_runtime_config(&self) -> crate::llm::throttle::ConcurrencyConfig {
+        crate::llm::throttle::ConcurrencyConfig {
+            max_concurrent_requests: self.max_concurrent_requests,
+            requests_per_minute: self.requests_per_minute,
+            enabled: self.enabled,
+            semaphore_enabled: self.semaphore_enabled,
+        }
+    }
+}
+
+/// Fallback behavior on errors.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
+#[serde(rename_all = "snake_case")]
+pub enum FallbackBehavior {
+    /// Retry the same model.
+    Retry,
+    /// Immediately fall back to next model.
+    Fallback,
+    /// Retry first, then fall back.
+    #[default]
+    RetryThenFallback,
+    /// Fail immediately.
+    Fail,
+}
+
+/// Behavior when all fallback attempts fail.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
+#[serde(rename_all = "snake_case")]
+pub enum OnAllFailedBehavior {
+    /// Return an error.
+    #[default]
+    ReturnError,
+    /// Return cached result if available.
+    ReturnCache,
 }
 
-/// Fallback configuration for LLM calls.
+/// Fallback configuration for error recovery.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct FallbackConfig {
     /// Enable fallback mechanism.
@@ -362,36 +485,42 @@
     /// Behavior when all attempts fail.
     #[serde(default)]
     pub on_all_failed: OnAllFailedBehavior,
+
+    /// Maximum retry attempts.
+    #[serde(default = "default_max_retries")]
+    pub max_retries: usize,
+
+    /// Initial retry delay in milliseconds.
+    #[serde(default = "default_initial_retry_delay_ms")]
+    pub initial_retry_delay_ms: u64,
+
+    /// Maximum retry delay in milliseconds.
+    #[serde(default = "default_max_retry_delay_ms")]
+    pub max_retry_delay_ms: u64,
+
+    /// Retry delay multiplier (exponential backoff).
+    #[serde(default = "default_retry_multiplier")]
+    pub retry_multiplier: f32,
 }
 
 fn default_fallback_models() -> Vec<String> {
     vec!["gpt-4o-mini".to_string(), "glm-4-flash".to_string()]
 }
 
-/// Fallback behavior on errors.
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
-#[serde(rename_all = "snake_case")]
-pub enum FallbackBehavior {
-    /// Retry the same model.
-    Retry,
-    /// Immediately fall back to next model.
-    Fallback,
-    /// Retry first, then fall back.
-    #[default]
-    RetryThenFallback,
-    /// Fail immediately.
-    Fail,
+fn default_max_retries() -> usize {
+    3
 }
 
-/// Behavior when all fallback attempts fail.
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
-#[serde(rename_all = "snake_case")]
-pub enum OnAllFailedBehavior {
-    /// Return an error.
-    #[default]
-    ReturnError,
-    /// Return cached result if available.
-    ReturnCache,
+fn default_initial_retry_delay_ms() -> u64 {
+    1000
+}
+
+fn default_max_retry_delay_ms() -> u64 {
+    30000
+}
+
+fn default_retry_multiplier() -> f32 {
+    2.0
 }
 
 impl Default for FallbackConfig {
@@ -403,6 +532,10 @@
             on_rate_limit: FallbackBehavior::default(),
             on_timeout: FallbackBehavior::default(),
             on_all_failed: OnAllFailedBehavior::default(),
+            max_retries: default_max_retries(),
+            initial_retry_delay_ms: default_initial_retry_delay_ms(),
+            max_retry_delay_ms: default_max_retry_delay_ms(),
+            retry_multiplier: default_retry_multiplier(),
         }
     }
 }
@@ -413,13 +546,37 @@
         Self::default()
     }
 
-    /// Disable fallback.
+    /// Disable fallback entirely.
     pub fn disabled() -> Self {
         Self {
            enabled: false,
            ..Self::default()
        }
    }
+
+    /// Set fallback models.
+    pub fn with_models(mut self, models: Vec<String>) -> Self {
+        self.models = models;
+        self
+    }
+
+    /// Set behavior on rate limit.
+    pub fn with_on_rate_limit(mut self, behavior: FallbackBehavior) -> Self {
+        self.on_rate_limit = behavior;
+        self
+    }
+
+    /// Calculate retry delay with exponential backoff.
+    pub fn calculate_retry_delay(&self, attempt: usize) -> std::time::Duration {
+        let delay_ms = if attempt == 0 {
+            self.initial_retry_delay_ms
+        } else {
+            let delay =
+                self.initial_retry_delay_ms as f32 * self.retry_multiplier.powi(attempt as i32);
+            delay.min(self.max_retry_delay_ms as f32) as u64
+        };
+        std::time::Duration::from_millis(delay_ms)
+    }
 }
 
 #[cfg(test)]
@@ -427,32 +584,68 @@ mod tests {
     use super::*;
 
     #[test]
-    fn test_llm_pool_config_defaults() {
-        let config = LlmPoolConfig::default();
-        assert!(config.index.model.is_empty());
-        assert!(config.retrieval.model.is_empty());
-        assert!(config.pilot.model.is_empty());
-        assert_eq!(config.retry.max_attempts, 3);
-        assert_eq!(config.throttle.max_concurrent_requests, 10);
+    fn test_llm_config_defaults() {
+        let config = LlmConfig::default();
+        assert!(config.api_key.is_none());
+        assert!(config.model.is_empty());
+        assert!(config.endpoint.is_none());
+        assert!(config.index.model.is_none());
+        assert!(config.retrieval.model.is_none());
+        assert!(config.pilot.model.is_none());
+        assert_eq!(config.index.max_tokens, 200);
+        assert_eq!(config.retrieval.max_tokens, 100);
+        assert_eq!(config.pilot.max_tokens, 300);
+    }
+
+    #[test]
+    fn test_llm_config_builder() {
+        let config = LlmConfig::new("gpt-4o")
+            .with_api_key("sk-test")
+            .with_endpoint("https://api.openai.com/v1")
+            .with_index(SlotConfig::fast().with_model("gpt-4o-mini"));
+
+        assert_eq!(config.model, "gpt-4o");
+        assert_eq!(config.api_key, Some("sk-test".to_string()));
+        assert_eq!(config.index.model, Some("gpt-4o-mini".to_string()));
+        assert_eq!(config.index.max_tokens, 100);
+    }
+
+    #[test]
+    fn test_resolve_model() {
+        let config =
+            LlmConfig::new("gpt-4o").with_retrieval(SlotConfig::new().with_model("gpt-4o-mini"));
+
+        assert_eq!(config.resolve_model(&config.index), "gpt-4o");
+        assert_eq!(config.resolve_model(&config.retrieval), "gpt-4o-mini");
+        assert_eq!(config.resolve_model(&config.pilot), "gpt-4o");
+    }
+
+    #[test]
+    fn test_slot_config_fast() {
+        let slot = SlotConfig::fast();
+        assert_eq!(slot.max_tokens, 100);
     }
 
     #[test]
     fn test_retry_delay_calculation() {
         let config = RetryConfig::default();
-
-        // Initial delay
         assert_eq!(
             config.delay_for_attempt(0),
             std::time::Duration::from_millis(500)
         );
-
-        // Second attempt: 500 * 2 = 1000
         assert_eq!(
             config.delay_for_attempt(1),
             std::time::Duration::from_millis(1000)
         );
     }
 
+    #[test]
+    fn test_throttle_config_defaults() {
+        let config = ThrottleConfig::default();
+        assert_eq!(config.max_concurrent_requests, 10);
+        assert_eq!(config.requests_per_minute, 500);
+    }
+
     #[test]
     fn test_fallback_config_defaults() {
         let config = FallbackConfig::default();
diff --git a/rust/src/config/types/mod.rs b/rust/src/config/types/mod.rs
index 8ca3b434..da53b1f3 100644
--- a/rust/src/config/types/mod.rs
+++ b/rust/src/config/types/mod.rs
@@ -2,15 +2,9 @@
 // SPDX-License-Identifier: Apache-2.0
 
 //! Configuration type definitions.
-//!
-//! All configuration values are defined inline in `Default` trait implementations.
-//! Configuration is loaded from TOML files only — no environment variables, no auto-detection.
 
-mod concurrency;
 mod content;
-mod fallback;
 mod indexer;
-mod llm;
 mod llm_pool;
 mod metrics;
 mod retrieval;
@@ -18,28 +12,53 @@ mod storage;
 
 use serde::{Deserialize, Serialize};
 
-pub(crate) use concurrency::ConcurrencyConfig;
-pub(crate) use content::ContentAggregatorConfig;
-pub(crate) use fallback::{FallbackBehavior, FallbackConfig, OnAllFailedBehavior};
 pub(crate) use indexer::IndexerConfig;
-pub(crate) use llm::{LlmConfig, SummaryConfig};
-pub(crate) use llm_pool::{LlmClientConfig, LlmPoolConfig};
+pub(crate) use llm_pool::{
+    FallbackBehavior, FallbackConfig, LlmConfig, OnAllFailedBehavior, SlotConfig,
+};
 pub(crate) use metrics::{
     LlmMetricsConfig, MetricsConfig, PilotMetricsConfig, RetrievalMetricsConfig,
 };
-pub(crate) use retrieval::{RetrievalConfig, SearchConfig};
-pub(crate) use storage::{
-    CacheConfig, CompressionAlgorithm, StorageConfig, StrategyConfig, SufficiencyConfig,
-};
+pub(crate) use retrieval::RetrievalConfig;
+pub(crate) use storage::{CacheConfig, CompressionAlgorithm, StorageConfig, SufficiencyConfig};
 
 /// Main configuration for vectorless.
+///
+/// Users typically configure via [`EngineBuilder`](crate::client::EngineBuilder):
+///
+/// ```rust,no_run
+/// use vectorless::EngineBuilder;
+///
+/// # async fn example() -> Result<(), vectorless::BuildError> {
+/// let engine = EngineBuilder::new()
+///     .with_key("sk-...")
+///     .with_model("gpt-4o")
+///     .with_endpoint("https://api.openai.com/v1")
+///     .build()
+///     .await?;
+/// # Ok(())
+/// # }
+/// ```
+///
+/// Advanced users can construct this programmatically:
+///
+/// ```rust,ignore
+/// use vectorless::config::{Config, LlmConfig, SlotConfig};
+///
+/// let config = Config::new().with_llm(
+///     LlmConfig::new("gpt-4o")
+///         .with_api_key("sk-...")
+///         .with_endpoint("https://api.openai.com/v1")
+///         .with_index(SlotConfig::fast().with_model("gpt-4o-mini"))
+/// );
+/// ```
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Config {
-    /// Unified LLM configuration (pool, retry, throttle, fallback).
+    /// LLM configuration (model, credentials, retry, throttle, fallback).
     #[serde(default)]
-    pub llm: LlmPoolConfig,
+    pub llm: LlmConfig,
 
-    /// Unified metrics configuration.
+    /// Metrics configuration.
     #[serde(default)]
     pub metrics: MetricsConfig,
 
@@ -47,11 +66,7 @@ pub struct Config {
     #[serde(default)]
     pub indexer: IndexerConfig,
 
-    /// Summary model configuration (legacy, prefer llm.summary).
-    #[serde(default)]
-    pub summary: SummaryConfig,
-
-    /// Retrieval model configuration.
+    /// Retrieval strategy configuration (search, content aggregation, etc.).
     #[serde(default)]
     pub retrieval: RetrievalConfig,
 
@@ -59,31 +74,20 @@ pub struct Config {
     #[serde(default)]
     pub storage: StorageConfig,
 
-    /// Concurrency control configuration (legacy, prefer llm.throttle).
- #[serde(default)] - pub concurrency: ConcurrencyConfig, - /// Document graph configuration. #[serde(default)] pub graph: crate::graph::DocumentGraphConfig, - - /// Fallback/error recovery configuration (legacy, prefer llm.fallback). - #[serde(default)] - pub fallback: FallbackConfig, } impl Default for Config { fn default() -> Self { Self { - llm: LlmPoolConfig::default(), + llm: LlmConfig::default(), metrics: MetricsConfig::default(), indexer: IndexerConfig::default(), - summary: SummaryConfig::default(), retrieval: RetrievalConfig::default(), storage: StorageConfig::default(), - concurrency: ConcurrencyConfig::default(), graph: crate::graph::DocumentGraphConfig::default(), - fallback: FallbackConfig::default(), } } } @@ -94,8 +98,8 @@ impl Config { Self::default() } - /// Set the LLM pool configuration. - pub fn with_llm(mut self, llm: LlmPoolConfig) -> Self { + /// Set the LLM configuration. + pub fn with_llm(mut self, llm: LlmConfig) -> Self { self.llm = llm; self } @@ -112,12 +116,6 @@ impl Config { self } - /// Set the summary configuration. - pub fn with_summary(mut self, summary: SummaryConfig) -> Self { - self.summary = summary; - self - } - /// Set the retrieval configuration. pub fn with_retrieval(mut self, retrieval: RetrievalConfig) -> Self { self.retrieval = retrieval; @@ -130,24 +128,12 @@ impl Config { self } - /// Set the concurrency configuration. - pub fn with_concurrency(mut self, concurrency: ConcurrencyConfig) -> Self { - self.concurrency = concurrency; - self - } - /// Set the document graph configuration. pub fn with_graph(mut self, graph: crate::graph::DocumentGraphConfig) -> Self { self.graph = graph; self } - /// Set the fallback configuration. - pub fn with_fallback(mut self, fallback: FallbackConfig) -> Self { - self.fallback = fallback; - self - } - /// Validate the configuration. 
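+    ///
+    /// Illustrative sketch (not part of this patch):
+    ///
+    /// ```rust,ignore
+    /// let config = Config::new();
+    /// // Zero token budgets or missing models surface as ValidationError entries.
+    /// config.validate()?;
+    /// ```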
pub fn validate(&self) -> Result<(), ConfigValidationError> { let mut errors = Vec::new(); @@ -160,11 +146,18 @@ impl Config { )); } - // Validate summary (index) - if self.summary.max_tokens == 0 { + // Validate LLM slot tokens + if self.llm.index.max_tokens == 0 { errors.push(ValidationError::error( - "summary.max_tokens", - "Summary max tokens must be greater than 0", + "llm.index.max_tokens", + "Index max tokens must be greater than 0", + )); + } + + if self.llm.retrieval.max_tokens == 0 { + errors.push(ValidationError::error( + "llm.retrieval.max_tokens", + "Retrieval max tokens must be greater than 0", )); } @@ -176,16 +169,6 @@ impl Config { )); } - if self.retrieval.temperature < 0.0 || self.retrieval.temperature > 2.0 { - errors.push( - ValidationError::warning( - "retrieval.temperature", - "Temperature outside typical range [0.0, 2.0]", - ) - .with_actual(self.retrieval.temperature.to_string()), - ); - } - // Validate content aggregator if self.retrieval.content.token_budget == 0 { errors.push(ValidationError::error( @@ -207,10 +190,10 @@ impl Config { ); } - // Validate concurrency - if self.concurrency.max_concurrent_requests == 0 { + // Validate throttle + if self.llm.throttle.max_concurrent_requests == 0 { errors.push(ValidationError::error( - "concurrency.max_concurrent_requests", + "llm.throttle.max_concurrent_requests", "Max concurrent requests must be greater than 0", )); } @@ -230,9 +213,9 @@ impl Config { } // Validate fallback - if self.fallback.enabled && self.fallback.models.is_empty() { + if self.llm.fallback.enabled && self.llm.fallback.models.is_empty() { errors.push(ValidationError::warning( - "fallback.models", + "llm.fallback.models", "Fallback enabled but no fallback models configured", )); } @@ -355,20 +338,18 @@ mod tests { #[test] fn test_config_defaults() { let config = Config::default(); + assert!(config.llm.model.is_empty()); + assert!(config.llm.index.model.is_none()); + assert_eq!(config.retrieval.top_k, 3); assert_eq!(config.indexer.subsection_threshold, 300); - assert!(config.summary.model.is_empty()); - assert!(config.retrieval.model.is_empty()); - assert_eq!(config.concurrency.max_concurrent_requests, 10); - // New fields - assert!(config.llm.index.model.is_empty()); assert!(config.metrics.enabled); } #[test] - fn test_llm_pool_config_defaults() { - let config = LlmPoolConfig::default(); - assert!(config.index.model.is_empty()); - assert!(config.retrieval.model.is_empty()); + fn test_llm_config_defaults() { + let config = LlmConfig::default(); + assert!(config.index.model.is_none()); + assert!(config.retrieval.model.is_none()); assert_eq!(config.retry.max_attempts, 3); assert_eq!(config.throttle.max_concurrent_requests, 10); } diff --git a/rust/src/config/types/retrieval.rs b/rust/src/config/types/retrieval.rs index fc131bc6..18df7cc8 100644 --- a/rust/src/config/types/retrieval.rs +++ b/rust/src/config/types/retrieval.rs @@ -1,36 +1,23 @@ // Copyright (c) 2026 vectorless developers // SPDX-License-Identifier: Apache-2.0 -//! Retrieval configuration types. +//! Retrieval strategy configuration types. +//! +//! LLM configuration (model, api_key, endpoint) is managed centrally +//! in [`LlmConfig`](super::LlmConfig). This module only contains +//! retrieval strategy parameters. use serde::{Deserialize, Serialize}; use super::content::ContentAggregatorConfig; use super::storage::{CacheConfig, StrategyConfig, SufficiencyConfig}; -/// Retrieval model configuration (for navigation). +/// Retrieval strategy configuration. 
+///
+/// Controls how documents are searched and retrieved, independent
+/// of which LLM model is used for navigation.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct RetrievalConfig {
-    /// Model name for retrieval/navigation.
-    #[serde(default)]
-    pub model: String,
-
-    /// API endpoint for retrieval model.
-    #[serde(default)]
-    pub endpoint: String,
-
-    /// API key.
-    #[serde(default)]
-    pub api_key: Option<String>,
-
-    /// Maximum tokens for retrieval context.
-    #[serde(default = "default_max_retrieval_tokens")]
-    pub max_tokens: usize,
-
-    /// Temperature for retrieval.
-    #[serde(default = "default_temperature")]
-    pub temperature: f32,
-
     /// Number of top-k results to return.
     #[serde(default = "default_top_k")]
     pub top_k: usize,
@@ -56,14 +43,6 @@ pub struct RetrievalConfig {
     pub content: ContentAggregatorConfig,
 }
 
-fn default_max_retrieval_tokens() -> usize {
-    1000
-}
-
-fn default_temperature() -> f32 {
-    0.0
-}
-
 fn default_top_k() -> usize {
     3
 }
@@ -71,11 +50,6 @@ fn default_top_k() -> usize {
 impl Default for RetrievalConfig {
     fn default() -> Self {
         Self {
-            model: String::new(),
-            endpoint: String::new(),
-            api_key: None,
-            max_tokens: default_max_retrieval_tokens(),
-            temperature: default_temperature(),
             top_k: default_top_k(),
             search: SearchConfig::default(),
             sufficiency: SufficiencyConfig::default(),
@@ -92,24 +66,6 @@ impl RetrievalConfig {
         Self::default()
     }
 
-    /// Set the model.
-    pub fn with_model(mut self, model: impl Into<String>) -> Self {
-        self.model = model.into();
-        self
-    }
-
-    /// Set the endpoint.
-    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
-        self.endpoint = endpoint.into();
-        self
-    }
-
-    /// Set the API key.
-    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
-        self.api_key = Some(api_key.into());
-        self
-    }
-
     /// Set the top_k.
     pub fn with_top_k(mut self, top_k: usize) -> Self {
         self.top_k = top_k;
@@ -206,7 +162,6 @@ mod tests {
     #[test]
     fn test_retrieval_config_defaults() {
         let config = RetrievalConfig::default();
-        assert!(config.model.is_empty());
         assert_eq!(config.top_k, 3);
         assert_eq!(config.search.top_k, 5);
     }
diff --git a/rust/src/config/types/storage.rs b/rust/src/config/types/storage.rs
index 00b9b7ea..b13304ea 100644
--- a/rust/src/config/types/storage.rs
+++ b/rust/src/config/types/storage.rs
@@ -33,6 +33,10 @@ pub struct StorageConfig {
     /// Enable compression for stored documents.
     #[serde(default)]
     pub compression: CompressionConfig,
+
+    /// Directory for pipeline checkpoints (derived from `workspace_dir`).
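+    /// Not serialized (`#[serde(skip)]` below); `Default` rebuilds it as
+    /// `workspace_dir.join("checkpoints")`.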
+    #[serde(skip)]
+    pub checkpoint_dir: PathBuf,
 }
 
 fn default_workspace_dir() -> PathBuf {
@@ -100,13 +104,16 @@ fn default_checksum_enabled() -> bool {
 
 impl Default for StorageConfig {
     fn default() -> Self {
+        let workspace_dir = default_workspace_dir();
+        let checkpoint_dir = workspace_dir.join("checkpoints");
         Self {
-            workspace_dir: default_workspace_dir(),
+            workspace_dir,
             cache_size: default_cache_size(),
             atomic_writes: default_atomic_writes(),
             file_lock: default_file_lock(),
             checksum_enabled: default_checksum_enabled(),
             compression: CompressionConfig::default(),
+            checkpoint_dir,
         }
     }
 }
diff --git a/rust/src/config/validator.rs b/rust/src/config/validator.rs
index c4000764..7f7d01f6 100644
--- a/rust/src/config/validator.rs
+++ b/rust/src/config/validator.rs
@@ -82,22 +82,26 @@ impl ValidationRule for RangeValidator {
         );
     }
 
-        // Summary ranges
-        if config.summary.max_tokens == 0 {
+        // LLM slot token ranges
+        if config.llm.index.max_tokens == 0 {
             errors.push(ValidationError::error(
-                "summary.max_tokens",
-                "Summary max tokens must be greater than 0",
+                "llm.index.max_tokens",
+                "Index max tokens must be greater than 0",
             ));
         }
 
-        if config.summary.temperature < 0.0 || config.summary.temperature > 2.0 {
-            errors.push(
-                ValidationError::warning(
-                    "summary.temperature",
-                    "Temperature outside typical range [0.0, 2.0]",
-                )
-                .with_actual(config.summary.temperature.to_string()),
-            );
+        if config.llm.retrieval.max_tokens == 0 {
+            errors.push(ValidationError::error(
+                "llm.retrieval.max_tokens",
+                "Retrieval max tokens must be greater than 0",
+            ));
+        }
+
+        if config.llm.pilot.max_tokens == 0 {
+            errors.push(ValidationError::error(
+                "llm.pilot.max_tokens",
+                "Pilot max tokens must be greater than 0",
+            ));
         }
 
         // Retrieval ranges
@@ -145,25 +149,25 @@ impl ValidationRule for RangeValidator {
         ));
     }
 
-        // Concurrency ranges
-        if config.concurrency.max_concurrent_requests == 0 {
+        // Throttle ranges
+        if config.llm.throttle.max_concurrent_requests == 0 {
             errors.push(ValidationError::error(
-                "concurrency.max_concurrent_requests",
+                "llm.throttle.max_concurrent_requests",
                 "Max concurrent requests must be greater than 0",
             ));
         }
 
-        if config.concurrency.requests_per_minute == 0 {
+        if config.llm.throttle.requests_per_minute == 0 {
             errors.push(ValidationError::error(
-                "concurrency.requests_per_minute",
+                "llm.throttle.requests_per_minute",
                 "Requests per minute must be greater than 0",
             ));
         }
 
         // Fallback ranges
-        if config.fallback.max_retries == 0 {
+        if config.llm.fallback.max_retries == 0 {
             errors.push(ValidationError::warning(
-                "fallback.max_retries",
+                "llm.fallback.max_retries",
                 "Max retries is 0, fallback will not retry",
             ));
         }
@@ -176,15 +180,15 @@
 struct ConsistencyValidator;
 
 impl ValidationRule for ConsistencyValidator {
     fn validate(&self, config: &Config, errors: &mut Vec<ValidationError>) {
-        // Check if summary tokens are reasonable
-        if config.summary.max_tokens > config.indexer.max_segment_tokens {
+        // Check if index tokens are reasonable
+        if config.llm.index.max_tokens > config.indexer.max_segment_tokens {
             errors.push(
                 ValidationError::warning(
-                    "summary.max_tokens",
-                    "Summary max tokens exceeds max segment tokens",
+                    "llm.index.max_tokens",
+                    "Index max tokens exceeds max segment tokens",
                 )
                 .with_expected(format!("<= {}", config.indexer.max_segment_tokens))
-                .with_actual(config.summary.max_tokens.to_string()),
+                .with_actual(config.llm.index.max_tokens.to_string()),
             );
         }
 
@@ -254,33 +258,32 @@
 struct DependencyValidator;
 
 impl ValidationRule for DependencyValidator {
     fn validate(&self, config: &Config, errors: &mut Vec<ValidationError>) {
         // Check if API key is available when summaries are needed
-        if config.summary.api_key.is_none() {
-            // Check if any feature requires LLM
+        if config.llm.api_key.is_none() {
             if config.indexer.max_summary_tokens > 0 {
                 errors.push(ValidationError::info(
-                    "summary.api_key",
+                    "llm.api_key",
                     "No API key configured, summary generation will be disabled",
                 ));
             }
         }
 
         // Check fallback configuration
-        if config.fallback.enabled {
-            if config.fallback.models.is_empty() && config.fallback.endpoints.is_empty() {
+        if config.llm.fallback.enabled {
+            if config.llm.fallback.models.is_empty() && config.llm.fallback.endpoints.is_empty() {
                 errors.push(ValidationError::warning(
-                    "fallback.models",
+                    "llm.fallback.models",
                     "Fallback enabled but no fallback models or endpoints configured",
                 ));
             }
 
             // Check retry behavior consistency
             if matches!(
-                config.fallback.on_rate_limit,
+                config.llm.fallback.on_rate_limit,
                 super::types::FallbackBehavior::Fallback
-            ) && config.fallback.models.is_empty()
+            ) && config.llm.fallback.models.is_empty()
             {
                 errors.push(ValidationError::error(
-                    "fallback.models",
+                    "llm.fallback.models",
                     "Rate limit behavior is 'fallback' but no fallback models configured",
                 ));
             }
@@ -374,8 +377,8 @@ mod tests {
     #[test]
     fn test_validator_catches_dependency_warnings() {
         let mut config = Config::default();
-        config.fallback.enabled = true;
-        config.fallback.models.clear();
+        config.llm.fallback.enabled = true;
+        config.llm.fallback.models.clear();
 
         let validator = ConfigValidator::new();
         let result = validator.validate(&config);
@@ -385,7 +388,7 @@
         assert!(
             err.errors
                 .iter()
-                .any(|e| e.path.contains("fallback.models"))
+                .any(|e| e.path.contains("llm.fallback.models"))
         );
     }
 }
diff --git a/rust/src/document/graph.rs b/rust/src/document/graph.rs
deleted file mode 100644
index 2a4cade8..00000000
--- a/rust/src/document/graph.rs
+++ /dev/null
@@ -1,12 +0,0 @@
-// Copyright (c) 2026 vectorless developers
-// SPDX-License-Identifier: Apache-2.0
-
-//! Re-export all graph types from the standalone `graph` module.
-//!
-//! This shim preserves backward compatibility for code importing
-//! from `crate::document::DocumentGraph`.
-
-pub use crate::graph::{
-    DocumentGraph, DocumentGraphConfig, DocumentGraphNode, EdgeEvidence, GraphEdge, GraphMetadata,
-    KeywordDocEntry, SharedKeyword, WeightedKeyword,
-};
diff --git a/rust/src/document/mod.rs b/rust/src/document/mod.rs
index cc7c22e8..7f00a84b 100644
--- a/rust/src/document/mod.rs
+++ b/rust/src/document/mod.rs
@@ -16,7 +16,6 @@
 //! - [`NodeReference`] - In-document reference (e.g., "see Appendix G")
 //! - [`RefType`] - Type of reference (Section, Appendix, Table, etc.)
-mod graph;
 mod node;
 mod reasoning;
 mod reference;
@@ -24,16 +23,12 @@
 mod structure;
 mod toc;
 mod tree;
 
-pub use graph::{
-    DocumentGraph, DocumentGraphConfig, DocumentGraphNode, EdgeEvidence, GraphEdge, GraphMetadata,
-    KeywordDocEntry, SharedKeyword, WeightedKeyword,
-};
 pub use node::{NodeId, TreeNode};
 pub use reasoning::{
     HotNodeEntry, ReasoningIndex, ReasoningIndexBuilder, ReasoningIndexConfig, SectionSummary,
     SummaryShortcut, TopicEntry,
 };
-pub use reference::{NodeReference, RefType, ReferenceExtractor, ReferenceResolver};
+pub use reference::{NodeReference, RefType, ReferenceExtractor};
 pub use structure::{DocumentStructure, StructureNode};
 pub use toc::{TocConfig, TocEntry, TocNode, TocView};
 pub use tree::{DocumentTree, RetrievalIndex};
diff --git a/rust/src/events/emitter.rs b/rust/src/events/emitter.rs
index c54efa94..7804a25c 100644
--- a/rust/src/events/emitter.rs
+++ b/rust/src/events/emitter.rs
@@ -8,24 +8,9 @@
 
 use std::sync::Arc;
 
-use async_trait::async_trait;
 use parking_lot::RwLock;
-use tracing::info;
 
-use super::types::{Event, IndexEvent, QueryEvent, WorkspaceEvent};
-
-/// Sync event handler trait.
-pub(crate) trait EventHandler: Send + Sync {
-    /// Handle an event.
-    fn handle(&self, event: &Event);
-}
-
-/// Async event handler trait.
-#[async_trait]
-pub(crate) trait AsyncEventHandler: Send + Sync {
-    /// Handle an event asynchronously.
-    async fn handle(&self, event: &Event);
-}
+use super::types::{IndexEvent, QueryEvent, WorkspaceEvent};
 
 /// Type alias for sync index handler.
 pub(crate) type IndexHandler = Box<dyn Fn(&IndexEvent) + Send + Sync>;
@@ -46,9 +31,6 @@ struct EventEmitterInner {
 
     /// Workspace event handlers.
     workspace_handlers: Vec<WorkspaceHandler>,
-
-    /// Async handlers.
-    async_handlers: Vec<Arc<dyn AsyncEventHandler>>,
 }
 
 impl Default for EventEmitterInner {
@@ -57,7 +39,6 @@
             index_handlers: Vec::new(),
             query_handlers: Vec::new(),
             workspace_handlers: Vec::new(),
-            async_handlers: Vec::new(),
         }
     }
 }
@@ -120,26 +101,12 @@ impl EventEmitter {
         self
     }
 
-    /// Add an async event handler.
-    pub(crate) fn with_async_handler<H>(self, handler: Arc<H>) -> Self
-    where
-        H: AsyncEventHandler + 'static,
-    {
-        self.inner.write().async_handlers.push(handler);
-        self
-    }
-
     /// Emit an index event.
     pub fn emit_index(&self, event: IndexEvent) {
         let inner = self.inner.read();
         for handler in &inner.index_handlers {
             handler(&event);
         }
-        for _handler in &inner.async_handlers {
-            // For sync context, we just log async handlers
-            let event = Event::Index(event.clone());
-            info!("Async event: {:?}", event);
-        }
     }
 
     /// Emit a query event.
@@ -164,7 +131,6 @@
         !inner.index_handlers.is_empty()
             || !inner.query_handlers.is_empty()
             || !inner.workspace_handlers.is_empty()
-            || !inner.async_handlers.is_empty()
     }
 
     /// Merge another emitter into this one.
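+    ///
+    /// Handlers registered on the other emitter are drained into this one,
+    /// leaving it empty; subsequent emits fire both handler sets (see the
+    /// `drain` calls below).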
@@ -180,9 +146,6 @@ impl EventEmitter { inner .workspace_handlers .extend(other_inner.workspace_handlers.drain(..)); - inner - .async_handlers - .extend(other_inner.async_handlers.drain(..)); drop(inner); drop(other_inner); self @@ -212,7 +175,6 @@ impl std::fmt::Debug for EventEmitter { .field("index_handlers", &inner.index_handlers.len()) .field("query_handlers", &inner.query_handlers.len()) .field("workspace_handlers", &inner.workspace_handlers.len()) - .field("async_handlers", &inner.async_handlers.len()) .finish() } } diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs index 7e390219..e8e55df5 100644 --- a/rust/src/events/mod.rs +++ b/rust/src/events/mod.rs @@ -28,4 +28,4 @@ mod emitter; mod types; pub use emitter::EventEmitter; -pub use types::{Event, IndexEvent, QueryEvent, WorkspaceEvent}; +pub use types::{IndexEvent, QueryEvent, WorkspaceEvent}; diff --git a/rust/src/events/types.rs b/rust/src/events/types.rs index 2d5c22f7..05ca0754 100644 --- a/rust/src/events/types.rs +++ b/rust/src/events/types.rs @@ -9,19 +9,6 @@ use crate::index::parse::DocumentFormat; use crate::retrieval::SufficiencyLevel; -/// Top-level event types for client operations. -#[derive(Debug, Clone)] -pub enum Event { - /// Indexing events. - Index(IndexEvent), - - /// Query events. - Query(QueryEvent), - - /// Workspace events. - Workspace(WorkspaceEvent), -} - /// Indexing operation events. #[derive(Debug, Clone)] pub enum IndexEvent { diff --git a/rust/src/graph/mod.rs b/rust/src/graph/mod.rs index 6c084e22..f1b48862 100644 --- a/rust/src/graph/mod.rs +++ b/rust/src/graph/mod.rs @@ -35,7 +35,4 @@ mod types; // Re-export public API pub use builder::DocumentGraphBuilder; pub use config::DocumentGraphConfig; -pub use types::{ - DocumentGraph, DocumentGraphNode, EdgeEvidence, GraphEdge, GraphMetadata, KeywordDocEntry, - SharedKeyword, WeightedKeyword, -}; +pub use types::{DocumentGraph, DocumentGraphNode, EdgeEvidence, GraphEdge, WeightedKeyword}; diff --git a/rust/src/index/config.rs b/rust/src/index/config.rs index edb20c2e..798951b1 100644 --- a/rust/src/index/config.rs +++ b/rust/src/index/config.rs @@ -10,8 +10,9 @@ //! 
- [`ThinningConfig`] - Node merging settings use super::summary::SummaryStrategy; -use crate::config::{ConcurrencyConfig, IndexerConfig}; +use crate::config::IndexerConfig; use crate::document::{DocumentTree, ReasoningIndexConfig}; +use crate::llm::throttle::ConcurrencyConfig; use crate::utils::fingerprint::{Fingerprint, Fingerprinter}; use std::path::PathBuf; diff --git a/rust/src/index/parse/toc/assigner.rs b/rust/src/index/parse/toc/assigner.rs index b7399dce..267cda18 100644 --- a/rust/src/index/parse/toc/assigner.rs +++ b/rust/src/index/parse/toc/assigner.rs @@ -7,9 +7,9 @@ use futures::stream::{self, StreamExt}; use std::collections::HashMap; use tracing::{debug, info}; -use crate::config::LlmConfig; use crate::error::Result; use crate::index::parse::pdf::PdfPage; +use crate::llm::config::LlmConfig; use super::types::{PageOffset, TocEntry}; use crate::llm::LlmClient; diff --git a/rust/src/index/parse/toc/detector.rs b/rust/src/index/parse/toc/detector.rs index 050c6b2a..8484e101 100644 --- a/rust/src/index/parse/toc/detector.rs +++ b/rust/src/index/parse/toc/detector.rs @@ -6,8 +6,8 @@ use regex::Regex; use tracing::debug; -use crate::config::LlmConfig; use crate::error::Result; +use crate::llm::config::LlmConfig; use super::types::TocDetection; use crate::index::parse::pdf::PdfPage; diff --git a/rust/src/index/parse/toc/parser.rs b/rust/src/index/parse/toc/parser.rs index 06aaade3..df0f306d 100644 --- a/rust/src/index/parse/toc/parser.rs +++ b/rust/src/index/parse/toc/parser.rs @@ -5,8 +5,8 @@ use tracing::debug; -use crate::config::LlmConfig; use crate::error::Result; +use crate::llm::config::LlmConfig; use super::types::TocEntry; use crate::llm::LlmClient; diff --git a/rust/src/index/parse/toc/repairer.rs b/rust/src/index/parse/toc/repairer.rs index 3c7666fe..61ba414e 100644 --- a/rust/src/index/parse/toc/repairer.rs +++ b/rust/src/index/parse/toc/repairer.rs @@ -6,9 +6,9 @@ use futures::stream::{self, StreamExt}; use tracing::{debug, info}; -use crate::config::LlmConfig; use crate::error::Result; use crate::index::parse::pdf::PdfPage; +use crate::llm::config::LlmConfig; use super::types::{TocEntry, VerificationError, VerificationReport}; use super::verifier::IndexVerifier; diff --git a/rust/src/index/parse/toc/structure_extractor.rs b/rust/src/index/parse/toc/structure_extractor.rs index 36925644..63ce9d7e 100644 --- a/rust/src/index/parse/toc/structure_extractor.rs +++ b/rust/src/index/parse/toc/structure_extractor.rs @@ -10,9 +10,9 @@ use futures::stream::{self, StreamExt}; use tracing::{debug, info, warn}; -use crate::config::LlmConfig; use crate::error::Result; use crate::index::parse::pdf::PdfPage; +use crate::llm::config::LlmConfig; use super::types::TocEntry; use crate::llm::LlmClient; diff --git a/rust/src/index/parse/toc/verifier.rs b/rust/src/index/parse/toc/verifier.rs index 3eda474c..1e3d1d45 100644 --- a/rust/src/index/parse/toc/verifier.rs +++ b/rust/src/index/parse/toc/verifier.rs @@ -7,9 +7,9 @@ use futures::stream::{self, StreamExt}; use rand::seq::SliceRandom; use tracing::{debug, info}; -use crate::config::LlmConfig; use crate::error::Result; use crate::index::parse::pdf::PdfPage; +use crate::llm::config::LlmConfig; use super::types::{ErrorType, TocEntry, VerificationError, VerificationReport}; use crate::llm::LlmClient; diff --git a/rust/src/index/parse/types.rs b/rust/src/index/parse/types.rs index 6a7fa07b..baaa8224 100644 --- a/rust/src/index/parse/types.rs +++ b/rust/src/index/parse/types.rs @@ -36,6 +36,12 @@ impl DocumentFormat { Self::Pdf => "pdf", } } 
+
+    /// All supported file extensions (lowercase).
+    ///
+    /// Single source of truth — used by directory scanning to
+    /// discover indexable files.
+    pub const SUPPORTED_EXTENSIONS: &'static [&'static str] = &["md", "pdf"];
 }
 
 /// A raw node extracted from a document.
diff --git a/rust/src/index/stages/enhance.rs b/rust/src/index/stages/enhance.rs
index d33e0acc..770fddf9 100644
--- a/rust/src/index/stages/enhance.rs
+++ b/rust/src/index/stages/enhance.rs
@@ -13,7 +13,7 @@ use crate::document::NodeId;
 use crate::error::Result;
 use crate::index::incremental;
 use crate::llm::LlmClient;
-use crate::memo::{MemoKey, MemoStore};
+use crate::llm::memo::{MemoKey, MemoStore};
 use crate::utils::fingerprint::Fingerprint;
 
 use super::{IndexStage, StageResult};
diff --git a/rust/src/index/summary/strategy.rs b/rust/src/index/summary/strategy.rs
index 2753be91..33f58e49 100644
--- a/rust/src/index/summary/strategy.rs
+++ b/rust/src/index/summary/strategy.rs
@@ -7,7 +7,7 @@ use async_trait::async_trait;
 
 use crate::document::{DocumentTree, NodeId};
 use crate::llm::{LlmClient, LlmResult};
-use crate::memo::{MemoKey, MemoStore, MemoValue};
+use crate::llm::memo::{MemoKey, MemoStore, MemoValue};
 use crate::utils::fingerprint::Fingerprint;
 
 /// Configuration for summary strategies.
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 26dcceae..6cc4f91c 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -1,12 +1,15 @@
 // Copyright (c) 2026 vectorless developers
 // SPDX-License-Identifier: Apache-2.0
+
 #![allow(dead_code)]
 
 //! # Vectorless
 //!
-//! A document engine for AI. It transforms documents into hierarchical semantic
-//! trees and uses the LLM itself to navigate and retrieve — purely LLM-guided,
-//! from indexing to querying. No vector databases, no embeddings, no similarity search.
+//! A reasoning-native document engine for AI.
+//!
+//! It will reason through any of your structured documents — **PDFs, Markdown,
+//! reports, contracts** — and retrieve only what's relevant. Nothing more,
+//! nothing less.
 //!
 //! ## Quick Start
 //!
@@ -15,17 +18,19 @@
 //!
 //! #[tokio::main]
 //! async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//!     let client = EngineBuilder::new()
+//!     let engine = EngineBuilder::new()
 //!         .with_key("sk-...")
 //!         .with_model("gpt-4o")
+//!         .with_endpoint("https://api.openai.com/v1")
 //!         .build()
 //!         .await?;
 //!
-//!     let result = client.index(IndexContext::from_path("./document.md")).await?;
+//!     let result = engine.index(IndexContext::from_path("./document.md")).await?;
 //!     let doc_id = result.doc_id().unwrap();
 //!
-//!     let result = client.query(
-//!         QueryContext::new("What is this about?").with_doc_ids(vec![doc_id.to_string()])
+//!     let result = engine.query(
+//!         QueryContext::new("What is this about?")
+//!             .with_doc_ids(vec![doc_id.to_string()]),
//!     ).await?;
 //!     println!("{}", result.content);
 //!
@@ -33,45 +38,58 @@
 //! }
 //! ```
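+//!
+//! Note: the `client`, `config`, and other submodules are private as of this
+//! change; import everything from the crate root (e.g. `use vectorless::EngineBuilder;`).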
-pub mod client;
-pub mod config;
-pub use config::Config;
-pub mod document;
-pub mod error;
-pub mod events;
-pub mod graph;
+// ── Modules ──────────────────────────────────────────────────────────────────
+
+mod client;
+mod config;
+mod document;
+mod error;
+mod events;
+mod graph;
+mod metrics;
+
 mod index;
 mod llm;
-mod memo;
-pub mod metrics;
 mod retrieval;
 mod storage;
-mod throttle;
 mod utils;
 
-// Client API
+// ── Public API ───────────────────────────────────────────────────────────────
+
+// Client
 pub use client::{
-    BuildError, ClientError, DocumentFormat, DocumentInfo, Engine, EngineBuilder, FailedItem,
-    IndexContext, IndexItem, IndexMode, IndexOptions, IndexResult, QueryContext, QueryResult,
-    QueryResultItem,
+    BuildError, DocumentFormat, DocumentInfo, Engine, EngineBuilder, FailedItem, IndexContext,
+    IndexItem, IndexMode, IndexOptions, IndexResult, QueryContext, QueryResult, QueryResultItem,
 };
 
-// Error types
-pub use error::{Error, Result};
+// Config
+pub use config::Config;
 
-// Document types
+// Documents
 pub use document::{
     DocumentStructure, DocumentTree, NodeId, ReasoningIndexConfig, StructureNode, TocConfig,
     TocEntry, TocNode, TocView, TreeNode,
 };
 
-// Graph types
-pub use graph::DocumentGraph;
+// Graph
+pub use graph::{DocumentGraph, DocumentGraphNode, EdgeEvidence, GraphEdge, WeightedKeyword};
 
-// Event types
+// Events
 pub use events::{EventEmitter, IndexEvent, QueryEvent, WorkspaceEvent};
 
-// Runtime metrics reports
+// Metrics
 pub use metrics::{
     IndexMetrics, LlmMetricsReport, MetricsReport, PilotMetricsReport, RetrievalMetricsReport,
 };
+
+// Errors
+pub use error::{Error, Result};
+
+/// Test-only utilities.
+///
+/// **Do not use in production code.** This module exposes helpers for writing
+/// integration tests without a real LLM endpoint.
+#[doc(hidden)]
+pub mod __test_support {
+    pub use crate::client::test_support::*;
+}
diff --git a/rust/src/llm/client.rs b/rust/src/llm/client.rs
index 0c01bbdc..3eeb60af 100644
--- a/rust/src/llm/client.rs
+++ b/rust/src/llm/client.rs
@@ -12,7 +12,7 @@ use super::config::LlmConfig;
 use super::error::{LlmError, LlmResult};
 use super::executor::LlmExecutor;
 use super::fallback::FallbackChain;
-use crate::throttle::ConcurrencyController;
+use super::throttle::ConcurrencyController;
 
 /// Unified LLM client.
 ///
@@ -113,6 +113,15 @@ impl LlmClient {
         self
     }
 
+    /// Replace the async-openai client with a shared instance (reuses connection pool).
+    pub fn with_shared_openai_client(
+        mut self,
+        client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
+    ) -> Self {
+        self.executor = self.executor.with_openai_client(client);
+        self
+    }
+
     /// Add fallback chain for error recovery.
     ///
     /// # Example
@@ -137,6 +146,12 @@
         self
     }
 
+    /// Add metrics hub for recording LLM call statistics.
+    pub fn with_shared_metrics(mut self, hub: Arc<crate::metrics::MetricsHub>) -> Self {
+        self.executor = self.executor.with_shared_metrics(hub);
+        self
+    }
+
     /// Get the configuration.
pub fn config(&self) -> &LlmConfig { self.executor.config() @@ -340,11 +355,24 @@ mod tests { #[test] fn test_client_with_concurrency() { - use crate::throttle::ConcurrencyConfig; + use crate::llm::throttle::ConcurrencyConfig; let controller = ConcurrencyController::new(ConcurrencyConfig::conservative()); let client = LlmClient::for_model("gpt-4o-mini").with_concurrency(controller); assert!(client.concurrency().is_some()); } + + #[test] + fn test_client_with_shared_metrics() { + use crate::metrics::MetricsHub; + + let hub = MetricsHub::shared(); + let client = LlmClient::for_model("gpt-4o").with_shared_metrics(hub.clone()); + + // Client should still function normally + assert_eq!(client.config().model, "gpt-4o"); + assert!(client.fallback().is_none()); // no fallback added + assert!(client.concurrency().is_none()); // no concurrency added + } } diff --git a/rust/src/llm/config.rs b/rust/src/llm/config.rs index 7be140a1..429cd971 100644 --- a/rust/src/llm/config.rs +++ b/rust/src/llm/config.rs @@ -1,109 +1,16 @@ // Copyright (c) 2026 vectorless developers // SPDX-License-Identifier: Apache-2.0 -//! LLM configuration types. +//! Runtime LLM configuration types. use serde::{Deserialize, Serialize}; use std::time::Duration; -/// Retry configuration for LLM calls. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RetryConfig { - /// Maximum number of retry attempts (including initial call). - /// e.g., max_attempts=3 means 1 initial + 2 retries. - #[serde(default = "default_max_attempts")] - pub max_attempts: usize, - - /// Initial delay before first retry (milliseconds). - #[serde(default = "default_initial_delay_ms")] - pub initial_delay_ms: u64, - - /// Maximum delay between retries (milliseconds). - #[serde(default = "default_max_delay_ms")] - pub max_delay_ms: u64, - - /// Multiplier for exponential backoff. - #[serde(default = "default_multiplier")] - pub multiplier: f64, - - /// Whether to retry on rate limit errors. - #[serde(default = "default_true")] - pub retry_on_rate_limit: bool, -} - -fn default_max_attempts() -> usize { - 3 -} -fn default_initial_delay_ms() -> u64 { - 500 -} -fn default_max_delay_ms() -> u64 { - 30000 -} -fn default_multiplier() -> f64 { - 2.0 -} -fn default_true() -> bool { - true -} - -impl Default for RetryConfig { - fn default() -> Self { - Self { - max_attempts: default_max_attempts(), - initial_delay_ms: default_initial_delay_ms(), - max_delay_ms: default_max_delay_ms(), - multiplier: default_multiplier(), - retry_on_rate_limit: default_true(), - } - } -} - -impl RetryConfig { - /// Create a new retry config with defaults. - pub fn new() -> Self { - Self::default() - } - - /// Set the maximum number of attempts. - pub fn with_max_attempts(mut self, max_attempts: usize) -> Self { - self.max_attempts = max_attempts; - self - } - - /// Set the initial delay (milliseconds). - pub fn with_initial_delay(mut self, delay_ms: u64) -> Self { - self.initial_delay_ms = delay_ms; - self - } - - /// Set the maximum delay (milliseconds). - pub fn with_max_delay(mut self, delay_ms: u64) -> Self { - self.max_delay_ms = delay_ms; - self - } - - /// Set the backoff multiplier. - pub fn with_multiplier(mut self, multiplier: f64) -> Self { - self.multiplier = multiplier; - self - } - - /// Set whether to retry on rate limit. - pub fn with_retry_on_rate_limit(mut self, retry: bool) -> Self { - self.retry_on_rate_limit = retry; - self - } - - /// Calculate delay for a given attempt (0-indexed). 
- pub fn delay_for_attempt(&self, attempt: usize) -> Duration { - let delay_ms = (self.initial_delay_ms as f64) * self.multiplier.powf(attempt as f64); - let delay_ms = delay_ms.min(self.max_delay_ms as f64); - Duration::from_millis(delay_ms as u64) - } -} - -/// LLM client configuration. +/// Runtime LLM client configuration. +/// +/// This is the runtime representation used by [`LlmClient`](super::LlmClient). +/// Created from the config-layer [`LlmConfig`](crate::config::LlmConfig) +/// during pool construction — users never construct this directly. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LlmConfig { /// Model name (e.g., "gpt-4o-mini", "gpt-4o"). @@ -134,6 +41,7 @@ pub struct LlmConfig { fn default_max_tokens() -> usize { 2000 } + fn default_temperature() -> f32 { 0.0 } @@ -197,129 +105,99 @@ impl LlmConfig { } } -/// Pool of LLM configurations for different purposes. +/// Runtime retry configuration for LLM calls. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LlmConfigs { - /// Configuration for indexing tasks (document summarization, etc.). - #[serde(default = "default_index_config", alias = "summary")] - pub index: LlmConfig, - - /// Configuration for retrieval/navigation tasks. - #[serde(default = "default_retrieval_config")] - pub retrieval: LlmConfig, - - /// Configuration for Pilot navigation tasks. - #[serde(default = "default_pilot_config")] - pub pilot: LlmConfig, -} +pub struct RetryConfig { + /// Maximum number of retry attempts (including initial call). + #[serde(default = "default_max_attempts")] + pub max_attempts: usize, -fn default_index_config() -> LlmConfig { - LlmConfig { - max_tokens: 200, - temperature: 0.0, - ..LlmConfig::default() - } -} + /// Initial delay before first retry (milliseconds). + #[serde(default = "default_initial_delay_ms")] + pub initial_delay_ms: u64, -fn default_retrieval_config() -> LlmConfig { - LlmConfig { - max_tokens: 100, - temperature: 0.0, - ..LlmConfig::default() - } + /// Maximum delay between retries (milliseconds). + #[serde(default = "default_max_delay_ms")] + pub max_delay_ms: u64, + + /// Multiplier for exponential backoff. + #[serde(default = "default_multiplier")] + pub multiplier: f64, + + /// Whether to retry on rate limit errors. 
+    #[serde(default = "default_true")]
+    pub retry_on_rate_limit: bool,
 }
 
-fn default_pilot_config() -> LlmConfig {
-    LlmConfig {
-        max_tokens: 300,
-        temperature: 0.0,
-        ..LlmConfig::default()
-    }
+fn default_max_attempts() -> usize {
+    3
+}
+fn default_initial_delay_ms() -> u64 {
+    500
+}
+fn default_max_delay_ms() -> u64 {
+    30000
+}
+fn default_multiplier() -> f64 {
+    2.0
+}
+fn default_true() -> bool {
+    true
 }
 
-impl Default for LlmConfigs {
+impl Default for RetryConfig {
     fn default() -> Self {
         Self {
-            index: default_index_config(),
-            retrieval: default_retrieval_config(),
-            pilot: default_pilot_config(),
+            max_attempts: default_max_attempts(),
+            initial_delay_ms: default_initial_delay_ms(),
+            max_delay_ms: default_max_delay_ms(),
+            multiplier: default_multiplier(),
+            retry_on_rate_limit: default_true(),
         }
     }
 }
 
-// ============================================================================
-// Conversion from config types
-// ============================================================================
-
-impl From<crate::config::LlmPoolConfig> for LlmConfigs {
-    fn from(pool: crate::config::LlmPoolConfig) -> Self {
-        // Resolve shared values before moving individual client configs
-        let default_api_key = pool.api_key.clone();
-        let default_endpoint = pool.endpoint.clone();
-
-        fn to_llm_config(
-            client: crate::config::LlmClientConfig,
-            default_api_key: &Option<String>,
-            default_endpoint: &Option<String>,
-        ) -> LlmConfig {
-            LlmConfig {
-                model: client.model,
-                endpoint: if client.endpoint.is_empty() {
-                    default_endpoint.clone().unwrap_or_default()
-                } else {
-                    client.endpoint
-                },
-                api_key: client.api_key.or_else(|| default_api_key.clone()),
-                max_tokens: client.max_tokens,
-                temperature: client.temperature,
-                retry: RetryConfig::default(),
-            }
-        }
+impl RetryConfig {
+    /// Create a new retry config with defaults.
+    pub fn new() -> Self {
+        Self::default()
+    }
 
-        Self {
-            index: to_llm_config(pool.index, &default_api_key, &default_endpoint),
-            retrieval: to_llm_config(pool.retrieval, &default_api_key, &default_endpoint),
-            pilot: to_llm_config(pool.pilot, &default_api_key, &default_endpoint),
-        }
+    /// Set the maximum number of attempts.
+    pub fn with_max_attempts(mut self, max_attempts: usize) -> Self {
+        self.max_attempts = max_attempts;
+        self
     }
-}
 
-impl From<crate::config::LlmConfig> for LlmConfig {
-    fn from(old: crate::config::LlmConfig) -> Self {
-        Self {
-            model: old.model,
-            endpoint: old.endpoint,
-            api_key: old.api_key,
-            max_tokens: old.max_tokens,
-            temperature: old.temperature,
-            retry: RetryConfig::default(),
-        }
+    /// Set the initial delay (milliseconds).
+    pub fn with_initial_delay(mut self, delay_ms: u64) -> Self {
+        self.initial_delay_ms = delay_ms;
+        self
     }
-}
 
-impl From<crate::config::SummaryConfig> for LlmConfig {
-    fn from(old: crate::config::SummaryConfig) -> Self {
-        Self {
-            model: old.model,
-            endpoint: old.endpoint,
-            api_key: old.api_key,
-            max_tokens: old.max_tokens,
-            temperature: old.temperature,
-            retry: RetryConfig::default(),
-        }
+    /// Set the maximum delay (milliseconds).
+    pub fn with_max_delay(mut self, delay_ms: u64) -> Self {
+        self.max_delay_ms = delay_ms;
+        self
     }
-}
 
-impl From<crate::config::RetrievalConfig> for LlmConfig {
-    fn from(old: crate::config::RetrievalConfig) -> Self {
-        Self {
-            model: old.model,
-            endpoint: old.endpoint,
-            api_key: old.api_key,
-            max_tokens: old.max_tokens,
-            temperature: old.temperature,
-            retry: RetryConfig::default(),
-        }
+    /// Set the backoff multiplier.
+    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
+        self.multiplier = multiplier;
+        self
+    }
+
+    /// Set whether to retry on rate limit.
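+    ///
+    /// When disabled, rate-limit errors skip the retry loop and go straight
+    /// to fallback handling (see `should_retry` in the executor).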
+    pub fn with_retry_on_rate_limit(mut self, retry: bool) -> Self {
+        self.retry_on_rate_limit = retry;
+        self
+    }
+
+    /// Calculate delay for a given attempt (0-indexed).
+    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
+        let delay_ms = (self.initial_delay_ms as f64) * self.multiplier.powf(attempt as f64);
+        let delay_ms = delay_ms.min(self.max_delay_ms as f64);
+        Duration::from_millis(delay_ms as u64)
+    }
 }
diff --git a/rust/src/llm/executor.rs b/rust/src/llm/executor.rs
index 8498062b..ef66c134 100644
--- a/rust/src/llm/executor.rs
+++ b/rust/src/llm/executor.rs
@@ -53,10 +53,16 @@ use std::sync::Arc;
 use std::time::Duration;
 use tracing::{debug, info, warn};
 
+use async_openai::types::chat::{
+    ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
+    CreateChatCompletionRequestArgs,
+};
+
 use super::config::LlmConfig;
 use super::error::{LlmError, LlmResult};
 use super::fallback::{FallbackChain, FallbackStep};
-use crate::throttle::ConcurrencyController;
+use crate::metrics::MetricsHub;
+use super::throttle::ConcurrencyController;
 
 /// Unified executor for LLM operations.
 ///
@@ -65,10 +71,14 @@ use crate::throttle::ConcurrencyController;
 pub struct LlmExecutor {
     /// LLM configuration.
     config: LlmConfig,
+    /// Reusable async-openai client (created once, shared via Arc).
+    openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
     /// Throttle controller (optional).
     throttle: Option<Arc<ConcurrencyController>>,
     /// Fallback chain (optional).
     fallback: Option<Arc<FallbackChain>>,
+    /// Metrics hub for recording LLM call statistics (optional).
+    metrics: Option<Arc<MetricsHub>>,
 }
 
 impl std::fmt::Debug for LlmExecutor {
@@ -78,6 +88,8 @@
             .field("endpoint", &self.config.endpoint)
             .field("has_throttle", &self.throttle.is_some())
             .field("has_fallback", &self.fallback.is_some())
+            .field("has_openai_client", &true)
+            .field("has_metrics", &self.metrics.is_some())
             .finish()
     }
 }
@@ -85,13 +97,32 @@ impl std::fmt::Debug for LlmExecutor {
 impl LlmExecutor {
     /// Create a new executor with the given configuration.
     pub fn new(config: LlmConfig) -> Self {
+        let openai_client = Self::build_openai_client(&config);
         Self {
             config,
+            openai_client: Arc::new(openai_client),
             throttle: None,
             fallback: None,
+            metrics: None,
         }
     }
 
+    /// Build the async-openai client from config.
+    fn build_openai_client(
+        config: &LlmConfig,
+    ) -> async_openai::Client<async_openai::config::OpenAIConfig> {
+        let api_key = config.api_key.clone().unwrap_or_default();
+        let endpoint = if config.endpoint.is_empty() {
+            "https://api.openai.com/v1".to_string()
+        } else {
+            config.endpoint.clone()
+        };
+        let openai_config = async_openai::config::OpenAIConfig::new()
+            .with_api_key(api_key)
+            .with_api_base(endpoint);
+        async_openai::Client::with_config(openai_config)
+    }
+
     /// Create an executor with default configuration.
     pub fn with_defaults() -> Self {
         Self::new(LlmConfig::default())
@@ -126,6 +157,21 @@
         self
     }
 
+    /// Add metrics hub for recording LLM call statistics.
+    pub fn with_shared_metrics(mut self, hub: Arc<MetricsHub>) -> Self {
+        self.metrics = Some(hub);
+        self
+    }
+
+    /// Replace the async-openai client (used when pool reconfigures clients).
+    pub fn with_openai_client(
+        mut self,
+        client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
+    ) -> Self {
+        self.openai_client = client;
+        self
+    }
+
     /// Get the configuration.
     /// Get the configuration.
     pub fn config(&self) -> &LlmConfig {
         &self.config
@@ -171,7 +217,6 @@ impl LlmExecutor {
     ) -> LlmResult<String> {
         let mut attempts = 0;
         let mut current_model = self.config.model.clone();
-        let current_endpoint = self.config.endpoint.clone();
         let mut fallback_history: Vec<FallbackStep> = vec![];
         let mut total_attempts_including_fallback = 0;
 
@@ -198,13 +243,12 @@ impl LlmExecutor {
             debug!(
                 attempt = attempts,
                 model = %current_model,
-                endpoint = %current_endpoint,
                 "Executing LLM request"
             );
 
             // Step 2: Execute the request
             let result = self
-                .do_request(&current_model, &current_endpoint, system, user, max_tokens)
+                .do_request(&current_model, system, user, max_tokens)
                 .await;
 
             match result {
@@ -224,6 +268,15 @@ impl LlmExecutor {
                     return Ok(response);
                 }
                 Err(error) => {
+                    // Record specific error events
+                    if let Some(ref metrics) = self.metrics {
+                        match &error {
+                            LlmError::RateLimit(_) => metrics.record_llm_rate_limit(),
+                            LlmError::Timeout(_) => metrics.record_llm_timeout(),
+                            _ => {}
+                        }
+                    }
+
                     // Step 3: Check if we should retry
                     if self.should_retry(&error, attempts) {
                         let delay = self.retry_delay(attempts);
@@ -250,11 +303,14 @@ impl LlmExecutor {
                             to_model = %next_model,
                             "Falling back to next model"
                         );
+                        if let Some(ref metrics) = self.metrics {
+                            metrics.record_llm_fallback();
+                        }
                        fallback.record_fallback(
                             &mut fallback_history,
                             current_model.clone(),
                             Some(next_model.clone()),
-                            current_endpoint.clone(),
+                            self.config.endpoint.clone(),
                             None,
                             error.to_string(),
                         );
@@ -297,19 +353,11 @@ impl LlmExecutor {
             return false;
         }
 
-        match error {
-            LlmError::RateLimit(_) => self.config.retry.retry_on_rate_limit,
-            LlmError::Timeout(_) => true,
-            LlmError::Api(msg) => {
-                let msg_lower = msg.to_lowercase();
-                msg_lower.contains("rate limit")
-                    || msg_lower.contains("429")
-                    || msg_lower.contains("503")
-                    || msg_lower.contains("502")
-                    || msg_lower.contains("timeout")
-                    || msg_lower.contains("overloaded")
-            }
-            _ => false,
+        // Use unified retryable check, with rate-limit override
+        if matches!(error, LlmError::RateLimit(_)) {
+            self.config.retry.retry_on_rate_limit
+        } else {
+            error.is_retryable()
         }
     }
 
@@ -322,74 +370,42 @@ impl LlmExecutor {
     async fn do_request(
         &self,
         model: &str,
-        endpoint: &str,
         system: &str,
         user: &str,
         max_tokens: Option<u16>,
     ) -> LlmResult<String> {
-        use async_openai::{
-            Client,
-            config::OpenAIConfig,
-            types::chat::{
-                ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
-                CreateChatCompletionRequestArgs,
-            },
-        };
-
-        let api_key = self.config.api_key.clone().ok_or_else(|| {
-            LlmError::Config(
-                "No API key configured. Call .with_key(\"sk-...\") when building the engine."
- .to_string(), - ) - })?; - - let openai_config = OpenAIConfig::new() - .with_api_key(api_key) - .with_api_base(endpoint); - - let client = Client::with_config(openai_config); - - // Truncate user prompt if too long - let truncated = self.truncate_prompt(user); - - // Build request based on whether max_tokens is specified - let request = if let Some(_tokens) = max_tokens { - CreateChatCompletionRequestArgs::default() - .model(model) - .messages([ - ChatCompletionRequestSystemMessage::from(system).into(), - ChatCompletionRequestUserMessage::from(truncated).into(), - ]) - .temperature(self.config.temperature) - // .max_tokens(tokens) - .build() - } else { - CreateChatCompletionRequestArgs::default() - .model(model) - .messages([ - ChatCompletionRequestSystemMessage::from(system).into(), - ChatCompletionRequestUserMessage::from(truncated).into(), - ]) - .temperature(self.config.temperature) - .build() - }; - - let request = - request.map_err(|e| LlmError::Request(format!("Failed to build request: {}", e)))?; + // Build request + let request = CreateChatCompletionRequestArgs::default() + .model(model) + .messages([ + ChatCompletionRequestSystemMessage::from(system).into(), + ChatCompletionRequestUserMessage::from(user).into(), + ]) + .temperature(self.config.temperature) + .max_tokens(max_tokens.unwrap_or(self.config.max_tokens as u16)) + .build() + .map_err(|e| LlmError::Request(format!("Failed to build request: {}", e)))?; info!( "LLM request → endpoint: {}, model: {}, system: {} chars, user: {} chars", - endpoint, + self.config.endpoint, model, system.len(), - truncated.len() + user.len() ); let request_start = std::time::Instant::now(); - let response = client.chat().create(request).await.map_err(|e| { - let msg = e.to_string(); - LlmError::from_api_message(&msg) - })?; + let response = match self.openai_client.chat().create(request).await { + Ok(r) => r, + Err(e) => { + let elapsed = request_start.elapsed(); + if let Some(ref metrics) = self.metrics { + metrics.record_llm_call(0, 0, elapsed.as_millis() as u64, false); + } + let msg = e.to_string(); + return Err(LlmError::from_api_message(&msg)); + } + }; let request_elapsed = request_start.elapsed(); let usage = response.usage.as_ref(); @@ -402,6 +418,15 @@ impl LlmExecutor { .and_then(|choice| choice.message.content.clone()) .ok_or(LlmError::NoContent)?; + if let Some(ref metrics) = self.metrics { + metrics.record_llm_call( + prompt_tokens as u64, + completion_tokens as u64, + request_elapsed.as_millis() as u64, + true, + ); + } + info!( "LLM response ← {}ms, tokens: {} prompt + {} completion, content: {} chars", request_elapsed.as_millis(), @@ -412,17 +437,6 @@ impl LlmExecutor { Ok(content) } - - /// Truncate a prompt to a reasonable length. 
-    fn truncate_prompt<'a>(&self, text: &'a str) -> &'a str {
-        // Roughly 4 chars per token, limit to ~30k chars
-        const MAX_CHARS: usize = 30000;
-        if text.len() > MAX_CHARS {
-            &text[..MAX_CHARS]
-        } else {
-            text
-        }
-    }
 }
 
 impl Default for LlmExecutor {
@@ -445,7 +459,7 @@ mod tests {
 
     #[test]
     fn test_executor_with_throttle() {
-        use crate::throttle::ConcurrencyConfig;
+        use crate::llm::throttle::ConcurrencyConfig;
 
         let controller = ConcurrencyController::new(ConcurrencyConfig::conservative());
         let executor = LlmExecutor::for_model("gpt-4o-mini").with_throttle(controller);
@@ -478,4 +492,18 @@ mod tests {
         let delay = executor.retry_delay(1);
         assert_eq!(delay, Duration::from_millis(500));
     }
+
+    #[test]
+    fn test_executor_with_metrics() {
+        let hub = MetricsHub::shared();
+        let executor = LlmExecutor::for_model("gpt-4o").with_shared_metrics(hub);
+
+        assert!(executor.metrics.is_some());
+    }
+
+    #[test]
+    fn test_executor_without_metrics() {
+        let executor = LlmExecutor::for_model("gpt-4o");
+        assert!(executor.metrics.is_none());
+    }
 }
diff --git a/rust/src/llm/memo/mod.rs b/rust/src/llm/memo/mod.rs
new file mode 100644
index 00000000..fff44e65
--- /dev/null
+++ b/rust/src/llm/memo/mod.rs
@@ -0,0 +1,14 @@
+// Copyright (c) 2026 vectorless developers
+// SPDX-License-Identifier: Apache-2.0
+
+//! LLM Memoization system for caching expensive LLM calls.
+//!
+//! Provides a caching layer for LLM-generated content, avoiding
+//! redundant API calls via content-addressed LRU cache with TTL
+//! and optional disk persistence.
+
+mod store;
+mod types;
+
+pub use store::MemoStore;
+pub use types::{MemoKey, MemoOpType, MemoValue, PilotDecisionValue};
diff --git a/rust/src/memo/store.rs b/rust/src/llm/memo/store.rs
similarity index 84%
rename from rust/src/memo/store.rs
rename to rust/src/llm/memo/store.rs
index 85860937..b75dfb92 100644
--- a/rust/src/memo/store.rs
+++ b/rust/src/llm/memo/store.rs
@@ -15,7 +15,6 @@ use chrono::Duration;
 use lru::LruCache;
 use parking_lot::RwLock;
 use serde::{Deserialize, Serialize};
-use tokio::sync::RwLock as AsyncRwLock;
 use tracing::{debug, info};
 
 use super::types::{MemoEntry, MemoKey, MemoOpType, MemoStats, MemoValue};
@@ -41,8 +40,8 @@ struct MemoStoreData {
     stats: MemoStats,
 }
 
-/// Atomic statistics for lock-free access.
-#[derive(Debug, Default)]
+/// Lock-free atomic statistics for concurrent access.
+#[derive(Debug)]
 struct AtomicStats {
     hits: AtomicU64,
     misses: AtomicU64,
@@ -77,6 +76,12 @@ impl AtomicStats {
             self.tokens_saved.load(Ordering::Relaxed),
         )
     }
+
+    fn load_from(&self, hits: u64, misses: u64, tokens_saved: u64) {
+        self.hits.store(hits, Ordering::Relaxed);
+        self.misses.store(misses, Ordering::Relaxed);
+        self.tokens_saved.store(tokens_saved, Ordering::Relaxed);
+    }
 }
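// Sketch of the counter pattern behind AtomicStats; the record_* methods are
// elided in the hunk above, so their exact shape here is an assumption.
// Relaxed ordering suffices: the counters are independent, and the snapshot
// only needs eventually-consistent totals, not cross-counter ordering.
struct HitCounter {
    hits: std::sync::atomic::AtomicU64,
}

impl HitCounter {
    fn record_hit(&self) {
        self.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    fn snapshot(&self) -> u64 {
        self.hits.load(std::sync::atomic::Ordering::Relaxed)
    }
}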
 
 /// LLM Memoization store.
@@ -90,7 +95,7 @@
 /// # Example
 ///
 /// ```rust,ignore
-/// let store = MemoStore::new(1000);
+/// let store = MemoStore::new();
 ///
 /// let summary = store.get_or_compute(
 ///     MemoKey::summary(&content_fp),
@@ -103,8 +108,8 @@
 pub struct MemoStore {
     /// LRU cache for entries.
     cache: Arc<RwLock<LruCache<String, MemoEntry>>>,
 
-    /// Statistics (async for safe updates).
-    stats: Arc<AsyncRwLock<MemoStats>>,
+    /// Lock-free statistics.
+    stats: Arc<AtomicStats>,
 
     /// TTL for entries.
     ttl: Duration,
@@ -152,7 +157,7 @@ impl MemoStore {
             std::num::NonZeroUsize::new(capacity)
                 .unwrap_or(std::num::NonZeroUsize::new(1000).unwrap()),
         ))),
-            stats: Arc::new(AsyncRwLock::new(MemoStats::default())),
+            stats: Arc::new(AtomicStats::new()),
             ttl: DEFAULT_TTL,
             model_id: None,
             version: 1,
@@ -183,13 +188,10 @@ impl MemoStore {
         let mut cache = self.cache.write();
 
         if let Some(entry) = cache.get_mut(&full_key) {
-            // Check TTL
             if entry.is_expired(self.ttl) {
                 cache.pop(&full_key);
                 return None;
             }
-
-            // Record hit
             entry.record_hit();
             debug!("Memo cache hit for {:?}", key.op_type);
             return Some(entry.value.clone());
@@ -226,17 +228,12 @@ impl MemoStore {
     {
         // Check cache first (synchronous)
         if let Some(value) = self.get(&key) {
-            // Update stats
-            let mut stats = self.stats.write().await;
-            stats.hits += 1;
+            self.stats.record_hit();
             return Ok(value);
         }
 
         // Record miss
-        {
-            let mut stats = self.stats.write().await;
-            stats.misses += 1;
-        }
+        self.stats.record_miss();
 
         // Compute
         let (value, tokens) = compute().await?;
@@ -244,11 +241,8 @@ impl MemoStore {
         // Cache result
         self.put_with_tokens(key.clone(), value.clone(), tokens);
 
-        // Update stats
-        {
-            let mut stats = self.stats.write().await;
-            stats.tokens_saved += tokens;
-        }
+        // Update tokens saved
+        self.stats.add_tokens_saved(tokens);
 
         Ok(value)
     }
@@ -285,50 +279,25 @@ impl MemoStore {
         self.len() == 0
     }
 
-    /// Get cache statistics.
-    pub async fn stats(&self) -> MemoStats {
-        let stats = self.stats.read().await;
-        let mut result = stats.clone();
-        result.entries = self.len();
-        result
-    }
-
-    /// Get cache statistics synchronously.
-    ///
-    /// This acquires a read lock on the stats, which is generally fast.
-    /// Use this when you need stats without async context.
-    pub fn stats_snapshot(&self) -> MemoStats {
-        // Use try_read to avoid blocking; fall back to defaults if locked
-        match self.stats.try_read() {
-            Ok(stats) => {
-                let mut result = stats.clone();
-                result.entries = self.len();
-                result
-            }
-            Err(_) => MemoStats {
-                entries: self.len(),
-                ..Default::default()
-            },
+    /// Get cache statistics (synchronous, lock-free).
+    pub fn stats(&self) -> MemoStats {
+        let (hits, misses, tokens_saved) = self.stats.snapshot();
+        MemoStats {
+            entries: self.len(),
+            hits,
+            misses,
+            tokens_saved,
+            cost_saved: 0.0,
         }
     }
 
     /// Invalidate all entries of a specific operation type.
     ///
-    /// This is useful for batch invalidation when the algorithm for
-    /// a specific operation type changes.
-    ///
-    /// # Example
-    ///
-    /// ```rust,ignore
-    /// // Invalidate all pilot decision caches
-    /// let removed = store.invalidate_by_op_type(MemoOpType::PilotDecision);
-    /// println!("Removed {} cached pilot decisions", removed);
-    /// ```
+    /// Useful when the algorithm for a specific operation changes.
     pub fn invalidate_by_op_type(&self, op_type: MemoOpType) -> usize {
         let mut cache = self.cache.write();
         let before = cache.len();
 
-        // Collect keys to remove based on entry value type
         let keys_to_remove: Vec<String> = cache
             .iter()
             .filter_map(|(key, entry)| {
@@ -343,7 +312,6 @@ impl MemoStore {
             })
             .collect();
 
-        // Remove entries
         for key in keys_to_remove {
             cache.pop(&key);
         }
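// Illustrative use of the API above (hypothetical fingerprint and token
// count; mirrors the get_or_compute test further down in this file, and
// assumes the crate's Result alias).
async fn cached_summary(store: &MemoStore, fp: &Fingerprint) -> crate::Result<MemoValue> {
    store
        .get_or_compute(MemoKey::summary(fp), || async {
            // On a miss this closure runs once; the 100 "tokens saved" are
            // credited to the stats on every later hit.
            Ok((MemoValue::Summary("computed summary".to_string()), 100))
        })
        .await
}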
@@ -357,21 +325,11 @@ impl MemoStore {
 
     /// Invalidate all entries matching a model ID prefix.
     ///
-    /// This is useful when switching models or when a model's behavior changes.
-    ///
-    /// # Example
-    ///
-    /// ```rust,ignore
-    /// // Invalidate all GPT-4 caches
-    /// let removed = store.invalidate_by_model_prefix("gpt-4");
-    /// ```
+    /// Useful when switching models or when a model's behavior changes.
     pub fn invalidate_by_model_prefix(&self, prefix: &str) -> usize {
         let mut cache = self.cache.write();
         let before = cache.len();
 
-        // Since the key is a fingerprint, we need to check model_id from entries
-        // For now, we'll clear all entries if prefix matches our model_id
-        // A better approach would be to store model_id in entry metadata
         let should_clear = self
             .model_id
             .as_ref()
@@ -396,14 +354,12 @@ impl MemoStore {
         let mut cache = self.cache.write();
         let before = cache.len();
 
-        // Collect expired keys
         let expired: Vec<String> = cache
             .iter()
             .filter(|(_, entry)| entry.is_expired(self.ttl))
             .map(|(k, _)| k.clone())
             .collect();
 
-        // Remove expired entries
         for key in expired {
             cache.pop(&key);
         }
@@ -417,8 +373,11 @@ impl MemoStore {
 
     /// Save the cache to disk.
     pub async fn save(&self, path: &Path) -> Result<()> {
+        // Prune expired entries before persisting
+        self.prune_expired();
+
         let cache = self.cache.read();
-        let stats = self.stats.read().await;
+        let stats = self.stats();
 
         let entries: HashMap<String, MemoEntry> =
             cache.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
 
         let data = MemoStoreData {
             version: 1,
             entries,
-            stats: stats.clone(),
+            stats,
         };
 
         let parent = path
@@ -459,20 +418,15 @@ impl MemoStore {
             .map_err(|e| crate::Error::Parse(format!("Failed to deserialize memo store: {}", e)))?;
 
         let mut cache = self.cache.write();
-        let mut stats = self.stats.write().await;
 
         for (key, entry) in data.entries {
-            // Skip expired entries
             if !entry.is_expired(self.ttl) {
                 cache.put(key, entry);
             }
         }
 
-        stats.entries = cache.len();
-        stats.hits = data.stats.hits;
-        stats.misses = data.stats.misses;
-        stats.tokens_saved = data.stats.tokens_saved;
-        stats.cost_saved = data.stats.cost_saved;
+        // Restore stats
+        self.stats.load_from(data.stats.hits, data.stats.misses, data.stats.tokens_saved);
 
         info!(
             "Loaded memo store with {} entries from {:?}",
@@ -484,7 +438,6 @@ impl MemoStore {
 
     /// Make a full cache key from a MemoKey.
fn make_key(&self, key: &MemoKey) -> String { - // Include model_id and version in the key let mut key_with_context = key.clone(); if key_with_context.model_id.is_none() { key_with_context.model_id = self.model_id.clone(); @@ -616,7 +569,6 @@ mod tests { store.put(key, MemoValue::Summary(format!("Summary {}", i))); } - // Only 3 entries should remain assert_eq!(store.len(), 3); } @@ -656,7 +608,7 @@ mod tests { .unwrap(); assert_eq!(result2.as_summary(), Some("Computed")); - assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1); // Still 1 + assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1); } #[tokio::test] @@ -699,7 +651,7 @@ mod tests { .await .unwrap(); - // Hit via get_or_compute (this updates global stats) + // Hit store .get_or_compute(key.clone(), || async { Ok((MemoValue::Summary("Should not be called".to_string()), 0)) @@ -707,7 +659,7 @@ mod tests { .await .unwrap(); - let stats = store.stats().await; + let stats = store.stats(); assert_eq!(stats.misses, 1); assert_eq!(stats.hits, 1); assert_eq!(stats.tokens_saved, 100); diff --git a/rust/src/memo/types.rs b/rust/src/llm/memo/types.rs similarity index 96% rename from rust/src/memo/types.rs rename to rust/src/llm/memo/types.rs index 0ed92400..a45aed12 100644 --- a/rust/src/memo/types.rs +++ b/rust/src/llm/memo/types.rs @@ -23,6 +23,18 @@ pub enum MemoOpType { /// Content extraction result. Extraction, + /// LLM node evaluation during retrieval. + NodeEvaluation, + + /// Sufficiency check result. + SufficiencyCheck, + + /// Query complexity detection. + ComplexityDetection, + + /// Query decomposition. + QueryDecomposition, + /// Custom operation type. Custom(u8), } @@ -35,6 +47,10 @@ impl MemoOpType { MemoOpType::PilotDecision => 1, MemoOpType::QueryAnalysis => 2, MemoOpType::Extraction => 3, + MemoOpType::NodeEvaluation => 4, + MemoOpType::SufficiencyCheck => 5, + MemoOpType::ComplexityDetection => 6, + MemoOpType::QueryDecomposition => 7, MemoOpType::Custom(n) => 100 + n, } } diff --git a/rust/src/llm/mod.rs b/rust/src/llm/mod.rs index 84fca4f2..215aba7f 100644 --- a/rust/src/llm/mod.rs +++ b/rust/src/llm/mod.rs @@ -8,13 +8,6 @@ //! - **Retrieval** — Document tree navigation //! - **Pilot** — Navigation guidance //! -//! # Features -//! -//! - Unified configuration with purpose-specific presets -//! - Automatic retry with exponential backoff -//! - JSON response parsing -//! - Unified error handling -//! //! # Architecture //! //! ```text @@ -34,45 +27,17 @@ //! │ └─────────────────────┘ │ //! └─────────────────────────────────────────────────────────────────┘ //! ``` -//! -//! # Example -//! -//! ```rust,no_run -//! use vectorless::llm::{LlmPool, LlmConfig, RetryConfig}; -//! -//! # #[tokio::main] -//! # async fn main() -> vectorless::llm::LlmResult<()> { -//! // Create a pool with default configurations -//! let pool = LlmPool::from_defaults(); -//! -//! // Use index client -//! let summary = pool.index().complete( -//! "You summarize text concisely.", -//! "Long text to summarize..." -//! ).await?; -//! -//! // Use retrieval client with JSON output -//! #[derive(serde::Deserialize)] -//! struct NavDecision { section: usize } -//! let decision: NavDecision = pool.retrieval().complete_json( -//! "You navigate documents.", -//! "Find section about X..." -//! ).await?; -//! -//! # Ok(()) -//! # } -//! 
```
 
 mod client;
-mod config;
+pub(crate) mod config;
 mod error;
 mod executor;
 mod fallback;
+pub(crate) mod memo;
 mod pool;
-mod retry;
+pub(crate) mod throttle;
 
 pub use client::LlmClient;
-pub use config::LlmConfigs;
 pub use error::LlmResult;
 pub use executor::LlmExecutor;
 pub use pool::LlmPool;
diff --git a/rust/src/llm/pool.rs b/rust/src/llm/pool.rs
index d7ddf637..76f04ada 100644
--- a/rust/src/llm/pool.rs
+++ b/rust/src/llm/pool.rs
@@ -6,8 +6,10 @@
 use std::sync::Arc;
 
 use super::client::LlmClient;
-use super::config::LlmConfigs;
-use crate::throttle::ConcurrencyController;
+use super::config::LlmConfig;
+use super::fallback::{FallbackChain, FallbackConfig};
+use crate::metrics::MetricsHub;
+use super::throttle::ConcurrencyController;
 
 /// Pool of LLM clients for different purposes.
 ///
@@ -17,155 +19,112 @@ use crate::throttle::ConcurrencyController;
 /// - **Retrieval** — Document navigation (capable model)
 /// - **Pilot** — Navigation guidance (fast model)
 ///
-/// # Example
+/// # Construction
 ///
-/// ```rust,no_run
+/// The pool is built from a [`config::LlmConfig`](crate::config::LlmConfig)
+/// which defines the global credentials and per-slot overrides.
+///
+/// ```rust,ignore
 /// use vectorless::llm::LlmPool;
 ///
-/// # #[tokio::main]
-/// # async fn main() -> vectorless::llm::LlmResult<()> {
-/// let pool = LlmPool::from_defaults();
+/// let pool = LlmPool::from_config(&config.llm, None);
 ///
 /// // Use index client for summarization
 /// let summary = pool.index().complete(
 ///     "You summarize text concisely.",
 ///     "Long text to summarize..."
 /// ).await?;
-///
-/// // Use retrieval client for navigation
-/// let nav = pool.retrieval().complete(
-///     "You navigate documents.",
-///     "Find information about X..."
-/// ).await?;
-///
-/// # Ok(())
-/// # }
 /// ```
 #[derive(Debug, Clone)]
 pub struct LlmPool {
     index: Arc<LlmClient>,
     retrieval: Arc<LlmClient>,
     pilot: Arc<LlmClient>,
-    concurrency: Option<Arc<ConcurrencyController>>,
 }
 
 impl LlmPool {
-    /// Create a new LLM pool from configurations.
-    pub fn new(configs: LlmConfigs) -> Self {
+    /// Create a pool from the unified LLM configuration.
+    ///
+    /// Resolves per-slot model overrides and creates individual
+    /// [`LlmClient`] instances with the appropriate settings.
+    /// When `metrics` is provided, all clients share the same hub
+    /// for unified LLM call statistics.
+    pub fn from_config(
+        config: &crate::config::LlmConfig,
+        metrics: Option<Arc<MetricsHub>>,
+    ) -> Self {
+        let api_key = config.api_key.clone();
+        let endpoint = config.endpoint.clone().unwrap_or_default();
+        let retry = config.retry.to_runtime_config();
+
+        let make_config = |slot: &crate::config::SlotConfig| -> LlmConfig {
+            LlmConfig {
+                model: config.resolve_model(slot),
+                endpoint: endpoint.clone(),
+                api_key: api_key.clone(),
+                max_tokens: slot.max_tokens,
+                temperature: slot.temperature,
+                retry: retry.clone(),
+            }
+        };
+
+        // Build a single shared async-openai client (reuses connection pool)
+        let openai_base = if endpoint.is_empty() {
+            "https://api.openai.com/v1".to_string()
+        } else {
+            endpoint.clone()
+        };
+        let openai_client = Arc::new(async_openai::Client::with_config(
+            async_openai::config::OpenAIConfig::new()
+                .with_api_key(api_key.clone().unwrap_or_default())
+                .with_api_base(openai_base),
+        ));
+
+        // Attach shared throttle controller from config
+        let concurrency_config = config.throttle.to_runtime_config();
+        let controller = Arc::new(ConcurrencyController::new(concurrency_config));
+
+        // Attach shared fallback chain from config
+        let fallback_config: FallbackConfig = config.fallback.clone().into();
+        let fallback_chain = Arc::new(FallbackChain::new(fallback_config));
+
+        let build_client = |slot_config: &crate::config::SlotConfig| {
+            let mut client = LlmClient::new(make_config(slot_config))
+                .with_shared_concurrency(controller.clone())
+                .with_shared_openai_client(openai_client.clone())
+                .with_shared_fallback(fallback_chain.clone());
+            if let Some(ref hub) = metrics {
+                client = client.with_shared_metrics(hub.clone());
+            }
+            Arc::new(client)
+        };
+
         Self {
-            index: Arc::new(LlmClient::new(configs.index)),
-            retrieval: Arc::new(LlmClient::new(configs.retrieval)),
-            pilot: Arc::new(LlmClient::new(configs.pilot)),
-            concurrency: None,
+            index: build_client(&config.index),
+            retrieval: build_client(&config.retrieval),
+            pilot: build_client(&config.pilot),
         }
     }
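// Illustrative sketch, not part of the patch: building a pool with a slot
// override and a shared metrics hub (mirrors test_pool_from_config below).
async fn pool_example() -> LlmResult<()> {
    let config = crate::config::LlmConfig::new("gpt-4o")
        .with_api_key("sk-test")
        .with_index(crate::config::SlotConfig::fast().with_model("gpt-4o-mini"));

    let hub = MetricsHub::shared();
    let pool = LlmPool::from_config(&config, Some(hub));

    // All three clients share one HTTP client, throttle, and fallback chain.
    let summary = pool.index().complete("You summarize.", "Long text...").await?;
    println!("{summary}");
    Ok(())
}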
 
     /// Create a pool with default configurations.
-    ///
-    /// Uses auto-detected models based on available API keys:
-    /// - OpenAI: gpt-4o-mini for summary/toc, gpt-4o for retrieval
-    /// - Anthropic: claude-3-haiku for summary/toc, claude-3-sonnet for retrieval
-    /// - Default: glm-4-flash for summary/toc, glm-4 for retrieval
     pub fn from_defaults() -> Self {
-        Self::new(LlmConfigs::default())
-    }
-
-    /// Add concurrency control to all clients in the pool.
-    ///
-    /// All clients share the same ConcurrencyController, which means
-    /// rate limiting and concurrency limits are applied globally
-    /// across all LLM operations.
-    ///
-    /// # Example
-    ///
-    /// ```rust,no_run
-    /// use vectorless::llm::LlmPool;
-    /// use vectorless::throttle::{ConcurrencyController, ConcurrencyConfig};
-    ///
-    /// let config = ConcurrencyConfig::new()
-    ///     .with_max_concurrent_requests(10)
-    ///     .with_requests_per_minute(500);
-    ///
-    /// let pool = LlmPool::from_defaults()
-    ///     .with_concurrency(ConcurrencyController::new(config));
-    /// ```
-    pub fn with_concurrency(mut self, controller: ConcurrencyController) -> Self {
-        let arc = Arc::new(controller);
-        self.concurrency = Some(arc.clone());
-        self.index = Arc::new(
-            LlmClient::new(self.index.config().clone()).with_shared_concurrency(arc.clone()),
-        );
-        self.retrieval = Arc::new(
-            LlmClient::new(self.retrieval.config().clone()).with_shared_concurrency(arc.clone()),
-        );
-        self.pilot = Arc::new(
-            LlmClient::new(self.pilot.config().clone()).with_shared_concurrency(arc.clone()),
-        );
-        self
-    }
-
-    /// Add concurrency control from an existing Arc.
-    pub fn with_shared_concurrency(mut self, controller: Arc<ConcurrencyController>) -> Self {
-        self.concurrency = Some(controller.clone());
-        self.index = Arc::new(
-            LlmClient::new(self.index.config().clone()).with_shared_concurrency(controller.clone()),
-        );
-        self.retrieval = Arc::new(
-            LlmClient::new(self.retrieval.config().clone())
-                .with_shared_concurrency(controller.clone()),
-        );
-        self.pilot = Arc::new(
-            LlmClient::new(self.pilot.config().clone()).with_shared_concurrency(controller.clone()),
-        );
-        self
-    }
-
-    /// Get the concurrency controller (if any).
-    pub fn concurrency(&self) -> Option<&ConcurrencyController> {
-        self.concurrency.as_deref()
+        Self::from_config(&crate::config::LlmConfig::default(), None)
     }
 
     /// Get the index client.
-    ///
-    /// Used for document indexing and summarization.
-    /// Typically uses a fast, cost-effective model.
     pub fn index(&self) -> &LlmClient {
         &self.index
     }
 
     /// Get the retrieval client.
-    ///
-    /// Used for document navigation and retrieval.
-    /// Typically uses a more capable model for better navigation decisions.
     pub fn retrieval(&self) -> &LlmClient {
         &self.retrieval
     }
 
     /// Get the pilot client.
-    ///
-    /// Used for intelligent navigation guidance.
-    /// Typically uses a fast model for quick decisions.
     pub fn pilot(&self) -> &LlmClient {
         &self.pilot
     }
-
-    /// Get a client for a specific purpose by name.
-    ///
-    /// # Arguments
-    ///
-    /// * `purpose` - One of: "index", "summary", "retrieval", "retrieve", "navigate", "pilot"
-    ///
-    /// # Returns
-    ///
-    /// Returns `None` if the purpose is not recognized.
- pub fn get(&self, purpose: &str) -> Option<&LlmClient> { - match purpose { - "index" | "summary" | "summarize" => Some(&self.index), - "retrieval" | "retrieve" | "navigate" => Some(&self.retrieval), - "pilot" => Some(&self.pilot), - _ => None, - } - } } impl Default for LlmPool { @@ -179,38 +138,50 @@ mod tests { use super::*; #[test] - fn test_pool_creation() { - let pool = LlmPool::from_defaults(); - - // Should have all clients - assert!(pool.get("index").is_some()); - assert!(pool.get("retrieval").is_some()); - assert!(pool.get("pilot").is_some()); - assert!(pool.get("unknown").is_none()); + fn test_pool_from_config() { + let config = crate::config::LlmConfig::new("gpt-4o") + .with_api_key("sk-test") + .with_endpoint("https://api.openai.com/v1") + .with_index(crate::config::SlotConfig::fast().with_model("gpt-4o-mini")); + + let pool = LlmPool::from_config(&config, None); + + assert_eq!(pool.index().config().model, "gpt-4o-mini"); + assert_eq!(pool.retrieval().config().model, "gpt-4o"); + assert_eq!(pool.pilot().config().model, "gpt-4o"); + assert_eq!(pool.index().config().max_tokens, 100); } #[test] - fn test_pool_get_aliases() { - let pool = LlmPool::from_defaults(); - - // Test aliases - assert!(pool.get("summary").is_some()); - assert!(pool.get("summarize").is_some()); - assert!(pool.get("retrieve").is_some()); - assert!(pool.get("navigate").is_some()); + fn test_pool_from_config_with_metrics() { + let config = crate::config::LlmConfig::new("gpt-4o") + .with_api_key("sk-test") + .with_endpoint("https://api.openai.com/v1"); + + let hub = MetricsHub::shared(); + let pool = LlmPool::from_config(&config, Some(hub.clone())); + + // Verify each client has fallback (which means executor was built correctly) + assert!(pool.index().fallback().is_some()); + assert!(pool.retrieval().fallback().is_some()); + assert!(pool.pilot().fallback().is_some()); + + // Verify models are resolved correctly + assert_eq!(pool.index().config().model, "gpt-4o"); + assert_eq!(pool.retrieval().config().model, "gpt-4o"); + assert_eq!(pool.pilot().config().model, "gpt-4o"); } #[test] - fn test_pool_with_concurrency() { - use crate::throttle::ConcurrencyConfig; + fn test_pool_shared_metrics_hub() { + let config = crate::config::LlmConfig::new("gpt-4o") + .with_api_key("sk-test") + .with_endpoint("https://api.openai.com/v1"); - let controller = ConcurrencyController::new(ConcurrencyConfig::conservative()); - let pool = LlmPool::from_defaults().with_concurrency(controller); + let hub = MetricsHub::shared(); + let _pool = LlmPool::from_config(&config, Some(hub.clone())); - // All clients should have concurrency enabled - assert!(pool.concurrency().is_some()); - assert!(pool.index().concurrency().is_some()); - assert!(pool.retrieval().concurrency().is_some()); - assert!(pool.pilot().concurrency().is_some()); + // Hub is shared with all three clients — Arc refcount > 1 + assert!(Arc::strong_count(&hub) > 1); } } diff --git a/rust/src/llm/retry.rs b/rust/src/llm/retry.rs deleted file mode 100644 index e0fdb19e..00000000 --- a/rust/src/llm/retry.rs +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright (c) 2026 vectorless developers -// SPDX-License-Identifier: Apache-2.0 - -//! Retry logic for LLM calls. - -use std::future::Future; -use tracing::{debug, warn}; - -use super::config::RetryConfig; -use super::error::{LlmError, LlmResult}; - -/// Execute an async operation with retry logic. 
-///
-/// This function implements exponential backoff retry for operations
-/// that may fail with transient errors (rate limits, timeouts, etc.).
-///
-/// # Example
-///
-/// ```rust,ignore
-/// use vectorless::llm::{RetryConfig, with_retry, LlmError, LlmResult};
-///
-/// # #[tokio::main]
-/// # async fn main() -> LlmResult<()> {
-/// let config = RetryConfig::default();
-///
-/// let result = with_retry(&config, || async {
-///     // Some operation that might fail
-///     Ok::<_, LlmError>("success".to_string())
-/// }).await?;
-///
-/// # Ok(())
-/// # }
-/// ```
-pub async fn with_retry<T, F, Fut>(config: &RetryConfig, operation: F) -> LlmResult<T>
-where
-    F: Fn() -> Fut,
-    Fut: Future<Output = LlmResult<T>>,
-{
-    let mut attempts = 0;
-
-    loop {
-        attempts += 1;
-
-        match operation().await {
-            Ok(result) => {
-                if attempts > 1 {
-                    debug!("Retry succeeded on attempt {}", attempts);
-                }
-                return Ok(result);
-            }
-            Err(e) => {
-                // Check if we should retry
-                if !should_retry(&e, config) {
-                    return Err(e);
-                }
-
-                // Check if we've exhausted retries
-                if attempts >= config.max_attempts {
-                    warn!(
-                        attempts = attempts,
-                        max_attempts = config.max_attempts,
-                        "Retry exhausted"
-                    );
-                    return Err(LlmError::RetryExhausted {
-                        attempts,
-                        last_error: e.to_string(),
-                    });
-                }
-
-                // Calculate delay for this attempt (0-indexed for delay calculation)
-                let delay = config.delay_for_attempt(attempts - 1);
-                warn!(
-                    attempt = attempts,
-                    max_attempts = config.max_attempts,
-                    delay_ms = delay.as_millis() as u64,
-                    error = %e,
-                    "LLM call failed, retrying..."
-                );
-
-                tokio::time::sleep(delay).await;
-            }
-        }
-    }
-}
-
-/// Determine if an error should trigger a retry.
-fn should_retry(error: &LlmError, config: &RetryConfig) -> bool {
-    match error {
-        LlmError::RateLimit(_) => config.retry_on_rate_limit,
-        LlmError::Timeout(_) => true,
-        LlmError::Api(msg) => {
-            let msg_lower = msg.to_lowercase();
-            // Check for retryable API errors
-            msg_lower.contains("rate limit")
-                || msg_lower.contains("429")
-                || msg_lower.contains("503")
-                || msg_lower.contains("502")
-                || msg_lower.contains("timeout")
-                || msg_lower.contains("overloaded")
-        }
-        _ => false,
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use std::sync::atomic::{AtomicU32, Ordering};
-
-    #[tokio::test]
-    async fn test_retry_success_on_second_attempt() {
-        let config = RetryConfig::new().with_max_attempts(3);
-        let attempts = AtomicU32::new(0);
-
-        let result = with_retry(&config, || async {
-            let current = attempts.fetch_add(1, Ordering::SeqCst) + 1;
-            if current < 2 {
-                Err(LlmError::Timeout("timeout".to_string()))
-            } else {
-                Ok("success")
-            }
-        })
-        .await;
-
-        assert_eq!(result.unwrap(), "success");
-        assert_eq!(attempts.load(Ordering::SeqCst), 2);
-    }
-
-    #[tokio::test]
-    async fn test_retry_max_attempts_reached() {
-        let config = RetryConfig::new().with_max_attempts(2);
-        let attempts = AtomicU32::new(0);
-
-        let result: LlmResult<String> = with_retry(&config, || async {
-            attempts.fetch_add(1, Ordering::SeqCst);
-            Err(LlmError::Timeout("timeout".to_string()))
-        })
-        .await;
-
-        assert!(matches!(result, Err(LlmError::RetryExhausted { ..
})));
-        assert_eq!(attempts.load(Ordering::SeqCst), 2);
-    }
-
-    #[tokio::test]
-    async fn test_non_retryable_error_fails_immediately() {
-        let config = RetryConfig::new().with_max_attempts(3);
-        let attempts = AtomicU32::new(0);
-
-        let result: LlmResult<String> = with_retry(&config, || async {
-            attempts.fetch_add(1, Ordering::SeqCst);
-            Err(LlmError::Config("bad config".to_string()))
-        })
-        .await;
-
-        assert!(matches!(result, Err(LlmError::Config(_))));
-        assert_eq!(attempts.load(Ordering::SeqCst), 1); // Should only try once
-    }
-}
diff --git a/rust/src/llm/throttle.rs b/rust/src/llm/throttle.rs
new file mode 100644
index 00000000..5de96743
--- /dev/null
+++ b/rust/src/llm/throttle.rs
@@ -0,0 +1,259 @@
+// Copyright (c) 2026 vectorless developers
+// SPDX-License-Identifier: Apache-2.0
+
+//! Concurrency control for LLM API calls.
+//!
+//! Combines semaphore (concurrency limit) with token-bucket rate limiter (RPM).
+
+use std::num::NonZeroU32;
+use std::sync::Arc;
+
+use governor::{
+    Quota, RateLimiter as GovernorLimiter,
+    clock::{Clock, DefaultClock},
+    state::{InMemoryState, NotKeyed},
+};
+use serde::{Deserialize, Serialize};
+use tokio::sync::{Semaphore, SemaphorePermit};
+use tracing::{debug, trace};
+
+// ============================================================
+// ConcurrencyConfig
+// ============================================================
+
+/// Concurrency control configuration.
+///
+/// Controls how LLM requests are rate-limited and throttled
+/// to avoid overwhelming the API.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ConcurrencyConfig {
+    /// Maximum concurrent LLM API calls.
+    #[serde(default = "default_max_concurrent_requests")]
+    pub max_concurrent_requests: usize,
+
+    /// Rate limit: requests per minute (token bucket).
+    #[serde(default = "default_requests_per_minute")]
+    pub requests_per_minute: usize,
+
+    /// Whether rate limiting is enabled.
+    #[serde(default = "default_true")]
+    pub enabled: bool,
+
+    /// Whether semaphore-based concurrency limiting is enabled.
+    #[serde(default = "default_true")]
+    pub semaphore_enabled: bool,
+}
+
+fn default_max_concurrent_requests() -> usize {
+    10
+}
+fn default_requests_per_minute() -> usize {
+    500
+}
+fn default_true() -> bool {
+    true
+}
+
+impl Default for ConcurrencyConfig {
+    fn default() -> Self {
+        Self {
+            max_concurrent_requests: default_max_concurrent_requests(),
+            requests_per_minute: default_requests_per_minute(),
+            enabled: true,
+            semaphore_enabled: true,
+        }
+    }
+}
+
+impl ConcurrencyConfig {
+    /// Create a new config with defaults.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set the maximum concurrent requests.
+    pub fn with_max_concurrent_requests(mut self, max: usize) -> Self {
+        self.max_concurrent_requests = max;
+        self
+    }
+
+    /// Set the requests per minute rate limit.
+    pub fn with_requests_per_minute(mut self, rpm: usize) -> Self {
+        self.requests_per_minute = rpm;
+        self
+    }
+
+    /// Enable or disable rate limiting.
+    pub fn with_enabled(mut self, enabled: bool) -> Self {
+        self.enabled = enabled;
+        self
+    }
+
+    /// Create a config for conservative scenarios.
+    pub fn conservative() -> Self {
+        Self {
+            max_concurrent_requests: 5,
+            requests_per_minute: 100,
+            enabled: true,
+            semaphore_enabled: true,
+        }
+    }
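// Illustrative sketch, not part of the patch: choosing a preset and guarding
// a call site with the controller defined below. The permit releases on drop.
async fn throttled_call() {
    let controller = ConcurrencyController::new(ConcurrencyConfig::conservative());

    // Waits on the RPM token bucket first, then on the semaphore.
    let _permit = controller.acquire().await;
    // ... make the LLM request while the permit is held ...
}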
+
+    /// Create a config that disables all limits.
+    pub fn unlimited() -> Self {
+        Self {
+            max_concurrent_requests: usize::MAX,
+            requests_per_minute: usize::MAX,
+            enabled: false,
+            semaphore_enabled: false,
+        }
+    }
+}
+
+// ============================================================
+// ConcurrencyController
+// ============================================================
+
+/// Concurrency controller for LLM API calls.
+///
+/// Combines:
+/// - **Rate Limiter** — Token bucket to limit requests per time period
+/// - **Semaphore** — Limit concurrent requests
+///
+/// The only operation needed by business code is [`acquire()`](ConcurrencyController::acquire).
+#[derive(Clone)]
+pub struct ConcurrencyController {
+    semaphore: Arc<Semaphore>,
+    rate_limiter: Option<Arc<GovernorLimiter<NotKeyed, InMemoryState, DefaultClock>>>,
+    semaphore_enabled: bool,
+}
+
+impl ConcurrencyController {
+    /// Create a new concurrency controller with the given configuration.
+    pub fn new(config: ConcurrencyConfig) -> Self {
+        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_requests));
+        let rate_limiter = if config.enabled {
+            let rpm = NonZeroU32::new(config.requests_per_minute as u32)
+                .unwrap_or_else(|| NonZeroU32::new(1).unwrap());
+            Some(Arc::new(GovernorLimiter::direct(Quota::per_minute(rpm))))
+        } else {
+            None
+        };
+
+        Self {
+            semaphore,
+            rate_limiter,
+            semaphore_enabled: config.semaphore_enabled,
+        }
+    }
+
+    /// Create a controller with default configuration.
+    pub fn with_defaults() -> Self {
+        Self::new(ConcurrencyConfig::default())
+    }
+
+    /// Acquire a permit for making an LLM request.
+    ///
+    /// This will:
+    /// 1. Wait for the rate limiter (if enabled)
+    /// 2. Acquire a semaphore permit (if enabled)
+    ///
+    /// The permit is automatically released when dropped.
+    pub async fn acquire(&self) -> Option<SemaphorePermit<'_>> {
+        // Step 1: Wait for rate limiter
+        if let Some(ref limiter) = self.rate_limiter {
+            let clock = DefaultClock::default();
+            loop {
+                match limiter.check() {
+                    Ok(_) => {
+                        trace!("Rate limiter: token acquired");
+                        break;
+                    }
+                    Err(negative) => {
+                        let wait_duration = negative.wait_time_from(clock.now());
+                        trace!(
+                            wait_ms = wait_duration.as_millis() as u64,
+                            "Rate limiter: waiting for token"
+                        );
+                        tokio::time::sleep(wait_duration).await;
+                    }
+                }
+            }
+            debug!("Rate limiter: token acquired");
+        }
+
+        // Step 2: Acquire semaphore permit
+        if self.semaphore_enabled {
+            trace!("Waiting for semaphore permit");
+            let permit = self
+                .semaphore
+                .acquire()
+                .await
+                .expect("semaphore should not be closed");
+            debug!(
+                "Semaphore: permit acquired (available: {})",
+                self.semaphore.available_permits()
+            );
+            Some(permit)
+        } else {
+            None
+        }
+    }
+}
+
+impl std::fmt::Debug for ConcurrencyController {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ConcurrencyController")
+            .field("available_permits", &self.semaphore.available_permits())
+            .field("has_rate_limiter", &self.rate_limiter.is_some())
+            .field("semaphore_enabled", &self.semaphore_enabled)
+            .finish()
+    }
+}
+
+impl Default for ConcurrencyController {
+    fn default() -> Self {
+        Self::with_defaults()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_controller_acquire() {
+        let controller = ConcurrencyController::new(ConcurrencyConfig {
+            max_concurrent_requests: 2,
+            requests_per_minute: 100,
+            enabled: false,
+            semaphore_enabled: true,
+        });
+
+        let permit1 = controller.acquire().await;
+        assert!(permit1.is_some());
+
+        let permit2 = controller.acquire().await;
+        assert!(permit2.is_some());
+
+        drop(permit1);
+    }
+
+    #[test]
+    fn
test_controller_creation() { + let controller = ConcurrencyController::with_defaults(); + assert!(controller.semaphore.available_permits() > 0); + } + + #[test] + fn test_rate_limiter_creation() { + let config = ConcurrencyConfig { + max_concurrent_requests: 10, + requests_per_minute: 100, + enabled: true, + semaphore_enabled: true, + }; + let controller = ConcurrencyController::new(config); + assert!(controller.rate_limiter.is_some()); + } +} diff --git a/rust/src/memo/mod.rs b/rust/src/memo/mod.rs deleted file mode 100644 index 50523c16..00000000 --- a/rust/src/memo/mod.rs +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2026 vectorless developers -// SPDX-License-Identifier: Apache-2.0 - -//! LLM Memoization system for caching expensive LLM calls. -//! -//! This module provides a caching layer for LLM-generated content, -//! enabling significant cost savings by avoiding redundant API calls. -//! -//! # Key Features -//! -//! - **Operation-based caching**: Cache summaries, pilot decisions, query results -//! - **Content-addressed**: Keys are based on content fingerprints -//! - **TTL support**: Optional time-to-live for cache entries -//! - **Persistence**: Save/load cache to disk for cross-session reuse -//! -//! # Usage -//! -//! ```rust,ignore -//! use vectorless::memo::{MemoStore, MemoKey, MemoOpType}; -//! -//! // Create a memo store -//! let mut store = MemoStore::new(1000); -//! -//! // Get or compute a summary -//! let key = MemoKey::summary(&node_fingerprint); -//! let summary = store.get_or_compute(key, || async { -//! llm_client.generate_summary(node).await -//! }).await?; -//! ``` - -mod store; -mod types; - -pub use store::MemoStore; -pub use types::{MemoKey, MemoValue, PilotDecisionValue}; diff --git a/rust/src/metrics/hub.rs b/rust/src/metrics/hub.rs index ee6e14af..c00471cc 100644 --- a/rust/src/metrics/hub.rs +++ b/rust/src/metrics/hub.rs @@ -297,17 +297,6 @@ impl MetricsReport { pub fn total_cost_usd(&self) -> f64 { self.llm.estimated_cost_usd } - - /// Calculate overall success rate. 
-    pub fn overall_success_rate(&self) -> f64 {
-        let llm_rate = self.llm.success_rate;
-        let pilot_rate = if self.pilot.total_decisions > 0 {
-            self.pilot.accuracy
-        } else {
-            1.0
-        };
-        (llm_rate + pilot_rate) / 2.0
-    }
 }
 
 #[cfg(test)]
@@ -354,4 +343,64 @@ mod tests {
         let report = hub.generate_report();
         assert_eq!(report.llm.total_calls, 0);
     }
+
+    #[test]
+    fn test_llm_metrics_success_and_failure() {
+        let hub = MetricsHub::with_defaults();
+
+        // Record successes
+        hub.record_llm_call(100, 50, 150, true);
+        hub.record_llm_call(200, 100, 300, true);
+
+        // Record failure
+        hub.record_llm_call(0, 0, 50, false);
+
+        let report = hub.llm_report();
+        assert_eq!(report.total_calls, 3);
+        assert_eq!(report.successful_calls, 2);
+        assert_eq!(report.failed_calls, 1);
+        assert!((report.success_rate - 0.666).abs() < 0.01);
+        assert_eq!(report.total_input_tokens, 300);
+        assert_eq!(report.total_output_tokens, 150);
+    }
+
+    #[test]
+    fn test_llm_error_events() {
+        let hub = MetricsHub::with_defaults();
+
+        hub.record_llm_rate_limit();
+        hub.record_llm_rate_limit();
+        hub.record_llm_timeout();
+        hub.record_llm_fallback();
+
+        let report = hub.llm_report();
+        assert_eq!(report.rate_limit_errors, 2);
+        assert_eq!(report.timeout_errors, 1);
+        assert_eq!(report.fallback_triggers, 1);
+    }
+
+    #[test]
+    fn test_shared_arc_metrics() {
+        let hub = MetricsHub::shared();
+
+        // Clone the Arc — both references point to the same hub
+        let hub2 = hub.clone();
+        hub.record_llm_call(100, 50, 100, true);
+        hub2.record_llm_call(200, 100, 200, true);
+
+        let report = hub.generate_report();
+        assert_eq!(report.llm.total_calls, 2);
+        assert_eq!(report.llm.total_input_tokens, 300);
+    }
+
+    #[test]
+    fn test_metrics_report_cost() {
+        let hub = MetricsHub::with_defaults();
+
+        hub.record_llm_call(1000, 500, 200, true);
+
+        let report = hub.generate_report();
+        // Cost should be non-negative (exact value depends on config pricing)
+        assert!(report.total_cost_usd() >= 0.0);
+    }
 }
diff --git a/rust/src/metrics/pilot.rs b/rust/src/metrics/pilot.rs
index 8b424935..f8365e45 100644
--- a/rust/src/metrics/pilot.rs
+++ b/rust/src/metrics/pilot.rs
@@ -22,16 +22,6 @@ pub enum InterventionPoint {
     Prune,
 }
 
-/// Helper to store f64 as u64 bits for atomic operations.
-fn f64_to_u64_bits(v: f64) -> u64 {
-    v.to_bits()
-}
-
-/// Helper to convert u64 bits back to f64.
-fn u64_bits_to_f64(v: u64) -> f64 {
-    f64::from_bits(v)
-}
-
 /// Pilot metrics tracker.
 #[derive(Debug, Default)]
 pub struct PilotMetrics {
diff --git a/rust/src/retrieval/complexity/detector.rs b/rust/src/retrieval/complexity/detector.rs
index 602da79c..74a14918 100644
--- a/rust/src/retrieval/complexity/detector.rs
+++ b/rust/src/retrieval/complexity/detector.rs
@@ -9,6 +9,8 @@
 use std::collections::HashSet;
 
 use super::QueryComplexity;
+use crate::llm::memo::{MemoKey, MemoOpType, MemoStore, MemoValue};
+use crate::utils::fingerprint::Fingerprint;
 
 /// Query complexity detector.
 ///
@@ -16,33 +18,96 @@
 pub struct ComplexityDetector {
     /// Optional LLM client for LLM-based detection.
     llm_client: Option<crate::llm::LlmClient>,
+    /// Memo store for caching complexity detection results.
+    memo_store: Option<MemoStore>,
 }
 
 impl ComplexityDetector {
     /// Create a new complexity detector (heuristic only).
     pub fn new() -> Self {
-        Self { llm_client: None }
+        Self {
+            llm_client: None,
+            memo_store: None,
+        }
     }
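// Illustrative wiring, not part of the patch (the builders used here are
// defined just below): a detector that caches its verdicts, so an identical
// query later skips the LLM round-trip.
async fn detect_cached(client: crate::llm::LlmClient, store: MemoStore) -> QueryComplexity {
    ComplexityDetector::with_llm_client(client)
        .with_memo_store(store)
        .detect("compare 2022 and 2023 revenue, then summarize the drivers")
        .await
}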
 
     /// Create with LLM client for accurate detection.
     pub fn with_llm_client(client: crate::llm::LlmClient) -> Self {
         Self {
             llm_client: Some(client),
+            memo_store: None,
         }
     }
 
+    /// Add memo store for caching complexity detection results.
+    pub fn with_memo_store(mut self, store: MemoStore) -> Self {
+        self.memo_store = Some(store);
+        self
+    }
+
     /// Detect the complexity of a query.
     ///
     /// Uses LLM when available; falls back to heuristic rules.
     pub async fn detect(&self, query: &str) -> QueryComplexity {
-        if let Some(ref client) = self.llm_client {
-            if let Some(complexity) = crate::retrieval::pilot::detect_with_llm(client, query).await
+        // Check memo cache
+        if let Some(ref store) = self.memo_store {
+            let cache_key = Self::build_cache_key(query);
+            if let Some(cached) = store.get(&cache_key) {
+                if let Some(complexity) = Self::deserialize_complexity(&cached) {
+                    return complexity;
+                }
+            }
+        }
+
+        let result = if let Some(ref client) = self.llm_client {
+            if let Some(complexity) =
+                crate::retrieval::pilot::detect_with_llm(client, query).await
             {
-                return complexity;
+                complexity
+            } else {
+                tracing::warn!("LLM complexity detection failed, falling back to heuristic");
+                self.detect_heuristic(query)
             }
-            tracing::warn!("LLM complexity detection failed, falling back to heuristic");
+        } else {
+            self.detect_heuristic(query)
+        };
+
+        // Cache the result
+        if let Some(ref store) = self.memo_store {
+            let cache_key = Self::build_cache_key(query);
+            store.put_with_tokens(
+                cache_key,
+                MemoValue::Text(format!("{:?}", result)),
+                (query.len() / 4) as u64,
+            );
+        }
+
+        result
+    }
+
+    /// Build a cache key for complexity detection.
+    fn build_cache_key(query: &str) -> MemoKey {
+        let fp = Fingerprint::from_str(query);
+        MemoKey {
+            op_type: MemoOpType::ComplexityDetection,
+            input_fp: fp,
+            model_id: None,
+            version: 1,
+            context_fp: Fingerprint::zero(),
+        }
+    }
+
+    /// Deserialize a QueryComplexity from a MemoValue.
+    fn deserialize_complexity(value: &MemoValue) -> Option<QueryComplexity> {
+        match value {
+            MemoValue::Text(s) => match s.as_str() {
+                "Simple" => Some(QueryComplexity::Simple),
+                "Medium" => Some(QueryComplexity::Medium),
+                "Complex" => Some(QueryComplexity::Complex),
+                _ => None,
+            },
+            _ => None,
         }
-        self.detect_heuristic(query)
     }
 
     /// Heuristic-based fallback: keyword matching + word count.
diff --git a/rust/src/retrieval/decompose.rs b/rust/src/retrieval/decompose.rs
index da928ab1..c596b51e 100644
--- a/rust/src/retrieval/decompose.rs
+++ b/rust/src/retrieval/decompose.rs
@@ -48,6 +48,8 @@ use serde::{Deserialize, Serialize};
 use tracing::{debug, info};
 
 use crate::llm::{LlmClient, LlmExecutor};
+use crate::llm::memo::{MemoKey, MemoOpType, MemoStore, MemoValue};
+use crate::utils::fingerprint::Fingerprint;
 
 /// Sub-query resulting from decomposition.
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -204,6 +206,8 @@ pub struct QueryDecomposer {
     llm_client: Option<LlmClient>,
     /// LLM executor for unified execution (optional).
     llm_executor: Option<LlmExecutor>,
+    /// Memo store for caching decomposition results.
+    memo_store: Option<MemoStore>,
 }
 
 impl Default for QueryDecomposer {
@@ -219,6 +223,7 @@ impl QueryDecomposer {
             config,
             llm_client: None,
             llm_executor: None,
+            memo_store: None,
         }
     }
 
@@ -234,6 +239,12 @@ impl QueryDecomposer {
         self
     }
 
+    /// Add memo store for caching decomposition results.
+    pub fn with_memo_store(mut self, store: MemoStore) -> Self {
+        self.memo_store = Some(store);
+        self
+    }
+
     /// Decompose a query into sub-queries.
     pub async fn decompose(&self, query: &str) -> crate::error::Result<DecompositionResult> {
         // Check if decomposition is needed
@@ -244,23 +255,71 @@ impl QueryDecomposer {
             ));
         }
 
+        // Check memo cache
+        if let Some(ref store) = self.memo_store {
+            let cache_key = Self::build_cache_key(query);
+            if let Some(cached) = store.get(&cache_key) {
+                if let Some(result) = Self::deserialize_decomposition(&cached) {
+                    tracing::debug!("Memo cache hit for query decomposition");
+                    return Ok(result);
+                }
+            }
+        }
+
         info!("Decomposing complex query: '{}'", query);
 
         // Try LLM-based decomposition if available
-        if self.config.use_llm && (self.llm_client.is_some() || self.llm_executor.is_some()) {
+        let result = if self.config.use_llm
+            && (self.llm_client.is_some() || self.llm_executor.is_some())
+        {
             match self.llm_decompose(query).await {
-                Ok(result) => return Ok(result),
+                Ok(result) => result,
                 Err(e) => {
                     debug!(
                         "LLM decomposition failed, falling back to rule-based: {}",
                         e
                     );
+                    self.rule_based_decompose(query)?
                 }
             }
+        } else {
+            self.rule_based_decompose(query)?
+        };
+
+        // Cache the result
+        if let Some(ref store) = self.memo_store {
+            let cache_key = Self::build_cache_key(query);
+            if let Ok(json) = serde_json::to_value(&CachedDecomposition::from_result(&result)) {
+                store.put_with_tokens(
+                    cache_key,
+                    MemoValue::Json(json),
+                    (query.len() / 4) as u64,
+                );
+            }
         }
 
-        // Fall back to rule-based decomposition
-        self.rule_based_decompose(query)
+        Ok(result)
+    }
+
+    /// Build a cache key for query decomposition.
+    fn build_cache_key(query: &str) -> MemoKey {
+        let fp = Fingerprint::from_str(query);
+        MemoKey {
+            op_type: MemoOpType::QueryDecomposition,
+            input_fp: fp,
+            model_id: None,
+            version: 1,
+            context_fp: Fingerprint::zero(),
+        }
+    }
+
+    /// Deserialize a DecompositionResult from a MemoValue.
+    fn deserialize_decomposition(value: &MemoValue) -> Option<DecompositionResult> {
+        match value {
+            MemoValue::Json(json) => {
+                let cached: CachedDecomposition = serde_json::from_value(json.clone()).ok()?;
+                Some(cached.into_result())
+            }
+            _ => None,
+        }
     }
 
     /// Check if a query should be decomposed.
@@ -536,6 +595,72 @@ fn extract_json(text: &str) -> String {
     text.to_string()
 }
 
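// Illustrative wiring, not part of the patch: decomposition results are
// cached keyed by the query fingerprint, so re-asking the same complex
// question reuses the earlier split instead of calling the LLM again.
async fn decompose_cached(
    client: LlmClient,
    store: MemoStore,
    query: &str,
) -> crate::error::Result<DecompositionResult> {
    QueryDecomposer::new(DecompositionConfig::default())
        .with_llm_client(client)
        .with_memo_store(store)
        .decompose(query)
        .await
}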
+/// Serializable decomposition result for caching.
+///
+/// Only caches the essential fields needed to reconstruct a DecompositionResult.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct CachedSubQuery {
+    text: String,
+    priority: u8,
+    query_type: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct CachedDecomposition {
+    original: String,
+    sub_queries: Vec<CachedSubQuery>,
+    was_decomposed: bool,
+    reason: String,
+}
+
+impl CachedDecomposition {
+    fn from_result(result: &DecompositionResult) -> Self {
+        Self {
+            original: result.original.clone(),
+            sub_queries: result
+                .sub_queries
+                .iter()
+                .map(|sq| CachedSubQuery {
+                    text: sq.text.clone(),
+                    priority: sq.priority,
+                    query_type: format!("{:?}", sq.query_type),
+                })
+                .collect(),
+            was_decomposed: result.was_decomposed,
+            reason: result.reason.clone(),
+        }
+    }
+
+    fn into_result(self) -> DecompositionResult {
+        let sub_queries: Vec<SubQuery> = self
+            .sub_queries
+            .into_iter()
+            .map(|csq| SubQuery {
+                text: csq.text,
+                priority: csq.priority,
+                query_type: match csq.query_type.as_str() {
+                    "Fact" => SubQueryType::Fact,
+                    "Explanation" => SubQueryType::Explanation,
+                    "Comparison" => SubQueryType::Comparison,
+                    "Synthesis" => SubQueryType::Synthesis,
+                    "Navigation" => SubQueryType::Navigation,
+                    _ => SubQueryType::Fact,
+                },
+                complexity: SubQueryComplexity::Simple,
+                depends_on: vec![],
+                path_constraint: None,
+            })
+            .collect();
+        DecompositionResult {
+            original: self.original,
+            sub_queries,
+            was_decomposed: self.was_decomposed,
+            reason: self.reason,
+            total_complexity: 0.5,
+        }
+    }
+}
+
 /// Result aggregator for multi-turn retrieval.
 #[derive(Debug, Clone)]
 pub struct SubQueryResult {
diff --git a/rust/src/retrieval/pilot/llm_pilot.rs b/rust/src/retrieval/pilot/llm_pilot.rs
index e252da96..f8df6536 100644
--- a/rust/src/retrieval/pilot/llm_pilot.rs
+++ b/rust/src/retrieval/pilot/llm_pilot.rs
@@ -12,7 +12,7 @@ use tracing::{debug, info, warn};
 
 use crate::document::{DocumentTree, NodeId};
 use crate::llm::{LlmClient, LlmExecutor};
-use crate::memo::{MemoKey, MemoStore, MemoValue};
+use crate::llm::memo::{MemoKey, MemoStore, MemoValue};
 use crate::utils::fingerprint::Fingerprint;
 
 use super::budget::BudgetController;
@@ -442,8 +442,8 @@ impl LlmPilot {
     fn decision_to_cached_value(
         &self,
         decision: &PilotDecision,
-    ) -> crate::memo::PilotDecisionValue {
-        crate::memo::PilotDecisionValue {
+    ) -> crate::llm::memo::PilotDecisionValue {
+        crate::llm::memo::PilotDecisionValue {
             selected_idx: decision
                 .ranked_candidates
                 .first()
@@ -457,7 +457,7 @@ impl LlmPilot {
 
     /// Convert a cached value back to a PilotDecision.
     fn cached_value_to_decision(
         &self,
-        value: crate::memo::PilotDecisionValue,
+        value: crate::llm::memo::PilotDecisionValue,
         candidates: &[super::parser::CandidateInfo],
         point: InterventionPoint,
     ) -> PilotDecision {
diff --git a/rust/src/retrieval/pipeline_retriever.rs b/rust/src/retrieval/pipeline_retriever.rs
index 2a655182..9471bc3d 100644
--- a/rust/src/retrieval/pipeline_retriever.rs
+++ b/rust/src/retrieval/pipeline_retriever.rs
@@ -13,11 +13,12 @@ use super::content::ContentAggregatorConfig;
 use super::pipeline::RetrievalOrchestrator;
 use super::retriever::{CostEstimate, Retriever, RetrieverError, RetrieverResult};
 use super::stages::{AnalyzeStage, EvaluateStage, PlanStage, SearchStage};
+use super::strategy::LlmStrategy;
 use super::stream::RetrieveEventReceiver;
 use super::types::{RetrieveOptions, RetrieveResponse};
 use crate::document::{DocumentTree, ReasoningIndex};
 use crate::llm::LlmClient;
-use crate::memo::MemoStore;
+use crate::llm::memo::MemoStore;
 use crate::retrieval::pilot::{LlmPilot, PilotConfig};
 
 /// Pipeline-based retriever using the stage architecture.
@@ -112,6 +113,9 @@ impl PipelineRetriever {
+        // Attach the memo store before the LLM client: with_llm_client builds
+        // the detector and decomposer, and only threads the store through if
+        // it is already set.
+        if let Some(ref store) = self.memo_store {
+            analyze_stage = analyze_stage.with_memo_store(store.clone());
+        }
         if let Some(ref client) = self.llm_client {
             analyze_stage = analyze_stage.with_llm_client(client.clone());
         }
         orchestrator = orchestrator.stage(analyze_stage);
 
         // Add plan stage
@@ -133,11 +137,21 @@ impl PipelineRetriever {
             }
 
             search_stage = search_stage.with_pilot(Arc::new(pilot));
+
+            // Create LLM strategy with memo store for node evaluation
+            let mut llm_strategy = LlmStrategy::new(client.clone());
+            if let Some(ref store) = self.memo_store {
+                llm_strategy = llm_strategy.with_memo_store(store.clone());
+            }
+            search_stage = search_stage.with_llm_strategy(llm_strategy);
         }
         orchestrator = orchestrator.stage(search_stage);
 
         // Add evaluate stage with optional content aggregator
         let mut evaluate_stage = EvaluateStage::new();
+        if let Some(ref store) = self.memo_store {
+            evaluate_stage = evaluate_stage.with_memo_store(store.clone());
+        }
         if let Some(ref client) = self.llm_client {
             evaluate_stage = evaluate_stage.with_llm_judge(client.clone());
         }
diff --git a/rust/src/retrieval/search/toc_navigator.rs b/rust/src/retrieval/search/toc_navigator.rs
index badf6444..77be5243 100644
--- a/rust/src/retrieval/search/toc_navigator.rs
+++ b/rust/src/retrieval/search/toc_navigator.rs
@@ -15,7 +15,7 @@ use tracing::{debug, info, warn};
 use crate::document::DocumentTree;
 use crate::document::NodeId;
 use crate::llm::LlmClient;
-use crate::memo::MemoStore;
+use crate::llm::memo::MemoStore;
 use crate::retrieval::pilot::NodeScorer;
 
 /// A navigation cue produced by the ToCNavigator.
diff --git a/rust/src/retrieval/stages/analyze.rs b/rust/src/retrieval/stages/analyze.rs
index cc0e6d50..79e2a75e 100644
--- a/rust/src/retrieval/stages/analyze.rs
+++ b/rust/src/retrieval/stages/analyze.rs
@@ -13,6 +13,7 @@ use async_trait::async_trait;
 use tracing::info;
 
 use crate::document::{DocumentTree, NodeId, TocView};
+use crate::llm::memo::MemoStore;
 use crate::retrieval::complexity::ComplexityDetector;
 use crate::retrieval::decompose::{DecompositionConfig, QueryDecomposer};
 use crate::retrieval::pipeline::{FailurePolicy, PipelineContext, RetrievalStage, StageOutcome};
@@ -101,6 +102,8 @@ pub struct AnalyzeStage {
     enable_decomposition: bool,
     /// Complexity threshold for triggering decomposition.
     decomposition_threshold: f32,
+    /// Memo store for caching LLM results.
+    memo_store: Option<MemoStore>,
 }
 
 impl Default for AnalyzeStage {
@@ -119,6 +122,7 @@ impl AnalyzeStage {
             query_decomposer: None,
             enable_decomposition: false,
             decomposition_threshold: 0.6,
+            memo_store: None,
         }
     }
 
@@ -144,17 +148,30 @@ impl AnalyzeStage {
         self
     }
 
+    /// Add memo store for caching complexity detection and decomposition results.
+    ///
+    /// Call this before [`with_llm_client`](Self::with_llm_client) so the
+    /// detector and decomposer built there pick up the store.
+    pub fn with_memo_store(mut self, store: MemoStore) -> Self {
+        self.memo_store = Some(store);
+        self
+    }
+
     /// Enable query decomposition and LLM-based complexity detection.
     pub fn with_llm_client(mut self, client: crate::llm::LlmClient) -> Self {
         // Use LLM client for complexity detection
-        self.complexity_detector = ComplexityDetector::with_llm_client(client.clone());
+        let mut detector = ComplexityDetector::with_llm_client(client.clone());
+        if let Some(ref store) = self.memo_store {
+            detector = detector.with_memo_store(store.clone());
+        }
+        self.complexity_detector = detector;
+
         // Also enable query decomposition
+        let mut decomposer =
+            QueryDecomposer::new(DecompositionConfig::default()).with_llm_client(client);
+        if let Some(ref store) = self.memo_store {
+            decomposer = decomposer.with_memo_store(store.clone());
+        }
         if self.query_decomposer.is_none() {
-            self.query_decomposer =
-                Some(QueryDecomposer::new(DecompositionConfig::default()).with_llm_client(client));
-        } else if let Some(ref mut decomposer) = self.query_decomposer {
-            *decomposer =
-                QueryDecomposer::new(DecompositionConfig::default()).with_llm_client(client);
+            self.query_decomposer = Some(decomposer);
+        } else if let Some(ref mut d) = self.query_decomposer {
+            *d = decomposer;
         }
         self.enable_decomposition = true;
         self
diff --git a/rust/src/retrieval/stages/evaluate.rs b/rust/src/retrieval/stages/evaluate.rs
index d3dc2ee4..b008afb1 100644
--- a/rust/src/retrieval/stages/evaluate.rs
+++ b/rust/src/retrieval/stages/evaluate.rs
@@ -11,6 +11,7 @@ use async_trait::async_trait;
 use tracing::{info, warn};
 
 use crate::llm::LlmClient;
+use crate::llm::memo::MemoStore;
 use crate::retrieval::content::{ContentAggregator, ContentAggregatorConfig};
 use crate::retrieval::pipeline::{FailurePolicy, PipelineContext, RetrievalStage, StageOutcome};
 use crate::retrieval::sufficiency::{LlmJudge, SufficiencyChecker, ThresholdChecker};
@@ -46,6 +47,8 @@ pub struct EvaluateStage {
     use_llm_judge: bool,
     /// Optional content aggregator for precision-focused aggregation.
     content_aggregator: Option<ContentAggregator>,
+    /// Memo store for caching LLM judgments.
+    memo_store: Option<MemoStore>,
 }
 
 impl Default for EvaluateStage {
@@ -63,16 +66,27 @@ impl EvaluateStage {
             max_iterations: 3,
             use_llm_judge: false,
             content_aggregator: None,
+            memo_store: None,
         }
     }
 
     /// Add LLM judge for more accurate sufficiency checking.
+    ///
+    /// Attach the memo store first; the judge clones it at construction time.
     pub fn with_llm_judge(mut self, client: LlmClient) -> Self {
-        self.llm_judge = Some(LlmJudge::new(Box::new(client)));
+        let mut judge = LlmJudge::new(Box::new(client));
+        if let Some(ref store) = self.memo_store {
+            judge = judge.with_memo_store(store.clone());
+        }
+        self.llm_judge = Some(judge);
         self.use_llm_judge = true;
         self
     }
 
+    /// Add memo store for caching LLM judgments.
+    pub fn with_memo_store(mut self, store: MemoStore) -> Self {
+        self.memo_store = Some(store);
+        self
+    }
+
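// Illustrative wiring, not part of the patch. Ordering matters for both
// stages above: attach the memo store before with_llm_judge / with_llm_client,
// since those builders clone the store into the judge or detector they
// construct at call time.
fn build_evaluate_stage(client: LlmClient, store: MemoStore) -> EvaluateStage {
    EvaluateStage::new()
        .with_memo_store(store) // first: with_llm_judge reads this field
        .with_llm_judge(client)
}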
diff --git a/rust/src/retrieval/strategy/llm.rs b/rust/src/retrieval/strategy/llm.rs
index e22b8b43..10241ed8 100644
--- a/rust/src/retrieval/strategy/llm.rs
+++ b/rust/src/retrieval/strategy/llm.rs
@@ -15,6 +15,8 @@ use super::super::types::{NavigationDecision, QueryComplexity};
 use super::r#trait::{NodeEvaluation, RetrievalStrategy, StrategyCapabilities};
 use crate::document::{DocumentTree, NodeId, TocView};
 use crate::llm::LlmClient;
+use crate::llm::memo::{MemoKey, MemoOpType, MemoStore, MemoValue};
+use crate::utils::fingerprint::Fingerprint;
 
 /// LLM response for a single node in batch evaluation.
 #[derive(Debug, Clone, Deserialize)]
@@ -85,6 +87,8 @@ pub struct LlmStrategy {
     toc_view: TocView,
     /// Whether to include ToC context in prompts.
     include_toc: bool,
+    /// Memo store for caching LLM evaluations.
+    memo_store: Option<MemoStore>,
 }
 
 impl LlmStrategy {
@@ -96,6 +100,7 @@
             batch_system_prompt: Self::default_batch_system_prompt(),
             toc_view: TocView::new(),
             include_toc: true,
+            memo_store: None,
         }
     }
@@ -116,6 +121,15 @@
         self
     }
 
+    /// Add memo store for caching LLM evaluations.
+    ///
+    /// When enabled, node evaluations are cached based on prompt fingerprints,
+    /// avoiding redundant LLM calls for the same node+query combinations.
+    pub fn with_memo_store(mut self, store: MemoStore) -> Self {
+        self.memo_store = Some(store);
+        self
+    }
+
     /// Default system prompt for single-node navigation.
     fn default_system_prompt() -> String {
         r#"You are a document navigation assistant. Your task is to help find the most relevant sections in a document tree.
@@ -254,6 +268,50 @@ Rules:
         )
     }
 
+    /// Build a memo cache key for a single node evaluation.
+    fn node_eval_cache_key(&self, node_id: NodeId, context: &RetrievalContext) -> MemoKey {
+        let mut parts = String::new();
+        parts.push_str(&context.query);
+        parts.push_str(":node:");
+        // Use the NodeId debug representation as part of the fingerprint
+        parts.push_str(&format!("{:?}", node_id));
+        let fp = Fingerprint::from_str(&parts);
+        MemoKey {
+            op_type: MemoOpType::NodeEvaluation,
+            input_fp: fp,
+            model_id: None,
+            version: 1,
+            context_fp: Fingerprint::zero(),
+        }
+    }
+
+    /// Build a memo cache key for a batch evaluation.
+    fn batch_eval_cache_key(&self, node_ids: &[NodeId], context: &RetrievalContext) -> MemoKey {
+        let mut parts = String::new();
+        parts.push_str(&context.query);
+        parts.push_str(":batch:");
+        for id in node_ids {
+            parts.push_str(&format!("{:?}", id));
+            parts.push(',');
+        }
+        let fp = Fingerprint::from_str(&parts);
+        MemoKey {
+            op_type: MemoOpType::NodeEvaluation,
+            input_fp: fp,
+            model_id: None,
+            version: 1,
+            context_fp: Fingerprint::zero(),
+        }
+    }
+
+    /// Try to deserialize a cached NodeEvaluation from MemoValue.
+    fn deserialize_cached_eval(&self, value: &MemoValue) -> Option<NodeEvaluation> {
+        match value {
+            MemoValue::Json(json) => serde_json::from_value(json.clone()).ok(),
+            _ => None,
+        }
+    }
+
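The three helpers above reduce every cached call site to the same get-or-compute shape. A sketch of that pattern in isolation, with the `MemoStore::get`/`put_with_tokens` signatures assumed from the call sites in this diff (the token charge is left at 0 for brevity; the real code estimates it from the prompt):

```rust
// Get-or-compute: consult the memo store, fall back to the real evaluation,
// and persist the fresh result for the next identical node+query pair.
async fn cached_eval<F>(store: &MemoStore, key: MemoKey, compute: F) -> NodeEvaluation
where
    F: std::future::Future<Output = NodeEvaluation>,
{
    // Hit: deserialize the cached JSON and skip the LLM call entirely.
    if let Some(MemoValue::Json(json)) = store.get(&key) {
        if let Ok(eval) = serde_json::from_value::<NodeEvaluation>(json) {
            return eval;
        }
    }
    // Miss: run the real evaluation, then store it under the same key.
    let eval = compute.await;
    if let Ok(json) = serde_json::to_value(&eval) {
        store.put_with_tokens(key, MemoValue::Json(json), 0);
    }
    eval
}
```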
     /// Parse LLM response to evaluation for a single node.
     fn parse_response(
         &self,
@@ -391,9 +449,20 @@ impl RetrievalStrategy for LlmStrategy {
         node_id: NodeId,
         context: &RetrievalContext,
     ) -> NodeEvaluation {
+        // Check memo cache
+        if let Some(ref store) = self.memo_store {
+            let cache_key = self.node_eval_cache_key(node_id, context);
+            if let Some(cached) = store.get(&cache_key) {
+                if let Some(eval) = self.deserialize_cached_eval(&cached) {
+                    tracing::debug!("Memo cache hit for node evaluation (node={:?})", node_id);
+                    return eval;
+                }
+            }
+        }
+
         let prompt = self.build_prompt(tree, node_id, context);
 
-        match self.client.complete(&self.system_prompt, &prompt).await {
+        let result = match self.client.complete(&self.system_prompt, &prompt).await {
             Ok(response) => self.parse_response(&response, tree, node_id),
             Err(e) => {
                 tracing::warn!("LLM evaluation failed: {}", e);
@@ -407,7 +476,18 @@
                     reasoning: Some(format!("LLM error: {}", e)),
                 }
             }
+        };
+
+        // Cache the result
+        if let Some(ref store) = self.memo_store {
+            let cache_key = self.node_eval_cache_key(node_id, context);
+            if let Ok(json) = serde_json::to_value(&result) {
+                let tokens = (prompt.len() / 4) as u64;
+                store.put_with_tokens(cache_key, MemoValue::Json(json), tokens);
+            }
         }
+
+        result
     }
 
     async fn evaluate_nodes(
@@ -425,10 +505,28 @@
             return vec![self.evaluate_node(tree, node_ids[0], context).await];
         }
 
+        // Check memo cache for the entire batch
+        if let Some(ref store) = self.memo_store {
+            let cache_key = self.batch_eval_cache_key(node_ids, context);
+            if let Some(cached) = store.get(&cache_key) {
+                if let MemoValue::Json(json) = &cached {
+                    if let Ok(evals) = serde_json::from_value::<Vec<NodeEvaluation>>(json.clone()) {
+                        if evals.len() == node_ids.len() {
+                            tracing::debug!(
+                                "Memo cache hit for batch evaluation ({} nodes)",
+                                node_ids.len()
+                            );
+                            return evals;
+                        }
+                    }
+                }
+            }
+        }
+
         // Batch: send all nodes in one LLM call
         let prompt = self.build_batch_prompt(tree, node_ids, context);
 
-        match self
+        let result = match self
             .client
             .complete(&self.batch_system_prompt, &prompt)
             .await
@@ -447,7 +545,18 @@
                 }
                 results
             }
+        };
+
+        // Cache the batch result
+        if let Some(ref store) = self.memo_store {
+            let cache_key = self.batch_eval_cache_key(node_ids, context);
+            if let Ok(json) = serde_json::to_value(&result) {
+                let tokens = (prompt.len() / 4) as u64;
+                store.put_with_tokens(cache_key, MemoValue::Json(json), tokens);
+            }
         }
+
+        result
     }
 
     fn name(&self) -> &'static str {
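Both cache writes above charge the store with `prompt.len() / 4`, a deliberate approximation rather than a tokenizer-exact count. As a worked instance of the heuristic:

```rust
// English text averages roughly 4 bytes per token, so a 2,000-byte prompt
// is recorded as ~500 tokens. Good enough for cache accounting; do not use
// this where exact token budgets matter.
fn approx_tokens(prompt: &str) -> u64 {
    (prompt.len() / 4) as u64
}
```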
diff --git a/rust/src/retrieval/strategy/trait.rs b/rust/src/retrieval/strategy/trait.rs
index 895d60a2..a9b5958a 100644
--- a/rust/src/retrieval/strategy/trait.rs
+++ b/rust/src/retrieval/strategy/trait.rs
@@ -10,7 +10,7 @@ use super::super::types::{NavigationDecision, QueryComplexity};
 use crate::document::{DocumentTree, NodeId};
 
 /// Result of evaluating a single node.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
 pub struct NodeEvaluation {
     /// Relevance score (0.0 - 1.0).
     pub score: f32,
diff --git a/rust/src/retrieval/sufficiency/llm_judge.rs b/rust/src/retrieval/sufficiency/llm_judge.rs
index df80379b..343153a5 100644
--- a/rust/src/retrieval/sufficiency/llm_judge.rs
+++ b/rust/src/retrieval/sufficiency/llm_judge.rs
@@ -10,6 +10,8 @@ use serde::{Deserialize, Serialize};
 
 use super::{SufficiencyChecker, SufficiencyLevel};
 use crate::config::SufficiencyConfig;
+use crate::llm::memo::{MemoKey, MemoOpType, MemoStore, MemoValue};
+use crate::utils::fingerprint::Fingerprint;
 
 /// LLM client trait for the judge.
 #[async_trait]
@@ -49,6 +51,8 @@ pub struct LlmJudge {
     system_prompt: String,
     /// Minimum confidence to consider sufficient.
     confidence_threshold: f32,
+    /// Memo store for caching sufficiency judgments.
+    memo_store: Option<MemoStore>,
 }
 
 impl LlmJudge {
@@ -63,9 +67,19 @@
             client,
             system_prompt: Self::default_system_prompt(),
             confidence_threshold: config.confidence_threshold,
+            memo_store: None,
         }
     }
 
+    /// Add memo store for caching sufficiency judgments.
+    ///
+    /// When enabled, sufficiency check results are cached based on
+    /// query+content fingerprints, avoiding redundant LLM calls.
+    pub fn with_memo_store(mut self, store: MemoStore) -> Self {
+        self.memo_store = Some(store);
+        self
+    }
+
     /// Set confidence threshold.
     pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
         self.confidence_threshold = threshold;
@@ -135,11 +149,71 @@ Be conservative - only mark as sufficient if you're confident the content answer
         content: &str,
         _token_count: usize,
     ) -> SufficiencyLevel {
+        // Check memo cache
+        if let Some(ref store) = self.memo_store {
+            let cache_key = self.build_cache_key(query, content);
+            if let Some(cached) = store.get(&cache_key) {
+                if let Some(level) = Self::deserialize_sufficiency(&cached) {
+                    tracing::debug!("Memo cache hit for sufficiency check");
+                    return level;
+                }
+            }
+        }
+
         let prompt = self.build_prompt(query, content);
 
-        match self.client.complete(&prompt).await {
+        let result = match self.client.complete(&prompt).await {
             Ok(response) => self.parse_response(&response).0,
             Err(_) => SufficiencyLevel::Insufficient,
+        };
+
+        // Cache the result
+        if let Some(ref store) = self.memo_store {
+            let cache_key = self.build_cache_key(query, content);
+            let tokens = (prompt.len() / 4) as u64;
+            store.put_with_tokens(
+                cache_key,
+                MemoValue::Text(format!("{:?}", result)),
+                tokens,
+            );
         }
+
+        result
     }
+
+    /// Build a cache key for sufficiency check.
+    fn build_cache_key(&self, query: &str, content: &str) -> MemoKey {
+        let mut input = String::with_capacity(query.len() + content.len() / 4);
+        input.push_str(query);
+        // Use only the first 2000 bytes of content for the fingerprint to
+        // avoid giant cache keys; the content prefix captures topic identity.
+        // Back off to the nearest char boundary so multi-byte UTF-8 content
+        // cannot cause a panic when sliced.
+        let mut cut = content.len().min(2000);
+        while !content.is_char_boundary(cut) {
+            cut -= 1;
+        }
+        input.push_str(&content[..cut]);
+        let fp = Fingerprint::from_str(&input);
+        MemoKey {
+            op_type: MemoOpType::SufficiencyCheck,
+            input_fp: fp,
+            model_id: None,
+            version: 1,
+            context_fp: Fingerprint::zero(),
+        }
+    }
+
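The boundary backoff in `build_cache_key` fixes a latent panic: the original `&content[..2000.min(content.len())]` slices at a fixed byte offset, which aborts if byte 2000 falls inside a multi-byte character. A self-contained demonstration of the safe cut (helper and test names are illustrative, not part of the crate):

```rust
// Floor a byte limit to the nearest UTF-8 char boundary before slicing.
fn utf8_prefix(s: &str, max_bytes: usize) -> &str {
    let mut cut = s.len().min(max_bytes);
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    &s[..cut]
}

#[test]
fn prefix_backs_off_mid_character() {
    let s = "é".repeat(1200); // 2,400 bytes, every char two bytes wide
    // 1999 would split a char; the helper backs off to 1998 instead of panicking.
    assert_eq!(utf8_prefix(&s, 1999).len(), 1998);
}
```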
+    /// Deserialize a SufficiencyLevel from a MemoValue.
+    fn deserialize_sufficiency(value: &MemoValue) -> Option<SufficiencyLevel> {
+        match value {
+            MemoValue::Text(s) => match s.as_str() {
+                "Sufficient" => Some(SufficiencyLevel::Sufficient),
+                "PartialSufficient" => Some(SufficiencyLevel::PartialSufficient),
+                "Insufficient" => Some(SufficiencyLevel::Insufficient),
+                _ => None,
+            },
+            _ => None,
+        }
+    }
 }
diff --git a/rust/src/storage/mod.rs b/rust/src/storage/mod.rs
index b10019ce..ca7c27f3 100644
--- a/rust/src/storage/mod.rs
+++ b/rust/src/storage/mod.rs
@@ -42,5 +42,5 @@ mod persistence;
 pub mod workspace;
 
 // Re-export main types
-pub use persistence::{DocumentMeta, PersistedDocument};
+pub use persistence::{DocumentMeta, PageContent, PersistedDocument};
 pub use workspace::Workspace;
diff --git a/rust/src/storage/persistence.rs b/rust/src/storage/persistence.rs
index b2be0030..b9d28317 100644
--- a/rust/src/storage/persistence.rs
+++ b/rust/src/storage/persistence.rs
@@ -22,6 +22,13 @@ use crate::error::Result;
 /// Current format version for persisted documents.
 const FORMAT_VERSION: u32 = 1;
 
+/// Current schema version for `PersistedDocument`.
+///
+/// Increment this when the document structure changes in a
+/// backward-incompatible way (e.g. field renames, new required fields).
+/// Old documents will be detected and logged as stale on load.
+const SCHEMA_VERSION: u32 = 1;
+
 /// Metadata for a persisted document.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct DocumentMeta {
@@ -203,6 +210,11 @@ impl DocumentMeta {
 /// A persisted document index containing tree and metadata.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct PersistedDocument {
+    /// Schema version — incremented on backward-incompatible changes.
+    /// Old documents default to `0` via serde when the field is absent.
+    #[serde(default)]
+    pub schema_version: u32,
+
     /// Document metadata.
     pub meta: DocumentMeta,
@@ -222,6 +234,7 @@ impl PersistedDocument {
     /// Create a new persisted document.
     pub fn new(meta: DocumentMeta, tree: DocumentTree) -> Self {
         Self {
+            schema_version: SCHEMA_VERSION,
             meta,
             tree,
             pages: Vec::new(),
@@ -441,6 +454,19 @@ pub fn load_document_with_options(
     let doc: PersistedDocument = serde_json::from_value(wrapper.payload)
         .map_err(|e| Error::Parse(format!("Failed to deserialize document: {}", e)))?;
 
+    // Check schema version — warn on stale documents, fail on future versions
+    if doc.schema_version == 0 {
+        tracing::warn!(
+            doc_id = %doc.meta.id,
+            "Document was created before schema versioning — consider re-indexing"
+        );
+    } else if doc.schema_version > SCHEMA_VERSION {
+        return Err(Error::Parse(format!(
+            "Document schema version {} is newer than supported {} — please upgrade vectorless",
+            doc.schema_version, SCHEMA_VERSION
+        )));
+    }
+
     Ok(doc)
 }
@@ -619,6 +645,19 @@ pub fn load_document_from_bytes_with_options(
     let doc: PersistedDocument = serde_json::from_value(wrapper.payload)
         .map_err(|e| Error::Parse(format!("Failed to deserialize document: {}", e)))?;
 
+    // Check schema version
+    if doc.schema_version == 0 {
+        tracing::warn!(
+            doc_id = %doc.meta.id,
+            "Document was created before schema versioning — consider re-indexing"
+        );
+    } else if doc.schema_version > SCHEMA_VERSION {
+        return Err(Error::Parse(format!(
+            "Document schema version {} is newer than supported {} — please upgrade vectorless",
+            doc.schema_version, SCHEMA_VERSION
+        )));
+    }
+
     Ok(doc)
 }
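The versioning scheme hinges on `#[serde(default)]`: documents written before this change have no `schema_version` field, so they deserialize to `0` and can be singled out. A sketch of the load-time triage in isolation (the function and error type here are illustrative, not the crate's API):

```rust
// Triage a persisted document's schema version against what we support:
// absent field -> 0 -> legacy (loadable, but warn); current or older known
// schema -> fine; newer than supported -> hard error.
fn triage_schema(found: u32, supported: u32) -> Result<bool, String> {
    match found {
        0 => Ok(false), // legacy: accept, caller logs a re-index suggestion
        v if v <= supported => Ok(true),
        v => Err(format!("schema {v} is newer than supported {supported}")),
    }
}
```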
diff --git a/rust/src/throttle/config.rs b/rust/src/throttle/config.rs
deleted file mode 100644
index 155e2e3a..00000000
--- a/rust/src/throttle/config.rs
+++ /dev/null
@@ -1,123 +0,0 @@
-// Copyright (c) 2026 vectorless developers
-// SPDX-License-Identifier: Apache-2.0
-
-//! Concurrency control configuration types.
-
-use serde::{Deserialize, Serialize};
-
-/// Concurrency control configuration.
-///
-/// This controls how LLM requests are rate-limited and throttled
-/// to avoid overwhelming the API.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ConcurrencyConfig {
-    /// Maximum concurrent LLM API calls.
-    ///
-    /// This limits how many requests can be in-flight at the same time.
-    /// Default: 10
-    #[serde(default = "default_max_concurrent_requests")]
-    pub max_concurrent_requests: usize,
-
-    /// Rate limit: requests per minute.
-    ///
-    /// This is a soft limit using token bucket algorithm.
-    /// Default: 500 (OpenAI default tier)
-    #[serde(default = "default_requests_per_minute")]
-    pub requests_per_minute: usize,
-
-    /// Alias for `enabled` - whether rate limiting is enabled.
-    ///
-    /// When disabled, only semaphore-based concurrency control is used.
-    /// Default: true
-    #[serde(default = "default_true")]
-    pub enabled: bool,
-
-    /// Whether to enable concurrency limiting via semaphore.
-    ///
-    /// When disabled, only rate limiting is used.
-    /// Default: true
-    #[serde(default = "default_true")]
-    pub semaphore_enabled: bool,
-}
-
-fn default_max_concurrent_requests() -> usize {
-    10
-}
-fn default_requests_per_minute() -> usize {
-    500
-}
-fn default_true() -> bool {
-    true
-}
-
-impl Default for ConcurrencyConfig {
-    fn default() -> Self {
-        Self {
-            max_concurrent_requests: default_max_concurrent_requests(),
-            requests_per_minute: default_requests_per_minute(),
-            enabled: true,
-            semaphore_enabled: true,
-        }
-    }
-}
-
-impl ConcurrencyConfig {
-    /// Create a new config with defaults.
-    pub fn new() -> Self {
-        Self::default()
-    }
-
-    /// Set the maximum concurrent requests.
-    pub fn with_max_concurrent_requests(mut self, max: usize) -> Self {
-        self.max_concurrent_requests = max;
-        self
-    }
-
-    /// Set the requests per minute rate limit.
-    pub fn with_requests_per_minute(mut self, rpm: usize) -> Self {
-        self.requests_per_minute = rpm;
-        self
-    }
-
-    /// Enable or disable rate limiting.
-    pub fn with_enabled(mut self, enabled: bool) -> Self {
-        self.enabled = enabled;
-        self
-    }
-
-    /// Create a config for high-throughput scenarios.
-    ///
-    /// Uses higher limits suitable for paid API tiers.
-    pub fn high_throughput() -> Self {
-        Self {
-            max_concurrent_requests: 50,
-            requests_per_minute: 3000,
-            enabled: true,
-            semaphore_enabled: true,
-        }
-    }
-
-    /// Create a config for conservative scenarios.
-    ///
-    /// Uses lower limits to avoid rate limit errors.
-    pub fn conservative() -> Self {
-        Self {
-            max_concurrent_requests: 5,
-            requests_per_minute: 100,
-            enabled: true,
-            semaphore_enabled: true,
-        }
-    }
-
-    /// Create a config that disables all limits.
-    ///
-    /// Useful for testing or when external rate limiting is used.
-    pub fn unlimited() -> Self {
-        Self {
-            max_concurrent_requests: usize::MAX,
-            requests_per_minute: usize::MAX,
-            enabled: false,
-            semaphore_enabled: false,
-        }
-    }
-}
diff --git a/rust/src/throttle/controller.rs b/rust/src/throttle/controller.rs
deleted file mode 100644
index 87193fa4..00000000
--- a/rust/src/throttle/controller.rs
+++ /dev/null
@@ -1,198 +0,0 @@
-// Copyright (c) 2026 vectorless developers
-// SPDX-License-Identifier: Apache-2.0
-
-//! Concurrency controller combining semaphore and rate limiter.
-
-use std::sync::Arc;
-use tokio::sync::{Semaphore, SemaphorePermit};
-use tracing::{debug, trace};
-
-use super::config::ConcurrencyConfig;
-use super::rate_limiter::RateLimiter;
-
-/// Concurrency controller for LLM API calls.
-///
-/// Combines:
-/// - **Rate Limiter** — Token bucket to limit requests per time period
-/// - **Semaphore** — Limit concurrent requests
-///
-/// # Example
-///
-/// ```rust
-/// use vectorless::throttle::{ConcurrencyController, ConcurrencyConfig};
-///
-/// # #[tokio::main]
-/// # async fn main() {
-/// let config = ConcurrencyConfig::default();
-/// let controller = ConcurrencyController::new(config);
-///
-/// // Before making an API call
-/// let permit = controller.acquire().await;
-///
-/// // Make the API call...
-/// drop(permit); // Release when done
-/// # }
-/// ```
-#[derive(Clone)]
-pub struct ConcurrencyController {
-    /// Semaphore for limiting concurrent requests.
-    semaphore: Arc<Semaphore>,
-    /// Rate limiter for throttling requests.
-    rate_limiter: Option<Arc<RateLimiter>>,
-    /// Configuration.
-    config: ConcurrencyConfig,
-}
-
-impl ConcurrencyController {
-    /// Create a new concurrency controller with the given configuration.
-    pub fn new(config: ConcurrencyConfig) -> Self {
-        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_requests));
-        let rate_limiter = if config.enabled {
-            Some(Arc::new(RateLimiter::new(config.requests_per_minute)))
-        } else {
-            None
-        };
-
-        Self {
-            semaphore,
-            rate_limiter,
-            config,
-        }
-    }
-
-    /// Create a controller with default configuration.
-    pub fn with_defaults() -> Self {
-        Self::new(ConcurrencyConfig::default())
-    }
-
-    /// Create a controller for high-throughput scenarios.
-    pub fn high_throughput() -> Self {
-        Self::new(ConcurrencyConfig::high_throughput())
-    }
-
-    /// Create a controller for conservative scenarios.
-    pub fn conservative() -> Self {
-        Self::new(ConcurrencyConfig::conservative())
-    }
-
-    /// Create a controller with no limits.
-    pub fn unlimited() -> Self {
-        Self::new(ConcurrencyConfig::unlimited())
-    }
-
-    /// Acquire a permit for making an LLM request.
-    ///
-    /// This will:
-    /// 1. Wait for the rate limiter (if enabled)
-    /// 2. Acquire a semaphore permit (if enabled)
-    ///
-    /// The permit is automatically released when dropped.
-    pub async fn acquire(&self) -> Option<SemaphorePermit<'_>> {
-        // Step 1: Wait for rate limiter
-        if let Some(ref limiter) = self.rate_limiter {
-            trace!("Waiting for rate limiter");
-            limiter.acquire().await;
-            debug!("Rate limiter: token acquired");
-        }
-
-        // Step 2: Acquire semaphore permit
-        if self.config.semaphore_enabled {
-            trace!("Waiting for semaphore permit");
-            let permit = self.semaphore.acquire().await.unwrap();
-            debug!(
-                "Semaphore: permit acquired (available: {})",
-                self.semaphore.available_permits()
-            );
-            Some(permit)
-        } else {
-            None
-        }
-    }
-
-    /// Try to acquire a permit without waiting.
-    ///
-    /// Returns `None` if the limit is reached.
-    pub fn try_acquire(&self) -> Option<SemaphorePermit<'_>> {
-        // Check rate limiter
-        if let Some(ref limiter) = self.rate_limiter {
-            if !limiter.try_acquire() {
-                return None;
-            }
-        }
-
-        // Try to acquire semaphore
-        if self.config.semaphore_enabled {
-            self.semaphore.try_acquire().ok()
-        } else {
-            None
-        }
-    }
-
-    /// Get the number of available semaphore permits.
-    pub fn available_permits(&self) -> usize {
-        self.semaphore.available_permits()
-    }
-
-    /// Get the configuration.
-    pub fn config(&self) -> &ConcurrencyConfig {
-        &self.config
-    }
-
-    /// Get the rate limiter (if any).
-    pub fn rate_limiter(&self) -> Option<&RateLimiter> {
-        self.rate_limiter.as_deref()
-    }
-}
-
-impl std::fmt::Debug for ConcurrencyController {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("ConcurrencyController")
-            .field(
-                "max_concurrent_requests",
-                &self.config.max_concurrent_requests,
-            )
-            .field("requests_per_minute", &self.config.requests_per_minute)
-            .field("rate_limiting_enabled", &self.config.enabled)
-            .field("semaphore_enabled", &self.config.semaphore_enabled)
-            .field("available_permits", &self.semaphore.available_permits())
-            .finish()
-    }
-}
-
-impl Default for ConcurrencyController {
-    fn default() -> Self {
-        Self::with_defaults()
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[tokio::test]
-    async fn test_controller_acquire() {
-        let controller = ConcurrencyController::new(ConcurrencyConfig {
-            max_concurrent_requests: 2,
-            requests_per_minute: 100,
-            enabled: false, // Disable rate limiting for faster test
-            semaphore_enabled: true,
-        });
-
-        let permit1 = controller.acquire().await;
-        assert!(permit1.is_some());
-        assert_eq!(controller.available_permits(), 1);
-
-        let permit2 = controller.acquire().await;
-        assert!(permit2.is_some());
-        assert_eq!(controller.available_permits(), 0);
-
-        drop(permit1);
-        assert_eq!(controller.available_permits(), 1);
-    }
-
-    #[test]
-    fn test_controller_creation() {
-        let controller = ConcurrencyController::with_defaults();
-        assert!(controller.available_permits() > 0);
-    }
-}
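The deleted controller's core idiom, the RAII permit, survives the removal and is worth keeping in mind at any replacement call sites. A sketch reduced to plain tokio (names here are illustrative): the slot is returned the moment the permit is dropped, including on early return or panic unwind.

```rust
use std::sync::Arc;
use tokio::sync::Semaphore;

// Bound concurrent in-flight calls to the semaphore's permit count.
async fn bounded_call(sem: Arc<Semaphore>) {
    let _permit = sem.acquire().await.expect("semaphore closed");
    // ... perform the rate-limited API call while holding the permit ...
} // `_permit` dropped here: the slot frees automatically
```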
diff --git a/rust/src/throttle/mod.rs b/rust/src/throttle/mod.rs
deleted file mode 100644
index 3bf6467f..00000000
--- a/rust/src/throttle/mod.rs
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright (c) 2026 vectorless developers
-// SPDX-License-Identifier: Apache-2.0
-
-//! Concurrency control for LLM API calls.
-//!
-//! This module provides rate limiting and concurrency control to prevent
-//! overwhelming LLM API endpoints:
-//!
-//! - **Rate Limiter** — Token bucket algorithm to limit requests per time period
-//! - **Concurrency Controller** — Combined semaphore + rate limiter
-//!
-//! # Example
-//!
-//! ```rust
-//! use vectorless::throttle::{ConcurrencyController, ConcurrencyConfig};
-//!
-//! # #[tokio::main]
-//! # async fn main() {
-//! // Create with default configuration
-//! let controller = ConcurrencyController::with_defaults();
-//!
-//! // Or customize
-//! let config = ConcurrencyConfig::new()
-//!     .with_max_concurrent_requests(20)
-//!     .with_requests_per_minute(1000);
-//! let controller = ConcurrencyController::new(config);
-//!
-//! // Before making an API call
-//! let permit = controller.acquire().await;
-//!
-//! // Make the API call...
-//! // Permit is automatically released when dropped
-//! # }
-//! ```
-
-mod config;
-mod controller;
-mod rate_limiter;
-
-pub use config::ConcurrencyConfig;
-pub use controller::ConcurrencyController;
diff --git a/rust/src/throttle/rate_limiter.rs b/rust/src/throttle/rate_limiter.rs
deleted file mode 100644
index 90a865e9..00000000
--- a/rust/src/throttle/rate_limiter.rs
+++ /dev/null
@@ -1,112 +0,0 @@
-// Copyright (c) 2026 vectorless developers
-// SPDX-License-Identifier: Apache-2.0
-
-//! Rate limiter using token bucket algorithm (governor).
-
-use governor::{
-    Quota, RateLimiter as GovernorLimiter,
-    clock::{Clock, DefaultClock},
-    state::{InMemoryState, NotKeyed},
-};
-use std::num::NonZeroU32;
-use std::sync::Arc;
-use tracing::trace;
-
-/// Rate limiter for API calls.
-///
-/// Uses the governor library's token bucket algorithm to limit
-/// the rate of API requests.
-#[derive(Clone)]
-pub struct RateLimiter {
-    inner: Arc<GovernorLimiter<NotKeyed, InMemoryState, DefaultClock>>,
-    requests_per_minute: usize,
-}
-
-impl RateLimiter {
-    /// Create a new rate limiter with the given requests per minute.
-    ///
-    /// # Arguments
-    ///
-    /// * `requests_per_minute` - Maximum number of requests allowed per minute.
-    ///
-    /// # Example
-    ///
-    /// ```
-    /// use vectorless::throttle::RateLimiter;
-    ///
-    /// let limiter = RateLimiter::new(100); // 100 requests per minute
-    /// ```
-    pub fn new(requests_per_minute: usize) -> Self {
-        let rpm = NonZeroU32::new(requests_per_minute as u32)
-            .unwrap_or_else(|| NonZeroU32::new(1).unwrap());
-
-        let quota = Quota::per_minute(rpm);
-        let inner = Arc::new(GovernorLimiter::direct(quota));
-
-        Self {
-            inner,
-            requests_per_minute,
-        }
-    }
-
-    /// Wait until a request can be made.
-    ///
-    /// This method will block until the rate limiter has an available token.
-    pub async fn acquire(&self) {
-        let clock = DefaultClock::default();
-        loop {
-            match self.inner.check() {
-                Ok(_) => {
-                    trace!("Rate limiter: token acquired");
-                    return;
-                }
-                Err(negative) => {
-                    let wait_duration = negative.wait_time_from(clock.now());
-                    trace!(
-                        wait_ms = wait_duration.as_millis() as u64,
-                        "Rate limiter: waiting for token"
-                    );
-                    tokio::time::sleep(wait_duration).await;
-                }
-            }
-        }
-    }
-
-    /// Try to acquire a token without waiting.
-    ///
-    /// Returns `true` if a token was acquired, `false` if the limit is reached.
-    pub fn try_acquire(&self) -> bool {
-        self.inner.check().is_ok()
-    }
-
-    /// Get the configured requests per minute.
-    pub fn requests_per_minute(&self) -> usize {
-        self.requests_per_minute
-    }
-}
-
-impl std::fmt::Debug for RateLimiter {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("RateLimiter")
-            .field("requests_per_minute", &self.requests_per_minute)
-            .finish()
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_rate_limiter_creation() {
-        let limiter = RateLimiter::new(100);
-        assert_eq!(limiter.requests_per_minute(), 100);
-    }
-
-    #[test]
-    fn test_try_acquire() {
-        let limiter = RateLimiter::new(10);
-        // Should be able to acquire at least one token
-        assert!(limiter.try_acquire());
-    }
-}
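For reference, the deleted `RateLimiter` was a thin wrapper over governor's direct (unkeyed) token bucket. A minimal sketch of that underlying API, assuming the governor 0.6 pinned in Cargo.toml; in real use the limiter is built once and shared, not constructed per call as here:

```rust
use std::num::NonZeroU32;
use governor::{Quota, RateLimiter};

// Token bucket refilled at a fixed per-minute quota, probed non-blockingly.
fn try_request(rpm: u32) -> bool {
    let quota = Quota::per_minute(NonZeroU32::new(rpm.max(1)).unwrap());
    let limiter = RateLimiter::direct(quota);
    limiter.check().is_ok() // Ok(()) means a token was available right now
}
```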
diff --git a/rust/tests/integration.rs b/rust/tests/integration.rs
new file mode 100644
index 00000000..526035d3
--- /dev/null
+++ b/rust/tests/integration.rs
@@ -0,0 +1,180 @@
+// Copyright (c) 2026 vectorless developers
+// SPDX-License-Identifier: Apache-2.0
+
+//! Integration tests for the Engine client.
+//!
+//! These tests exercise the full index → persist → query lifecycle
+//! without requiring a real LLM endpoint, using the no-LLM pipeline.
+
+use std::path::PathBuf;
+
+use vectorless::__test_support::build_test_engine;
+use vectorless::{Engine, IndexContext, IndexMode};
+
+async fn setup() -> (Engine, tempfile::TempDir) {
+    let tmp = tempfile::tempdir().unwrap();
+    let engine = build_test_engine(tmp.path()).await;
+    (engine, tmp)
+}
+
+#[tokio::test]
+async fn test_index_and_persist_single_markdown() {
+    let (engine, tmp) = setup().await;
+
+    // Write a test markdown file
+    let md_path = tmp.path().join("test.md");
+    std::fs::write(&md_path, "# Hello\n\nWorld content here.").unwrap();
+
+    let ctx = IndexContext::from_path(&md_path).with_mode(IndexMode::Force);
+    let result = engine.index(ctx).await.unwrap();
+
+    assert_eq!(result.len(), 1);
+    assert!(!result.has_failures());
+    let doc_id = result.doc_id().unwrap();
+    assert!(!doc_id.is_empty());
+
+    // Verify persisted
+    assert!(engine.exists(doc_id).await.unwrap());
+
+    // List should contain 1 doc
+    let docs = engine.list().await.unwrap();
+    assert_eq!(docs.len(), 1);
+    assert_eq!(docs[0].name, "test");
+
+    // Remove
+    assert!(engine.remove(doc_id).await.unwrap());
+    assert!(!engine.exists(doc_id).await.unwrap());
+}
+
+#[tokio::test]
+async fn test_index_from_content() {
+    let (engine, _tmp) = setup().await;
+
+    let ctx = IndexContext::from_content(
+        "# Title\n\nParagraph 1\n\n## Section\n\nParagraph 2",
+        vectorless::DocumentFormat::Markdown,
+    )
+    .with_name("inline-doc");
+
+    let result = engine.index(ctx).await.unwrap();
+    assert_eq!(result.len(), 1);
+    let doc_id = result.doc_id().unwrap();
+
+    // Verify it's persisted and loadable
+    assert!(engine.exists(doc_id).await.unwrap());
+
+    // Clean up
+    engine.remove(doc_id).await.unwrap();
+}
+
+#[tokio::test]
+async fn test_index_multiple_sources_parallel() {
+    let (engine, tmp) = setup().await;
+
+    // Create 3 markdown files
+    let paths: Vec<PathBuf> = (0..3)
+        .map(|i| {
+            let p = tmp.path().join(format!("doc{i}.md"));
+            std::fs::write(&p, format!("# Doc {i}\n\nContent {i}")).unwrap();
+            p
+        })
+        .collect();
+
+    let ctx = IndexContext::from_paths(paths).with_mode(IndexMode::Force);
+    let result = engine.index(ctx).await.unwrap();
+
+    assert_eq!(result.len(), 3);
+    assert!(!result.has_failures());
+
+    let docs = engine.list().await.unwrap();
+    assert_eq!(docs.len(), 3);
+
+    // Clear all
+    let count = engine.clear().await.unwrap();
+    assert_eq!(count, 3);
+}
+
+#[tokio::test]
+async fn test_index_default_mode_skips_existing() {
+    let
 (engine, tmp) = setup().await;
+
+    let md_path = tmp.path().join("existing.md");
+    std::fs::write(&md_path, "# Original\n\nOriginal content.").unwrap();
+
+    // First index
+    let ctx = IndexContext::from_path(&md_path);
+    let result1 = engine.index(ctx).await.unwrap();
+    assert_eq!(result1.len(), 1);
+    let id1 = result1.doc_id().unwrap().to_string();
+
+    // Second index with Default mode — should skip
+    let ctx = IndexContext::from_path(&md_path);
+    let result2 = engine.index(ctx).await.unwrap();
+    assert_eq!(result2.len(), 1);
+    assert!(!result2.has_failures());
+    // Same doc ID — not re-indexed
+    assert_eq!(result2.doc_id().unwrap(), id1);
+}
+
+#[tokio::test]
+async fn test_force_mode_reindexes() {
+    let (engine, tmp) = setup().await;
+
+    let md_path = tmp.path().join("force.md");
+    std::fs::write(&md_path, "# Version 1").unwrap();
+
+    // First index
+    let ctx = IndexContext::from_path(&md_path);
+    let result1 = engine.index(ctx).await.unwrap();
+    let id1 = result1.doc_id().unwrap().to_string();
+
+    // Force re-index — should get a new doc ID
+    let ctx = IndexContext::from_path(&md_path).with_mode(IndexMode::Force);
+    let result2 = engine.index(ctx).await.unwrap();
+    assert_eq!(result2.len(), 1);
+    // Different doc ID — re-indexed
+    assert_ne!(result2.doc_id().unwrap(), id1);
+}
+
+#[tokio::test]
+async fn test_cancel_blocks_new_operations() {
+    let (engine, _tmp) = setup().await;
+
+    engine.cancel();
+    assert!(engine.is_cancelled());
+
+    let ctx = IndexContext::from_content("# test", vectorless::DocumentFormat::Markdown);
+    let err = engine.index(ctx).await.unwrap_err();
+    assert!(err.to_string().contains("cancelled"));
+
+    engine.reset_cancel();
+    assert!(!engine.is_cancelled());
+}
+
+#[tokio::test]
+async fn test_clear_empty_workspace() {
+    let (engine, _tmp) = setup().await;
+
+    let count = engine.clear().await.unwrap();
+    assert_eq!(count, 0);
+}
+
+#[tokio::test]
+async fn test_remove_nonexistent() {
+    let (engine, _tmp) = setup().await;
+
+    let removed = engine.remove("nonexistent-id").await.unwrap();
+    assert!(!removed);
+}
+
+#[tokio::test]
+async fn test_index_from_bytes() {
+    let (engine, _tmp) = setup().await;
+
+    let ctx = IndexContext::from_bytes(vec![1, 2, 3, 4], vectorless::DocumentFormat::Pdf)
+        .with_name("test-bytes");
+
+    // This will fail at parse (not a real PDF), but should error gracefully
+    let result = engine.index(ctx).await;
+    assert!(result.is_err());
+}