diff --git a/Cargo.lock b/Cargo.lock index d892261..11747cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3707,7 +3707,7 @@ dependencies = [ [[package]] name = "sqlx-sqlite-conn-mgr" -version = "0.8.6" +version = "0.8.7" dependencies = [ "serde", "sqlx", @@ -3719,7 +3719,7 @@ dependencies = [ [[package]] name = "sqlx-sqlite-observer" -version = "0.8.6" +version = "0.8.7" dependencies = [ "futures", "libsqlite3-sys", @@ -3737,7 +3737,7 @@ dependencies = [ [[package]] name = "sqlx-sqlite-toolkit" -version = "0.8.6" +version = "0.8.7" dependencies = [ "base64 0.22.1", "indexmap 2.13.0", diff --git a/README.md b/README.md index 7f81a4c..f2bc71f 100644 --- a/README.md +++ b/README.md @@ -948,6 +948,41 @@ Working Tauri demo apps are in the [`examples/`](examples) directory: See the [toolkit crate README](crates/sqlx-sqlite-toolkit/README.md#examples) for setup instructions. +## Security Considerations + +### Cross-Window Shared State + +Database instances are shared across all webviews/windows within the same Tauri +application. A database loaded in one window is accessible from any other window +without calling `load()` again. Writes from one window are immediately visible +to reads in another, and closing a database affects all windows. + +### Resource Limits + +The plugin enforces several resource limits to prevent denial-of-service from +untrusted or buggy frontend code: + + * **Database count**: Maximum 50 concurrently loaded databases (configurable + via `Builder::max_databases()`) + * **Interruptible transaction timeout**: Transactions that exceed the + default (5 minutes) are automatically rolled back on the next access + attempt (configurable via `Builder::transaction_timeout()`) + * **Observer channel capacity**: Capped at 10,000 (default 256) + * **Observed tables**: Maximum 100 tables per `observe()` call + * **Subscriptions**: Maximum 100 active subscriptions per database + +### Unbounded Result Sets + +`fetchAll()` returns the entire result set in a single response with no built-in +size limit. For large or unbounded queries, prefer `fetchPage()` with keyset +pagination to keep memory usage bounded on both the Rust and TypeScript sides. + +### Path Validation + +Database paths are validated to prevent directory traversal. Absolute paths, +`..` segments, and null bytes are rejected. All paths are resolved relative to +the app config directory. + ## Development This project follows diff --git a/crates/sqlx-sqlite-conn-mgr/Cargo.toml b/crates/sqlx-sqlite-conn-mgr/Cargo.toml index efd981b..604c2f5 100644 --- a/crates/sqlx-sqlite-conn-mgr/Cargo.toml +++ b/crates/sqlx-sqlite-conn-mgr/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "sqlx-sqlite-conn-mgr" # Sync major.minor with major.minor of SQLx crate -version = "0.8.6" +version = "0.8.7" description = "Wraps SQLx for SQLite, enforcing pragmatic connection policies for mobile and desktop applications" authors = ["Jeremy Thomerson"] license = "MIT" diff --git a/crates/sqlx-sqlite-conn-mgr/src/attached.rs b/crates/sqlx-sqlite-conn-mgr/src/attached.rs index f68985b..92d84e9 100644 --- a/crates/sqlx-sqlite-conn-mgr/src/attached.rs +++ b/crates/sqlx-sqlite-conn-mgr/src/attached.rs @@ -70,7 +70,7 @@ impl AttachedReadConnection { /// attached databases may persist when the connection is returned to the pool. pub async fn detach_all(mut self) -> Result<()> { for schema_name in &self.schema_names { - let detach_sql = format!("DETACH DATABASE {}", schema_name); + let detach_sql = format!("DETACH DATABASE \"{}\"", schema_name); sqlx::query(&detach_sql).execute(&mut *self.conn).await?; } Ok(()) @@ -140,7 +140,7 @@ impl AttachedWriteGuard { /// attached databases may persist when the connection is returned to the pool. pub async fn detach_all(mut self) -> Result<()> { for schema_name in &self.schema_names { - let detach_sql = format!("DETACH DATABASE {}", schema_name); + let detach_sql = format!("DETACH DATABASE \"{}\"", schema_name); sqlx::query(&detach_sql).execute(&mut *self.writer).await?; } Ok(()) @@ -252,7 +252,10 @@ pub async fn acquire_reader_with_attached( // Schema name is validated above to contain only safe identifier characters let path = spec.database.path_str(); let escaped_path = path.replace("'", "''"); - let attach_sql = format!("ATTACH DATABASE '{}' AS {}", escaped_path, spec.schema_name); + let attach_sql = format!( + "ATTACH DATABASE '{}' AS \"{}\"", + escaped_path, spec.schema_name + ); sqlx::query(&attach_sql).execute(&mut *conn).await?; schema_names.push(spec.schema_name); @@ -349,7 +352,10 @@ pub async fn acquire_writer_with_attached( for spec in specs { let path = spec.database.path_str(); let escaped_path = path.replace("'", "''"); - let attach_sql = format!("ATTACH DATABASE '{}' AS {}", escaped_path, spec.schema_name); + let attach_sql = format!( + "ATTACH DATABASE '{}' AS \"{}\"", + escaped_path, spec.schema_name + ); sqlx::query(&attach_sql).execute(&mut *writer).await?; schema_names.push(spec.schema_name); diff --git a/crates/sqlx-sqlite-observer/Cargo.toml b/crates/sqlx-sqlite-observer/Cargo.toml index 36742b8..6cfc264 100644 --- a/crates/sqlx-sqlite-observer/Cargo.toml +++ b/crates/sqlx-sqlite-observer/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "sqlx-sqlite-observer" # Sync major.minor with major.minor of SQLx crate -version = "0.8.6" +version = "0.8.7" license = "MIT" edition = "2024" rust-version = "1.89" @@ -29,7 +29,7 @@ regex = "1.12.3" sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio"], default-features = false } # Required for preupdate_hook - SQLite must be compiled with SQLITE_ENABLE_PREUPDATE_HOOK libsqlite3-sys = { version = "0.30.1", features = ["preupdate_hook"] } -sqlx-sqlite-conn-mgr = { path = "../sqlx-sqlite-conn-mgr", version = "0.8.6", optional = true } +sqlx-sqlite-conn-mgr = { path = "../sqlx-sqlite-conn-mgr", version = "0.8.7", optional = true } [dev-dependencies] tokio = { version = "1.49.0", features = ["full", "macros"] } diff --git a/crates/sqlx-sqlite-observer/src/broker.rs b/crates/sqlx-sqlite-observer/src/broker.rs index df3ef08..401214c 100644 --- a/crates/sqlx-sqlite-observer/src/broker.rs +++ b/crates/sqlx-sqlite-observer/src/broker.rs @@ -59,7 +59,16 @@ pub struct ObservationBroker { impl ObservationBroker { /// Creates a new broker with the specified broadcast channel capacity. + /// + /// # Panics + /// + /// Panics if `channel_capacity` is 0. pub fn new(channel_capacity: usize, capture_values: bool) -> Arc { + // broadcast::channel panics on zero capacity. Assert here to surface a clear + // message rather than an internal tokio panic. Changing the return type to + // Result would ripple through every call site for a case that the plugin layer + // already validates before reaching this point. + assert!(channel_capacity > 0, "channel_capacity must be at least 1"); let (change_tx, _) = broadcast::channel(channel_capacity); Arc::new(Self { buffer: Mutex::new(Vec::new()), diff --git a/crates/sqlx-sqlite-observer/src/config.rs b/crates/sqlx-sqlite-observer/src/config.rs index 3d03c58..d6a94b6 100644 --- a/crates/sqlx-sqlite-observer/src/config.rs +++ b/crates/sqlx-sqlite-observer/src/config.rs @@ -90,6 +90,9 @@ impl ObserverConfig { /// Sets the broadcast channel capacity for change notifications. /// + /// Capacity must be at least 1. A capacity of 0 will cause a panic when the + /// observer is initialized. + /// /// See [`channel_capacity`](Self::channel_capacity) for details on sizing. pub fn with_channel_capacity(mut self, capacity: usize) -> Self { self.channel_capacity = capacity; diff --git a/crates/sqlx-sqlite-observer/src/schema.rs b/crates/sqlx-sqlite-observer/src/schema.rs index 8615514..1ec59e2 100644 --- a/crates/sqlx-sqlite-observer/src/schema.rs +++ b/crates/sqlx-sqlite-observer/src/schema.rs @@ -20,11 +20,11 @@ pub async fn query_table_info( // Check if table exists and get WITHOUT ROWID status let without_rowid = is_without_rowid(conn, table_name).await?; - // Get primary key columns using PRAGMA table_info + // Get primary key columns using pragma_table_info() let pk_columns = query_pk_columns(conn, table_name).await?; // Determine if table exists: - // - If pk_columns is None, PRAGMA table_info returned no rows (table doesn't exist) + // - If pk_columns is None, pragma_table_info returned no rows (table doesn't exist) // - If without_rowid is true, the table must exist (we found it in sqlite_master) // - A table with no explicit PK returns Some([]), not None if pk_columns.is_none() && !without_rowid { @@ -78,15 +78,20 @@ fn has_without_rowid_clause(create_sql: &str) -> bool { /// Returns column indices in the order they appear in the PRIMARY KEY definition. /// For composite primary keys, the `pk` column in PRAGMA table_info indicates /// the position (1-indexed) within the PK. +/// +/// Uses the `pragma_table_info()` table-valued function (available since SQLite +/// 3.16.0) so the table name can be bound as a parameter instead of interpolated +/// into the SQL string. async fn query_pk_columns( conn: &mut SqliteConnection, table_name: &str, ) -> crate::Result>> { - // PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk + // pragma_table_info returns: cid, name, type, notnull, dflt_value, pk // pk is 0 for non-PK columns, or 1-indexed position for PK columns - let pragma = format!("PRAGMA table_info({})", quote_identifier(table_name)); + let sql = "SELECT cid, name, type, \"notnull\", dflt_value, pk FROM pragma_table_info(?1)"; - let rows = sqlx::query(&pragma) + let rows = sqlx::query(sql) + .bind(table_name) .fetch_all(&mut *conn) .await .map_err(crate::Error::Sqlx)?; @@ -116,23 +121,10 @@ async fn query_pk_columns( Ok(Some(pk_columns.into_iter().map(|(cid, _)| cid).collect())) } -/// Quotes a SQLite identifier to prevent SQL injection. -fn quote_identifier(name: &str) -> String { - // Double any existing double quotes and wrap in double quotes - format!("\"{}\"", name.replace('"', "\"\"")) -} - #[cfg(test)] mod tests { use super::*; - #[test] - fn test_quote_identifier() { - assert_eq!(quote_identifier("users"), "\"users\""); - assert_eq!(quote_identifier("my table"), "\"my table\""); - assert_eq!(quote_identifier("foo\"bar"), "\"foo\"\"bar\""); - } - #[test] fn test_has_without_rowid_clause() { // Positive cases diff --git a/crates/sqlx-sqlite-toolkit/Cargo.toml b/crates/sqlx-sqlite-toolkit/Cargo.toml index ac74d4d..41be89b 100644 --- a/crates/sqlx-sqlite-toolkit/Cargo.toml +++ b/crates/sqlx-sqlite-toolkit/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "sqlx-sqlite-toolkit" # Sync major.minor with major.minor of SQLx crate -version = "0.8.6" +version = "0.8.7" license = "MIT" edition = "2024" rust-version = "1.89" diff --git a/crates/sqlx-sqlite-toolkit/src/error.rs b/crates/sqlx-sqlite-toolkit/src/error.rs index 37ffe46..f030bc6 100644 --- a/crates/sqlx-sqlite-toolkit/src/error.rs +++ b/crates/sqlx-sqlite-toolkit/src/error.rs @@ -45,6 +45,10 @@ pub enum Error { #[error("invalid transaction token")] InvalidTransactionToken, + /// Transaction timed out (exceeded the configured timeout). + #[error("transaction timed out for database: {0}")] + TransactionTimedOut(String), + /// Error from the observer (change notifications). #[cfg(feature = "observer")] #[error(transparent)] @@ -115,6 +119,7 @@ impl Error { Error::TransactionAlreadyActive(_) => "TRANSACTION_ALREADY_ACTIVE".to_string(), Error::NoActiveTransaction(_) => "NO_ACTIVE_TRANSACTION".to_string(), Error::InvalidTransactionToken => "INVALID_TRANSACTION_TOKEN".to_string(), + Error::TransactionTimedOut(_) => "TRANSACTION_TIMED_OUT".to_string(), #[cfg(feature = "observer")] Error::Observer(_) => "OBSERVER_ERROR".to_string(), Error::Io(_) => "IO_ERROR".to_string(), @@ -194,6 +199,13 @@ mod tests { assert_eq!(err.error_code(), "IO_ERROR"); } + #[test] + fn test_error_code_transaction_timed_out() { + let err = Error::TransactionTimedOut("test.db".into()); + assert_eq!(err.error_code(), "TRANSACTION_TIMED_OUT"); + assert!(err.to_string().contains("test.db")); + } + #[test] fn test_error_code_other() { let err = Error::Other("something went wrong".into()); diff --git a/crates/sqlx-sqlite-toolkit/src/transactions.rs b/crates/sqlx-sqlite-toolkit/src/transactions.rs index 438eb92..01cb709 100644 --- a/crates/sqlx-sqlite-toolkit/src/transactions.rs +++ b/crates/sqlx-sqlite-toolkit/src/transactions.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; +use std::time::{Duration, Instant}; use indexmap::IndexMap; use serde::Deserialize; @@ -10,7 +11,7 @@ use sqlx::{Column, Row}; use sqlx_sqlite_conn_mgr::{AttachedWriteGuard, WriteGuard}; use tokio::sync::{Mutex, RwLock}; use tokio::task::AbortHandle; -use tracing::debug; +use tracing::{debug, warn}; #[cfg(feature = "observer")] use sqlx_sqlite_observer::ObservableWriteGuard; @@ -97,6 +98,7 @@ pub struct ActiveInterruptibleTransaction { db_path: String, transaction_id: String, writer: Option, + created_at: Instant, } impl ActiveInterruptibleTransaction { @@ -105,6 +107,7 @@ impl ActiveInterruptibleTransaction { db_path, transaction_id, writer: Some(writer), + created_at: Instant::now(), } } @@ -241,34 +244,71 @@ impl Drop for ActiveInterruptibleTransaction { } } +/// Default transaction timeout (5 minutes). +const DEFAULT_TRANSACTION_TIMEOUT: Duration = Duration::from_secs(300); + /// Global state tracking all active interruptible transactions. /// -/// Enforces one interruptible transaction per database path. +/// Enforces one interruptible transaction per database path and applies a configurable +/// timeout. Expired transactions are cleaned up lazily on the next `insert()` or +/// `remove()` call — no background task is needed. +/// /// Uses `Mutex` rather than `RwLock` because all operations require write access, /// and `Mutex` only requires `T: Send` (not `T: Sync`) — avoiding an /// `unsafe impl Sync` that would otherwise be needed due to non-`Sync` inner /// types (`PoolConnection`, raw pointers in observer guards). -#[derive(Clone, Default)] -pub struct ActiveInterruptibleTransactions( - Arc>>, -); +#[derive(Clone)] +pub struct ActiveInterruptibleTransactions { + inner: Arc>>, + timeout: Duration, +} + +impl Default for ActiveInterruptibleTransactions { + fn default() -> Self { + Self::new(DEFAULT_TRANSACTION_TIMEOUT) + } +} impl ActiveInterruptibleTransactions { + /// Create a new instance with the given transaction timeout. + pub fn new(timeout: Duration) -> Self { + Self { + inner: Arc::new(Mutex::new(HashMap::new())), + timeout, + } + } + pub async fn insert(&self, db_path: String, tx: ActiveInterruptibleTransaction) -> Result<()> { use std::collections::hash_map::Entry; - let mut txs = self.0.lock().await; + let mut txs = self.inner.lock().await; match txs.entry(db_path.clone()) { Entry::Vacant(e) => { e.insert(tx); Ok(()) } - Entry::Occupied(_) => Err(Error::TransactionAlreadyActive(db_path)), + Entry::Occupied(mut e) => { + // If the existing transaction has expired, drop it (auto-rollback) and + // replace with the new one. + if e.get().created_at.elapsed() >= self.timeout { + warn!( + "Evicting expired transaction for db: {} (age: {:?}, timeout: {:?})", + db_path, + e.get().created_at.elapsed(), + self.timeout, + ); + // Drop the expired transaction (auto-rollback) before inserting the new one + let _expired = e.insert(tx); + Ok(()) + } else { + Err(Error::TransactionAlreadyActive(db_path)) + } + } } } pub async fn abort_all(&self) { - let mut txs = self.0.lock().await; + let mut txs = self.inner.lock().await; debug!("Aborting {} active interruptible transaction(s)", txs.len()); for db_path in txs.keys() { @@ -283,13 +323,17 @@ impl ActiveInterruptibleTransactions { txs.clear(); } - /// Remove and return transaction for commit/rollback + /// Remove and return transaction for commit/rollback. + /// + /// Returns `Err(Error::TransactionTimedOut)` if the transaction has exceeded the + /// configured timeout. The expired transaction is dropped (auto-rolled-back) in + /// that case. pub async fn remove( &self, db_path: &str, token_id: &str, ) -> Result { - let mut txs = self.0.lock().await; + let mut txs = self.inner.lock().await; // Validate token before removal let tx = txs @@ -300,6 +344,19 @@ impl ActiveInterruptibleTransactions { return Err(Error::InvalidTransactionToken); } + // Check if the transaction has expired + if tx.created_at.elapsed() >= self.timeout { + warn!( + "Transaction timed out for db: {} (age: {:?}, timeout: {:?})", + db_path, + tx.created_at.elapsed(), + self.timeout, + ); + // Drop the expired transaction (auto-rollback via Drop) + txs.remove(db_path); + return Err(Error::TransactionTimedOut(db_path.to_string())); + } + // Safe unwrap: we just confirmed the key exists above Ok(txs.remove(db_path).unwrap()) } diff --git a/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs b/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs index ba5ba00..7ac675d 100644 --- a/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs +++ b/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs @@ -202,6 +202,80 @@ async fn test_insert_after_abort_all_succeeds() { state.insert("reuse-key".into(), tx2).await.unwrap(); } +// ============================================================================ +// ActiveInterruptibleTransactions timeout tests +// ============================================================================ + +#[tokio::test] +async fn test_expired_transaction_evicted_on_insert() { + let (db1, _temp1) = create_test_db("expire1.db").await; + let (db2, _temp2) = create_test_db("expire2.db").await; + + for db in [&db1, &db2] { + db.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)".into(), vec![]) + .await + .unwrap(); + } + + // Use a 1ms timeout so the first transaction expires immediately + let state = ActiveInterruptibleTransactions::new(std::time::Duration::from_millis(1)); + + let tx1 = begin_transaction(&db1, "shared-key").await; + state.insert("shared-key".into(), tx1).await.unwrap(); + + // Sleep to ensure the transaction expires + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + // Second insert should succeed because the expired transaction is evicted + let tx2 = begin_transaction(&db2, "shared-key").await; + state.insert("shared-key".into(), tx2).await.unwrap(); +} + +#[tokio::test] +async fn test_remove_expired_transaction_returns_timed_out() { + let (db, _temp) = create_test_db("timeout.db").await; + + db.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)".into(), vec![]) + .await + .unwrap(); + + let state = ActiveInterruptibleTransactions::new(std::time::Duration::from_millis(1)); + + let tx = begin_transaction(&db, "timeout.db").await; + let tx_id = tx.transaction_id().to_string(); + + state.insert("timeout.db".into(), tx).await.unwrap(); + + // Sleep to ensure the transaction expires + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + let err = expect_err(state.remove("timeout.db", &tx_id).await); + assert_eq!(err.error_code(), "TRANSACTION_TIMED_OUT"); +} + +#[tokio::test] +async fn test_non_expired_transaction_not_evicted() { + let (db1, _temp1) = create_test_db("live1.db").await; + let (db2, _temp2) = create_test_db("live2.db").await; + + for db in [&db1, &db2] { + db.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)".into(), vec![]) + .await + .unwrap(); + } + + // Use a long timeout so the first transaction does NOT expire + let state = ActiveInterruptibleTransactions::new(std::time::Duration::from_secs(300)); + + let tx1 = begin_transaction(&db1, "shared-key").await; + state.insert("shared-key".into(), tx1).await.unwrap(); + + // Second insert should still fail because the first transaction is alive + let tx2 = begin_transaction(&db2, "shared-key").await; + let err = state.insert("shared-key".into(), tx2).await.unwrap_err(); + assert_eq!(err.error_code(), "TRANSACTION_ALREADY_ACTIVE"); +} + // ============================================================================ // ActiveRegularTransactions tests // ============================================================================ diff --git a/guest-js/index.ts b/guest-js/index.ts index d064953..3537fae 100644 --- a/guest-js/index.ts +++ b/guest-js/index.ts @@ -757,6 +757,12 @@ class TransactionBuilder implements PromiseLike { * * The `Database` class serves as the primary interface for * communicating with SQLite databases through the plugin. + * + * @remarks + * Database instances are shared across all webviews/windows within the same Tauri + * application. A database loaded in one window is accessible from any other window + * without calling `load()` again. This means writes from one window are immediately + * visible to reads in another, and closing a database affects all windows. */ export default class Database { public path: string; @@ -930,6 +936,11 @@ export default class Database { * * SQLite uses `$1`, `$2`, etc. for parameter binding. * + * @remarks + * This method returns the entire result set in a single response. For large or + * unbounded queries, prefer {@link fetchPage} with keyset pagination to keep memory + * usage bounded on both the Rust and TypeScript sides. + * * @param query - SQL SELECT query * @param bindValues - Optional parameter values * diff --git a/src/commands.rs b/src/commands.rs index 68adc67..2222bda 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -112,7 +112,7 @@ pub async fn load( // Wait for migrations to complete if registered for this database await_migrations(&migration_states, &db).await?; - let instances = db_instances.0.read().await; + let instances = db_instances.inner.read().await; // Return cached if db was already loaded if instances.contains_key(&db) { @@ -121,7 +121,14 @@ pub async fn load( drop(instances); // Release read lock before acquiring write lock - let mut instances = db_instances.0.write().await; + let mut instances = db_instances.inner.write().await; + + // Check database count limit before creating a new connection. + // This check is before entry() to avoid borrow conflicts, and the write lock + // prevents races between the len() check and the insert. + if !instances.contains_key(&db) && instances.len() >= db_instances.max { + return Err(Error::TooManyDatabases(db_instances.max)); + } // Use entry API to atomically check and insert, avoiding race conditions // where two callers could both create wrappers @@ -187,7 +194,7 @@ pub async fn execute( values: Vec, attached: Option>, ) -> Result<(u64, i64)> { - let instances = db_instances.0.read().await; + let instances = db_instances.inner.read().await; let wrapper = instances .get(&db) @@ -214,7 +221,7 @@ pub async fn execute_transaction( statements: Vec, attached: Option>, ) -> Result> { - let instances = db_instances.0.read().await; + let instances = db_instances.inner.read().await; let wrapper = instances .get(&db) @@ -283,7 +290,10 @@ pub async fn execute_transaction( } } -/// Execute a SELECT query returning all matching rows +/// Execute a SELECT query returning all matching rows. +/// +/// Returns the entire result set in a single response. For large or unbounded queries, +/// prefer `fetch_page` with keyset pagination to keep memory usage bounded. #[tauri::command] pub async fn fetch_all( db_instances: State<'_, DbInstances>, @@ -292,7 +302,7 @@ pub async fn fetch_all( values: Vec, attached: Option>, ) -> Result>> { - let instances = db_instances.0.read().await; + let instances = db_instances.inner.read().await; let wrapper = instances .get(&db) @@ -319,7 +329,7 @@ pub async fn fetch_one( values: Vec, attached: Option>, ) -> Result>> { - let instances = db_instances.0.read().await; + let instances = db_instances.inner.read().await; let wrapper = instances .get(&db) @@ -357,7 +367,7 @@ pub async fn fetch_page( )); } - let instances = db_instances.0.read().await; + let instances = db_instances.inner.read().await; let wrapper = instances .get(&db) @@ -394,7 +404,7 @@ pub async fn close( ) -> Result { active_subs.remove_for_db(&db).await; - let mut instances = db_instances.0.write().await; + let mut instances = db_instances.inner.write().await; if let Some(wrapper) = instances.remove(&db) { wrapper.close().await?; @@ -415,7 +425,7 @@ pub async fn close_all( ) -> Result<()> { active_subs.abort_all().await; - let mut instances = db_instances.0.write().await; + let mut instances = db_instances.inner.write().await; // Collect all wrappers to close let wrappers: Vec = instances.drain().map(|(_, v)| v).collect(); @@ -447,7 +457,7 @@ pub async fn remove( ) -> Result { active_subs.remove_for_db(&db).await; - let mut instances = db_instances.0.write().await; + let mut instances = db_instances.inner.write().await; if let Some(wrapper) = instances.remove(&db) { wrapper.remove().await?; @@ -489,7 +499,7 @@ pub async fn begin_interruptible_transaction( initial_statements: Vec, attached: Option>, ) -> Result { - let instances = db_instances.0.read().await; + let instances = db_instances.inner.read().await; let wrapper = instances .get(&db) @@ -637,11 +647,21 @@ pub async fn observe( tables: Vec, config: Option, ) -> Result<()> { + const MAX_OBSERVED_TABLES: usize = 100; + const MAX_CHANNEL_CAPACITY: usize = 10_000; + + if tables.is_empty() || tables.len() > MAX_OBSERVED_TABLES { + return Err(Error::InvalidConfig(format!( + "tables count must be between 1 and {MAX_OBSERVED_TABLES}, got {}", + tables.len() + ))); + } + // Abort plugin-level subscription tasks before the crate-level // enable_observation() drops the old broker active_subs.remove_for_db(&db).await; - let mut instances = db_instances.0.write().await; + let mut instances = db_instances.inner.write().await; let wrapper = instances .get_mut(&db) @@ -651,6 +671,11 @@ pub async fn observe( if let Some(params) = config { if let Some(capacity) = params.channel_capacity { + if capacity == 0 || capacity > MAX_CHANNEL_CAPACITY { + return Err(Error::InvalidConfig(format!( + "channel_capacity must be between 1 and {MAX_CHANNEL_CAPACITY}, got {capacity}" + ))); + } observer_config = observer_config.with_channel_capacity(capacity); } if let Some(capture) = params.capture_values { @@ -676,7 +701,14 @@ pub async fn subscribe( tables: Vec, on_event: Channel, ) -> Result { - let instances = db_instances.0.read().await; + const MAX_SUBSCRIPTIONS_PER_DATABASE: usize = 100; + + let sub_count = active_subs.count_for_db(&db).await; + if sub_count >= MAX_SUBSCRIPTIONS_PER_DATABASE { + return Err(Error::TooManySubscriptions(MAX_SUBSCRIPTIONS_PER_DATABASE)); + } + + let instances = db_instances.inner.read().await; let wrapper = instances .get(&db) @@ -740,7 +772,7 @@ pub async fn unobserve( // Abort all subscriptions for this database first active_subs.remove_for_db(&db).await; - let mut instances = db_instances.0.write().await; + let mut instances = db_instances.inner.write().await; let wrapper = instances .get_mut(&db) diff --git a/src/error.rs b/src/error.rs index 0d02874..4fc5e1e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -27,6 +27,10 @@ pub enum Error { #[error("invalid database path: {0}")] InvalidPath(String), + /// Path traversal attempt detected. + #[error("path traversal not allowed: {0}")] + PathTraversal(String), + /// Attempted to access a database that hasn't been loaded. #[error("database {0} not loaded")] DatabaseNotLoaded(String), @@ -35,6 +39,18 @@ pub enum Error { #[error("observation not enabled for database: {0}")] ObservationNotEnabled(String), + /// Too many databases loaded simultaneously. + #[error("cannot load more than {0} databases")] + TooManyDatabases(usize), + + /// Too many subscriptions for a single database. + #[error("cannot create more than {0} subscriptions per database")] + TooManySubscriptions(usize), + + /// Invalid configuration parameter. + #[error("invalid configuration: {0}")] + InvalidConfig(String), + /// Generic error for operations that don't fit other categories. #[error("{0}")] Other(String), @@ -67,8 +83,12 @@ impl Error { Error::Toolkit(e) => e.error_code(), Error::Migration(_) => "MIGRATION_ERROR".to_string(), Error::InvalidPath(_) => "INVALID_PATH".to_string(), + Error::PathTraversal(_) => "PATH_TRAVERSAL".to_string(), Error::DatabaseNotLoaded(_) => "DATABASE_NOT_LOADED".to_string(), Error::ObservationNotEnabled(_) => "OBSERVATION_NOT_ENABLED".to_string(), + Error::TooManyDatabases(_) => "TOO_MANY_DATABASES".to_string(), + Error::TooManySubscriptions(_) => "TOO_MANY_SUBSCRIPTIONS".to_string(), + Error::InvalidConfig(_) => "INVALID_CONFIG".to_string(), Error::Other(_) => "ERROR".to_string(), } } diff --git a/src/lib.rs b/src/lib.rs index e40cee2..9717310 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,12 +22,38 @@ pub use sqlx_sqlite_toolkit::{ TransactionExecutionBuilder, WriteQueryResult, }; +/// Default maximum number of concurrently loaded databases. +const DEFAULT_MAX_DATABASES: usize = 50; + /// Database instances managed by the plugin. /// /// This struct maintains a thread-safe map of database paths to their corresponding -/// connection wrappers. -#[derive(Clone, Default)] -pub struct DbInstances(pub Arc>>); +/// connection wrappers, with a configurable upper limit on how many databases can be +/// loaded simultaneously. +#[derive(Clone)] +pub struct DbInstances { + pub(crate) inner: Arc>>, + pub(crate) max: usize, +} + +impl Default for DbInstances { + fn default() -> Self { + Self { + inner: Arc::new(RwLock::new(HashMap::new())), + max: DEFAULT_MAX_DATABASES, + } + } +} + +impl DbInstances { + /// Create a new instance with the given maximum database count. + pub fn new(max: usize) -> Self { + Self { + inner: Arc::new(RwLock::new(HashMap::new())), + max, + } + } +} /// Migration status for a database. #[derive(Debug, Clone)] @@ -130,10 +156,14 @@ pub struct MigrationEvent { /// .expect("error while running tauri application"); /// # } /// ``` -#[derive(Default)] +#[derive(Debug, Default)] pub struct Builder { /// Migrations registered per database path migrations: HashMap>, + /// Timeout for interruptible transactions. Defaults to 5 minutes. + transaction_timeout: Option, + /// Maximum number of concurrently loaded databases. Defaults to 50. + max_databases: Option, } impl Builder { @@ -141,6 +171,8 @@ impl Builder { pub fn new() -> Self { Self { migrations: HashMap::new(), + transaction_timeout: None, + max_databases: None, } } @@ -170,9 +202,43 @@ impl Builder { self } + /// Set the timeout for interruptible transactions. + /// + /// If an interruptible transaction exceeds this duration, it will be automatically + /// rolled back on the next access attempt. Defaults to 5 minutes. + /// + /// Returns `Err(Error::InvalidConfig)` if `timeout` is zero. + pub fn transaction_timeout(mut self, timeout: std::time::Duration) -> Result { + if timeout.is_zero() { + return Err(Error::InvalidConfig( + "transaction_timeout must be greater than zero".to_string(), + )); + } + self.transaction_timeout = Some(timeout); + Ok(self) + } + + /// Set the maximum number of databases that can be loaded simultaneously. + /// + /// Prevents unbounded memory growth from connection pool proliferation. + /// Defaults to 50. + /// + /// Returns `Err(Error::InvalidConfig)` if `max` is zero. + pub fn max_databases(mut self, max: usize) -> Result { + if max == 0 { + return Err(Error::InvalidConfig( + "max_databases must be greater than zero".to_string(), + )); + } + self.max_databases = Some(max); + Ok(self) + } + /// Build the plugin with command registration and state management. pub fn build(self) -> tauri::plugin::TauriPlugin { let migrations = Arc::new(self.migrations); + let transaction_timeout = self.transaction_timeout; + let max_databases = self.max_databases; PluginBuilder::::new("sqlite") .invoke_handler(tauri::generate_handler![ @@ -195,9 +261,15 @@ impl Builder { commands::unobserve, ]) .setup(move |app, _api| { - app.manage(DbInstances::default()); + app.manage(match max_databases { + Some(max) => DbInstances::new(max), + None => DbInstances::default(), + }); app.manage(MigrationStates::default()); - app.manage(ActiveInterruptibleTransactions::default()); + app.manage(match transaction_timeout { + Some(timeout) => ActiveInterruptibleTransactions::new(timeout), + None => ActiveInterruptibleTransactions::default(), + }); app.manage(ActiveRegularTransactions::default()); app.manage(subscriptions::ActiveSubscriptions::default()); @@ -270,7 +342,7 @@ impl Builder { // Close databases (each wrapper's close() disables its own // observer at the crate level, unregistering SQLite hooks) - let mut guard = instances_clone.0.write().await; + let mut guard = instances_clone.inner.write().await; let wrappers: Vec = guard.drain().map(|(_, v)| v).collect(); @@ -313,7 +385,7 @@ impl Builder { // ExitRequested should have already closed all databases // This is just a safety check let instances = app.state::(); - match instances.0.try_read() { + match instances.inner.try_read() { Ok(guard) => { if !guard.is_empty() { warn!( @@ -473,17 +545,49 @@ fn emit_migration_event( } } -/// Resolve database path for migrations (similar to wrapper but accessible at init). +/// Resolve database path for migrations. +/// +/// Delegates to `resolve::resolve_database_path` to ensure consistent path validation +/// across all entry points. fn resolve_migration_path( path: &str, app: &tauri::AppHandle, ) -> Result { - let app_path = app - .path() - .app_config_dir() - .map_err(|_| Error::InvalidPath("No app config path found".to_string()))?; + crate::resolve::resolve_database_path(path, app) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_max_databases_rejects_zero() { + let err = Builder::new().max_databases(0).unwrap_err(); + assert!(matches!(err, Error::InvalidConfig(_))); + } - std::fs::create_dir_all(&app_path)?; + #[test] + fn test_max_databases_accepts_positive() { + let builder = Builder::new().max_databases(1).unwrap(); + assert_eq!(builder.max_databases, Some(1)); + } + + #[test] + fn test_transaction_timeout_rejects_zero() { + let err = Builder::new() + .transaction_timeout(std::time::Duration::ZERO) + .unwrap_err(); + assert!(matches!(err, Error::InvalidConfig(_))); + } - Ok(app_path.join(path)) + #[test] + fn test_transaction_timeout_accepts_positive() { + let builder = Builder::new() + .transaction_timeout(std::time::Duration::from_secs(1)) + .unwrap(); + assert_eq!( + builder.transaction_timeout, + Some(std::time::Duration::from_secs(1)) + ); + } } diff --git a/src/resolve.rs b/src/resolve.rs index bc67a12..08d3b25 100644 --- a/src/resolve.rs +++ b/src/resolve.rs @@ -1,5 +1,5 @@ use std::fs::create_dir_all; -use std::path::PathBuf; +use std::path::{Component, Path, PathBuf}; use sqlx_sqlite_conn_mgr::SqliteDatabaseConfig; use sqlx_sqlite_toolkit::DatabaseWrapper; @@ -23,8 +23,11 @@ pub async fn connect( /// Resolve database file path relative to app config directory. /// -/// Paths are joined to `app_config_dir()` (e.g., `Library/Application Support/${bundleIdentifier}` on iOS). -/// Special paths like `:memory:` are passed through unchanged. +/// Paths are joined to `app_config_dir()` (e.g., `Library/Application Support/${bundleIdentifier}` +/// on iOS). Special paths like `:memory:` are passed through unchanged. +/// +/// Returns `Err(Error::PathTraversal)` if the path attempts to escape the app config directory +/// via absolute paths, `..` segments, or null bytes. pub fn resolve_database_path(path: &str, app: &AppHandle) -> Result { let app_path = app .path() @@ -33,6 +36,186 @@ pub fn resolve_database_path(path: &str, app: &AppHandle) -> Resu create_dir_all(&app_path)?; - // Join the relative path to the app config directory - Ok(app_path.join(path)) + validate_and_resolve(path, &app_path) +} + +/// Validate a user-supplied path and resolve it against a base directory. +/// +/// In-memory database paths are passed through unchanged. All other paths are validated +/// to ensure they cannot escape the base directory. +fn validate_and_resolve(path: &str, base: &Path) -> Result { + // Pass through in-memory database paths unchanged — they don't touch the filesystem. + // Matches the same patterns as `is_memory_database` in sqlx-sqlite-conn-mgr. + if is_memory_path(path) { + return Ok(PathBuf::from(path)); + } + + // Reject null bytes — these can truncate paths in C-level filesystem calls + if path.contains('\0') { + return Err(Error::PathTraversal("path contains null byte".to_string())); + } + + let rel = Path::new(path); + + // Reject absolute paths — PathBuf::join replaces the base when given an absolute path + if rel.is_absolute() { + return Err(Error::PathTraversal( + "absolute paths are not allowed".to_string(), + )); + } + + // Reject parent directory components — prevents escaping the base via `../` + for component in rel.components() { + if matches!(component, Component::ParentDir) { + return Err(Error::PathTraversal( + "parent directory references are not allowed".to_string(), + )); + } + } + + // Join and canonicalize to verify containment. The parent directory is canonicalized + // because the file may not exist yet. + let joined = base.join(rel); + let canonical_base = base + .canonicalize() + .map_err(|e| Error::InvalidPath(format!("cannot canonicalize base path: {e}")))?; + + let canonical_resolved = if joined.exists() { + joined.canonicalize() + } else { + // Ensure intermediate directories exist so that canonicalize can resolve the + // parent. This matches the caller's expectation that nested relative paths like + // "subdir/mydb.db" work without pre-creating "subdir/". + let parent = joined + .parent() + .ok_or_else(|| Error::InvalidPath("path has no parent".to_string()))?; + + create_dir_all(parent)?; + + parent + .canonicalize() + .map(|p| p.join(joined.file_name().unwrap_or_default())) + } + .map_err(|e| Error::InvalidPath(format!("cannot canonicalize path: {e}")))?; + + if !canonical_resolved.starts_with(&canonical_base) { + return Err(Error::PathTraversal( + "resolved path escapes the base directory".to_string(), + )); + } + + // Return the original (non-canonicalized) joined path for consistency with how the + // rest of the codebase references database paths. + Ok(joined) +} + +/// Check if a path string represents an in-memory SQLite database. +/// +/// Matches the same patterns as `is_memory_database` in `sqlx-sqlite-conn-mgr`: +/// `:memory:`, `file::memory:*` URIs, and `mode=memory` query parameters. +fn is_memory_path(path: &str) -> bool { + path == ":memory:" + || path.starts_with("file::memory:") + || (path.starts_with("file:") && path.contains("mode=memory")) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + + /// Helper that creates a temporary base directory for testing. + fn make_temp_base() -> PathBuf { + let dir = std::env::temp_dir().join(format!("tauri_sqlite_test_{}", std::process::id())); + fs::create_dir_all(&dir).unwrap(); + dir + } + + #[test] + fn test_simple_filename() { + let base = make_temp_base(); + let result = validate_and_resolve("mydb.db", &base).unwrap(); + assert_eq!(result, base.join("mydb.db")); + } + + #[test] + fn test_subdirectory_path() { + let base = make_temp_base(); + // Intermediate directories are auto-created — no manual setup needed + let result = validate_and_resolve("subdir/mydb.db", &base).unwrap(); + assert_eq!(result, base.join("subdir/mydb.db")); + assert!(base.join("subdir").is_dir()); + } + + #[test] + fn test_nested_subdirectory_path() { + let base = make_temp_base(); + let result = validate_and_resolve("a/b/c/mydb.db", &base).unwrap(); + assert_eq!(result, base.join("a/b/c/mydb.db")); + assert!(base.join("a/b/c").is_dir()); + } + + #[test] + fn test_memory_passthrough() { + let base = make_temp_base(); + assert_eq!( + validate_and_resolve(":memory:", &base).unwrap(), + PathBuf::from(":memory:"), + ); + } + + #[test] + fn test_file_memory_uri_passthrough() { + let base = make_temp_base(); + assert_eq!( + validate_and_resolve("file::memory:?cache=shared", &base).unwrap(), + PathBuf::from("file::memory:?cache=shared"), + ); + } + + #[test] + fn test_mode_memory_passthrough() { + let base = make_temp_base(); + assert_eq!( + validate_and_resolve("file:test?mode=memory", &base).unwrap(), + PathBuf::from("file:test?mode=memory"), + ); + } + + #[test] + fn test_rejects_parent_traversal() { + let base = make_temp_base(); + let err = validate_and_resolve("../../../etc/passwd", &base).unwrap_err(); + assert!(matches!(err, Error::PathTraversal(_))); + } + + #[test] + fn test_rejects_absolute_path() { + let base = make_temp_base(); + let err = validate_and_resolve("/etc/passwd", &base).unwrap_err(); + assert!(matches!(err, Error::PathTraversal(_))); + } + + #[test] + fn test_rejects_embedded_traversal() { + let base = make_temp_base(); + let err = validate_and_resolve("foo/../../bar", &base).unwrap_err(); + assert!(matches!(err, Error::PathTraversal(_))); + } + + #[test] + fn test_rejects_null_byte() { + let base = make_temp_base(); + let err = validate_and_resolve("path\0evil", &base).unwrap_err(); + assert!(matches!(err, Error::PathTraversal(_))); + } + + #[test] + fn test_rejects_non_uri_mode_memory() { + let base = make_temp_base(); + // A bare filename containing "mode=memory" is not a valid SQLite URI — + // it should go through normal path validation, not be passed through. + let result = validate_and_resolve("evil.db?mode=memory", &base).unwrap(); + assert_eq!(result, base.join("evil.db?mode=memory")); + } } diff --git a/src/subscriptions.rs b/src/subscriptions.rs index bc5176a..2805d6c 100644 --- a/src/subscriptions.rs +++ b/src/subscriptions.rs @@ -161,6 +161,12 @@ impl ActiveSubscriptions { } } + /// Count active subscriptions for a specific database. + pub async fn count_for_db(&self, db_path: &str) -> usize { + let subs = self.0.read().await; + subs.values().filter(|sub| sub.db_path == db_path).count() + } + /// Abort all subscriptions (for cleanup on app exit). pub async fn abort_all(&self) { let mut subs = self.0.write().await;