diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index b41ccd5..d89bb33 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -15,6 +15,7 @@ use datafusion::error::DataFusionError; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::*; use datafusion::sql::parser::Statement; +use datafusion::sql::sqlparser::ast::{Expr, Ident, ObjectName, Statement as SqlStatement}; use log::{info, warn}; use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::auth::StartupHandler; @@ -247,35 +248,39 @@ impl DfSessionService { Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream))) } - async fn try_respond_set_statements<'a, C>( + /// Handle structured SET statements from parsed AST (replaces string matching) + async fn handle_set_statement_structured<'a, C>( &self, client: &mut C, - query_lower: &str, - ) -> PgWireResult>> + variables: &[ObjectName], + value: &[Expr], + ) -> PgWireResult> where C: ClientInfo, { - if query_lower.starts_with("set") { - if query_lower.starts_with("set time zone") { - let parts: Vec<&str> = query_lower.split_whitespace().collect(); - if parts.len() >= 4 { - let tz = parts[3].trim_matches('"'); + let var_name = variables.first().map(|v| v.to_string()).unwrap_or_default(); + match var_name.to_lowercase().as_str() { + "time_zone" | "timezone" => { + if let Some(val) = value.first() { + let val_str = val.to_string(); + let tz = val_str.trim_matches('"').trim_matches('\''); let mut timezone = self.timezone.lock().await; *timezone = tz.to_string(); - Ok(Some(Response::Execution(Tag::new("SET")))) + Ok(Response::Execution(Tag::new("SET"))) } else { Err(PgWireError::UserError(Box::new( pgwire::error::ErrorInfo::new( "ERROR".to_string(), "42601".to_string(), - "Invalid SET TIME ZONE syntax".to_string(), + "Invalid SET TIME ZONE value".to_string(), ), ))) } - } else if query_lower.starts_with("set statement_timeout") { - let parts: Vec<&str> = query_lower.split_whitespace().collect(); - if parts.len() >= 3 { - let timeout_str = parts[2].trim_matches('"').trim_matches('\''); + } + "statement_timeout" => { + if let Some(val) = value.first() { + let val_str = val.to_string(); + let timeout_str = val_str.trim_matches('"').trim_matches('\''); let timeout = if timeout_str == "0" || timeout_str.is_empty() { None @@ -305,27 +310,112 @@ impl DfSessionService { }; Self::set_statement_timeout(client, timeout); - Ok(Some(Response::Execution(Tag::new("SET")))) + Ok(Response::Execution(Tag::new("SET"))) } else { Err(PgWireError::UserError(Box::new( pgwire::error::ErrorInfo::new( "ERROR".to_string(), "42601".to_string(), - "Invalid SET statement_timeout syntax".to_string(), + "Invalid SET statement_timeout value".to_string(), ), ))) } - } else { - // pass SET query to datafusion - if let Err(e) = self.session_context.sql(query_lower).await { - warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored"); + } + _ => { + // Pass unknown SET statements to DataFusion + let set_sql = format!( + "SET {} = {}", + var_name, + value + .iter() + .map(|v| v.to_string()) + .collect::>() + .join(", ") + ); + if let Err(e) = self.session_context.sql(&set_sql).await { + warn!("SET statement {set_sql} is not supported by datafusion, error {e}, statement ignored"); } + Ok(Response::Execution(Tag::new("SET"))) + } + } + } - // Always return SET success - Ok(Some(Response::Execution(Tag::new("SET")))) + /// Handle structured SHOW statements from parsed AST (replaces string matching) + async fn handle_show_statement_structured<'a, C>( + &self, + client: &C, + variable: &[Ident], + ) -> PgWireResult> + where + C: ClientInfo, + { + let var_name = variable + .iter() + .map(|i| i.to_string()) + .collect::>() + .join("_"); + match var_name.to_lowercase().as_str() { + "time_zone" | "timezone" => { + let timezone = self.timezone.lock().await.clone(); + let resp = Self::mock_show_response("TimeZone", &timezone)?; + Ok(Response::Query(resp)) + } + "server_version" => { + let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?; + Ok(Response::Query(resp)) + } + "transaction_isolation" => { + let resp = Self::mock_show_response("transaction_isolation", "read uncommitted")?; + Ok(Response::Query(resp)) + } + "search_path" => { + let default_schema = "public"; + let resp = Self::mock_show_response("search_path", default_schema)?; + Ok(Response::Query(resp)) + } + "statement_timeout" => { + let timeout = Self::get_statement_timeout(client); + let timeout_str = match timeout { + Some(duration) => format!("{}ms", duration.as_millis()), + None => "0".to_string(), + }; + let resp = Self::mock_show_response("statement_timeout", &timeout_str)?; + Ok(Response::Query(resp)) + } + _ => { + let catalogs = self.session_context.catalog_names(); + let value = catalogs.join(", "); + let resp = Self::mock_show_response(&var_name, &value)?; + Ok(Response::Query(resp)) } - } else { - Ok(None) + } + } + + /// Handle structured statements using AST instead of fragile string matching + async fn try_handle_structured_statement<'a, C>( + &self, + client: &mut C, + statement: &SqlStatement, + ) -> PgWireResult>> + where + C: ClientInfo, + { + match statement { + SqlStatement::SetVariable { + variables, value, .. + } => { + let response = self + .handle_set_statement_structured(client, variables, value) + .await?; + Ok(Some(response)) + } + SqlStatement::ShowVariable { variable } => { + let response = self + .handle_show_statement_structured(client, variable) + .await?; + Ok(Some(response)) + } + _ => Ok(None), } } @@ -380,6 +470,8 @@ impl DfSessionService { } } + /// Legacy string-based SHOW statement handler (deprecated - use structured AST instead) + #[allow(dead_code)] async fn try_respond_show_statements<'a, C>( &self, client: &C, @@ -462,6 +554,14 @@ impl SimpleQueryHandler for DfSessionService { // TODO: deal with multiple statements let mut statement = statements.remove(0); + // Handle SET/SHOW statements using structured AST (replaces fragile string matching) + if let Some(resp) = self + .try_handle_structured_statement(client, &statement) + .await? + { + return Ok(vec![resp]); + } + // Attempt to rewrite statement = rewrite(statement, &self.sql_rewrite_rules); @@ -482,19 +582,7 @@ impl SimpleQueryHandler for DfSessionService { self.check_query_permission(client, &query).await?; } - if let Some(resp) = self - .try_respond_set_statements(client, &query_lower) - .await? - { - return Ok(vec![resp]); - } - - if let Some(resp) = self - .try_respond_show_statements(client, &query_lower) - .await? - { - return Ok(vec![resp]); - } + // SET/SHOW statements now handled by structured AST parsing above // Check if we're in a failed transaction and block non-transaction // commands @@ -626,34 +714,38 @@ impl ExtendedQueryHandler for DfSessionService { where C: ClientInfo + Unpin + Send + Sync, { - let query = portal - .statement - .statement - .0 - .to_lowercase() - .trim() - .to_string(); - log::debug!("Received execute extended query: {query}"); // Log for debugging - - // Check permissions for the query (skip for SET and SHOW statements) - if !query.starts_with("set") && !query.starts_with("show") { - self.check_query_permission(client, &portal.statement.statement.0) - .await?; - } - - if let Some(resp) = self.try_respond_set_statements(client, &query).await? { - return Ok(resp); + let original_sql = &portal.statement.statement.0; + log::debug!("Received execute extended query: {original_sql}"); // Log for debugging + + // Handle SET/SHOW statements using structured AST (re-parse for AST access) + if let Ok(parsed_statements) = crate::sql::parse(original_sql) { + if let Some(statement) = parsed_statements.first() { + if let Some(resp) = self + .try_handle_structured_statement(client, statement) + .await? + { + return Ok(resp); + } + } } + // Handle transaction statements + let query_lower = original_sql.to_lowercase().trim().to_string(); if let Some(resp) = self - .try_respond_transaction_statements(client, &query) + .try_respond_transaction_statements(client, &query_lower) .await? { return Ok(resp); } - if let Some(resp) = self.try_respond_show_statements(client, &query).await? { - return Ok(resp); + // Check permissions for non-SET/SHOW/transaction statements + if !query_lower.starts_with("set") + && !query_lower.starts_with("show") + && !query_lower.starts_with("begin") + && !query_lower.starts_with("commit") + && !query_lower.starts_with("rollback") + { + self.check_query_permission(client, original_sql).await?; } // Check if we're in a failed transaction and block non-transaction @@ -752,18 +844,25 @@ impl Parser { ))); } - // show statement may not be supported by datafusion - if sql_trimmed.starts_with("show") { - // Return a dummy plan for transaction commands - they'll be handled by transaction handler - let show_schema = - Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)])); - let df_schema = show_schema.to_dfschema()?; - return Ok(Some(LogicalPlan::EmptyRelation( - datafusion::logical_expr::EmptyRelation { - produce_one_row: true, - schema: Arc::new(df_schema), - }, - ))); + // Parse and check for SET/SHOW statements using structured AST + if let Ok(parsed_statements) = crate::sql::parse(sql) { + if let Some(statement) = parsed_statements.first() { + if matches!( + statement, + SqlStatement::SetVariable { .. } | SqlStatement::ShowVariable { .. } + ) { + // Return a dummy plan for SET/SHOW commands - they'll be handled by structured handler + let show_schema = + Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)])); + let df_schema = show_schema.to_dfschema()?; + return Ok(Some(LogicalPlan::EmptyRelation( + datafusion::logical_expr::EmptyRelation { + produce_one_row: true, + schema: Arc::new(df_schema), + }, + ))); + } + } } Ok(None) @@ -898,9 +997,11 @@ mod tests { let service = DfSessionService::new(session_context, auth_manager); let mut client = MockClient::new(); - // Test setting timeout to 5000ms + // Test setting timeout using structured AST parsing + use crate::sql::parse; + let set_statements = parse("SET statement_timeout = '5000ms'").unwrap(); let set_response = service - .try_respond_set_statements(&mut client, "set statement_timeout '5000ms'") + .try_handle_structured_statement(&mut client, &set_statements[0]) .await .unwrap(); assert!(set_response.is_some()); @@ -909,9 +1010,10 @@ mod tests { let timeout = DfSessionService::get_statement_timeout(&client); assert_eq!(timeout, Some(Duration::from_millis(5000))); - // Test SHOW statement_timeout + // Test SHOW statement_timeout using structured AST parsing + let show_statements = parse("SHOW statement_timeout").unwrap(); let show_response = service - .try_respond_show_statements(&client, "show statement_timeout") + .try_handle_structured_statement(&mut client, &show_statements[0]) .await .unwrap(); assert!(show_response.is_some()); @@ -924,19 +1026,53 @@ mod tests { let service = DfSessionService::new(session_context, auth_manager); let mut client = MockClient::new(); - // Set timeout first + // Set timeout first using structured AST + use crate::sql::parse; + let set_statements = parse("SET statement_timeout = '1000ms'").unwrap(); service - .try_respond_set_statements(&mut client, "set statement_timeout '1000ms'") + .try_handle_structured_statement(&mut client, &set_statements[0]) .await .unwrap(); - // Disable timeout with 0 + // Disable timeout with 0 using structured AST + let disable_statements = parse("SET statement_timeout = '0'").unwrap(); service - .try_respond_set_statements(&mut client, "set statement_timeout '0'") + .try_handle_structured_statement(&mut client, &disable_statements[0]) .await .unwrap(); let timeout = DfSessionService::get_statement_timeout(&client); assert_eq!(timeout, None); } + + #[tokio::test] + async fn test_structured_vs_string_statement_handling() { + let session_context = Arc::new(SessionContext::new()); + let auth_manager = Arc::new(AuthManager::new()); + let service = DfSessionService::new(session_context, auth_manager); + let mut client = MockClient::new(); + + // Test that structured parsing works for complex SET statements + use crate::sql::parse; + + // Test with equals sign (structured parsing should handle this better) + let statements = parse("SET statement_timeout = '5000ms'").unwrap(); + let result = service + .try_handle_structured_statement(&mut client, &statements[0]) + .await; + assert!(result.is_ok()); + assert!(result.unwrap().is_some()); + + // Verify timeout was set correctly via structured parsing + let timeout = DfSessionService::get_statement_timeout(&client); + assert_eq!(timeout, Some(Duration::from_millis(5000))); + + // Test SHOW with structured parsing + let show_statements = parse("SHOW statement_timeout").unwrap(); + let show_result = service + .try_handle_structured_statement(&mut client, &show_statements[0]) + .await; + assert!(show_result.is_ok()); + assert!(show_result.unwrap().is_some()); + } } diff --git a/datafusion-postgres/tests/metabase.rs b/datafusion-postgres/tests/metabase.rs index 3c15700..a17e1e2 100644 --- a/datafusion-postgres/tests/metabase.rs +++ b/datafusion-postgres/tests/metabase.rs @@ -46,8 +46,8 @@ pub async fn test_metabase_startup_sql() { for query in METABASE_QUERIES { SimpleQueryHandler::do_query(&service, &mut client, query) .await - .expect(&format!( - "failed to run sql: \n--------------\n {query}\n--------------\n" - )); + .unwrap_or_else(|_| { + panic!("failed to run sql: \n--------------\n {query}\n--------------\n") + }); } } diff --git a/datafusion-postgres/tests/pgcli.rs b/datafusion-postgres/tests/pgcli.rs index dc59e9f..21e5320 100644 --- a/datafusion-postgres/tests/pgcli.rs +++ b/datafusion-postgres/tests/pgcli.rs @@ -138,8 +138,8 @@ pub async fn test_pgcli_startup_sql() { for query in PGCLI_QUERIES { SimpleQueryHandler::do_query(&service, &mut client, query) .await - .expect(&format!( - "failed to run sql:\n--------------\n {query}\n--------------\n" - )); + .unwrap_or_else(|_| { + panic!("failed to run sql:\n--------------\n {query}\n--------------\n") + }); } }