diff --git a/src/query/parser.rs b/src/query/parser.rs index eb3e0b36..194f1538 100644 --- a/src/query/parser.rs +++ b/src/query/parser.rs @@ -1,6 +1,6 @@ use anyhow::{anyhow, Result}; use sha3::{Digest, Keccak256}; -use sqlparser::ast::{visit_expressions, BinaryOperator, Expr, Value}; +use sqlparser::ast::{visit_expressions, BinaryOperator, Expr, SetExpr, Statement, Value}; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; use std::collections::{HashMap, HashSet}; @@ -576,6 +576,10 @@ fn extract_ident_from_expr(expr: &Expr, columns: &mut HashSet) { /// Returns SQL fragments like `block_num >= 100`, `address = '0x...'` etc. /// Only extracts simple comparisons (=, >=, <=, >, <) and IN lists on known /// raw columns. Decoded event columns are NOT extracted. +/// +/// IMPORTANT: Only top-level AND conjuncts are extracted. Predicates inside OR, +/// CASE, or other complex expressions are NOT pushed down, because converting +/// `WHERE a OR b` into `WHERE a AND b` would silently change query semantics. pub fn extract_raw_column_predicates(sql: &str) -> Vec<String> { let mut predicates = Vec::new(); @@ -585,15 +589,50 @@ pub fn extract_raw_column_predicates(sql: &str) -> Vec<String> { }; for stmt in &statements { - let _ = visit_expressions(stmt, |expr| { - extract_raw_predicate(expr, &mut predicates); - ControlFlow::<()>::Continue(()) - }); + if let Some(where_expr) = extract_where_clause(stmt) { + collect_and_conjuncts(where_expr, &mut predicates); + } } predicates } +/// Extract the WHERE clause of a plain SELECT; `None` for any other query body or statement. +fn extract_where_clause(stmt: &Statement) -> Option<&Expr> { + match stmt { + Statement::Query(query) => { + if let SetExpr::Select(select) = query.body.as_ref() { + select.selection.as_ref() + } else { + None + } + } + _ => None, + } +} + +/// Walk only top-level AND conjuncts, extracting raw predicates from each leaf.
+/// Stops recursing at OR, CASE, or any other non-AND expression to avoid +/// incorrectly converting disjunctions into conjunctions. +fn collect_and_conjuncts(expr: &Expr, predicates: &mut Vec<String>) { + match expr { + Expr::BinaryOp { + left, + op: BinaryOperator::And, + right, + } => { + collect_and_conjuncts(left, predicates); + collect_and_conjuncts(right, predicates); + } + Expr::Nested(inner) => { + collect_and_conjuncts(inner, predicates); + } + other => { + extract_raw_predicate(other, predicates); + } + } +} + /// Extract a single raw-column predicate from an expression. fn extract_raw_predicate(expr: &Expr, predicates: &mut Vec<String>) { match expr { @@ -1532,4 +1571,56 @@ mod tests { assert_eq!(without, with); } + // ======================================================================== + // Pushdown Safety: OR, CASE, and mixed expressions + // ======================================================================== + + #[test] + fn test_or_predicates_not_pushed_down() { + let preds = extract_raw_column_predicates( + "SELECT * FROM Transfer WHERE block_num = 1 OR block_num = 2", + ); + assert!(preds.is_empty(), "OR predicates must not be pushed down, got: {preds:?}"); + } + + #[test] + fn test_case_predicates_not_pushed_down() { + let preds = extract_raw_column_predicates( + "SELECT * FROM Transfer WHERE block_num = CASE WHEN 1=1 THEN 100 ELSE 200 END", + ); + assert!(preds.is_empty(), "CASE predicates must not be pushed down, got: {preds:?}"); + } + + #[test] + fn test_simple_and_predicates_pushed_down() { + let preds = extract_raw_column_predicates( + "SELECT * FROM Transfer WHERE block_num >= 100 AND block_num <= 200 AND address = '0xABC'", + ); + assert_eq!(preds.len(), 3); + assert!(preds.contains(&"block_num >= 100".to_string())); + assert!(preds.contains(&"block_num <= 200".to_string())); + assert!(preds.contains(&"address = '0xABC'".to_string())); + } + + #[test] + fn test_mixed_and_or_only_pushes_safe_conjuncts() { + // `block_num >= 100 AND (address = 
'0xA' OR address = '0xB')` + // Only the top-level AND conjunct `block_num >= 100` is safe to push down. + // The OR branch must NOT be pushed down. + let preds = extract_raw_column_predicates( + "SELECT * FROM Transfer WHERE block_num >= 100 AND (address = '0xA' OR address = '0xB')", + ); + assert_eq!(preds.len(), 1, "Only safe AND conjuncts should be pushed, got: {preds:?}"); + assert!(preds.contains(&"block_num >= 100".to_string())); + } + + #[test] + fn test_nested_or_inside_and_not_pushed_down() { + let preds = extract_raw_column_predicates( + "SELECT * FROM Transfer WHERE (block_num = 1 OR block_num = 2) AND address = '0xABC'", + ); + assert_eq!(preds.len(), 1, "Only simple AND conjuncts should be pushed, got: {preds:?}"); + assert!(preds.contains(&"address = '0xABC'".to_string())); + } + }