diff --git a/Cargo.toml b/Cargo.toml index d0e2244..187f052 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,9 +56,6 @@ lmdb = "0.8" # SQLite rusqlite = { version = "0.31", features = ["bundled"] } -# File I/O -memmap2 = "0.9" - # Serialization bincode = "1.3" serde_json = "1.0" diff --git a/crates/azoth-file-log/Cargo.toml b/crates/azoth-file-log/Cargo.toml index 7abfcb7..88f7041 100644 --- a/crates/azoth-file-log/Cargo.toml +++ b/crates/azoth-file-log/Cargo.toml @@ -17,7 +17,7 @@ serde = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } -memmap2 = { workspace = true } +parking_lot = { workspace = true } [dev-dependencies] tempfile = { workspace = true } diff --git a/crates/azoth-file-log/src/lib.rs b/crates/azoth-file-log/src/lib.rs index 2e84a85..8fab54d 100644 --- a/crates/azoth-file-log/src/lib.rs +++ b/crates/azoth-file-log/src/lib.rs @@ -6,7 +6,7 @@ //! //! Features: //! - Fast sequential writes (no ACID overhead) -//! - Memory-mapped reads for iteration +//! - Buffered sequential reads for iteration //! - Automatic log rotation based on size //! - EventId allocation via atomic counter + file sync //! - Multiple concurrent readers, single writer diff --git a/crates/azoth-file-log/src/store.rs b/crates/azoth-file-log/src/store.rs index 2d481d6..ba899fc 100644 --- a/crates/azoth-file-log/src/store.rs +++ b/crates/azoth-file-log/src/store.rs @@ -3,13 +3,14 @@ use azoth_core::{ event_log::{EventLog, EventLogIterator, EventLogStats}, types::EventId, }; +use parking_lot::Mutex; use serde::{Deserialize, Serialize}; use std::fs::{File, OpenOptions}; use std::io::{BufWriter, Read, Write}; use std::path::{Path, PathBuf}; use std::sync::{ atomic::{AtomicU64, Ordering}, - Arc, Mutex, + Arc, }; /// Configuration for file-based event log @@ -128,7 +129,7 @@ impl FileEventLog { /// Save metadata to disk fn save_meta(&self) -> Result<()> { - let meta = self.meta.lock().unwrap(); + let meta = self.meta.lock(); let meta_path = self.config.base_dir.join("meta.json"); // Use compact serialization (not pretty) for better performance let data = serde_json::to_string(&*meta) @@ -156,7 +157,7 @@ impl FileEventLog { ))); } - let mut writer = self.writer.lock().unwrap(); + let mut writer = self.writer.lock(); // Write event_id (8 bytes, big-endian) writer.write_all(&event_id.to_be_bytes())?; @@ -191,7 +192,7 @@ impl FileEventLog { fn rotate_internal(&self) -> Result> { // Flush current writer { - let mut writer = self.writer.lock().unwrap(); + let mut writer = self.writer.lock(); writer.flush()?; } @@ -204,7 +205,7 @@ impl FileEventLog { // Update metadata { - let mut meta = self.meta.lock().unwrap(); + let mut meta = self.meta.lock(); meta.current_file_num = new_file_num; } self.save_meta()?; @@ -218,7 +219,7 @@ impl FileEventLog { // Replace writer { - let mut writer = self.writer.lock().unwrap(); + let mut writer = self.writer.lock(); *writer = BufWriter::with_capacity(self.config.write_buffer_size, file); } @@ -239,13 +240,13 @@ impl EventLog for FileEventLog { // Conditionally flush based on config if self.config.flush_on_append { - let mut writer = self.writer.lock().unwrap(); + let mut writer = self.writer.lock(); writer.flush()?; } // Update metadata { - let mut meta = self.meta.lock().unwrap(); + let mut meta = self.meta.lock(); meta.next_event_id = event_id + 1; meta.total_events += 1; } @@ -321,20 +322,20 @@ impl EventLog for FileEventLog { } // Single lock acquisition + single write syscall - let mut writer = self.writer.lock().unwrap(); + let mut writer = self.writer.lock(); writer.write_all(&buffer)?; } // Conditionally flush based on config if self.config.flush_on_append { - let mut writer = self.writer.lock().unwrap(); + let mut writer = self.writer.lock(); writer.flush()?; } // Update metadata let last_id = first_event_id + events.len() as u64 - 1; { - let mut meta = self.meta.lock().unwrap(); + let mut meta = self.meta.lock(); meta.next_event_id = last_id + 1; meta.total_events += events.len() as u64; } @@ -359,11 +360,11 @@ impl EventLog for FileEventLog { ) -> Result> { // Flush writer to ensure all data is on disk { - let mut writer = self.writer.lock().unwrap(); + let mut writer = self.writer.lock(); writer.flush()?; } - let meta = self.meta.lock().unwrap(); + let meta = self.meta.lock(); let end_id = end.unwrap_or(meta.next_event_id); Ok(Box::new(FileEventLogIter::new( @@ -405,7 +406,7 @@ impl EventLog for FileEventLog { } fn oldest_event_id(&self) -> Result { - let meta = self.meta.lock().unwrap(); + let meta = self.meta.lock(); Ok(meta.oldest_event_id) } @@ -419,7 +420,7 @@ impl EventLog for FileEventLog { } fn sync(&self) -> Result<()> { - let mut writer = self.writer.lock().unwrap(); + let mut writer = self.writer.lock(); writer.flush()?; writer.get_ref().sync_all()?; self.save_meta()?; @@ -427,7 +428,7 @@ impl EventLog for FileEventLog { } fn stats(&self) -> Result { - let meta = self.meta.lock().unwrap(); + let meta = self.meta.lock(); // Calculate total bytes across all log files let mut total_bytes = 0u64; diff --git a/crates/azoth-lmdb/src/store.rs b/crates/azoth-lmdb/src/store.rs index d0b8352..f4c4707 100644 --- a/crates/azoth-lmdb/src/store.rs +++ b/crates/azoth-lmdb/src/store.rs @@ -8,10 +8,11 @@ use azoth_core::{ }; use azoth_file_log::{FileEventLog, FileEventLogConfig}; use lmdb::{Database, DatabaseFlags, Environment, EnvironmentFlags, Transaction, WriteFlags}; +use parking_lot::Mutex; use std::path::Path; use std::sync::{ atomic::{AtomicBool, AtomicUsize, Ordering}, - Arc, Mutex, + Arc, }; use std::time::Duration; @@ -316,7 +317,7 @@ impl CanonicalStore for LmdbCanonicalStore { } fn seal(&self) -> Result { - let _guard = self.write_lock.lock().unwrap(); + let _guard = self.write_lock.lock(); let mut txn = self .env @@ -426,7 +427,7 @@ impl LmdbCanonicalStore { /// Sealing is used as a temporary barrier to create deterministic snapshots. Backups should /// clear the seal before resuming ingestion; otherwise the DB becomes permanently read-only. pub fn clear_seal(&self) -> Result<()> { - let _guard = self.write_lock.lock().unwrap(); + let _guard = self.write_lock.lock(); let mut txn = self .env diff --git a/crates/azoth-sqlite/Cargo.toml b/crates/azoth-sqlite/Cargo.toml index 8628839..8fa78ab 100644 --- a/crates/azoth-sqlite/Cargo.toml +++ b/crates/azoth-sqlite/Cargo.toml @@ -20,6 +20,7 @@ tracing = { workspace = true } anyhow = { workspace = true } chrono = { workspace = true } tokio = { workspace = true, features = ["time", "sync"] } +parking_lot = { workspace = true } [dev-dependencies] tempfile = { workspace = true } diff --git a/crates/azoth-sqlite/src/read_pool.rs b/crates/azoth-sqlite/src/read_pool.rs index b60634d..5ec2f43 100644 --- a/crates/azoth-sqlite/src/read_pool.rs +++ b/crates/azoth-sqlite/src/read_pool.rs @@ -7,10 +7,10 @@ use azoth_core::{ error::{AzothError, Result}, ReadPoolConfig, }; +use parking_lot::Mutex; use rusqlite::{Connection, OpenFlags}; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Mutex; use std::time::{Duration, Instant}; use tokio::sync::{Semaphore, SemaphorePermit}; @@ -19,7 +19,7 @@ use tokio::sync::{Semaphore, SemaphorePermit}; /// This wraps a SQLite read-only connection with automatic permit release /// when the connection is returned to the pool. pub struct PooledSqliteConnection<'a> { - conn: std::sync::MutexGuard<'a, Connection>, + conn: parking_lot::MutexGuard<'a, Connection>, _permit: SemaphorePermit<'a>, } @@ -128,7 +128,7 @@ impl SqliteReadPool { let start = self.next_idx.fetch_add(1, Ordering::Relaxed) % self.connections.len(); for i in 0..self.connections.len() { let idx = (start + i) % self.connections.len(); - if let Ok(guard) = self.connections[idx].try_lock() { + if let Some(guard) = self.connections[idx].try_lock() { return Ok(PooledSqliteConnection { conn: guard, _permit: permit, @@ -152,7 +152,7 @@ impl SqliteReadPool { let start = self.next_idx.fetch_add(1, Ordering::Relaxed) % self.connections.len(); for i in 0..self.connections.len() { let idx = (start + i) % self.connections.len(); - if let Ok(guard) = self.connections[idx].try_lock() { + if let Some(guard) = self.connections[idx].try_lock() { return Ok(Some(PooledSqliteConnection { conn: guard, _permit: permit, diff --git a/crates/azoth-sqlite/src/store.rs b/crates/azoth-sqlite/src/store.rs index f8bc913..9b06d49 100644 --- a/crates/azoth-sqlite/src/store.rs +++ b/crates/azoth-sqlite/src/store.rs @@ -4,9 +4,10 @@ use azoth_core::{ types::EventId, ProjectionConfig, }; +use parking_lot::Mutex; use rusqlite::{Connection, OpenFlags}; use std::path::Path; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use crate::read_pool::SqliteReadPool; use crate::schema; @@ -138,7 +139,7 @@ impl SqliteProjectionStore { let conn = self.read_conn.clone(); tokio::task::spawn_blocking(move || { - let conn_guard = conn.lock().unwrap(); + let conn_guard = conn.lock(); f(&conn_guard) }) .await @@ -166,7 +167,7 @@ impl SqliteProjectionStore { return f(conn.connection()); } - let conn_guard = self.read_conn.lock().unwrap(); + let conn_guard = self.read_conn.lock(); f(&conn_guard) } @@ -189,7 +190,7 @@ impl SqliteProjectionStore { { let conn = self.write_conn.clone(); tokio::task::spawn_blocking(move || { - let conn_guard = conn.lock().unwrap(); + let conn_guard = conn.lock(); f(&conn_guard) }) .await @@ -211,7 +212,7 @@ impl SqliteProjectionStore { where F: FnOnce(&Connection) -> Result<()>, { - let conn_guard = self.write_conn.lock().unwrap(); + let conn_guard = self.write_conn.lock(); f(&conn_guard) } @@ -233,7 +234,7 @@ impl SqliteProjectionStore { where F: FnOnce(&rusqlite::Transaction) -> Result<()>, { - let mut conn_guard = self.write_conn.lock().unwrap(); + let mut conn_guard = self.write_conn.lock(); let tx = conn_guard .transaction() .map_err(|e| AzothError::Projection(e.to_string()))?; @@ -254,7 +255,7 @@ impl SqliteProjectionStore { { let conn = self.write_conn.clone(); tokio::task::spawn_blocking(move || { - let mut conn_guard = conn.lock().unwrap(); + let mut conn_guard = conn.lock(); let tx = conn_guard .transaction() .map_err(|e| AzothError::Projection(e.to_string()))?; @@ -337,13 +338,13 @@ impl ProjectionStore for SqliteProjectionStore { fn begin_txn(&self) -> Result> { // Begin exclusive transaction using SimpleProjectionTxn (uses write connection) - let guard = self.write_conn.lock().unwrap(); + let guard = self.write_conn.lock(); SimpleProjectionTxn::new(guard) } fn get_cursor(&self) -> Result { // Use read connection for this read-only operation - let conn = self.read_conn.lock().unwrap(); + let conn = self.read_conn.lock(); let cursor: i64 = conn .query_row( "SELECT last_applied_event_id FROM projection_meta WHERE id = 0", @@ -356,14 +357,14 @@ impl ProjectionStore for SqliteProjectionStore { } fn migrate(&self, target_version: u32) -> Result<()> { - let conn = self.write_conn.lock().unwrap(); + let conn = self.write_conn.lock(); schema::migrate(&conn, target_version) } fn backup_to(&self, path: &Path) -> Result<()> { // Checkpoint WAL to flush all changes to the main database file { - let conn = self.write_conn.lock().unwrap(); + let conn = self.write_conn.lock(); // Execute checkpoint with full iteration of results let mut stmt = conn .prepare("PRAGMA wal_checkpoint(RESTART)") @@ -394,7 +395,7 @@ impl ProjectionStore for SqliteProjectionStore { fn schema_version(&self) -> Result { // Use read connection for this read-only operation - let conn = self.read_conn.lock().unwrap(); + let conn = self.read_conn.lock(); let version: i64 = conn .query_row( "SELECT schema_version FROM projection_meta WHERE id = 0", diff --git a/crates/azoth-sqlite/src/txn.rs b/crates/azoth-sqlite/src/txn.rs index eaa0f4c..d2df49e 100644 --- a/crates/azoth-sqlite/src/txn.rs +++ b/crates/azoth-sqlite/src/txn.rs @@ -3,8 +3,8 @@ use azoth_core::{ traits::ProjectionTxn, types::EventId, }; +use parking_lot::MutexGuard; use rusqlite::Connection; -use std::sync::MutexGuard; // Projection transaction that works with Connection directly pub struct SimpleProjectionTxn<'a> { diff --git a/crates/azoth-vector/README.md b/crates/azoth-vector/README.md index cf83aad..9427da1 100644 --- a/crates/azoth-vector/README.md +++ b/crates/azoth-vector/README.md @@ -155,14 +155,13 @@ let results = search.knn(&query, 20).await?; // Search with distance threshold let similar_only = search.threshold(&query, 0.5, 100).await?; -// Search with SQL filter +// Search with structured filter +let filter = VectorFilter::new() + .like("metadata", "%important%") + .gt("created_at", "2024-01-01"); + let filtered = search - .knn_filtered( - &query, - 10, - "metadata LIKE ? AND created_at > ?", - vec!["%important%".to_string(), "2024-01-01".to_string()], - ) + .knn_filtered(&query, 10, &filter) .await?; ``` @@ -290,13 +289,12 @@ let results = search.knn(&Vector::new(query_embedding), 5).await?; let user_preferences = Vector::new(user_embedding); let search = VectorSearch::new(db.projection(), "items", "embedding"); +let filter = VectorFilter::new() + .eq("category", &user_category) + .eq_i64("in_stock", 1); + let recommendations = search - .knn_filtered( - &user_preferences, - 20, - "category = ? AND in_stock = 1", - vec![user_category.to_string()], - ) + .knn_filtered(&user_preferences, 20, &filter) .await?; ``` diff --git a/crates/azoth-vector/src/extension.rs b/crates/azoth-vector/src/extension.rs index 6ab8bbc..4f6fca1 100644 --- a/crates/azoth-vector/src/extension.rs +++ b/crates/azoth-vector/src/extension.rs @@ -1,5 +1,6 @@ //! Extension loading and initialization +use crate::search::validate_sql_identifier; use crate::types::VectorConfig; use azoth_core::{error::AzothError, Result}; use rusqlite::Connection; @@ -48,6 +49,10 @@ pub fn load_vector_extension(conn: &Connection, path: Option<&Path>) -> Result<( } }); + // Reject paths that contain path-traversal components to prevent loading + // arbitrary libraries from unexpected locations. + validate_extension_path(ext_path)?; + unsafe { let _guard = rusqlite::LoadExtensionGuard::new(conn) .map_err(|e| AzothError::Projection(format!("Failed to enable extensions: {}", e)))?; @@ -65,6 +70,40 @@ pub fn load_vector_extension(conn: &Connection, path: Option<&Path>) -> Result<( Ok(()) } +/// Validate that an extension path does not contain path-traversal components +/// or point through symlinks. +/// +/// Rejects paths containing `..` components and paths that are symlinks, +/// as these could be used to load arbitrary shared libraries. +fn validate_extension_path(path: &Path) -> Result<()> { + // Reject ".." components anywhere in the path + for component in path.components() { + if let std::path::Component::ParentDir = component { + return Err(AzothError::Config(format!( + "Extension path '{}' contains '..' component. \ + Path traversal is not allowed for extension loading.", + path.display() + ))); + } + } + + // If the file exists, reject symlinks + if path.exists() + && path + .symlink_metadata() + .map(|m| m.file_type().is_symlink()) + .unwrap_or(false) + { + return Err(AzothError::Config(format!( + "Extension path '{}' is a symbolic link. \ + Symlinks are not allowed for extension loading.", + path.display() + ))); + } + + Ok(()) +} + /// Extend SqliteProjectionStore with vector support pub trait VectorExtension { /// Load the sqlite-vector extension @@ -135,20 +174,20 @@ pub trait VectorExtension { impl VectorExtension for azoth_sqlite::SqliteProjectionStore { fn load_vector_extension(&self, path: Option<&Path>) -> Result<()> { - let conn = self.conn().lock().unwrap(); + let conn = self.conn().lock(); load_vector_extension(&conn, path) } fn vector_init(&self, table: &str, column: &str, config: VectorConfig) -> Result<()> { - let conn = self.conn().lock().unwrap(); + // Validate identifiers to prevent SQL injection (same check used in VectorSearch::new) + validate_sql_identifier(table, "Table")?; + validate_sql_identifier(column, "Column")?; + + let conn = self.conn().lock(); let config_str = config.to_config_string(); conn.query_row( - &format!( - "SELECT vector_init('{}', '{}', ?)", - table.replace('\'', "''"), - column.replace('\'', "''") - ), + &format!("SELECT vector_init('{table}', '{column}', ?)"), [&config_str], |_row| Ok(()), ) @@ -169,13 +208,13 @@ impl VectorExtension for azoth_sqlite::SqliteProjectionStore { } fn has_vector_support(&self) -> bool { - let conn = self.conn().lock().unwrap(); + let conn = self.conn().lock(); let result = conn.prepare("SELECT vector_version()"); result.is_ok() } fn vector_version(&self) -> Result { - let conn = self.conn().lock().unwrap(); + let conn = self.conn().lock(); let version: String = conn .query_row("SELECT vector_version()", [], |row| row.get(0)) .map_err(|e| AzothError::Projection(format!("Failed to get vector version: {}", e)))?; diff --git a/crates/azoth-vector/src/lib.rs b/crates/azoth-vector/src/lib.rs index c8d93b3..9de6dd3 100644 --- a/crates/azoth-vector/src/lib.rs +++ b/crates/azoth-vector/src/lib.rs @@ -55,5 +55,5 @@ pub mod types; pub use extension::VectorExtension; pub use migration::{add_vector_column, create_vector_table}; -pub use search::VectorSearch; +pub use search::{VectorFilter, VectorSearch}; pub use types::{DistanceMetric, SearchResult, Vector, VectorConfig, VectorType}; diff --git a/crates/azoth-vector/src/search.rs b/crates/azoth-vector/src/search.rs index e224588..778d561 100644 --- a/crates/azoth-vector/src/search.rs +++ b/crates/azoth-vector/src/search.rs @@ -10,7 +10,7 @@ use std::sync::Arc; /// /// Only allows `[a-zA-Z_][a-zA-Z0-9_]*` to prevent SQL injection via /// identifier manipulation. Returns an error if the identifier is invalid. -fn validate_sql_identifier(name: &str, kind: &str) -> Result<()> { +pub(crate) fn validate_sql_identifier(name: &str, kind: &str) -> Result<()> { if name.is_empty() { return Err(azoth_core::error::AzothError::Config(format!( "{} name must not be empty", @@ -44,6 +44,191 @@ fn validate_sql_identifier(name: &str, kind: &str) -> Result<()> { Ok(()) } +/// A bound parameter value for SQL queries. +/// +/// Used internally by [`VectorFilter`] to hold typed values that are bound +/// via parameterized queries, preventing SQL injection. +#[derive(Clone, Debug)] +pub enum FilterValue { + /// String parameter + String(String), + /// 64-bit integer parameter + I64(i64), + /// 64-bit float parameter + F64(f64), +} + +impl FilterValue { + /// Convert into a boxed `ToSql` trait object for rusqlite parameter binding. + pub fn to_boxed_sql(self) -> Box { + match self { + Self::String(s) => Box::new(s), + Self::I64(i) => Box::new(i), + Self::F64(f) => Box::new(f), + } + } +} + +/// A single condition in a vector search filter. +#[derive(Clone, Debug)] +struct FilterCondition { + /// SQL fragment, e.g. `t.category = ?` or `t.in_stock = ?` + sql: String, + /// Bound parameter values (one per `?` placeholder in `sql`) + params: Vec, +} + +/// Type-safe filter builder for vector search queries. +/// +/// All column names are validated as safe SQL identifiers and all values are +/// bound via parameterized queries, eliminating SQL injection by construction. +/// +/// # Example +/// +/// ``` +/// use azoth_vector::VectorFilter; +/// +/// let filter = VectorFilter::new() +/// .eq("category", "electronics") +/// .eq_i64("in_stock", 1) +/// .gt("price", "9.99"); +/// +/// let (sql, params) = filter.to_sql().unwrap(); +/// assert_eq!(sql, "t.category = ? AND t.in_stock = ? AND t.price > ?"); +/// assert_eq!(params.len(), 3); +/// ``` +#[derive(Clone, Debug, Default)] +pub struct VectorFilter { + conditions: Vec, +} + +impl VectorFilter { + /// Create an empty filter (matches all rows). + pub fn new() -> Self { + Self::default() + } + + /// Add a string equality condition: `t. = ?` + pub fn eq(self, column: &str, value: impl Into) -> Self { + self.add_op(column, "=", FilterValue::String(value.into())) + } + + /// Add a string inequality condition: `t. != ?` + pub fn neq(self, column: &str, value: impl Into) -> Self { + self.add_op(column, "!=", FilterValue::String(value.into())) + } + + /// Add a string greater-than condition: `t. > ?` + pub fn gt(self, column: &str, value: impl Into) -> Self { + self.add_op(column, ">", FilterValue::String(value.into())) + } + + /// Add a string greater-or-equal condition: `t. >= ?` + pub fn gte(self, column: &str, value: impl Into) -> Self { + self.add_op(column, ">=", FilterValue::String(value.into())) + } + + /// Add a string less-than condition: `t. < ?` + pub fn lt(self, column: &str, value: impl Into) -> Self { + self.add_op(column, "<", FilterValue::String(value.into())) + } + + /// Add a string less-or-equal condition: `t. <= ?` + pub fn lte(self, column: &str, value: impl Into) -> Self { + self.add_op(column, "<=", FilterValue::String(value.into())) + } + + /// Add a LIKE condition: `t. LIKE ?` + pub fn like(self, column: &str, pattern: impl Into) -> Self { + self.add_op(column, "LIKE", FilterValue::String(pattern.into())) + } + + /// Add an integer equality condition: `t. = ?` + pub fn eq_i64(self, column: &str, value: i64) -> Self { + self.add_op(column, "=", FilterValue::I64(value)) + } + + /// Add an integer greater-than condition: `t. > ?` + pub fn gt_i64(self, column: &str, value: i64) -> Self { + self.add_op(column, ">", FilterValue::I64(value)) + } + + /// Add an integer greater-or-equal condition: `t. >= ?` + pub fn gte_i64(self, column: &str, value: i64) -> Self { + self.add_op(column, ">=", FilterValue::I64(value)) + } + + /// Add an integer less-than condition: `t. < ?` + pub fn lt_i64(self, column: &str, value: i64) -> Self { + self.add_op(column, "<", FilterValue::I64(value)) + } + + /// Add an integer less-or-equal condition: `t. <= ?` + pub fn lte_i64(self, column: &str, value: i64) -> Self { + self.add_op(column, "<=", FilterValue::I64(value)) + } + + /// Add a float equality condition: `t. = ?` + pub fn eq_f64(self, column: &str, value: f64) -> Self { + self.add_op(column, "=", FilterValue::F64(value)) + } + + /// Add a float greater-than condition: `t. > ?` + pub fn gt_f64(self, column: &str, value: f64) -> Self { + self.add_op(column, ">", FilterValue::F64(value)) + } + + /// Add a float less-than condition: `t. < ?` + pub fn lt_f64(self, column: &str, value: f64) -> Self { + self.add_op(column, "<", FilterValue::F64(value)) + } + + /// Internal helper: validate column and push a condition. + fn add_op(mut self, column: &str, op: &str, value: FilterValue) -> Self { + self.conditions.push(FilterCondition { + // We store validated column + op; validation happens in to_sql() + sql: format!("t.{column} {op} ?"), + params: vec![value], + }); + self + } + + /// Emit the WHERE clause and its bound parameters. + /// + /// Returns `("1 = 1", [])` for an empty filter (matches all rows). + /// + /// # Errors + /// + /// Returns `AzothError::Config` if any column name fails identifier validation. + pub fn to_sql(&self) -> Result<(String, Vec)> { + if self.conditions.is_empty() { + return Ok(("1 = 1".to_string(), Vec::new())); + } + + // Validate all column names before emitting SQL + for cond in &self.conditions { + // Extract column name from `t. ?` + let col_name = cond + .sql + .strip_prefix("t.") + .and_then(|rest| rest.split_whitespace().next()) + .unwrap_or(""); + validate_sql_identifier(col_name, "Filter column")?; + } + + let sql_parts: Vec<&str> = self.conditions.iter().map(|c| c.sql.as_str()).collect(); + let sql = sql_parts.join(" AND "); + + let params: Vec = self + .conditions + .iter() + .flat_map(|c| c.params.clone()) + .collect(); + + Ok((sql, params)) + } +} + /// Vector search builder /// /// Provides k-NN search with optional filtering and custom distance metrics. @@ -191,35 +376,25 @@ impl VectorSearch { .collect()) } - /// Search with custom SQL filter - /// - /// Allows filtering results by additional columns in the table. - /// - /// # Safety (SQL Injection) + /// Search with structured filter conditions /// - /// The `filter` string is interpolated into the WHERE clause of the query. - /// **Always use `?` placeholders** for values and pass them via `filter_params`. - /// Never interpolate user input directly into the filter string. - /// - /// Table and column identifiers are validated at `VectorSearch::new()` time - /// to prevent identifier injection. + /// Allows filtering results by additional columns in the table using a + /// type-safe [`VectorFilter`] builder. All column names are validated as + /// safe SQL identifiers, and all values are bound via parameterized queries, + /// preventing SQL injection by construction. /// /// # Example /// /// ```no_run - /// # use azoth_vector::{VectorSearch, Vector}; + /// # use azoth_vector::{VectorSearch, Vector, VectorFilter}; /// # async fn example(search: VectorSearch) -> Result<(), Box> { /// let query = Vector::new(vec![0.1, 0.2, 0.3]); /// - /// // GOOD: parameterized filter - /// let results = search - /// .knn_filtered(&query, 10, "t.category = ?", vec!["tech".to_string()]) - /// .await?; + /// let filter = VectorFilter::new() + /// .eq("category", "tech") + /// .eq_i64("in_stock", 1); /// - /// // BAD (DO NOT DO THIS): interpolating user input - /// // let results = search - /// // .knn_filtered(&query, 10, &format!("t.category = '{}'", user_input), vec![]) - /// // .await?; + /// let results = search.knn_filtered(&query, 10, &filter).await?; /// # Ok(()) /// # } /// ``` @@ -227,15 +402,14 @@ impl VectorSearch { &self, query: &Vector, k: usize, - filter: &str, - filter_params: Vec, + filter: &VectorFilter, ) -> Result> { - // Table and column are validated at construction time via validate_sql_identifier + let (where_clause, filter_params) = filter.to_sql()?; + let table = self.table.clone(); let column = self.column.clone(); let query_json = query.to_json(); let k_i64 = k as i64; - let filter = filter.to_string(); self.projection .query_async(move |conn| { @@ -243,14 +417,14 @@ impl VectorSearch { "SELECT v.rowid, v.distance FROM vector_quantize_scan('{table}', '{column}', ?, ?) AS v JOIN {table} AS t ON v.rowid = t.rowid - WHERE {filter} + WHERE {where_clause} ORDER BY v.distance ASC", ); let mut params_vec: Vec> = vec![Box::new(query_json), Box::new(k_i64)]; for p in filter_params { - params_vec.push(Box::new(p)); + params_vec.push(p.to_boxed_sql()); } let mut stmt = conn diff --git a/crates/azoth-vector/tests/integration_test.rs b/crates/azoth-vector/tests/integration_test.rs index 1e426f2..a09d386 100644 --- a/crates/azoth-vector/tests/integration_test.rs +++ b/crates/azoth-vector/tests/integration_test.rs @@ -136,7 +136,9 @@ mod with_extension { use azoth_core::error::AzothError; use azoth_core::{ProjectionConfig, ProjectionStore}; use azoth_sqlite::SqliteProjectionStore; - use azoth_vector::{DistanceMetric, Vector, VectorConfig, VectorExtension, VectorSearch}; + use azoth_vector::{ + DistanceMetric, Vector, VectorConfig, VectorExtension, VectorFilter, VectorSearch, + }; use std::sync::Arc; use tempfile::tempdir; @@ -343,15 +345,11 @@ mod with_extension { let query = Vector::new(vec![0.95, 0.05, 0.0]); let search = VectorSearch::new(store.clone(), "items", "vector").unwrap(); - let results = search - .knn_filtered( - &query, - 10, - "category = ? AND in_stock = 1", - vec!["electronics".to_string()], - ) - .await - .unwrap(); + let filter = VectorFilter::new() + .eq("category", "electronics") + .eq_i64("in_stock", 1); + + let results = search.knn_filtered(&query, 10, &filter).await.unwrap(); assert_eq!(results.len(), 2); let rowids: std::collections::HashSet<_> = results.iter().map(|r| r.rowid).collect(); diff --git a/crates/azoth/Cargo.toml b/crates/azoth/Cargo.toml index ebd5f61..af5e40b 100644 --- a/crates/azoth/Cargo.toml +++ b/crates/azoth/Cargo.toml @@ -31,6 +31,8 @@ chrono = { workspace = true } bincode = { workspace = true } rmp-serde = "1.3" +parking_lot = { workspace = true } + # Encryption and compression age = "0.10" zstd = "0.13" diff --git a/crates/azoth/examples/batch_processing.rs b/crates/azoth/examples/batch_processing.rs index 8d5404b..99146bd 100644 --- a/crates/azoth/examples/batch_processing.rs +++ b/crates/azoth/examples/batch_processing.rs @@ -62,7 +62,7 @@ fn main() -> Result<()> { // Setup projection schema println!("1. Setting up projection schema..."); let conn = db.projection().conn(); - let locked_conn = conn.lock().unwrap(); + let locked_conn = conn.lock(); locked_conn .execute( "CREATE TABLE IF NOT EXISTS accounts ( @@ -110,7 +110,7 @@ fn main() -> Result<()> { // Process events in batches println!("\n4. Processing events in batches..."); let start = Instant::now(); - let locked_conn = db.projection().conn().lock().unwrap(); + let locked_conn = db.projection().conn().lock(); let processed = registry.process_batched( &locked_conn, events.iter().map(|(id, data)| (*id, data.as_slice())), diff --git a/crates/azoth/examples/migration_example.rs b/crates/azoth/examples/migration_example.rs index a6688c2..2c30c7a 100644 --- a/crates/azoth/examples/migration_example.rs +++ b/crates/azoth/examples/migration_example.rs @@ -143,7 +143,7 @@ fn main() -> Result<()> { // Check that tables were created { - let conn = db.projection().conn().lock().unwrap(); + let conn = db.projection().conn().lock(); // Query for users table let tables: Vec = conn diff --git a/crates/azoth/src/checkpoint.rs b/crates/azoth/src/checkpoint.rs index 3088e91..8b516a6 100644 --- a/crates/azoth/src/checkpoint.rs +++ b/crates/azoth/src/checkpoint.rs @@ -92,6 +92,46 @@ impl LocalStorage { pub fn new(base_path: PathBuf) -> Self { Self { base_path } } + + /// Resolve `id` to a path under `base_path`, rejecting path-traversal attempts. + /// + /// Only allows filenames that consist of ASCII alphanumeric characters, + /// hyphens, underscores, and periods. Rejects any id containing path + /// separators (`/`, `\`, `..`) or other special characters. + fn safe_path(&self, id: &str) -> Result { + // Reject empty ids + if id.is_empty() { + return Err(AzothError::Config( + "Checkpoint id must not be empty".to_string(), + )); + } + + // Whitelist: only safe filename characters + let is_safe = id + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.'); + + if !is_safe || id.contains("..") { + return Err(AzothError::Config(format!( + "Checkpoint id '{}' contains unsafe characters. \ + Only alphanumeric, hyphen, underscore, and period are allowed.", + id + ))); + } + + let resolved = self.base_path.join(id); + + // Defence-in-depth: verify the resolved path is still under base_path + // even after the whitelist check above. + if resolved.parent() != Some(&self.base_path) { + return Err(AzothError::Config(format!( + "Checkpoint id '{}' resolves outside the storage directory", + id + ))); + } + + Ok(resolved) + } } #[async_trait] @@ -104,12 +144,15 @@ impl CheckpointStorage for LocalStorage { metadata.timestamp.format("%Y%m%d-%H%M%S"), &metadata.id ); - let dest_path = self.base_path.join(&filename); + + // Validate that the generated filename is safe before writing + let dest_path = self.safe_path(&filename)?; std::fs::copy(path, &dest_path)?; // Also save metadata - let metadata_path = self.base_path.join(format!("{}.json", &filename)); + let metadata_filename = format!("{}.json", &filename); + let metadata_path = self.safe_path(&metadata_filename)?; let metadata_json = serde_json::to_string_pretty(metadata) .map_err(|e| AzothError::Serialization(e.to_string()))?; std::fs::write(metadata_path, metadata_json)?; @@ -118,7 +161,7 @@ impl CheckpointStorage for LocalStorage { } async fn download(&self, id: &str, path: &Path) -> Result<()> { - let src_path = self.base_path.join(id); + let src_path = self.safe_path(id)?; if !src_path.exists() { return Err(AzothError::NotFound(format!( "Checkpoint not found: {}", @@ -131,8 +174,9 @@ impl CheckpointStorage for LocalStorage { } async fn delete(&self, id: &str) -> Result<()> { - let checkpoint_path = self.base_path.join(id); - let metadata_path = self.base_path.join(format!("{}.json", id)); + let checkpoint_path = self.safe_path(id)?; + let metadata_filename = format!("{}.json", id); + let metadata_path = self.safe_path(&metadata_filename)?; if checkpoint_path.exists() { std::fs::remove_file(&checkpoint_path)?; diff --git a/crates/azoth/src/db.rs b/crates/azoth/src/db.rs index a9970d9..d3c37fd 100644 --- a/crates/azoth/src/db.rs +++ b/crates/azoth/src/db.rs @@ -104,10 +104,10 @@ impl AzothDb { /// # Example /// ```ignore /// let conn = db.projection_connection(); - /// let guard = conn.lock().unwrap(); + /// let guard = conn.lock(); /// guard.execute("INSERT INTO ...", params![])?; /// ``` - pub fn projection_connection(&self) -> &Arc> { + pub fn projection_connection(&self) -> &Arc> { self.projection.conn() } diff --git a/crates/azoth/src/lib.rs b/crates/azoth/src/lib.rs index 4868202..1cf2afc 100644 --- a/crates/azoth/src/lib.rs +++ b/crates/azoth/src/lib.rs @@ -97,5 +97,6 @@ pub use migration::{ }; pub use transaction::{ execute_transaction_async, AsyncTransaction, PreflightContext, Transaction, TransactionContext, + MAX_DECLARED_KEYS, }; pub use typed_values::{Array, Set, TypedValue, I256, U256}; diff --git a/crates/azoth/src/migration.rs b/crates/azoth/src/migration.rs index d56f933..b7a2f66 100644 --- a/crates/azoth/src/migration.rs +++ b/crates/azoth/src/migration.rs @@ -204,7 +204,7 @@ impl MigrationManager { /// Initialize the migration history table fn init_migration_history(&self, projection: &Arc) -> Result<()> { - let conn = projection.conn().lock().unwrap(); + let conn = projection.conn().lock(); conn.execute( "CREATE TABLE IF NOT EXISTS migration_history ( version INTEGER PRIMARY KEY, @@ -230,7 +230,7 @@ impl MigrationManager { ); // Run migration, history write, and schema-version bump atomically. - let conn = projection.conn().lock().unwrap(); + let conn = projection.conn().lock(); conn.execute_batch("BEGIN IMMEDIATE TRANSACTION") .map_err(|e| AzothError::Projection(e.to_string()))?; @@ -299,7 +299,7 @@ impl MigrationManager { ); // Execute rollback and metadata updates atomically. - let conn = projection.conn().lock().unwrap(); + let conn = projection.conn().lock(); conn.execute_batch("BEGIN IMMEDIATE TRANSACTION") .map_err(|e| AzothError::Projection(e.to_string()))?; @@ -375,7 +375,7 @@ impl MigrationManager { ) -> Result> { self.init_migration_history(projection)?; - let conn = projection.conn().lock().unwrap(); + let conn = projection.conn().lock(); let mut stmt = conn .prepare("SELECT version, name, applied_at FROM migration_history ORDER BY version") .map_err(|e| AzothError::Projection(e.to_string()))?; diff --git a/crates/azoth/src/transaction.rs b/crates/azoth/src/transaction.rs index 46549b4..f885e65 100644 --- a/crates/azoth/src/transaction.rs +++ b/crates/azoth/src/transaction.rs @@ -111,6 +111,25 @@ use azoth_lmdb::preflight_cache::{CachedValue, PreflightCache}; use std::collections::HashSet; use std::sync::Arc; +/// Maximum number of keys that can be declared per transaction. +/// +/// Declaring too many keys causes excessive stripe lock acquisition, high memory +/// usage for the value cache, and potential cache pollution. This limit acts as +/// a safety guard against accidental misuse or denial-of-service scenarios. +pub const MAX_DECLARED_KEYS: usize = 10_000; + +/// Check that the number of declared keys does not exceed [`MAX_DECLARED_KEYS`]. +fn check_key_limit(count: usize) -> Result<()> { + if count > MAX_DECLARED_KEYS { + return Err(AzothError::Config(format!( + "Transaction declares {} keys, which exceeds the maximum of {}. \ + Consider batching operations or increasing MAX_DECLARED_KEYS.", + count, MAX_DECLARED_KEYS + ))); + } + Ok(()) +} + // ============================================================================ // Async Transaction API // ============================================================================ @@ -206,6 +225,9 @@ impl AsyncTransaction { let declared_keys = self.declared_keys; let validators = self.validators; + // Check key count limit before spawning the blocking task + check_key_limit(declared_keys.len())?; + tokio::task::spawn_blocking(move || { // Phase 1: Acquire locks on declared keys (sorted, deadlock-free) let lock_manager = db.canonical().lock_manager(); @@ -776,6 +798,9 @@ impl<'a> Transaction<'a> { } } + // Guard against excessive key declarations + check_key_limit(self.declared_keys.len())?; + // Phase 1: Acquire locks on declared keys (sorted, deadlock-free) let lock_manager = self.db.canonical().lock_manager(); let keys_vec: Vec<&[u8]> = self.declared_keys.iter().map(|k| k.as_slice()).collect(); @@ -850,6 +875,9 @@ impl<'a> Transaction<'a> { { // No async safety check - caller takes responsibility + // Guard against excessive key declarations + check_key_limit(self.declared_keys.len())?; + // Phase 1: Acquire locks on declared keys (sorted, deadlock-free) let lock_manager = self.db.canonical().lock_manager(); let keys_vec: Vec<&[u8]> = self.declared_keys.iter().map(|k| k.as_slice()).collect(); diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md index 0a24e1b..791bf7b 100644 --- a/docs/GETTING_STARTED.md +++ b/docs/GETTING_STARTED.md @@ -283,7 +283,7 @@ println!("Restored to EventId: {}", cursor); ### Direct SQL Access ```rust -let conn = db.projection().conn().lock().unwrap(); +let conn = db.projection().conn().lock(); let balance: i64 = conn.query_row( "SELECT balance FROM accounts WHERE id = ?1", @@ -297,7 +297,7 @@ println!("Account balance: {}", balance); ### Custom Queries ```rust -let conn = db.projection().conn().lock().unwrap(); +let conn = db.projection().conn().lock(); let mut stmt = conn.prepare( "SELECT id, balance FROM accounts WHERE balance > ?1" @@ -494,7 +494,7 @@ fn test_end_to_end() { db.projector().run_once().unwrap(); // Verify SQL - let conn = db.projection().conn().lock().unwrap(); + let conn = db.projection().conn().lock(); let balance: i64 = conn.query_row( "SELECT balance FROM accounts WHERE id = 1", [],