Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changelog/pr-181-query-result-bounds.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
tidx: patch
---

Bound PostgreSQL query result processing by streaming rows with a hard request limit and appending automatic LIMIT clauses on a separate line.
96 changes: 56 additions & 40 deletions src/service/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use anyhow::{Result, anyhow};
use chrono::{DateTime, Utc};
use futures::TryStreamExt;
use serde::Serialize;
use std::time::Instant;
use tokio_postgres::types::ToSql;

use crate::db::Pool;
use crate::metrics;
Expand Down Expand Up @@ -218,16 +220,40 @@ pub async fn execute_query_postgres(
.await?;

let start = Instant::now();
let result = tokio::time::timeout(
std::time::Duration::from_millis(options.timeout_ms + 100),
conn.query(&sql, &[]),
)
let timeout = std::time::Duration::from_millis(options.timeout_ms + 100);
let limit = options.limit as usize;
let result = tokio::time::timeout(timeout, async {
let params = std::iter::empty::<&(dyn ToSql + Sync)>();
let stream = conn.query_raw(&sql, params).await?;
futures::pin_mut!(stream);
let mut columns: Option<Vec<String>> = None;
let mut rows = Vec::new();

while let Some(row) = stream.try_next().await? {
if columns.is_none() {
columns = Some(row.columns().iter().map(|c| c.name().to_string()).collect());
}
if rows.len() >= limit {
return Err(anyhow!("Query returned more than {limit} rows"));
}
let cols = columns
.as_ref()
.expect("columns initialized from first row");
rows.push(
(0..cols.len())
.map(|i| format_column_json(&row, i))
.collect::<Vec<_>>(),
);
}

Ok::<_, anyhow::Error>((columns.unwrap_or_default(), rows))
})
.await;

let rows = match result {
Ok(Ok(rows)) => {
let (mut columns, result_rows) = match result {
Ok(Ok(result)) => {
metrics::record_query_duration(start.elapsed());
rows
result
}
Ok(Err(e)) => {
return Err(anyhow!(
Expand All @@ -238,45 +264,19 @@ pub async fn execute_query_postgres(
Err(_) => return Err(anyhow!("Query timeout")),
};

// Get columns from result (even if empty, prepared statement has column info)
let columns: Vec<String> = if rows.is_empty() {
// For empty results, prepare statement to get column metadata
conn.prepare(&sql)
if columns.is_empty() {
columns = conn
.prepare(&sql)
.await
.ok()
.map(|s| s.columns().iter().map(|c| c.name().to_string()).collect())
.unwrap_or_default()
} else {
rows[0]
.columns()
.iter()
.map(|c| c.name().to_string())
.collect()
};
.unwrap_or_default();
}

let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;

if rows.is_empty() {
return Ok(QueryResult {
columns,
rows: vec![],
row_count: 0,
engine: Some("postgres".to_string()),
query_time_ms: Some(elapsed_ms),
});
}
let row_count = rows.len();
let row_count = result_rows.len();
metrics::record_query_rows(row_count as u64);

let result_rows: Vec<Vec<serde_json::Value>> = rows
.iter()
.map(|row| {
(0..columns.len())
.map(|i| format_column_json(row, i))
.collect()
})
.collect();

Ok(QueryResult {
columns,
rows: result_rows,
Expand All @@ -294,7 +294,7 @@ fn append_limit_if_missing(sql: &str, limit: i64) -> String {
if let Ok(stmts) = Parser::parse_sql(&dialect, sql) {
if let Some(sqlparser::ast::Statement::Query(query)) = stmts.first() {
if query.limit_clause.is_none() {
return format!("{sql} LIMIT {limit}");
return format!("{sql}\nLIMIT {limit}");
}
}
}
Expand Down Expand Up @@ -530,6 +530,22 @@ mod tests {
assert_eq!(options.limit, 10000);
}

#[test]
fn test_append_limit_uses_newline_after_line_comment() {
let sql = "SELECT * FROM blocks -- trailing comment";
let limited = append_limit_if_missing(sql, 100);
assert_eq!(
limited,
"SELECT * FROM blocks -- trailing comment\nLIMIT 100"
);
}

#[test]
fn test_append_limit_preserves_existing_limit() {
let sql = "SELECT * FROM blocks LIMIT 10";
assert_eq!(append_limit_if_missing(sql, 100), sql);
}

// ========================================================================
// Sanitize Error Tests
// ========================================================================
Expand Down
Loading