diff --git a/.changelog/pr-181-query-result-bounds.md b/.changelog/pr-181-query-result-bounds.md
new file mode 100644
index 00000000..0e5a1bf9
--- /dev/null
+++ b/.changelog/pr-181-query-result-bounds.md
@@ -0,0 +1,5 @@
+---
+tidx: patch
+---
+
+Bound PostgreSQL query result processing by streaming rows with a hard request limit and appending automatic LIMIT clauses on a separate line.
diff --git a/src/service/mod.rs b/src/service/mod.rs
index 01307a96..82e9763c 100644
--- a/src/service/mod.rs
+++ b/src/service/mod.rs
@@ -1,7 +1,9 @@
 use anyhow::{Result, anyhow};
 use chrono::{DateTime, Utc};
+use futures::TryStreamExt;
 use serde::Serialize;
 use std::time::Instant;
+use tokio_postgres::types::ToSql;
 
 use crate::db::Pool;
 use crate::metrics;
@@ -218,16 +220,40 @@ pub async fn execute_query_postgres(
         .await?;
 
     let start = Instant::now();
-    let result = tokio::time::timeout(
-        std::time::Duration::from_millis(options.timeout_ms + 100),
-        conn.query(&sql, &[]),
-    )
+    let timeout = std::time::Duration::from_millis(options.timeout_ms + 100);
+    let limit = options.limit as usize;
+    let result = tokio::time::timeout(timeout, async {
+        let params = std::iter::empty::<&(dyn ToSql + Sync)>();
+        let stream = conn.query_raw(&sql, params).await?;
+        futures::pin_mut!(stream);
+        let mut columns: Option<Vec<String>> = None;
+        let mut rows = Vec::new();
+
+        while let Some(row) = stream.try_next().await? {
+            if columns.is_none() {
+                columns = Some(row.columns().iter().map(|c| c.name().to_string()).collect());
+            }
+            if rows.len() >= limit {
+                return Err(anyhow!("Query returned more than {limit} rows"));
+            }
+            let cols = columns
+                .as_ref()
+                .expect("columns initialized from first row");
+            rows.push(
+                (0..cols.len())
+                    .map(|i| format_column_json(&row, i))
+                    .collect::<Vec<_>>(),
+            );
+        }
+
+        Ok::<_, anyhow::Error>((columns.unwrap_or_default(), rows))
+    })
     .await;
 
-    let rows = match result {
-        Ok(Ok(rows)) => {
+    let (mut columns, result_rows) = match result {
+        Ok(Ok(result)) => {
             metrics::record_query_duration(start.elapsed());
-            rows
+            result
         }
         Ok(Err(e)) => {
             return Err(anyhow!(
@@ -238,45 +264,19 @@
         Err(_) => return Err(anyhow!("Query timeout")),
     };
 
-    // Get columns from result (even if empty, prepared statement has column info)
-    let columns: Vec<String> = if rows.is_empty() {
-        // For empty results, prepare statement to get column metadata
-        conn.prepare(&sql)
+    if columns.is_empty() {
+        columns = conn
+            .prepare(&sql)
             .await
             .ok()
             .map(|s| s.columns().iter().map(|c| c.name().to_string()).collect())
-            .unwrap_or_default()
-    } else {
-        rows[0]
-            .columns()
-            .iter()
-            .map(|c| c.name().to_string())
-            .collect()
-    };
+            .unwrap_or_default();
+    }
 
     let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
-
-    if rows.is_empty() {
-        return Ok(QueryResult {
-            columns,
-            rows: vec![],
-            row_count: 0,
-            engine: Some("postgres".to_string()),
-            query_time_ms: Some(elapsed_ms),
-        });
-    }
-
-    let row_count = rows.len();
+    let row_count = result_rows.len();
     metrics::record_query_rows(row_count as u64);
 
-    let result_rows: Vec<Vec<serde_json::Value>> = rows
-        .iter()
-        .map(|row| {
-            (0..columns.len())
-                .map(|i| format_column_json(row, i))
-                .collect()
-        })
-        .collect();
-
     Ok(QueryResult {
         columns,
         rows: result_rows,
@@ -294,7 +294,7 @@ fn append_limit_if_missing(sql: &str, limit: i64) -> String {
     if let Ok(stmts) = Parser::parse_sql(&dialect, sql) {
         if let Some(sqlparser::ast::Statement::Query(query)) = stmts.first() {
             if query.limit_clause.is_none() {
-                return format!("{sql} LIMIT {limit}");
+                return format!("{sql}\nLIMIT {limit}");
            }
         }
     }
@@ -530,6 +530,22 @@ mod tests {
         assert_eq!(options.limit, 10000);
     }
 
+    #[test]
+    fn test_append_limit_uses_newline_after_line_comment() {
+        let sql = "SELECT * FROM blocks -- trailing comment";
+        let limited = append_limit_if_missing(sql, 100);
+        assert_eq!(
+            limited,
+            "SELECT * FROM blocks -- trailing comment\nLIMIT 100"
+        );
+    }
+
+    #[test]
+    fn test_append_limit_preserves_existing_limit() {
+        let sql = "SELECT * FROM blocks LIMIT 10";
+        assert_eq!(append_limit_if_missing(sql, 100), sql);
+    }
+
     // ========================================================================
     // Sanitize Error Tests
     // ========================================================================