From bb3089e795b50be822757dc4b95ac70ba58330c1 Mon Sep 17 00:00:00 2001 From: Georgios Konstantopoulos Date: Fri, 6 Feb 2026 10:28:30 +0000 Subject: [PATCH 1/6] fix: replace string-based SQL injection in inject_block_filter with AST manipulation The previous implementation used string position matching (finding WHERE, ORDER BY, LIMIT keywords) and format! interpolation to splice block_num filters into user-provided SQL. This was vulnerable to structural SQL injection where crafted queries could exploit the naive keyword matching (e.g. WHERE inside string literals, UNION bypasses). Replace with sqlparser AST parsing and manipulation: - Parse user SQL into AST, requiring a single simple SELECT statement - Determine filter column from the FROM table (num for blocks, block_num for others) - Safely AND the block filter into the existing WHERE clause (or add one) - Serialize modified AST back to SQL Also: - Reject UNION/INTERSECT/set operations in live mode (ambiguous filtering) - Return Result instead of String for proper error handling - Add Display impl for ApiError - Add tests for UNION rejection, non-SELECT rejection, and WHERE keyword in string literals Amp-Thread-ID: https://ampcode.com/threads/T-019c3272-f632-763c-8078-504a90852a67 Co-authored-by: Amp --- src/api/mod.rs | 141 +++++++++++++++++++++++++++++------------ tests/api_live_test.rs | 40 +++++++++--- 2 files changed, 129 insertions(+), 52 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index 07a05b69..25300174 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -477,7 +477,16 @@ async fn handle_query_live( } else { let catch_up_start = last_block_num + 1; for block_num in catch_up_start..=end { - let filtered_sql = inject_block_filter(&sql, block_num); + let filtered_sql = match inject_block_filter(&sql, block_num) { + Ok(s) => s, + Err(e) => { + yield Ok(SseEvent::default() + .event("error") + .json_data(serde_json::json!({ "ok": false, "error": e.to_string() })) + .unwrap()); + return; + } + }; match crate::service::execute_query_postgres(&pool, &filtered_sql, signature.as_deref(), &options).await { Ok(result) => { yield Ok(SseEvent::default() @@ -516,50 +525,85 @@ async fn handle_query_live( /// Inject a block number filter into SQL query for live streaming. /// Transforms queries to only return data for the specific block. /// Uses 'num' for blocks table, 'block_num' for txs/logs tables. +/// +/// Uses sqlparser AST manipulation to safely add the filter condition, +/// avoiding SQL injection risks from string-based splicing. #[doc(hidden)] -pub fn inject_block_filter(sql: &str, block_num: u64) -> String { - let sql_upper = sql.to_uppercase(); - - // Determine column name based on table being queried - let col = if sql_upper.contains("FROM BLOCKS") || sql_upper.contains("FROM \"BLOCKS\"") { - "num" - } else { - "block_num" +pub fn inject_block_filter(sql: &str, block_num: u64) -> Result { + use sqlparser::ast::{ + BinaryOperator, Expr, Ident, SetExpr, Statement, Value, }; - - // Find WHERE clause position - if let Some(where_pos) = sql_upper.find("WHERE") { - // Insert after WHERE - let insert_pos = where_pos + 5; - format!( - "{} {} = {} AND {}", - &sql[..insert_pos], - col, - block_num, - &sql[insert_pos..] - ) - } else if let Some(order_pos) = sql_upper.find("ORDER BY") { - // Insert WHERE before ORDER BY - format!( - "{} WHERE {} = {} {}", - &sql[..order_pos], - col, - block_num, - &sql[order_pos..] - ) - } else if let Some(limit_pos) = sql_upper.find("LIMIT") { - // Insert WHERE before LIMIT - format!( - "{} WHERE {} = {} {}", - &sql[..limit_pos], - col, - block_num, - &sql[limit_pos..] - ) - } else { - // Append WHERE at end - format!("{sql} WHERE {col} = {block_num}") + use sqlparser::dialect::GenericDialect; + use sqlparser::parser::Parser; + + let dialect = GenericDialect {}; + let mut statements = Parser::parse_sql(&dialect, sql) + .map_err(|e| ApiError::BadRequest(format!("SQL parse error: {e}")))?; + + if statements.len() != 1 { + return Err(ApiError::BadRequest( + "Live mode requires exactly one SQL statement".to_string(), + )); } + + let stmt = &mut statements[0]; + let query = match stmt { + Statement::Query(q) => q, + _ => { + return Err(ApiError::BadRequest( + "Live mode requires a SELECT query".to_string(), + )) + } + }; + + let select = match query.body.as_mut() { + SetExpr::Select(s) => s, + _ => { + return Err(ApiError::BadRequest( + "Live mode requires a simple SELECT query (UNION/INTERSECT not supported)" + .to_string(), + )) + } + }; + + let table_name: String = select + .from + .first() + .and_then(|twj| match &twj.relation { + sqlparser::ast::TableFactor::Table { name, .. } => { + name.0.last().and_then(|part| part.as_ident()).map(|ident| ident.value.to_lowercase()) + } + _ => None, + }) + .ok_or_else(|| { + ApiError::BadRequest( + "Live mode requires a query with a FROM table clause".to_string(), + ) + })?; + + let col_name = if table_name == "blocks" { "num" } else { "block_num" }; + + let col_expr = Expr::CompoundIdentifier(vec![ + Ident::new(&table_name), + Ident::new(col_name), + ]); + + let block_filter = Expr::BinaryOp { + left: Box::new(col_expr), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number(block_num.to_string(), false).into())), + }; + + select.selection = Some(match select.selection.take() { + Some(existing) => Expr::BinaryOp { + left: Box::new(Expr::Nested(Box::new(existing))), + op: BinaryOperator::And, + right: Box::new(block_filter), + }, + None => block_filter, + }); + + Ok(stmt.to_string()) } /// Rewrite analytics table references to include chain-specific database prefix. @@ -599,6 +643,19 @@ pub enum ApiError { NotFound(String), } +impl std::fmt::Display for ApiError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ApiError::BadRequest(msg) => write!(f, "{msg}"), + ApiError::Timeout => write!(f, "Query timeout"), + ApiError::QueryError(msg) => write!(f, "{msg}"), + ApiError::Internal(msg) => write!(f, "{msg}"), + ApiError::Forbidden(msg) => write!(f, "{msg}"), + ApiError::NotFound(msg) => write!(f, "{msg}"), + } + } +} + impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { let (status, message) = match self { diff --git a/tests/api_live_test.rs b/tests/api_live_test.rs index 69e4de0b..4608cd44 100644 --- a/tests/api_live_test.rs +++ b/tests/api_live_test.rs @@ -361,37 +361,57 @@ async fn test_query_live_returns_sse() { #[test] fn test_inject_block_filter_blocks_table() { let sql = "SELECT num, hash FROM blocks ORDER BY num DESC LIMIT 1"; - let filtered = inject_block_filter(sql, 100); - assert!(filtered.contains("num = 100"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 100).unwrap(); + assert!(filtered.contains("blocks.num = 100"), "got: {filtered}"); assert!(filtered.contains("ORDER BY"), "should preserve ORDER BY"); } #[test] fn test_inject_block_filter_txs_table() { let sql = "SELECT * FROM txs ORDER BY block_num DESC LIMIT 10"; - let filtered = inject_block_filter(sql, 200); - assert!(filtered.contains("block_num = 200"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 200).unwrap(); + assert!(filtered.contains("txs.block_num = 200"), "got: {filtered}"); } #[test] fn test_inject_block_filter_logs_table() { let sql = "SELECT * FROM logs WHERE address = '0x123' ORDER BY block_num DESC"; - let filtered = inject_block_filter(sql, 300); - assert!(filtered.contains("block_num = 300"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 300).unwrap(); + assert!(filtered.contains("logs.block_num = 300"), "got: {filtered}"); assert!(filtered.contains("address = '0x123'"), "should preserve existing WHERE"); } #[test] fn test_inject_block_filter_with_existing_where() { let sql = "SELECT * FROM txs WHERE gas_used > 21000 ORDER BY block_num DESC"; - let filtered = inject_block_filter(sql, 400); - assert!(filtered.contains("block_num = 400"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 400).unwrap(); + assert!(filtered.contains("txs.block_num = 400"), "got: {filtered}"); assert!(filtered.contains("gas_used > 21000"), "should preserve existing condition"); } #[test] fn test_inject_block_filter_no_order_by() { let sql = "SELECT COUNT(*) FROM blocks LIMIT 1"; - let filtered = inject_block_filter(sql, 500); - assert!(filtered.contains("num = 500"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 500).unwrap(); + assert!(filtered.contains("blocks.num = 500"), "got: {filtered}"); +} + +#[test] +fn test_inject_block_filter_rejects_union() { + let sql = "SELECT * FROM txs UNION SELECT * FROM logs"; + assert!(inject_block_filter(sql, 100).is_err()); +} + +#[test] +fn test_inject_block_filter_rejects_non_select() { + let sql = "INSERT INTO txs VALUES (1)"; + assert!(inject_block_filter(sql, 100).is_err()); +} + +#[test] +fn test_inject_block_filter_where_keyword_in_string_literal() { + let sql = "SELECT * FROM txs WHERE input = 'WHERE clause test'"; + let filtered = inject_block_filter(sql, 100).unwrap(); + assert!(filtered.contains("txs.block_num = 100"), "got: {filtered}"); + assert!(filtered.contains("'WHERE clause test'"), "should preserve string literal"); } From 9ee9e8c90af5b88567840de2d1bc16c779f1a179 Mon Sep 17 00:00:00 2001 From: Georgios Konstantopoulos Date: Fri, 6 Feb 2026 10:44:44 +0000 Subject: [PATCH 2/6] fix: add table allowlist to query validator and read-only API role Replace the blocklist-only approach with a table allowlist so API users can only query: blocks, txs, logs, receipts, token_holders, token_balances, and CTE-defined tables. Previously, users could query sync_state (internal), pg_tables (schema enumeration), or any other table accessible to the tidx DB user. Changes: - Allowlist in validator: only permitted tables + CTE-defined names pass - Block dblink function family (cross-database access) - Add db/api_role.sql migration creating a tidx_api read-only role with SELECT-only grants on indexed tables (defense-in-depth) - Thread CTE names through all validate_* functions - 6 new tests: sync_state rejected, pg_tables rejected, unknown table rejected, CTE tables allowed, dblink blocked, analytics tables allowed Amp-Thread-ID: https://ampcode.com/threads/T-019c3272-f632-763c-8078-504a90852a67 Co-authored-by: Amp --- db/api_role.sql | 26 ++++++ src/db/schema.rs | 3 + src/query/validator.rs | 178 +++++++++++++++++++++++++++++------------ 3 files changed, 157 insertions(+), 50 deletions(-) create mode 100644 db/api_role.sql diff --git a/db/api_role.sql b/db/api_role.sql new file mode 100644 index 00000000..e471e19b --- /dev/null +++ b/db/api_role.sql @@ -0,0 +1,26 @@ +-- Create a read-only role for API query connections. +-- The API should connect as this role to provide defense-in-depth +-- against SQL injection, even if the query validator is bypassed. +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'tidx_api') THEN + CREATE ROLE tidx_api WITH LOGIN PASSWORD 'tidx_api' NOSUPERUSER NOCREATEDB NOCREATEROLE; + END IF; +END $$; + +-- Grant read-only access to indexed tables only +GRANT USAGE ON SCHEMA public TO tidx_api; +GRANT SELECT ON blocks, txs, logs, receipts TO tidx_api; + +-- Allow calling ABI decode helper functions +GRANT EXECUTE ON FUNCTION abi_uint(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_int(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_address(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_bool(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_bytes(bytea, int) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_string(bytea, int) TO tidx_api; +GRANT EXECUTE ON FUNCTION format_address(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION format_uint(bytea) TO tidx_api; + +-- Revoke everything else (defense-in-depth) +REVOKE ALL ON sync_state FROM tidx_api; diff --git a/src/db/schema.rs b/src/db/schema.rs index 1c51511b..9cc11af7 100644 --- a/src/db/schema.rs +++ b/src/db/schema.rs @@ -39,6 +39,9 @@ pub async fn run_migrations(pool: &Pool) -> Result<()> { // Load any optional extensions conn.batch_execute(include_str!("../../db/extensions.sql")).await?; + // Create read-only API role with SELECT-only access to indexed tables + conn.batch_execute(include_str!("../../db/api_role.sql")).await?; + Ok(()) } diff --git a/src/query/validator.rs b/src/query/validator.rs index 87a2a881..1034cf05 100644 --- a/src/query/validator.rs +++ b/src/query/validator.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use anyhow::{anyhow, Result}; use sqlparser::ast::{ Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, ObjectName, Query, SetExpr, @@ -6,6 +8,15 @@ use sqlparser::ast::{ use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; +const ALLOWED_TABLES: &[&str] = &[ + "blocks", + "txs", + "logs", + "receipts", + "token_holders", + "token_balances", +]; + /// Validates that a SQL query is safe to execute. /// /// Rejects: @@ -30,48 +41,64 @@ pub fn validate_query(sql: &str) -> Result<()> { let stmt = &statements[0]; match stmt { - Statement::Query(query) => validate_query_ast(query), + Statement::Query(query) => { + let cte_names = extract_cte_names(query); + validate_query_ast(query, &cte_names) + } _ => Err(anyhow!("Only SELECT queries are allowed")), } } -fn validate_query_ast(query: &Query) -> Result<()> { - // Check CTEs for data-modifying statements +fn extract_cte_names(query: &Query) -> HashSet { + let mut names = HashSet::new(); + if let Some(with) = &query.with { + for cte in &with.cte_tables { + names.insert(cte.alias.name.value.to_lowercase()); + } + } + names +} + +fn validate_query_ast(query: &Query, cte_names: &HashSet) -> Result<()> { + let mut all_cte_names = cte_names.clone(); + if let Some(with) = &query.with { + for cte in &with.cte_tables { + all_cte_names.insert(cte.alias.name.value.to_lowercase()); + } + } + for cte in &query.with.as_ref().map_or(vec![], |w| w.cte_tables.clone()) { - validate_query_ast(&cte.query)?; + validate_query_ast(&cte.query, &all_cte_names)?; } - validate_set_expr(&query.body) + validate_set_expr(&query.body, &all_cte_names) } -fn validate_set_expr(set_expr: &SetExpr) -> Result<()> { +fn validate_set_expr(set_expr: &SetExpr, cte_names: &HashSet) -> Result<()> { match set_expr { SetExpr::Select(select) => { - // Validate FROM clause for table in &select.from { - validate_table_with_joins(table)?; + validate_table_with_joins(table, cte_names)?; } - // Validate SELECT expressions for item in &select.projection { if let sqlparser::ast::SelectItem::UnnamedExpr(expr) | sqlparser::ast::SelectItem::ExprWithAlias { expr, .. } = item { - validate_expr(expr)?; + validate_expr(expr, cte_names)?; } } - // Validate WHERE clause if let Some(selection) = &select.selection { - validate_expr(selection)?; + validate_expr(selection, cte_names)?; } Ok(()) } - SetExpr::Query(q) => validate_query_ast(q), + SetExpr::Query(q) => validate_query_ast(q, cte_names), SetExpr::SetOperation { left, right, .. } => { - validate_set_expr(left)?; - validate_set_expr(right) + validate_set_expr(left, cte_names)?; + validate_set_expr(right, cte_names) } SetExpr::Values(_) => Ok(()), SetExpr::Insert(_) => Err(anyhow!("INSERT not allowed")), @@ -82,29 +109,27 @@ fn validate_set_expr(set_expr: &SetExpr) -> Result<()> { } } -fn validate_table_with_joins(table: &TableWithJoins) -> Result<()> { - validate_table_factor(&table.relation)?; +fn validate_table_with_joins(table: &TableWithJoins, cte_names: &HashSet) -> Result<()> { + validate_table_factor(&table.relation, cte_names)?; for join in &table.joins { - validate_table_factor(&join.relation)?; + validate_table_factor(&join.relation, cte_names)?; } Ok(()) } -fn validate_table_factor(factor: &TableFactor) -> Result<()> { +fn validate_table_factor(factor: &TableFactor, cte_names: &HashSet) -> Result<()> { match factor { TableFactor::Table { name, args, .. } => { - // Check if this is a table-valued function like read_csv(...) if args.is_some() { let func_name = name.to_string().to_lowercase(); if is_dangerous_table_function(&func_name) { return Err(anyhow!("Table function '{func_name}' is not allowed")); } } - validate_table_name(name) + validate_table_name(name, cte_names) } - TableFactor::Derived { subquery, .. } => validate_query_ast(subquery), + TableFactor::Derived { subquery, .. } => validate_query_ast(subquery, cte_names), TableFactor::TableFunction { expr, .. } => { - // Block table functions that can read filesystem if let Expr::Function(func) = expr { let func_name = func.name.to_string().to_lowercase(); if is_dangerous_table_function(&func_name) { @@ -121,16 +146,15 @@ fn validate_table_factor(factor: &TableFactor) -> Result<()> { Ok(()) } TableFactor::NestedJoin { table_with_joins, .. } => { - validate_table_with_joins(table_with_joins) + validate_table_with_joins(table_with_joins, cte_names) } _ => Ok(()), } } -fn validate_table_name(name: &ObjectName) -> Result<()> { +fn validate_table_name(name: &ObjectName, cte_names: &HashSet) -> Result<()> { let full_name = name.to_string().to_lowercase(); - // Block system catalogs const BLOCKED_SCHEMAS: &[&str] = &[ "pg_catalog", "information_schema", @@ -144,7 +168,6 @@ fn validate_table_name(name: &ObjectName) -> Result<()> { } } - // Block specific dangerous tables const BLOCKED_TABLES: &[&str] = &[ "pg_stat_activity", "pg_settings", @@ -160,44 +183,57 @@ fn validate_table_name(name: &ObjectName) -> Result<()> { } } - Ok(()) + let bare_name = name.0.last() + .and_then(|part| part.as_ident()) + .map(|ident| ident.value.to_lowercase()) + .unwrap_or_default(); + + if ALLOWED_TABLES.contains(&bare_name.as_str()) { + return Ok(()); + } + + if cte_names.contains(&bare_name) { + return Ok(()); + } + + Err(anyhow!("Access to table '{bare_name}' is not allowed")) } -fn validate_expr(expr: &Expr) -> Result<()> { +fn validate_expr(expr: &Expr, cte_names: &HashSet) -> Result<()> { match expr { - Expr::Function(func) => validate_function(func), - Expr::Subquery(q) => validate_query_ast(q), - Expr::InSubquery { subquery, .. } => validate_query_ast(subquery), - Expr::Exists { subquery, .. } => validate_query_ast(subquery), + Expr::Function(func) => validate_function(func, cte_names), + Expr::Subquery(q) => validate_query_ast(q, cte_names), + Expr::InSubquery { subquery, .. } => validate_query_ast(subquery, cte_names), + Expr::Exists { subquery, .. } => validate_query_ast(subquery, cte_names), Expr::BinaryOp { left, right, .. } => { - validate_expr(left)?; - validate_expr(right) + validate_expr(left, cte_names)?; + validate_expr(right, cte_names) } - Expr::UnaryOp { expr, .. } => validate_expr(expr), + Expr::UnaryOp { expr, .. } => validate_expr(expr, cte_names), Expr::Between { expr, low, high, .. } => { - validate_expr(expr)?; - validate_expr(low)?; - validate_expr(high) + validate_expr(expr, cte_names)?; + validate_expr(low, cte_names)?; + validate_expr(high, cte_names) } Expr::Case { operand, conditions, else_result, .. } => { if let Some(op) = operand { - validate_expr(op)?; + validate_expr(op, cte_names)?; } for case_when in conditions { - validate_expr(&case_when.condition)?; - validate_expr(&case_when.result)?; + validate_expr(&case_when.condition, cte_names)?; + validate_expr(&case_when.result, cte_names)?; } if let Some(else_r) = else_result { - validate_expr(else_r)?; + validate_expr(else_r, cte_names)?; } Ok(()) } - Expr::Cast { expr, .. } => validate_expr(expr), - Expr::Nested(e) => validate_expr(e), + Expr::Cast { expr, .. } => validate_expr(expr, cte_names), + Expr::Nested(e) => validate_expr(e, cte_names), Expr::InList { expr, list, .. } => { - validate_expr(expr)?; + validate_expr(expr, cte_names)?; for item in list { - validate_expr(item)?; + validate_expr(item, cte_names)?; } Ok(()) } @@ -205,20 +241,19 @@ fn validate_expr(expr: &Expr) -> Result<()> { } } -fn validate_function(func: &Function) -> Result<()> { +fn validate_function(func: &Function, cte_names: &HashSet) -> Result<()> { let func_name = func.name.to_string().to_lowercase(); if is_dangerous_function(&func_name) { return Err(anyhow!("Function '{func_name}' is not allowed")); } - // Recursively validate function arguments if let FunctionArguments::List(arg_list) = &func.args { for arg in &arg_list.args { if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) | FunctionArg::Named { arg: FunctionArgExpr::Expr(expr), .. } = arg { - validate_expr(expr)?; + validate_expr(expr, cte_names)?; } } } @@ -250,6 +285,12 @@ fn is_dangerous_function(name: &str) -> bool { "lo_export", // PostgreSQL command execution "pg_execute_server_program", + // PostgreSQL dblink (remote connections) + "dblink", + "dblink_exec", + "dblink_connect", + "dblink_send_query", + "dblink_get_result", // ClickHouse system functions "system.flush_logs", "system.reload_config", @@ -403,4 +444,41 @@ mod tests { fn test_rejects_nested_dangerous_function() { assert!(validate_query("SELECT COALESCE(pg_sleep(1), 0)").is_err()); } + + #[test] + fn test_rejects_sync_state() { + assert!(validate_query("SELECT * FROM sync_state").is_err()); + } + + #[test] + fn test_rejects_pg_tables() { + assert!(validate_query("SELECT * FROM pg_tables").is_err()); + } + + #[test] + fn test_rejects_unknown_table() { + assert!(validate_query("SELECT * FROM some_random_table").is_err()); + } + + #[test] + fn test_allows_cte_defined_table() { + assert!(validate_query( + "WITH my_cte AS (SELECT * FROM blocks) SELECT * FROM my_cte" + ) + .is_ok()); + } + + #[test] + fn test_rejects_dblink() { + assert!(validate_query("SELECT * FROM dblink('host=evil dbname=secrets', 'SELECT * FROM passwords')").is_err()); + assert!(validate_query("SELECT dblink_connect('myconn', 'host=evil')").is_err()); + assert!(validate_query("SELECT dblink_exec('myconn', 'DROP TABLE blocks')").is_err()); + } + + #[test] + fn test_allows_analytics_tables() { + assert!(validate_query("SELECT * FROM token_holders").is_ok()); + assert!(validate_query("SELECT * FROM token_balances").is_ok()); + assert!(validate_query("SELECT * FROM public.blocks").is_ok()); + } } From 4038458e2ba8640653909f2314495a8e22ddcf5c Mon Sep 17 00:00:00 2001 From: Georgios Konstantopoulos Date: Fri, 6 Feb 2026 10:51:56 +0000 Subject: [PATCH 3/6] fix: close DoS, privilege escalation, and file read bypass vectors Block three categories of attacks: 1. DoS via resource exhaustion: - Reject WITH RECURSIVE (endless loop CTEs) - Block generate_series() (billion-row generation) - Block SELECT INTO (object creation) 2. Privilege escalation / validator bypass: - Validate expressions inside VALUES rows (previously VALUES(pg_sleep(10)) bypassed the entire function blocklist) - Reject TABLE statement (TABLE pg_shadow bypassed table allowlist) - Validate GROUP BY, HAVING, JOIN ON expressions (could hide function calls) - Walk IsNull/IsNotNull/IsTrue/IsFalse/Like expressions recursively 3. File read hardening: - Block lo_get/lo_open/lo_close/loread/lo_creat/lo_create/lo_unlink/lo_put - Block pg_file_read/pg_file_write/pg_file_rename/pg_file_unlink/pg_logdir_ls - VALUES bypass closure prevents pg_read_file via VALUES(...) 11 new tests covering all vectors. Amp-Thread-ID: https://ampcode.com/threads/T-019c3272-f632-763c-8078-504a90852a67 Co-authored-by: Amp --- src/query/validator.rs | 147 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 145 insertions(+), 2 deletions(-) diff --git a/src/query/validator.rs b/src/query/validator.rs index 1034cf05..b0847412 100644 --- a/src/query/validator.rs +++ b/src/query/validator.rs @@ -60,6 +60,13 @@ fn extract_cte_names(query: &Query) -> HashSet { } fn validate_query_ast(query: &Query, cte_names: &HashSet) -> Result<()> { + // Block recursive CTEs (can cause endless loops / resource exhaustion) + if let Some(with) = &query.with { + if with.recursive { + return Err(anyhow!("Recursive CTEs are not allowed")); + } + } + let mut all_cte_names = cte_names.clone(); if let Some(with) = &query.with { for cte in &with.cte_tables { @@ -77,6 +84,11 @@ fn validate_query_ast(query: &Query, cte_names: &HashSet) -> Result<()> fn validate_set_expr(set_expr: &SetExpr, cte_names: &HashSet) -> Result<()> { match set_expr { SetExpr::Select(select) => { + // Reject SELECT INTO (creates objects) + if select.into.is_some() { + return Err(anyhow!("SELECT INTO is not allowed")); + } + for table in &select.from { validate_table_with_joins(table, cte_names)?; } @@ -93,6 +105,18 @@ fn validate_set_expr(set_expr: &SetExpr, cte_names: &HashSet) -> Result< validate_expr(selection, cte_names)?; } + // Validate GROUP BY expressions + if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &select.group_by { + for expr in exprs { + validate_expr(expr, cte_names)?; + } + } + + // Validate HAVING + if let Some(having) = &select.having { + validate_expr(having, cte_names)?; + } + Ok(()) } SetExpr::Query(q) => validate_query_ast(q, cte_names), @@ -100,12 +124,20 @@ fn validate_set_expr(set_expr: &SetExpr, cte_names: &HashSet) -> Result< validate_set_expr(left, cte_names)?; validate_set_expr(right, cte_names) } - SetExpr::Values(_) => Ok(()), + SetExpr::Values(values) => { + // Validate all expressions in VALUES rows to prevent function call bypass + for row in &values.rows { + for expr in row { + validate_expr(expr, cte_names)?; + } + } + Ok(()) + } SetExpr::Insert(_) => Err(anyhow!("INSERT not allowed")), SetExpr::Update(_) => Err(anyhow!("UPDATE not allowed")), SetExpr::Delete(_) => Err(anyhow!("DELETE not allowed")), SetExpr::Merge(_) => Err(anyhow!("MERGE not allowed")), - SetExpr::Table(_) => Ok(()), + SetExpr::Table(_) => Err(anyhow!("TABLE statement is not allowed")), } } @@ -113,6 +145,27 @@ fn validate_table_with_joins(table: &TableWithJoins, cte_names: &HashSet validate_table_factor(&table.relation, cte_names)?; for join in &table.joins { validate_table_factor(&join.relation, cte_names)?; + // Validate JOIN ON expressions + let constraint = match &join.join_operator { + sqlparser::ast::JoinOperator::Join(c) + | sqlparser::ast::JoinOperator::Inner(c) + | sqlparser::ast::JoinOperator::Left(c) + | sqlparser::ast::JoinOperator::LeftOuter(c) + | sqlparser::ast::JoinOperator::Right(c) + | sqlparser::ast::JoinOperator::RightOuter(c) + | sqlparser::ast::JoinOperator::FullOuter(c) + | sqlparser::ast::JoinOperator::CrossJoin(c) + | sqlparser::ast::JoinOperator::Semi(c) + | sqlparser::ast::JoinOperator::LeftSemi(c) + | sqlparser::ast::JoinOperator::RightSemi(c) + | sqlparser::ast::JoinOperator::Anti(c) + | sqlparser::ast::JoinOperator::LeftAnti(c) + | sqlparser::ast::JoinOperator::RightAnti(c) => Some(c), + _ => None, + }; + if let Some(sqlparser::ast::JoinConstraint::On(expr)) = constraint { + validate_expr(expr, cte_names)?; + } } Ok(()) } @@ -237,6 +290,19 @@ fn validate_expr(expr: &Expr, cte_names: &HashSet) -> Result<()> { } Ok(()) } + Expr::IsNull(e) + | Expr::IsNotNull(e) + | Expr::IsTrue(e) + | Expr::IsFalse(e) + | Expr::IsNotTrue(e) + | Expr::IsNotFalse(e) + | Expr::IsUnknown(e) + | Expr::IsNotUnknown(e) => validate_expr(e, cte_names), + Expr::Like { expr, pattern, .. } | Expr::ILike { expr, pattern, .. } => { + validate_expr(expr, cte_names)?; + validate_expr(pattern, cte_names) + } + Expr::AnyOp { right, .. } | Expr::AllOp { right, .. } => validate_expr(right, cte_names), _ => Ok(()), } } @@ -291,6 +357,24 @@ fn is_dangerous_function(name: &str) -> bool { "dblink_connect", "dblink_send_query", "dblink_get_result", + // PostgreSQL large object access + "lo_get", + "lo_open", + "lo_close", + "loread", + "lowrite", + "lo_creat", + "lo_create", + "lo_unlink", + "lo_put", + // PostgreSQL set-returning functions (DoS via row generation) + "generate_series", + // PostgreSQL admin extension functions (file access) + "pg_file_read", + "pg_file_write", + "pg_file_rename", + "pg_file_unlink", + "pg_logdir_ls", // ClickHouse system functions "system.flush_logs", "system.reload_config", @@ -481,4 +565,63 @@ mod tests { assert!(validate_query("SELECT * FROM token_balances").is_ok()); assert!(validate_query("SELECT * FROM public.blocks").is_ok()); } + + #[test] + fn test_rejects_recursive_cte() { + assert!(validate_query( + "WITH RECURSIVE r AS (SELECT 1 AS n UNION ALL SELECT n+1 FROM r) SELECT * FROM r" + ).is_err()); + } + + #[test] + fn test_rejects_generate_series() { + assert!(validate_query("SELECT generate_series(1, 1000000000)").is_err()); + assert!(validate_query("SELECT * FROM blocks WHERE num IN (SELECT generate_series(1, 1000000))").is_err()); + } + + #[test] + fn test_rejects_values_function_bypass() { + assert!(validate_query("VALUES (pg_sleep(10))").is_err()); + assert!(validate_query("VALUES (pg_read_file('/etc/passwd'))").is_err()); + } + + #[test] + fn test_rejects_table_statement() { + assert!(validate_query("TABLE blocks").is_err()); + assert!(validate_query("TABLE pg_shadow").is_err()); + } + + #[test] + fn test_rejects_select_into() { + assert!(validate_query("SELECT * INTO newtable FROM blocks").is_err()); + } + + #[test] + fn test_rejects_lo_functions() { + assert!(validate_query("SELECT lo_get(12345)").is_err()); + assert!(validate_query("SELECT lo_open(12345, 262144)").is_err()); + } + + #[test] + fn test_rejects_admin_file_functions() { + assert!(validate_query("SELECT pg_file_read('/etc/passwd', 0, 1000)").is_err()); + assert!(validate_query("SELECT pg_file_write('/tmp/evil', 'data', false)").is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_having() { + assert!(validate_query("SELECT COUNT(*) FROM blocks GROUP BY num HAVING pg_sleep(1) IS NOT NULL").is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_join_on() { + assert!(validate_query( + "SELECT * FROM blocks JOIN txs ON pg_sleep(1) IS NOT NULL" + ).is_err()); + } + + #[test] + fn test_allows_simple_values() { + assert!(validate_query("VALUES (1, 'hello'), (2, 'world')").is_ok()); + } } From 589222daa89a1b8c6800da836c6debf2a44012f3 Mon Sep 17 00:00:00 2001 From: Georgios Konstantopoulos Date: Tue, 10 Feb 2026 22:27:15 +0000 Subject: [PATCH 4/6] fix: replace function blocklist with allowlist and harden API role Switch from a function blocklist (reject known-bad) to an allowlist (permit known-good only). This eliminates the risk of missing dangerous functions as PostgreSQL adds new ones. Validator changes: - ALLOWED_FUNCTIONS allowlist: ABI helpers, aggregates, scalars, string, numeric, time, window functions, and type casting - Reject ALL table functions (FROM func(...)) unconditionally - Reject unsupported TableFactor variants (catch-all _ => Err) - Remove is_dangerous_function() and is_dangerous_table_function() API role hardening (db/api_role.sql): - Deny-by-default: REVOKE ALL on tables, sequences, and functions before granting specific access - Add token_holders and token_balances to SELECT grants - CONNECTION LIMIT 64 - statement_timeout = 30s, work_mem = 64MB, temp_file_limit = 256MB 5 new tests, all 147 lib tests passing. Amp-Thread-ID: https://ampcode.com/threads/T-019c499a-f07e-73a9-9526-6c18fd511372 Co-authored-by: Amp --- db/api_role.sql | 23 ++++-- src/query/validator.rs | 181 +++++++++++++++-------------------------- 2 files changed, 83 insertions(+), 121 deletions(-) diff --git a/db/api_role.sql b/db/api_role.sql index e471e19b..c709bcb1 100644 --- a/db/api_role.sql +++ b/db/api_role.sql @@ -1,6 +1,7 @@ --- Create a read-only role for API query connections. --- The API should connect as this role to provide defense-in-depth --- against SQL injection, even if the query validator is bypassed. +-- Read-only API role for query connections. +-- Defense-in-depth: even if the query validator is bypassed, +-- this role cannot modify data, execute arbitrary functions, +-- or exhaust server resources. DO $$ BEGIN IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'tidx_api') THEN @@ -8,11 +9,16 @@ BEGIN END IF; END $$; +-- Revoke all privileges first (deny-by-default) +REVOKE ALL ON ALL TABLES IN SCHEMA public FROM tidx_api; +REVOKE ALL ON ALL SEQUENCES IN SCHEMA public FROM tidx_api; +REVOKE EXECUTE ON ALL FUNCTIONS IN SCHEMA public FROM tidx_api; + -- Grant read-only access to indexed tables only GRANT USAGE ON SCHEMA public TO tidx_api; -GRANT SELECT ON blocks, txs, logs, receipts TO tidx_api; +GRANT SELECT ON blocks, txs, logs, receipts, token_holders, token_balances TO tidx_api; --- Allow calling ABI decode helper functions +-- Grant execute only on ABI decode helper functions GRANT EXECUTE ON FUNCTION abi_uint(bytea) TO tidx_api; GRANT EXECUTE ON FUNCTION abi_int(bytea) TO tidx_api; GRANT EXECUTE ON FUNCTION abi_address(bytea) TO tidx_api; @@ -22,5 +28,8 @@ GRANT EXECUTE ON FUNCTION abi_string(bytea, int) TO tidx_api; GRANT EXECUTE ON FUNCTION format_address(bytea) TO tidx_api; GRANT EXECUTE ON FUNCTION format_uint(bytea) TO tidx_api; --- Revoke everything else (defense-in-depth) -REVOKE ALL ON sync_state FROM tidx_api; +-- Resource limits (prevent DoS) +ALTER ROLE tidx_api CONNECTION LIMIT 64; +ALTER ROLE tidx_api SET statement_timeout = '30s'; +ALTER ROLE tidx_api SET work_mem = '64MB'; +ALTER ROLE tidx_api SET temp_file_limit = '256MB'; diff --git a/src/query/validator.rs b/src/query/validator.rs index b0847412..c75e488a 100644 --- a/src/query/validator.rs +++ b/src/query/validator.rs @@ -174,34 +174,17 @@ fn validate_table_factor(factor: &TableFactor, cte_names: &HashSet) -> R match factor { TableFactor::Table { name, args, .. } => { if args.is_some() { - let func_name = name.to_string().to_lowercase(); - if is_dangerous_table_function(&func_name) { - return Err(anyhow!("Table function '{func_name}' is not allowed")); - } + return Err(anyhow!("Table functions are not allowed")); } validate_table_name(name, cte_names) } TableFactor::Derived { subquery, .. } => validate_query_ast(subquery, cte_names), - TableFactor::TableFunction { expr, .. } => { - if let Expr::Function(func) = expr { - let func_name = func.name.to_string().to_lowercase(); - if is_dangerous_table_function(&func_name) { - return Err(anyhow!("Table function '{func_name}' is not allowed")); - } - } - Ok(()) - } - TableFactor::Function { name, .. } => { - let func_name = name.to_string().to_lowercase(); - if is_dangerous_table_function(&func_name) { - return Err(anyhow!("Table function '{func_name}' is not allowed")); - } - Ok(()) - } + TableFactor::TableFunction { .. } => Err(anyhow!("Table functions are not allowed")), + TableFactor::Function { .. } => Err(anyhow!("Table functions are not allowed")), TableFactor::NestedJoin { table_with_joins, .. } => { validate_table_with_joins(table_with_joins, cte_names) } - _ => Ok(()), + _ => Err(anyhow!("Unsupported FROM clause type")), } } @@ -307,11 +290,40 @@ fn validate_expr(expr: &Expr, cte_names: &HashSet) -> Result<()> { } } +const ALLOWED_FUNCTIONS: &[&str] = &[ + // ABI decode helpers (custom PostgreSQL functions) + "abi_uint", "abi_int", "abi_address", "abi_bool", "abi_bytes", "abi_string", + "format_address", "format_uint", + // Aggregates + "count", "sum", "avg", "min", "max", + // Scalar / null handling + "coalesce", "nullif", "greatest", "least", + // Numeric + "abs", "round", "floor", "ceil", "ceiling", "trunc", "pow", "power", + // String + "lower", "upper", "length", "substring", "substr", "trim", "ltrim", "rtrim", + "replace", "concat", "left", "right", "lpad", "rpad", + // Bytea / hex + "encode", "decode", "octet_length", + // Time + "date_trunc", "extract", "to_timestamp", "now", + // Window functions + "row_number", "rank", "dense_rank", "lag", "lead", "first_value", "last_value", + "ntile", "percent_rank", "cume_dist", + // Type casting helpers + "cast", +]; + +fn is_allowed_function(name: &str) -> bool { + let bare_name = name.rsplit('.').next().unwrap_or(name); + ALLOWED_FUNCTIONS.contains(&bare_name) +} + fn validate_function(func: &Function, cte_names: &HashSet) -> Result<()> { let func_name = func.name.to_string().to_lowercase(); - if is_dangerous_function(&func_name) { - return Err(anyhow!("Function '{func_name}' is not allowed")); + if !is_allowed_function(&func_name) { + return Err(anyhow!("Function '{}' is not allowed", func_name)); } if let FunctionArguments::List(arg_list) = &func.args { @@ -327,97 +339,6 @@ fn validate_function(func: &Function, cte_names: &HashSet) -> Result<()> Ok(()) } -/// Check if a function is dangerous (DoS, file access, side effects). -fn is_dangerous_function(name: &str) -> bool { - const DANGEROUS: &[&str] = &[ - // PostgreSQL DoS/side-effect functions - "pg_sleep", - "pg_terminate_backend", - "pg_cancel_backend", - "pg_reload_conf", - "pg_rotate_logfile", - "pg_switch_wal", - "pg_create_restore_point", - "pg_start_backup", - "pg_stop_backup", - "set_config", - "current_setting", - // PostgreSQL file access - "pg_read_file", - "pg_read_binary_file", - "pg_ls_dir", - "pg_stat_file", - "lo_import", - "lo_export", - // PostgreSQL command execution - "pg_execute_server_program", - // PostgreSQL dblink (remote connections) - "dblink", - "dblink_exec", - "dblink_connect", - "dblink_send_query", - "dblink_get_result", - // PostgreSQL large object access - "lo_get", - "lo_open", - "lo_close", - "loread", - "lowrite", - "lo_creat", - "lo_create", - "lo_unlink", - "lo_put", - // PostgreSQL set-returning functions (DoS via row generation) - "generate_series", - // PostgreSQL admin extension functions (file access) - "pg_file_read", - "pg_file_write", - "pg_file_rename", - "pg_file_unlink", - "pg_logdir_ls", - // ClickHouse system functions - "system.flush_logs", - "system.reload_config", - "system.shutdown", - "system.kill_query", - "system.drop_dns_cache", - "system.drop_mark_cache", - "system.drop_uncompressed_cache", - ]; - - DANGEROUS.iter().any(|&d| name == d || name.ends_with(&format!(".{d}"))) -} - -/// Check if a table function is dangerous (filesystem access). -fn is_dangerous_table_function(name: &str) -> bool { - const DANGEROUS: &[&str] = &[ - // ClickHouse file/URL table functions - "file", - "url", - "s3", - "gcs", - "hdfs", - "remote", - "remoteSecure", - "cluster", - "clusterAllReplicas", - // ClickHouse input formats - "input", - "format", - // ClickHouse system access - "system", - "numbers", - "zeros", - "generateRandom", - // ClickHouse dictionary access (could leak data) - "dictGet", - "dictGetOrDefault", - "dictHas", - ]; - - DANGEROUS.iter().any(|&d| name == d || name.contains(&format!("{d}("))) -} - #[cfg(test)] mod tests { use super::*; @@ -624,4 +545,36 @@ mod tests { fn test_allows_simple_values() { assert!(validate_query("VALUES (1, 'hello'), (2, 'world')").is_ok()); } + + #[test] + fn test_rejects_unknown_function() { + assert!(validate_query("SELECT md5('test') FROM blocks").is_err()); + assert!(validate_query("SELECT regexp_replace(hash, 'a', 'b') FROM blocks").is_err()); + } + + #[test] + fn test_allows_abi_helpers() { + assert!(validate_query("SELECT abi_uint(input) FROM txs").is_ok()); + assert!(validate_query("SELECT abi_address(input) FROM txs").is_ok()); + assert!(validate_query("SELECT format_address(miner) FROM blocks").is_ok()); + } + + #[test] + fn test_allows_common_functions() { + assert!(validate_query("SELECT COALESCE(gas_used, 0) FROM blocks").is_ok()); + assert!(validate_query("SELECT ABS(gas_used) FROM blocks").is_ok()); + assert!(validate_query("SELECT LOWER('test') FROM blocks").is_ok()); + assert!(validate_query("SELECT date_trunc('hour', to_timestamp(ts)) FROM blocks").is_ok()); + } + + #[test] + fn test_rejects_all_table_functions() { + assert!(validate_query("SELECT * FROM generate_series(1, 100)").is_err()); + assert!(validate_query("SELECT * FROM unnest(ARRAY[1,2,3])").is_err()); + } + + #[test] + fn test_rejects_unsupported_table_factor() { + assert!(validate_query("SELECT * FROM UNNEST(ARRAY[1,2,3])").is_err()); + } } From ce6f506e8536115f0020704364b356e417867b5d Mon Sep 17 00:00:00 2001 From: Georgios Konstantopoulos Date: Tue, 10 Feb 2026 22:31:20 +0000 Subject: [PATCH 5/6] fix: reject-by-default expression validation, LIMIT/depth/size caps Switch validate_expr to reject-by-default: only explicitly allowed expression types are permitted (identifiers, literals, binary/unary ops, CASE, CAST, BETWEEN, IN, LIKE, subqueries, SQL builtins like EXTRACT, SUBSTRING, TRIM, etc). Unknown expression variants are rejected. Query structure hardening: - Reject FOR UPDATE/SHARE locking clauses - Validate ORDER BY expressions through validate_expr - Validate LIMIT/OFFSET: must be numeric literals, capped at 10,000 - Reject LIMIT BY (ClickHouse-specific) - Subquery depth limit: max 4 levels of nesting - Query size limit: max 64KB - Validate window function OVER clause (PARTITION BY, ORDER BY) Service layer: - Replace string-based LIMIT detection (contains("LIMIT")) with AST-based detection via append_limit_if_missing() - Use HARD_LIMIT_MAX constant (10,000) across validator, service, and API - API param clamping uses HARD_LIMIT_MAX instead of hardcoded 100,000 15 new tests, all 162 lib tests passing. Amp-Thread-ID: https://ampcode.com/threads/T-019c499a-f07e-73a9-9526-6c18fd511372 Co-authored-by: Amp --- src/api/mod.rs | 6 +- src/query/mod.rs | 2 +- src/query/validator.rs | 577 ++++++++++++++++++++++++++++++++++------- src/service/mod.rs | 28 +- 4 files changed, 501 insertions(+), 112 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index 25300174..096049ed 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -257,7 +257,7 @@ fn default_timeout() -> u64 { 5000 } fn default_limit() -> i64 { - 10000 + crate::query::HARD_LIMIT_MAX } #[derive(Serialize)] @@ -299,7 +299,7 @@ async fn handle_query_once( let options = QueryOptions { timeout_ms: params.timeout_ms.clamp(100, 30000), - limit: params.limit.clamp(1, 100000), + limit: params.limit.clamp(1, crate::query::HARD_LIMIT_MAX), }; // Route to appropriate engine @@ -397,7 +397,7 @@ async fn handle_query_live( let signature = params.signature; let options = QueryOptions { timeout_ms: params.timeout_ms.clamp(100, 30000), - limit: params.limit.clamp(1, 100000), + limit: params.limit.clamp(1, crate::query::HARD_LIMIT_MAX), }; // Detect if this is an OLAP query (aggregations, etc.) diff --git a/src/query/mod.rs b/src/query/mod.rs index e0d16a2e..dee2f249 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -7,7 +7,7 @@ pub use parser::{ extract_order_by_columns, AbiParam, AbiType, EventSignature, }; pub use router::{route_query, QueryEngine}; -pub use validator::validate_query; +pub use validator::{validate_query, HARD_LIMIT_MAX}; use regex_lite::Regex; use std::sync::LazyLock; diff --git a/src/query/validator.rs b/src/query/validator.rs index c75e488a..6544860a 100644 --- a/src/query/validator.rs +++ b/src/query/validator.rs @@ -17,18 +17,26 @@ const ALLOWED_TABLES: &[&str] = &[ "token_balances", ]; +const MAX_QUERY_LENGTH: usize = 65_536; +const MAX_SUBQUERY_DEPTH: usize = 4; +pub const HARD_LIMIT_MAX: i64 = 10_000; + /// Validates that a SQL query is safe to execute. /// -/// Rejects: -/// - Multiple statements -/// - Non-SELECT statements (INSERT, UPDATE, DELETE, etc.) -/// - Data-modifying CTEs -/// - Dangerous functions (pg_sleep, read_csv, pg_read_file, etc.) -/// - System catalog access +/// Uses a reject-by-default approach: only explicitly allowed tables, +/// functions, and expression types are permitted. Everything else is rejected. pub fn validate_query(sql: &str) -> Result<()> { + if sql.len() > MAX_QUERY_LENGTH { + return Err(anyhow!( + "Query too large ({} bytes, max {})", + sql.len(), + MAX_QUERY_LENGTH + )); + } + let dialect = GenericDialect {}; - let statements = Parser::parse_sql(&dialect, sql) - .map_err(|e| anyhow!("SQL parse error: {e}"))?; + let statements = + Parser::parse_sql(&dialect, sql).map_err(|e| anyhow!("SQL parse error: {e}"))?; if statements.is_empty() { return Err(anyhow!("Empty query")); @@ -43,7 +51,7 @@ pub fn validate_query(sql: &str) -> Result<()> { match stmt { Statement::Query(query) => { let cte_names = extract_cte_names(query); - validate_query_ast(query, &cte_names) + validate_query_ast(query, &cte_names, 0) } _ => Err(anyhow!("Only SELECT queries are allowed")), } @@ -59,7 +67,14 @@ fn extract_cte_names(query: &Query) -> HashSet { names } -fn validate_query_ast(query: &Query, cte_names: &HashSet) -> Result<()> { +fn validate_query_ast(query: &Query, cte_names: &HashSet, depth: usize) -> Result<()> { + if depth > MAX_SUBQUERY_DEPTH { + return Err(anyhow!( + "Subquery nesting too deep (max {} levels)", + MAX_SUBQUERY_DEPTH + )); + } + // Block recursive CTEs (can cause endless loops / resource exhaustion) if let Some(with) = &query.with { if with.recursive { @@ -67,6 +82,13 @@ fn validate_query_ast(query: &Query, cte_names: &HashSet) -> Result<()> } } + // Block FOR UPDATE / FOR SHARE locking clauses + if !query.locks.is_empty() { + return Err(anyhow!( + "Locking clauses (FOR UPDATE/SHARE) are not allowed" + )); + } + let mut all_cte_names = cte_names.clone(); if let Some(with) = &query.with { for cte in &with.cte_tables { @@ -74,14 +96,86 @@ fn validate_query_ast(query: &Query, cte_names: &HashSet) -> Result<()> } } - for cte in &query.with.as_ref().map_or(vec![], |w| w.cte_tables.clone()) { - validate_query_ast(&cte.query, &all_cte_names)?; + for cte in &query + .with + .as_ref() + .map_or(vec![], |w| w.cte_tables.clone()) + { + validate_query_ast(&cte.query, &all_cte_names, depth + 1)?; } - validate_set_expr(&query.body, &all_cte_names) + validate_set_expr(&query.body, &all_cte_names, depth)?; + + // Validate ORDER BY expressions + if let Some(order_by) = &query.order_by { + match &order_by.kind { + sqlparser::ast::OrderByKind::Expressions(exprs) => { + for order_expr in exprs { + validate_expr(&order_expr.expr, &all_cte_names, depth)?; + } + } + sqlparser::ast::OrderByKind::All(_) => {} + } + } + + // Validate LIMIT / OFFSET: only allow numeric literals + if let Some(limit_clause) = &query.limit_clause { + match limit_clause { + sqlparser::ast::LimitClause::LimitOffset { + limit, + offset, + limit_by, + } => { + if let Some(limit_expr) = limit { + validate_limit_expr(limit_expr, "LIMIT")?; + } + if let Some(offset) = offset { + validate_limit_expr(&offset.value, "OFFSET")?; + } + if !limit_by.is_empty() { + return Err(anyhow!("LIMIT BY is not allowed")); + } + } + sqlparser::ast::LimitClause::OffsetCommaLimit { offset, limit } => { + validate_limit_expr(offset, "OFFSET")?; + validate_limit_expr(limit, "LIMIT")?; + } + } + } + + Ok(()) +} + +fn validate_limit_expr(expr: &Expr, context: &str) -> Result<()> { + match expr { + Expr::Value(v) => { + let val = &v.value; + match val { + sqlparser::ast::Value::Number(n, _) => { + if let Ok(num) = n.parse::() { + if num > HARD_LIMIT_MAX { + return Err(anyhow!( + "{context} value {num} exceeds maximum ({HARD_LIMIT_MAX})" + )); + } + Ok(()) + } else { + Err(anyhow!("{context} must be a valid integer")) + } + } + sqlparser::ast::Value::Null => Ok(()), + _ => Err(anyhow!("{context} must be a numeric literal")), + } + } + _ => Err(anyhow!("{context} must be a numeric literal")), + } } -fn validate_set_expr(set_expr: &SetExpr, cte_names: &HashSet) -> Result<()> { +fn validate_set_expr( + set_expr: &SetExpr, + cte_names: &HashSet, + depth: usize, +) -> Result<()> { match set_expr { SetExpr::Select(select) => { // Reject SELECT INTO (creates objects) @@ -90,45 +184,44 @@ fn validate_set_expr(set_expr: &SetExpr, cte_names: &HashSet) -> Result< } for table in &select.from { - validate_table_with_joins(table, cte_names)?; + validate_table_with_joins(table, cte_names, depth)?; } for item in &select.projection { if let sqlparser::ast::SelectItem::UnnamedExpr(expr) | sqlparser::ast::SelectItem::ExprWithAlias { expr, .. } = item { - validate_expr(expr, cte_names)?; + validate_expr(expr, cte_names, depth)?; } } if let Some(selection) = &select.selection { - validate_expr(selection, cte_names)?; + validate_expr(selection, cte_names, depth)?; } // Validate GROUP BY expressions if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &select.group_by { for expr in exprs { - validate_expr(expr, cte_names)?; + validate_expr(expr, cte_names, depth)?; } } // Validate HAVING if let Some(having) = &select.having { - validate_expr(having, cte_names)?; + validate_expr(having, cte_names, depth)?; } Ok(()) } - SetExpr::Query(q) => validate_query_ast(q, cte_names), + SetExpr::Query(q) => validate_query_ast(q, cte_names, depth), SetExpr::SetOperation { left, right, .. } => { - validate_set_expr(left, cte_names)?; - validate_set_expr(right, cte_names) + validate_set_expr(left, cte_names, depth)?; + validate_set_expr(right, cte_names, depth) } SetExpr::Values(values) => { - // Validate all expressions in VALUES rows to prevent function call bypass for row in &values.rows { for expr in row { - validate_expr(expr, cte_names)?; + validate_expr(expr, cte_names, depth)?; } } Ok(()) @@ -141,11 +234,14 @@ fn validate_set_expr(set_expr: &SetExpr, cte_names: &HashSet) -> Result< } } -fn validate_table_with_joins(table: &TableWithJoins, cte_names: &HashSet) -> Result<()> { - validate_table_factor(&table.relation, cte_names)?; +fn validate_table_with_joins( + table: &TableWithJoins, + cte_names: &HashSet, + depth: usize, +) -> Result<()> { + validate_table_factor(&table.relation, cte_names, depth)?; for join in &table.joins { - validate_table_factor(&join.relation, cte_names)?; - // Validate JOIN ON expressions + validate_table_factor(&join.relation, cte_names, depth)?; let constraint = match &join.join_operator { sqlparser::ast::JoinOperator::Join(c) | sqlparser::ast::JoinOperator::Inner(c) @@ -164,13 +260,17 @@ fn validate_table_with_joins(table: &TableWithJoins, cte_names: &HashSet _ => None, }; if let Some(sqlparser::ast::JoinConstraint::On(expr)) = constraint { - validate_expr(expr, cte_names)?; + validate_expr(expr, cte_names, depth)?; } } Ok(()) } -fn validate_table_factor(factor: &TableFactor, cte_names: &HashSet) -> Result<()> { +fn validate_table_factor( + factor: &TableFactor, + cte_names: &HashSet, + depth: usize, +) -> Result<()> { match factor { TableFactor::Table { name, args, .. } => { if args.is_some() { @@ -178,12 +278,14 @@ fn validate_table_factor(factor: &TableFactor, cte_names: &HashSet) -> R } validate_table_name(name, cte_names) } - TableFactor::Derived { subquery, .. } => validate_query_ast(subquery, cte_names), + TableFactor::Derived { subquery, .. } => { + validate_query_ast(subquery, cte_names, depth + 1) + } TableFactor::TableFunction { .. } => Err(anyhow!("Table functions are not allowed")), TableFactor::Function { .. } => Err(anyhow!("Table functions are not allowed")), - TableFactor::NestedJoin { table_with_joins, .. } => { - validate_table_with_joins(table_with_joins, cte_names) - } + TableFactor::NestedJoin { + table_with_joins, .. + } => validate_table_with_joins(table_with_joins, cte_names, depth), _ => Err(anyhow!("Unsupported FROM clause type")), } } @@ -200,7 +302,9 @@ fn validate_table_name(name: &ObjectName, cte_names: &HashSet) -> Result for schema in BLOCKED_SCHEMAS { if full_name.starts_with(schema) { - return Err(anyhow!("Access to system catalog '{schema}' is not allowed")); + return Err(anyhow!( + "Access to system catalog '{schema}' is not allowed" + )); } } @@ -219,7 +323,9 @@ fn validate_table_name(name: &ObjectName, cte_names: &HashSet) -> Result } } - let bare_name = name.0.last() + let bare_name = name + .0 + .last() .and_then(|part| part.as_ident()) .map(|ident| ident.value.to_lowercase()) .unwrap_or_default(); @@ -235,44 +341,79 @@ fn validate_table_name(name: &ObjectName, cte_names: &HashSet) -> Result Err(anyhow!("Access to table '{bare_name}' is not allowed")) } -fn validate_expr(expr: &Expr, cte_names: &HashSet) -> Result<()> { +/// Reject-by-default expression validation. +/// Only explicitly allowed expression types are permitted. +fn validate_expr(expr: &Expr, cte_names: &HashSet, depth: usize) -> Result<()> { match expr { - Expr::Function(func) => validate_function(func, cte_names), - Expr::Subquery(q) => validate_query_ast(q, cte_names), - Expr::InSubquery { subquery, .. } => validate_query_ast(subquery, cte_names), - Expr::Exists { subquery, .. } => validate_query_ast(subquery, cte_names), + // Safe leaf nodes + Expr::Identifier(_) | Expr::CompoundIdentifier(_) => Ok(()), + Expr::Value(_) => Ok(()), + Expr::TypedString(_) => Ok(()), + Expr::Wildcard(_) | Expr::QualifiedWildcard(_, _) => Ok(()), + + // Function calls (validated against allowlist) + Expr::Function(func) => validate_function(func, cte_names, depth), + + // Subqueries (increment depth) + Expr::Subquery(q) => validate_query_ast(q, cte_names, depth + 1), + Expr::InSubquery { + expr, subquery, .. + } => { + validate_expr(expr, cte_names, depth)?; + validate_query_ast(subquery, cte_names, depth + 1) + } + Expr::Exists { subquery, .. } => validate_query_ast(subquery, cte_names, depth + 1), + + // Binary / unary operations Expr::BinaryOp { left, right, .. } => { - validate_expr(left, cte_names)?; - validate_expr(right, cte_names) + validate_expr(left, cte_names, depth)?; + validate_expr(right, cte_names, depth) } - Expr::UnaryOp { expr, .. } => validate_expr(expr, cte_names), - Expr::Between { expr, low, high, .. } => { - validate_expr(expr, cte_names)?; - validate_expr(low, cte_names)?; - validate_expr(high, cte_names) + Expr::UnaryOp { expr, .. } => validate_expr(expr, cte_names, depth), + + // Range expressions + Expr::Between { + expr, low, high, .. + } => { + validate_expr(expr, cte_names, depth)?; + validate_expr(low, cte_names, depth)?; + validate_expr(high, cte_names, depth) } - Expr::Case { operand, conditions, else_result, .. } => { + + // CASE WHEN + Expr::Case { + operand, + conditions, + else_result, + .. + } => { if let Some(op) = operand { - validate_expr(op, cte_names)?; + validate_expr(op, cte_names, depth)?; } for case_when in conditions { - validate_expr(&case_when.condition, cte_names)?; - validate_expr(&case_when.result, cte_names)?; + validate_expr(&case_when.condition, cte_names, depth)?; + validate_expr(&case_when.result, cte_names, depth)?; } if let Some(else_r) = else_result { - validate_expr(else_r, cte_names)?; + validate_expr(else_r, cte_names, depth)?; } Ok(()) } - Expr::Cast { expr, .. } => validate_expr(expr, cte_names), - Expr::Nested(e) => validate_expr(e, cte_names), + + // Type casting + Expr::Cast { expr, .. } => validate_expr(expr, cte_names, depth), + Expr::Nested(e) => validate_expr(e, cte_names, depth), + + // IN list Expr::InList { expr, list, .. } => { - validate_expr(expr, cte_names)?; + validate_expr(expr, cte_names, depth)?; for item in list { - validate_expr(item, cte_names)?; + validate_expr(item, cte_names, depth)?; } Ok(()) } + + // Boolean tests Expr::IsNull(e) | Expr::IsNotNull(e) | Expr::IsTrue(e) @@ -280,36 +421,159 @@ fn validate_expr(expr: &Expr, cte_names: &HashSet) -> Result<()> { | Expr::IsNotTrue(e) | Expr::IsNotFalse(e) | Expr::IsUnknown(e) - | Expr::IsNotUnknown(e) => validate_expr(e, cte_names), + | Expr::IsNotUnknown(e) => validate_expr(e, cte_names, depth), + + // Pattern matching Expr::Like { expr, pattern, .. } | Expr::ILike { expr, pattern, .. } => { - validate_expr(expr, cte_names)?; - validate_expr(pattern, cte_names) + validate_expr(expr, cte_names, depth)?; + validate_expr(pattern, cte_names, depth) + } + Expr::SimilarTo { expr, pattern, .. } => { + validate_expr(expr, cte_names, depth)?; + validate_expr(pattern, cte_names, depth) + } + + // ANY/ALL operators + Expr::AnyOp { right, .. } | Expr::AllOp { right, .. } => { + validate_expr(right, cte_names, depth) + } + + // IS DISTINCT FROM + Expr::IsDistinctFrom(a, b) | Expr::IsNotDistinctFrom(a, b) => { + validate_expr(a, cte_names, depth)?; + validate_expr(b, cte_names, depth) + } + + // SQL builtins parsed as dedicated Expr variants (not Function) + Expr::Extract { expr, .. } => validate_expr(expr, cte_names, depth), + Expr::Substring { expr, substring_from, substring_for, .. } => { + validate_expr(expr, cte_names, depth)?; + if let Some(from) = substring_from { + validate_expr(from, cte_names, depth)?; + } + if let Some(for_expr) = substring_for { + validate_expr(for_expr, cte_names, depth)?; + } + Ok(()) + } + Expr::Trim { expr, trim_what, .. } => { + validate_expr(expr, cte_names, depth)?; + if let Some(what) = trim_what { + validate_expr(what, cte_names, depth)?; + } + Ok(()) + } + Expr::Ceil { expr, .. } | Expr::Floor { expr, .. } => { + validate_expr(expr, cte_names, depth) + } + Expr::Position { expr, r#in, .. } => { + validate_expr(expr, cte_names, depth)?; + validate_expr(r#in, cte_names, depth) + } + Expr::Overlay { expr, overlay_what, overlay_from, overlay_for, .. } => { + validate_expr(expr, cte_names, depth)?; + validate_expr(overlay_what, cte_names, depth)?; + validate_expr(overlay_from, cte_names, depth)?; + if let Some(for_expr) = overlay_for { + validate_expr(for_expr, cte_names, depth)?; + } + Ok(()) + } + Expr::Collate { expr, .. } => validate_expr(expr, cte_names, depth), + Expr::AtTimeZone { timestamp, time_zone, .. } => { + validate_expr(timestamp, cte_names, depth)?; + validate_expr(time_zone, cte_names, depth) + } + + // Tuple / row constructors + Expr::Tuple(exprs) => { + for e in exprs { + validate_expr(e, cte_names, depth)?; + } + Ok(()) + } + + // Array literal + Expr::Array(arr) => { + for e in &arr.elem { + validate_expr(e, cte_names, depth)?; + } + Ok(()) } - Expr::AnyOp { right, .. } | Expr::AllOp { right, .. } => validate_expr(right, cte_names), - _ => Ok(()), + + // Interval literal + Expr::Interval(_) => Ok(()), + + // Reject everything else (reject-by-default) + _ => Err(anyhow!("Unsupported expression type")), } } const ALLOWED_FUNCTIONS: &[&str] = &[ // ABI decode helpers (custom PostgreSQL functions) - "abi_uint", "abi_int", "abi_address", "abi_bool", "abi_bytes", "abi_string", - "format_address", "format_uint", + "abi_uint", + "abi_int", + "abi_address", + "abi_bool", + "abi_bytes", + "abi_string", + "format_address", + "format_uint", // Aggregates - "count", "sum", "avg", "min", "max", + "count", + "sum", + "avg", + "min", + "max", // Scalar / null handling - "coalesce", "nullif", "greatest", "least", + "coalesce", + "nullif", + "greatest", + "least", // Numeric - "abs", "round", "floor", "ceil", "ceiling", "trunc", "pow", "power", + "abs", + "round", + "floor", + "ceil", + "ceiling", + "trunc", + "pow", + "power", // String - "lower", "upper", "length", "substring", "substr", "trim", "ltrim", "rtrim", - "replace", "concat", "left", "right", "lpad", "rpad", + "lower", + "upper", + "length", + "substring", + "substr", + "trim", + "ltrim", + "rtrim", + "replace", + "concat", + "left", + "right", + "lpad", + "rpad", // Bytea / hex - "encode", "decode", "octet_length", + "encode", + "decode", + "octet_length", // Time - "date_trunc", "extract", "to_timestamp", "now", + "date_trunc", + "extract", + "to_timestamp", + "now", // Window functions - "row_number", "rank", "dense_rank", "lag", "lead", "first_value", "last_value", - "ntile", "percent_rank", "cume_dist", + "row_number", + "rank", + "dense_rank", + "lag", + "lead", + "first_value", + "last_value", + "ntile", + "percent_rank", + "cume_dist", // Type casting helpers "cast", ]; @@ -319,7 +583,7 @@ fn is_allowed_function(name: &str) -> bool { ALLOWED_FUNCTIONS.contains(&bare_name) } -fn validate_function(func: &Function, cte_names: &HashSet) -> Result<()> { +fn validate_function(func: &Function, cte_names: &HashSet, depth: usize) -> Result<()> { let func_name = func.name.to_string().to_lowercase(); if !is_allowed_function(&func_name) { @@ -329,9 +593,24 @@ fn validate_function(func: &Function, cte_names: &HashSet) -> Result<()> if let FunctionArguments::List(arg_list) = &func.args { for arg in &arg_list.args { if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) - | FunctionArg::Named { arg: FunctionArgExpr::Expr(expr), .. } = arg + | FunctionArg::Named { + arg: FunctionArgExpr::Expr(expr), + .. + } = arg { - validate_expr(expr, cte_names)?; + validate_expr(expr, cte_names, depth)?; + } + } + } + + // Validate window function OVER clause + if let Some(window_type) = &func.over { + if let sqlparser::ast::WindowType::WindowSpec(spec) = window_type { + for expr in &spec.partition_by { + validate_expr(expr, cte_names, depth)?; + } + for order_expr in &spec.order_by { + validate_expr(&order_expr.expr, cte_names, depth)?; } } } @@ -376,7 +655,6 @@ mod tests { #[test] fn test_rejects_data_modifying_cte() { - // This is the V1 bypass attempt let result = validate_query( "WITH x AS (UPDATE blocks SET miner = 'pwn' RETURNING 1) SELECT * FROM x", ); @@ -385,11 +663,9 @@ mod tests { #[test] fn test_rejects_comment_bypass() { - // Comments are stripped by parser, so this becomes a valid UPDATE let result = validate_query( "WITH x AS (UPDA/**/TE blocks SET miner = 'pwn' RETURNING 1) SELECT * FROM x", ); - // Parser will either fail to parse or recognize it as UPDATE assert!(result.is_err()); } @@ -431,18 +707,16 @@ mod tests { #[test] fn test_allows_window_functions() { - assert!(validate_query( - "SELECT num, ROW_NUMBER() OVER (ORDER BY num) FROM blocks" - ) - .is_ok()); + assert!( + validate_query("SELECT num, ROW_NUMBER() OVER (ORDER BY num) FROM blocks").is_ok() + ); } #[test] fn test_allows_subquery() { - assert!(validate_query( - "SELECT * FROM blocks WHERE num IN (SELECT block_num FROM txs)" - ) - .is_ok()); + assert!( + validate_query("SELECT * FROM blocks WHERE num IN (SELECT block_num FROM txs)").is_ok() + ); } #[test] @@ -467,15 +741,17 @@ mod tests { #[test] fn test_allows_cte_defined_table() { - assert!(validate_query( - "WITH my_cte AS (SELECT * FROM blocks) SELECT * FROM my_cte" - ) - .is_ok()); + assert!( + validate_query("WITH my_cte AS (SELECT * FROM blocks) SELECT * FROM my_cte").is_ok() + ); } #[test] fn test_rejects_dblink() { - assert!(validate_query("SELECT * FROM dblink('host=evil dbname=secrets', 'SELECT * FROM passwords')").is_err()); + assert!(validate_query( + "SELECT * FROM dblink('host=evil dbname=secrets', 'SELECT * FROM passwords')" + ) + .is_err()); assert!(validate_query("SELECT dblink_connect('myconn', 'host=evil')").is_err()); assert!(validate_query("SELECT dblink_exec('myconn', 'DROP TABLE blocks')").is_err()); } @@ -491,13 +767,17 @@ mod tests { fn test_rejects_recursive_cte() { assert!(validate_query( "WITH RECURSIVE r AS (SELECT 1 AS n UNION ALL SELECT n+1 FROM r) SELECT * FROM r" - ).is_err()); + ) + .is_err()); } #[test] fn test_rejects_generate_series() { assert!(validate_query("SELECT generate_series(1, 1000000000)").is_err()); - assert!(validate_query("SELECT * FROM blocks WHERE num IN (SELECT generate_series(1, 1000000))").is_err()); + assert!(validate_query( + "SELECT * FROM blocks WHERE num IN (SELECT generate_series(1, 1000000))" + ) + .is_err()); } #[test] @@ -531,14 +811,17 @@ mod tests { #[test] fn test_rejects_dangerous_function_in_having() { - assert!(validate_query("SELECT COUNT(*) FROM blocks GROUP BY num HAVING pg_sleep(1) IS NOT NULL").is_err()); + assert!(validate_query( + "SELECT COUNT(*) FROM blocks GROUP BY num HAVING pg_sleep(1) IS NOT NULL" + ) + .is_err()); } #[test] fn test_rejects_dangerous_function_in_join_on() { - assert!(validate_query( - "SELECT * FROM blocks JOIN txs ON pg_sleep(1) IS NOT NULL" - ).is_err()); + assert!( + validate_query("SELECT * FROM blocks JOIN txs ON pg_sleep(1) IS NOT NULL").is_err() + ); } #[test] @@ -564,7 +847,9 @@ mod tests { assert!(validate_query("SELECT COALESCE(gas_used, 0) FROM blocks").is_ok()); assert!(validate_query("SELECT ABS(gas_used) FROM blocks").is_ok()); assert!(validate_query("SELECT LOWER('test') FROM blocks").is_ok()); - assert!(validate_query("SELECT date_trunc('hour', to_timestamp(ts)) FROM blocks").is_ok()); + assert!( + validate_query("SELECT date_trunc('hour', to_timestamp(ts)) FROM blocks").is_ok() + ); } #[test] @@ -577,4 +862,98 @@ mod tests { fn test_rejects_unsupported_table_factor() { assert!(validate_query("SELECT * FROM UNNEST(ARRAY[1,2,3])").is_err()); } + + // === New tests for this commit === + + #[test] + fn test_rejects_for_update() { + assert!(validate_query("SELECT * FROM blocks FOR UPDATE").is_err()); + assert!(validate_query("SELECT * FROM blocks FOR SHARE").is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_order_by() { + assert!( + validate_query("SELECT * FROM blocks ORDER BY pg_sleep(1)").is_err() + ); + } + + #[test] + fn test_rejects_excessive_limit() { + assert!(validate_query("SELECT * FROM blocks LIMIT 100000000").is_err()); + assert!(validate_query("SELECT * FROM blocks LIMIT 10001").is_err()); + } + + #[test] + fn test_allows_reasonable_limit() { + assert!(validate_query("SELECT * FROM blocks LIMIT 100").is_ok()); + assert!(validate_query("SELECT * FROM blocks LIMIT 10000").is_ok()); + assert!(validate_query("SELECT * FROM blocks LIMIT 1 OFFSET 5").is_ok()); + } + + #[test] + fn test_rejects_subquery_in_limit() { + assert!(validate_query("SELECT * FROM blocks LIMIT (SELECT 1)").is_err()); + } + + #[test] + fn test_rejects_deep_subquery_nesting() { + // 5 levels of derived table nesting exceeds MAX_SUBQUERY_DEPTH (4) + let deep = "SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM blocks) a) b) c) d) e"; + assert!(validate_query(deep).is_err()); + } + + #[test] + fn test_allows_moderate_subquery_nesting() { + // 3 levels of nesting is within limits + let moderate = "SELECT * FROM (SELECT * FROM (SELECT * FROM blocks) a) b"; + assert!(validate_query(moderate).is_ok()); + } + + #[test] + fn test_rejects_query_too_large() { + let huge = format!("SELECT * FROM blocks WHERE num IN ({})", "1,".repeat(70_000)); + assert!(validate_query(&huge).is_err()); + } + + #[test] + fn test_allows_order_by_column() { + assert!(validate_query("SELECT * FROM blocks ORDER BY num DESC").is_ok()); + assert!( + validate_query("SELECT * FROM blocks ORDER BY num DESC, hash ASC").is_ok() + ); + } + + #[test] + fn test_allows_cast_expression() { + assert!(validate_query("SELECT CAST(num AS TEXT) FROM blocks").is_ok()); + } + + #[test] + fn test_allows_between() { + assert!(validate_query("SELECT * FROM blocks WHERE num BETWEEN 1 AND 100").is_ok()); + } + + #[test] + fn test_allows_like() { + assert!(validate_query("SELECT * FROM txs WHERE hash LIKE '%abc%'").is_ok()); + } + + #[test] + fn test_allows_is_null() { + assert!(validate_query("SELECT * FROM blocks WHERE miner IS NOT NULL").is_ok()); + } + + #[test] + fn test_allows_case_when() { + assert!(validate_query( + "SELECT CASE WHEN num > 100 THEN 'big' ELSE 'small' END FROM blocks" + ) + .is_ok()); + } + + #[test] + fn test_allows_array_literal() { + assert!(validate_query("SELECT * FROM blocks WHERE num = ANY(ARRAY[1,2,3])").is_ok()); + } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 27a0d987..384dcf39 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -5,7 +5,7 @@ use std::time::Instant; use crate::db::Pool; use crate::metrics; -use crate::query::{extract_column_references, validate_query, EventSignature}; +use crate::query::{extract_column_references, validate_query, EventSignature, HARD_LIMIT_MAX}; #[derive(Debug, Clone, Serialize)] pub struct SyncStatus { @@ -109,7 +109,7 @@ impl Default for QueryOptions { fn default() -> Self { Self { timeout_ms: 5000, - limit: 10000, + limit: HARD_LIMIT_MAX, } } } @@ -145,13 +145,8 @@ pub async fn execute_query_postgres( sql.to_string() }; - // Add LIMIT if not present - let sql_upper = sql.to_uppercase(); - let sql = if !sql_upper.contains("LIMIT") { - format!("{} LIMIT {}", sql, options.limit) - } else { - sql - }; + // Add LIMIT if not present (AST-based detection to avoid string matching bypass) + let sql = append_limit_if_missing(&sql, options.limit); // Convert '0x...' hex literals to '\x...' bytea literals for PostgreSQL // Only replace hex values (40+ chars), not short '0x' prefixes used in concat() @@ -226,6 +221,21 @@ pub async fn execute_query_postgres( }) } +fn append_limit_if_missing(sql: &str, limit: i64) -> String { + use sqlparser::dialect::GenericDialect; + use sqlparser::parser::Parser; + + let dialect = GenericDialect {}; + if let Ok(stmts) = Parser::parse_sql(&dialect, sql) { + if let Some(sqlparser::ast::Statement::Query(query)) = stmts.first() { + if query.limit_clause.is_none() { + return format!("{sql} LIMIT {limit}"); + } + } + } + sql.to_string() +} + pub fn format_column_json(row: &tokio_postgres::Row, idx: usize) -> serde_json::Value { let col = &row.columns()[idx]; From 455942a24b2fc5b8b8661fe9920f7048bdb2468f Mon Sep 17 00:00:00 2001 From: Georgios Konstantopoulos Date: Tue, 10 Feb 2026 22:44:01 +0000 Subject: [PATCH 6/6] fix: close FILTER clause bypass, LIMIT NULL, FETCH, negative limit Fixes found during oracle review: - Validate function FILTER (WHERE ...) clause: previously COUNT(*) FILTER (WHERE pg_sleep(1) IS NOT NULL) bypassed the function allowlist entirely - Validate WITHIN GROUP (ORDER BY ...) expressions - Reject LIMIT NULL (effectively means no limit, bypasses cap) - Reject negative LIMIT/OFFSET values - Reject FETCH clause (FETCH FIRST N ROWS ONLY bypasses LIMIT cap) 4 new tests, all 166 lib tests passing. Amp-Thread-ID: https://ampcode.com/threads/T-019c499a-f07e-73a9-9526-6c18fd511372 Co-authored-by: Amp --- src/query/validator.rs | 47 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/query/validator.rs b/src/query/validator.rs index 6544860a..a8563892 100644 --- a/src/query/validator.rs +++ b/src/query/validator.rs @@ -89,6 +89,11 @@ fn validate_query_ast(query: &Query, cte_names: &HashSet, depth: usize) )); } + // Block FETCH clause (alternative to LIMIT, could bypass cap) + if query.fetch.is_some() { + return Err(anyhow!("FETCH clause is not allowed, use LIMIT instead")); + } + let mut all_cte_names = cte_names.clone(); if let Some(with) = &query.with { for cte in &with.cte_tables { @@ -153,6 +158,9 @@ fn validate_limit_expr(expr: &Expr, context: &str) -> Result<()> { match val { sqlparser::ast::Value::Number(n, _) => { if let Ok(num) = n.parse::() { + if num < 0 { + return Err(anyhow!("{context} must not be negative")); + } if num > HARD_LIMIT_MAX { return Err(anyhow!( "{context} value {num} exceeds maximum ({HARD_LIMIT_MAX})" @@ -163,7 +171,9 @@ fn validate_limit_expr(expr: &Expr, context: &str) -> Result<()> { Err(anyhow!("{context} must be a valid integer")) } } - sqlparser::ast::Value::Null => Ok(()), + sqlparser::ast::Value::Null => { + Err(anyhow!("{context} NULL is not allowed")) + } _ => Err(anyhow!("{context} must be a numeric literal")), } } @@ -603,6 +613,16 @@ fn validate_function(func: &Function, cte_names: &HashSet, depth: usize) } } + // Validate FILTER (WHERE ...) clause + if let Some(filter) = &func.filter { + validate_expr(filter, cte_names, depth)?; + } + + // Validate WITHIN GROUP (ORDER BY ...) clause + for order_expr in &func.within_group { + validate_expr(&order_expr.expr, cte_names, depth)?; + } + // Validate window function OVER clause if let Some(window_type) = &func.over { if let sqlparser::ast::WindowType::WindowSpec(spec) = window_type { @@ -956,4 +976,29 @@ mod tests { fn test_allows_array_literal() { assert!(validate_query("SELECT * FROM blocks WHERE num = ANY(ARRAY[1,2,3])").is_ok()); } + + #[test] + fn test_rejects_filter_clause_bypass() { + assert!(validate_query( + "SELECT COUNT(*) FILTER (WHERE pg_sleep(1) IS NOT NULL) FROM blocks" + ) + .is_err()); + } + + #[test] + fn test_rejects_limit_null() { + assert!(validate_query("SELECT * FROM blocks LIMIT NULL").is_err()); + } + + #[test] + fn test_rejects_negative_limit() { + assert!(validate_query("SELECT * FROM blocks LIMIT -1").is_err()); + } + + #[test] + fn test_rejects_fetch_clause() { + assert!( + validate_query("SELECT * FROM blocks FETCH FIRST 10 ROWS ONLY").is_err() + ); + } }