From c1f0d89b6bea878bc0705ff53452a2228ef90091 Mon Sep 17 00:00:00 2001 From: Brendan Ryan <1572504+brendanjryan@users.noreply.github.com> Date: Sun, 3 May 2026 13:40:43 -0700 Subject: [PATCH 1/2] fix: harden sql validator traversal --- src/query/validator.rs | 262 ++++++++++++++++++++++++++++++++++------- 1 file changed, 220 insertions(+), 42 deletions(-) diff --git a/src/query/validator.rs b/src/query/validator.rs index 34ff7257..43a9db29 100644 --- a/src/query/validator.rs +++ b/src/query/validator.rs @@ -26,6 +26,12 @@ pub fn validate_query(sql: &str) -> Result<()> { MAX_QUERY_LENGTH )); } + if regex_lite::Regex::new(r"(?i)\blimit\s+all\b") + .expect("valid LIMIT ALL regex") + .is_match(sql) + { + return Err(anyhow!("LIMIT ALL is not allowed")); + } let dialect = GenericDialect {}; let statements = @@ -43,23 +49,13 @@ pub fn validate_query(sql: &str) -> Result<()> { match stmt { Statement::Query(query) => { - let cte_names = extract_cte_names(query); + let cte_names = HashSet::new(); validate_query_ast(query, &cte_names, 0) } _ => Err(anyhow!("Only SELECT queries are allowed")), } } -fn extract_cte_names(query: &Query) -> HashSet { - let mut names = HashSet::new(); - if let Some(with) = &query.with { - for cte in &with.cte_tables { - names.insert(cte.alias.name.value.to_lowercase()); - } - } - names -} - fn validate_query_ast(query: &Query, cte_names: &HashSet, depth: usize) -> Result<()> { if depth > MAX_SUBQUERY_DEPTH { return Err(anyhow!( @@ -90,14 +86,11 @@ fn validate_query_ast(query: &Query, cte_names: &HashSet, depth: usize) let mut all_cte_names = cte_names.clone(); if let Some(with) = &query.with { for cte in &with.cte_tables { + validate_query_ast(&cte.query, &all_cte_names, depth + 1)?; all_cte_names.insert(cte.alias.name.value.to_lowercase()); } } - for cte in &query.with.as_ref().map_or(vec![], |w| w.cte_tables.clone()) { - validate_query_ast(&cte.query, &all_cte_names, depth + 1)?; - } - validate_set_expr(&query.body, &all_cte_names, depth)?; // Validate ORDER BY expressions @@ -122,6 +115,8 @@ fn validate_query_ast(query: &Query, cte_names: &HashSet, depth: usize) } => { if let Some(limit_expr) = limit { validate_limit_expr(limit_expr, "LIMIT")?; + } else { + return Err(anyhow!("LIMIT ALL is not allowed")); } if let Some(offset) = offset { validate_limit_expr(&offset.value, "OFFSET")?; @@ -169,6 +164,13 @@ fn validate_limit_expr(expr: &Expr, context: &str) -> Result<()> { } fn validate_set_expr(set_expr: &SetExpr, cte_names: &HashSet, depth: usize) -> Result<()> { + if depth > MAX_SUBQUERY_DEPTH { + return Err(anyhow!( + "Subquery nesting too deep (max {} levels)", + MAX_SUBQUERY_DEPTH + )); + } + match set_expr { SetExpr::Select(select) => { // Reject SELECT INTO (creates objects) @@ -176,6 +178,12 @@ fn validate_set_expr(set_expr: &SetExpr, cte_names: &HashSet, depth: usi return Err(anyhow!("SELECT INTO is not allowed")); } + if let Some(sqlparser::ast::Distinct::On(exprs)) = &select.distinct { + for expr in exprs { + validate_expr(expr, cte_names, depth)?; + } + } + for table in &select.from { validate_table_with_joins(table, cte_names, depth)?; } @@ -204,12 +212,16 @@ fn validate_set_expr(set_expr: &SetExpr, cte_names: &HashSet, depth: usi validate_expr(having, cte_names, depth)?; } + for window in &select.named_window { + validate_named_window(window, cte_names, depth)?; + } + Ok(()) } - SetExpr::Query(q) => validate_query_ast(q, cte_names, depth), + SetExpr::Query(q) => validate_query_ast(q, cte_names, depth + 1), SetExpr::SetOperation { left, right, .. } => { - validate_set_expr(left, cte_names, depth)?; - validate_set_expr(right, cte_names, depth) + validate_set_expr(left, cte_names, depth + 1)?; + validate_set_expr(right, cte_names, depth + 1) } SetExpr::Values(values) => { for row in &values.rows { @@ -265,10 +277,31 @@ fn validate_table_factor( depth: usize, ) -> Result<()> { match factor { - TableFactor::Table { name, args, .. } => { + TableFactor::Table { + name, + args, + with_hints, + version, + with_ordinality, + partitions, + json_path, + sample, + index_hints, + .. + } => { if args.is_some() { return Err(anyhow!("Table functions are not allowed")); } + if !with_hints.is_empty() + || version.is_some() + || *with_ordinality + || !partitions.is_empty() + || json_path.is_some() + || sample.is_some() + || !index_hints.is_empty() + { + return Err(anyhow!("Unsupported table modifier")); + } validate_table_name(name, cte_names) } TableFactor::Derived { subquery, .. } => validate_query_ast(subquery, cte_names, depth + 1), @@ -282,16 +315,14 @@ fn validate_table_factor( } fn validate_table_name(name: &ObjectName, cte_names: &HashSet) -> Result<()> { - let full_name = name.to_string().to_lowercase(); - const BLOCKED_SCHEMAS: &[&str] = &["pg_catalog", "information_schema", "pg_temp", "pg_toast"]; - for schema in BLOCKED_SCHEMAS { - if full_name.starts_with(schema) { - return Err(anyhow!( - "Access to system catalog '{schema}' is not allowed" - )); - } + let name_parts = object_name_parts(name)?; + if name_parts.len() > 1 && BLOCKED_SCHEMAS.contains(&name_parts[0].as_str()) { + return Err(anyhow!( + "Access to system catalog '{}' is not allowed", + name_parts[0] + )); } const BLOCKED_TABLES: &[&str] = &[ @@ -303,30 +334,36 @@ fn validate_table_name(name: &ObjectName, cte_names: &HashSet) -> Result "pg_roles", ]; - for table in BLOCKED_TABLES { - if full_name.contains(table) { - return Err(anyhow!("Access to table '{table}' is not allowed")); + for part in &name_parts { + if BLOCKED_TABLES.contains(&part.as_str()) { + return Err(anyhow!("Access to table '{part}' is not allowed")); } } - let bare_name = name - .0 - .last() - .and_then(|part| part.as_ident()) - .map(|ident| ident.value.to_lowercase()) - .unwrap_or_default(); + let bare_name = name_parts.last().cloned().unwrap_or_default(); if ALLOWED_TABLES.contains(&bare_name.as_str()) { return Ok(()); } - if cte_names.contains(&bare_name) { + if name_parts.len() == 1 && cte_names.contains(&bare_name) { return Ok(()); } Err(anyhow!("Access to table '{bare_name}' is not allowed")) } +fn object_name_parts(name: &ObjectName) -> Result> { + name.0 + .iter() + .map(|part| { + part.as_ident() + .map(|ident| ident.value.to_lowercase()) + .ok_or_else(|| anyhow!("Unsupported table identifier")) + }) + .collect() +} + /// Reject-by-default expression validation. /// Only explicitly allowed expression types are permitted. fn validate_expr(expr: &Expr, cte_names: &HashSet, depth: usize) -> Result<()> { @@ -501,7 +538,7 @@ fn validate_expr(expr: &Expr, cte_names: &HashSet, depth: usize) -> Resu } // Interval literal - Expr::Interval(_) => Ok(()), + Expr::Interval(interval) => validate_expr(&interval.value, cte_names, depth), // Reject everything else (reject-by-default) _ => Err(anyhow!("Unsupported expression type")), @@ -619,6 +656,9 @@ fn validate_function(func: &Function, cte_names: &HashSet, depth: usize) validate_expr(expr, cte_names, depth)?; } } + for clause in &arg_list.clauses { + validate_function_argument_clause(clause, cte_names, depth)?; + } } // Validate FILTER (WHERE ...) clause @@ -633,17 +673,84 @@ fn validate_function(func: &Function, cte_names: &HashSet, depth: usize) // Validate window function OVER clause if let Some(sqlparser::ast::WindowType::WindowSpec(spec)) = &func.over { - for expr in &spec.partition_by { - validate_expr(expr, cte_names, depth)?; + validate_window_spec(spec, cte_names, depth)?; + } + + Ok(()) +} + +fn validate_function_argument_clause( + clause: &sqlparser::ast::FunctionArgumentClause, + cte_names: &HashSet, + depth: usize, +) -> Result<()> { + match clause { + sqlparser::ast::FunctionArgumentClause::IgnoreOrRespectNulls(_) => Ok(()), + sqlparser::ast::FunctionArgumentClause::OrderBy(order_by) => { + for order_expr in order_by { + validate_expr(&order_expr.expr, cte_names, depth)?; + } + Ok(()) + } + sqlparser::ast::FunctionArgumentClause::Limit(expr) => { + validate_limit_expr(expr, "FUNCTION LIMIT") } - for order_expr in &spec.order_by { - validate_expr(&order_expr.expr, cte_names, depth)?; + sqlparser::ast::FunctionArgumentClause::Having(bound) => { + validate_expr(&bound.1, cte_names, depth) + } + _ => Err(anyhow!("Unsupported function clause")), + } +} + +fn validate_named_window( + window: &sqlparser::ast::NamedWindowDefinition, + cte_names: &HashSet, + depth: usize, +) -> Result<()> { + match &window.1 { + sqlparser::ast::NamedWindowExpr::NamedWindow(_) => Ok(()), + sqlparser::ast::NamedWindowExpr::WindowSpec(spec) => { + validate_window_spec(spec, cte_names, depth) } } +} +fn validate_window_spec( + spec: &sqlparser::ast::WindowSpec, + cte_names: &HashSet, + depth: usize, +) -> Result<()> { + for expr in &spec.partition_by { + validate_expr(expr, cte_names, depth)?; + } + for order_expr in &spec.order_by { + validate_expr(&order_expr.expr, cte_names, depth)?; + } + if let Some(frame) = &spec.window_frame { + validate_window_frame_bound(&frame.start_bound, cte_names, depth)?; + if let Some(end_bound) = &frame.end_bound { + validate_window_frame_bound(end_bound, cte_names, depth)?; + } + } Ok(()) } +fn validate_window_frame_bound( + bound: &sqlparser::ast::WindowFrameBound, + cte_names: &HashSet, + depth: usize, +) -> Result<()> { + match bound { + sqlparser::ast::WindowFrameBound::CurrentRow + | sqlparser::ast::WindowFrameBound::Preceding(None) + | sqlparser::ast::WindowFrameBound::Following(None) => Ok(()), + sqlparser::ast::WindowFrameBound::Preceding(Some(expr)) + | sqlparser::ast::WindowFrameBound::Following(Some(expr)) => { + validate_expr(expr, cte_names, depth) + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -1006,4 +1113,75 @@ mod tests { fn test_rejects_fetch_clause() { assert!(validate_query("SELECT * FROM blocks FETCH FIRST 10 ROWS ONLY").is_err()); } + + #[test] + fn test_rejects_self_referencing_cte_shadowing() { + assert!( + validate_query( + "WITH sync_state AS (SELECT * FROM sync_state) SELECT * FROM sync_state" + ) + .is_err() + ); + } + + #[test] + fn test_rejects_schema_qualified_cte_shadowing() { + assert!( + validate_query( + "WITH sync_state AS (SELECT * FROM blocks) SELECT * FROM public.sync_state" + ) + .is_err() + ); + } + + #[test] + fn test_rejects_quoted_blocked_schema() { + assert!(validate_query(r#"SELECT * FROM "pg_catalog".pg_proc"#).is_err()); + assert!(validate_query(r#"SELECT * FROM "information_schema".tables"#).is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_distinct_on() { + assert!(validate_query("SELECT DISTINCT ON (pg_sleep(1)) num FROM blocks").is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_named_window() { + assert!( + validate_query( + "SELECT row_number() OVER w FROM blocks WINDOW w AS (ORDER BY pg_sleep(1))" + ) + .is_err() + ); + } + + #[test] + fn test_rejects_dangerous_function_in_aggregate_order_by() { + assert!(validate_query("SELECT array_agg(num ORDER BY pg_sleep(1)) FROM blocks").is_err()); + } + + #[test] + fn test_rejects_table_sample_modifier() { + assert!(validate_query("SELECT * FROM blocks TABLESAMPLE SYSTEM (10)").is_err()); + } + + #[test] + fn test_rejects_dangerous_interval_expression() { + assert!(validate_query("SELECT INTERVAL pg_sleep(1) FROM blocks").is_err()); + } + + #[test] + fn test_rejects_limit_all() { + assert!(validate_query("SELECT * FROM blocks LIMIT ALL").is_err()); + } + + #[test] + fn test_rejects_deep_parenthesized_query_nesting() { + let mut sql = "SELECT * FROM blocks".to_string(); + for _ in 0..32 { + sql = format!("({sql})"); + } + let sql = format!("SELECT * FROM {sql} nested"); + assert!(validate_query(&sql).is_err()); + } } From 0e7efe5939e339fe7d1068f76114483c70008bd5 Mon Sep 17 00:00:00 2001 From: Brendan Ryan <1572504+brendanjryan@users.noreply.github.com> Date: Sun, 3 May 2026 13:54:52 -0700 Subject: [PATCH 2/2] docs: add sql validator changelog --- .changelog/pr-179-sql-validator-hardening.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changelog/pr-179-sql-validator-hardening.md diff --git a/.changelog/pr-179-sql-validator-hardening.md b/.changelog/pr-179-sql-validator-hardening.md new file mode 100644 index 00000000..6825f513 --- /dev/null +++ b/.changelog/pr-179-sql-validator-hardening.md @@ -0,0 +1,5 @@ +--- +tidx: patch +--- + +Harden PostgreSQL SQL validation by fixing CTE scope handling, schema-qualified table checks, recursive depth accounting, LIMIT ALL rejection, and traversal of previously unchecked AST clauses.