From 4479eef7d7ce4d3318d93f0127f9a0b647aa8b54 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Fri, 28 Nov 2025 18:31:59 -0800 Subject: [PATCH 01/23] Rewrite 3.0 --- .../client/query_engine/internal_values.rs | 2 +- pgdog/src/frontend/router/mod.rs | 1 + pgdog/src/frontend/router/parser/insert.rs | 4 + pgdog/src/frontend/router/parser/value.rs | 20 +- pgdog/src/frontend/router/rewrite/error.rs | 16 ++ pgdog/src/frontend/router/rewrite/mod.rs | 12 + .../router/rewrite/unique_id/insert.rs | 217 ++++++++++++++++++ .../frontend/router/rewrite/unique_id/mod.rs | 3 + pgdog/src/net/messages/bind.rs | 23 ++ pgdog/src/net/messages/parse.rs | 1 - pgdog/src/unique_id.rs | 60 +++-- 11 files changed, 318 insertions(+), 41 deletions(-) create mode 100644 pgdog/src/frontend/router/rewrite/error.rs create mode 100644 pgdog/src/frontend/router/rewrite/mod.rs create mode 100644 pgdog/src/frontend/router/rewrite/unique_id/insert.rs create mode 100644 pgdog/src/frontend/router/rewrite/unique_id/mod.rs diff --git a/pgdog/src/frontend/client/query_engine/internal_values.rs b/pgdog/src/frontend/client/query_engine/internal_values.rs index 7d7245932..4721d30b4 100644 --- a/pgdog/src/frontend/client/query_engine/internal_values.rs +++ b/pgdog/src/frontend/client/query_engine/internal_values.rs @@ -32,7 +32,7 @@ impl QueryEngine { &mut self, context: &mut QueryEngineContext<'_>, ) -> Result<(), Error> { - let id = unique_id::UniqueId::generator()?.next_id().await; + let id = unique_id::UniqueId::generator()?.next_id(); let bytes_sent = context .stream .send_many(&[ diff --git a/pgdog/src/frontend/router/mod.rs b/pgdog/src/frontend/router/mod.rs index eb78d3e9e..df5543cc2 100644 --- a/pgdog/src/frontend/router/mod.rs +++ b/pgdog/src/frontend/router/mod.rs @@ -5,6 +5,7 @@ pub mod context; pub mod copy; pub mod error; pub mod parser; +pub mod rewrite; pub mod round_robin; pub mod search_path; pub mod sharding; diff --git a/pgdog/src/frontend/router/parser/insert.rs b/pgdog/src/frontend/router/parser/insert.rs index be04e0fbb..3213192d8 100644 --- a/pgdog/src/frontend/router/parser/insert.rs +++ b/pgdog/src/frontend/router/parser/insert.rs @@ -57,6 +57,10 @@ impl<'a> Insert<'a> { .unwrap_or(vec![]) } + pub fn stmt(&'a self) -> &'a InsertStmt { + &self.stmt + } + /// Get table name, if specified (should always be). pub fn table(&'a self) -> Option> { self.stmt.relation.as_ref().map(Table::from) diff --git a/pgdog/src/frontend/router/parser/value.rs b/pgdog/src/frontend/router/parser/value.rs index 81cc1350f..d9f9a0e08 100644 --- a/pgdog/src/frontend/router/parser/value.rs +++ b/pgdog/src/frontend/router/parser/value.rs @@ -17,7 +17,7 @@ pub enum Value<'a> { Null, Placeholder(i32), Vector(Vector), - Function(&'a str), + Function(std::string::String), } impl Value<'_> { @@ -76,13 +76,19 @@ impl<'a> TryFrom<&'a Option> for Value<'a> { Some(NodeEnum::AConst(a_const)) => Ok(a_const.into()), Some(NodeEnum::ParamRef(param_ref)) => Ok(Value::Placeholder(param_ref.number)), Some(NodeEnum::FuncCall(func)) => { - if let Some(Node { - node: Some(NodeEnum::String(sval)), - }) = func.funcname.first() - { - Ok(Value::Function(&sval.sval)) - } else { + let mut name = Vec::new(); + for comp in func.funcname.iter() { + if let Node { + node: Some(NodeEnum::String(sval)), + } = comp + { + name.push(sval.sval.to_string()); + } + } + if name.is_empty() { Ok(Value::Null) + } else { + Ok(Value::Function(name.join("."))) } } Some(NodeEnum::TypeCast(cast)) => { diff --git a/pgdog/src/frontend/router/rewrite/error.rs b/pgdog/src/frontend/router/rewrite/error.rs new file mode 100644 index 000000000..24ba8159d --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/error.rs @@ -0,0 +1,16 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum Error { + #[error("parser error")] + ParserError, + + #[error("unique id: {0}")] + UniqueId(#[from] crate::unique_id::Error), + + #[error("pg_query: {0}")] + PgQuery(#[from] pg_query::Error), + + #[error("net: {0}")] + Net(#[from] crate::net::Error), +} diff --git a/pgdog/src/frontend/router/rewrite/mod.rs b/pgdog/src/frontend/router/rewrite/mod.rs new file mode 100644 index 000000000..4bd04d941 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/mod.rs @@ -0,0 +1,12 @@ +//! Query rewrite engine. +//! +//! It handles the following scenarios: +//! +//! 1. Sharding key UPDATE: rewrite to send a DELETE and INSERT +//! 2. Multi-tuple INSERT: rewrite to send multiple INSERTs +//! 3. pgdog.unique_id() call: inject a unique ID +//! +pub mod error; +pub mod unique_id; + +pub use error::Error; diff --git a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs new file mode 100644 index 000000000..d852bdb88 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs @@ -0,0 +1,217 @@ +use pg_query::{ + protobuf::{a_const::Val, AConst, InsertStmt, ParamRef}, + Node, NodeEnum, +}; + +use super::super::Error; +use crate::{ + frontend::{ + router::parser::{Insert, Value}, + PreparedStatements, + }, + net::{Bind, Datum, Parse, Query}, + unique_id, +}; + +pub struct InsertUniqueIdRewrite<'a> { + stmt: &'a InsertStmt, + bind: Option<&'a Bind>, +} + +#[derive(Debug, Clone)] +pub enum InsertRewriteResult { + Extended { parse: Parse, bind: Bind }, + Simple { query: Query }, +} + +#[derive(Debug, Clone)] +pub struct InsertRewriteOutput { + /// Rewritten AST. This should be used for any subsequent operations. + pub stmt: InsertStmt, + /// Rewritten Query or Parse and Bind. + pub rewrite: InsertRewriteResult, +} + +impl InsertRewriteOutput { + /// Get query text. + pub fn query(&self) -> &str { + match &self.rewrite { + InsertRewriteResult::Extended { parse, .. } => parse.query(), + InsertRewriteResult::Simple { query } => query.query(), + } + } +} + +impl<'a> InsertUniqueIdRewrite<'a> { + /// Create new INSERT statement rewriter + pub fn new(stmt: &'a InsertStmt, bind: Option<&'a Bind>) -> Self { + Self { stmt, bind } + } + + /// Handle statement rewrite + pub fn rewrite(&self) -> Result, Error> { + let mut need_rewrite = false; + + let wrapper = Insert::new(self.stmt); + + for tuple in wrapper.tuples() { + for value in tuple.values { + if let Value::Function(ref func) = value { + if *func == "pgdog.unique_id" { + need_rewrite = true; + } + } + } + } + + if !need_rewrite { + return Ok(None); + } + + let mut bind = self.bind.cloned(); + let mut stmt = self.stmt.clone(); + + let select = stmt + .select_stmt + .as_mut() + .ok_or(Error::ParserError)? + .node + .as_mut() + .ok_or(Error::ParserError)?; + if let NodeEnum::SelectStmt(stmt) = select { + for tuple in stmt.values_lists.iter_mut() { + if let Some(NodeEnum::List(ref mut tuple)) = tuple.node { + for column in tuple.items.iter_mut() { + if let Ok(Value::Function(name)) = Value::try_from(&column.node) { + // Replace function call with value. + if name == "pgdog.unique_id" { + let id = unique_id::UniqueId::generator()?.next_id(); + + let node = if let Some(ref mut bind) = bind { + NodeEnum::ParamRef(ParamRef { + number: bind.add_parameter(Datum::Bigint(id))?, + ..Default::default() + }) + } else { + NodeEnum::AConst(AConst { + val: Some(Val::Sval(pg_query::protobuf::String { + sval: id.to_string(), + })), + ..Default::default() + }) + }; + + column.node = Some(node); + } + } + } + } + } + } + + let wrapper = Node { + node: Some(NodeEnum::InsertStmt(Box::new(stmt.clone()))), + ..Default::default() + }; + + let rewrite = match bind { + Some(mut bind) => { + let mut parse = Parse::new_anonymous(&wrapper.deparse()?); + if !bind.anonymous() { + let (_, name) = PreparedStatements::global().write().insert(&parse); + parse.rename_fast(&name); + bind.rename(&name); + } + + InsertRewriteResult::Extended { parse, bind } + } + + None => InsertRewriteResult::Simple { + query: Query::new(wrapper.deparse()?), + }, + }; + + Ok(Some(InsertRewriteOutput { stmt, rewrite })) + } +} + +#[cfg(test)] +mod test { + + use std::env::set_var; + + use crate::net::bind::Parameter; + + use super::*; + + fn insert_root(query: &str) -> InsertStmt { + let stmt = pg_query::parse(query).unwrap(); + let root = stmt + .protobuf + .stmts + .first() + .cloned() + .unwrap() + .stmt + .unwrap() + .node + .unwrap(); + if let NodeEnum::InsertStmt(stmt) = root { + *stmt.clone() + } else { + panic!("not an insert") + } + } + + #[test] + fn test_unique_id_insert() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = insert_root( + r#" + INSERT INTO omnisharded (id, settings) + VALUES + (pgdog.unique_id(), '{}'::JSONB), + (pgdog.unique_id(), '{"hello": "world"}'::JSONB)"#, + ); + let insert = InsertUniqueIdRewrite::new(&stmt, None); + let rewrite = insert.rewrite().unwrap().unwrap(); + assert!(!rewrite.query().contains("pgdog.unique_id")); + } + + #[test] + fn test_unique_id_insert_parse() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = insert_root( + r#" + INSERT INTO omnisharded (id, settings) + VALUES + (pgdog.unique_id(), $1::JSONB), + (pgdog.unique_id(), $2::JSONB)"#, + ); + let bind = Bind::new_params( + "", + &[ + Parameter { + len: 2, + data: "{}".into(), + }, + Parameter { + len: 2, + data: "{}".into(), + }, + ], + ); + let rewrite = InsertUniqueIdRewrite::new(&stmt, Some(&bind)) + .rewrite() + .unwrap() + .unwrap(); + assert_eq!( + rewrite.query(), + "INSERT INTO omnisharded (id, settings) VALUES ($3, $1::jsonb), ($4, $2::jsonb)" + ); + } +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs new file mode 100644 index 000000000..848372c91 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs @@ -0,0 +1,3 @@ +//! Unique ID rewrite engine. + +pub mod insert; diff --git a/pgdog/src/net/messages/bind.rs b/pgdog/src/net/messages/bind.rs index b7042fc6d..e1803593f 100644 --- a/pgdog/src/net/messages/bind.rs +++ b/pgdog/src/net/messages/bind.rs @@ -1,5 +1,6 @@ //! Bind (F) message. use crate::net::c_string_buf_len; +use crate::net::Datum; use uuid::Uuid; use super::code; @@ -203,6 +204,28 @@ impl Bind { &self.codes } + /// Add parameter and return its placeholder number. + pub fn add_parameter(&mut self, data: Datum) -> Result { + let format = { + let codes = self.codes(); + if codes.is_empty() { + Format::Text + } else if codes.len() == 1 { + codes[0] + } else { + self.codes.push(Format::Text); + Format::Text + } + }; + let bytes = data.encode(format)?; + self.params.push(Parameter { + len: bytes.len() as i32, + data: bytes, + }); + // Param codes are 1-indexed. + Ok(self.params.len() as i32) + } + pub fn new_statement(name: &str) -> Self { Self { statement: Bytes::from(name.to_string() + "\0"), diff --git a/pgdog/src/net/messages/parse.rs b/pgdog/src/net/messages/parse.rs index 06e2be6a7..7481e70f8 100644 --- a/pgdog/src/net/messages/parse.rs +++ b/pgdog/src/net/messages/parse.rs @@ -38,7 +38,6 @@ impl Parse { } /// New anonymous prepared statement. - #[cfg(test)] pub fn new_anonymous(query: &str) -> Self { Self { name: Bytes::from("\0"), diff --git a/pgdog/src/unique_id.rs b/pgdog/src/unique_id.rs index 660e2814b..043612015 100644 --- a/pgdog/src/unique_id.rs +++ b/pgdog/src/unique_id.rs @@ -8,13 +8,13 @@ //! clock, so `std::time::SystemTime` returns a good value. //! use std::sync::Arc; +use std::thread::sleep; use std::time::UNIX_EPOCH; use std::time::{Duration, SystemTime}; use once_cell::sync::OnceCell; +use parking_lot::Mutex; use thiserror::Error; -use tokio::sync::Mutex; -use tokio::time::sleep; use crate::config::config; use crate::util::{instance_id, node_id}; @@ -52,14 +52,14 @@ impl Default for State { impl State { // Generate next unique ID in a distributed sequence. // The `node_id` argument must be globally unique. - async fn next_id(&mut self, node_id: u64, id_offset: u64) -> u64 { - let mut now = wait_until(self.last_timestamp_ms).await; + fn next_id(&mut self, node_id: u64, id_offset: u64) -> u64 { + let mut now = wait_until(self.last_timestamp_ms); if now == self.last_timestamp_ms { self.sequence = (self.sequence + 1) & MAX_SEQUENCE; // Wraparound. if self.sequence == 0 { - now = wait_until(now + 1).await; + now = wait_until(now + 1); } } else { // Reset sequence to zero once we reach next ms. @@ -92,13 +92,13 @@ fn now_ms() -> u64 { // Get a monotonically increasing timestamp in ms. // Protects against clock drift. -async fn wait_until(target_ms: u64) -> u64 { +fn wait_until(target_ms: u64) -> u64 { loop { let now = now_ms(); if now >= target_ms { return now; } - sleep(Duration::from_millis(1)).await; + sleep(Duration::from_millis(1)); } } @@ -149,12 +149,8 @@ impl UniqueId { } /// Generate a globally unique, monotonically increasing identifier. - pub async fn next_id(&self) -> i64 { - self.inner - .lock() - .await - .next_id(self.node_id, self.id_offset) - .await as i64 + pub fn next_id(&self) -> i64 { + self.inner.lock().next_id(self.node_id, self.id_offset) as i64 } } @@ -164,8 +160,8 @@ mod test { use super::*; - #[tokio::test] - async fn test_unique_ids() { + #[test] + fn test_unique_ids() { unsafe { set_var("NODE_ID", "pgdog-1"); } @@ -174,32 +170,32 @@ mod test { let mut ids = HashSet::new(); for _ in 0..num_ids { - ids.insert(UniqueId::generator().unwrap().next_id().await); + ids.insert(UniqueId::generator().unwrap().next_id()); } assert_eq!(ids.len(), num_ids); } - #[tokio::test] - async fn test_ids_monotonically_increasing() { + #[test] + fn test_ids_monotonically_increasing() { let mut state = State::default(); let node_id = 1u64; let mut prev_id = 0u64; for _ in 0..10_000 { - let id = state.next_id(node_id, 0).await; + let id = state.next_id(node_id, 0); assert!(id > prev_id, "ID {id} not greater than previous {prev_id}"); prev_id = id; } } - #[tokio::test] - async fn test_ids_always_positive() { + #[test] + fn test_ids_always_positive() { let mut state = State::default(); let node_id = MAX_NODE_ID; // Use max node to maximize bits used for _ in 0..10_000 { - let id = state.next_id(node_id, 0).await; + let id = state.next_id(node_id, 0); let signed = id as i64; assert!(signed > 0, "ID should be positive, got {signed}"); } @@ -227,12 +223,12 @@ mod test { assert_eq!(id >> 63, 0, "Bit 63 should be clear"); } - #[tokio::test] - async fn test_extract_components() { + #[test] + fn test_extract_components() { let node: u64 = 42; let mut state = State::default(); - let id = state.next_id(node, 0).await; + let id = state.next_id(node, 0); // Extract components back let extracted_seq = id & MAX_SEQUENCE; @@ -244,7 +240,7 @@ mod test { assert!(extracted_elapsed > 0); // Elapsed time since epoch // Generate another ID and verify sequence increments - let id2 = state.next_id(node, 0).await; + let id2 = state.next_id(node, 0); let extracted_seq2 = id2 & MAX_SEQUENCE; let extracted_node2 = (id2 >> NODE_SHIFT) & MAX_NODE_ID; @@ -252,14 +248,14 @@ mod test { assert!(matches!(extracted_seq2, 1 | 0)); // Sequence incremented (or time advanced and reset to 0) } - #[tokio::test] - async fn test_id_offset() { + #[test] + fn test_id_offset() { let offset: u64 = 1_000_000_000; let node: u64 = 5; let mut state = State::default(); for _ in 0..1000 { - let id = state.next_id(node, offset).await; + let id = state.next_id(node, offset); assert!( id > offset, "ID {id} should be greater than offset {offset}" @@ -267,15 +263,15 @@ mod test { } } - #[tokio::test] - async fn test_id_offset_monotonic() { + #[test] + fn test_id_offset_monotonic() { let offset: u64 = 1_000_000_000; let node: u64 = 5; let mut state = State::default(); let mut prev_id = 0u64; for _ in 0..1000 { - let id = state.next_id(node, offset).await; + let id = state.next_id(node, offset); assert!(id > prev_id, "ID {id} not greater than previous {prev_id}"); prev_id = id; } From 89b7a5d0625142e6e3c5e3461f5916041eff973e Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sat, 29 Nov 2025 13:05:53 -0800 Subject: [PATCH 02/23] Rewrite 3.0 save --- pgdog/src/frontend/client_request.rs | 12 +- pgdog/src/frontend/router/parser/error.rs | 3 + pgdog/src/frontend/router/parser/query/mod.rs | 25 +- pgdog/src/frontend/router/rewrite/error.rs | 12 + pgdog/src/frontend/router/rewrite/input.rs | 126 +++++++ .../src/frontend/router/rewrite/interface.rs | 14 + pgdog/src/frontend/router/rewrite/mod.rs | 39 +++ .../router/rewrite/prepared/execute.rs | 90 +++++ .../frontend/router/rewrite/prepared/mod.rs | 33 ++ .../router/rewrite/prepared/prepare.rs | 74 +++++ .../router/rewrite/unique_id/insert.rs | 199 +++++------ .../frontend/router/rewrite/unique_id/mod.rs | 39 +++ .../router/rewrite/unique_id/select.rs | 313 ++++++++++++++++++ .../router/rewrite/unique_id/update.rs | 137 ++++++++ pgdog/src/net/messages/parse.rs | 4 +- 15 files changed, 983 insertions(+), 137 deletions(-) create mode 100644 pgdog/src/frontend/router/rewrite/input.rs create mode 100644 pgdog/src/frontend/router/rewrite/interface.rs create mode 100644 pgdog/src/frontend/router/rewrite/prepared/execute.rs create mode 100644 pgdog/src/frontend/router/rewrite/prepared/mod.rs create mode 100644 pgdog/src/frontend/router/rewrite/prepared/prepare.rs create mode 100644 pgdog/src/frontend/router/rewrite/unique_id/select.rs create mode 100644 pgdog/src/frontend/router/rewrite/unique_id/update.rs diff --git a/pgdog/src/frontend/client_request.rs b/pgdog/src/frontend/client_request.rs index 825073174..b289aea36 100644 --- a/pgdog/src/frontend/client_request.rs +++ b/pgdog/src/frontend/client_request.rs @@ -181,11 +181,15 @@ impl ClientRequest { /// Rewrite query in buffer. pub fn rewrite(&mut self, request: &[ProtocolMessage]) -> Result<(), Error> { - if self.messages.iter().any(|c| c.code() != 'Q') { - return Err(Error::OnlySimpleForRewrites); + for new_message in request { + if let Some(pos) = self + .messages + .iter() + .position(|message| message.code() == new_message.code()) + { + self.messages[pos] = new_message.clone(); + } } - self.messages.clear(); - self.messages.extend(request.to_vec()); Ok(()) } diff --git a/pgdog/src/frontend/router/parser/error.rs b/pgdog/src/frontend/router/parser/error.rs index 2d360f831..aba4975e5 100644 --- a/pgdog/src/frontend/router/parser/error.rs +++ b/pgdog/src/frontend/router/parser/error.rs @@ -99,4 +99,7 @@ pub enum Error { #[error("prepared statement \"{0}\" doesn't exist")] PreparedStatementDoesntExist(String), + + #[error("rewrite: {0}")] + Rewrite(#[from] super::super::rewrite::Error), } diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index fd91b4d2b..773bd9173 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -8,6 +8,7 @@ use crate::{ router::{ context::RouterContext, parser::{rewrite::Rewrite, OrderBy, Shard}, + rewrite::{self, RewriteModule}, round_robin, sharding::{Centroids, ContextBuilder, Value as ShardingValue}, }, @@ -16,6 +17,7 @@ use crate::{ net::{ messages::{Bind, Vector}, parameter::ParameterValue, + ProtocolMessage, }, plugin::plugins, }; @@ -195,6 +197,23 @@ impl QueryParser { } }; + let mut input = rewrite::Input::new(&statement.ast().protobuf, context.router_context.bind); + rewrite::Rewrite::new(context.prepared_statements()).rewrite(&mut input)?; + + match input.build()? { + rewrite::Output::NoOp => (), + rewrite::Output::Extended { parse, bind } => { + return Ok(Command::Rewrite(vec![ + ProtocolMessage::from(parse), + bind.into(), + ])) + } + rewrite::Output::Simple { query } => { + return Ok(Command::Rewrite(vec![ProtocolMessage::from(query)])) + } + _ => todo!("multi rewrite not supported yet"), + } + self.ensure_explain_recorder(statement.ast(), context); // Parse hardcoded shard from a query comment. @@ -218,12 +237,6 @@ impl QueryParser { debug!("{}", context.query()?.query()); trace!("{:#?}", statement); - let rewrite = Rewrite::new(statement.ast()); - if rewrite.needs_rewrite() { - debug!("rewrite needed"); - return rewrite.rewrite(context.prepared_statements()); - } - if let Some(multi_tenant) = context.multi_tenant() { debug!("running multi-tenant check"); diff --git a/pgdog/src/frontend/router/rewrite/error.rs b/pgdog/src/frontend/router/rewrite/error.rs index 24ba8159d..cd52af219 100644 --- a/pgdog/src/frontend/router/rewrite/error.rs +++ b/pgdog/src/frontend/router/rewrite/error.rs @@ -13,4 +13,16 @@ pub enum Error { #[error("net: {0}")] Net(#[from] crate::net::Error), + + #[error("rewrite engine didn't rewrite bind")] + NoBind, + + #[error("empty query")] + EmptyQuery, + + #[error("no rewrite")] + NoRewrite, + + #[error("prepared statement not found: {0}")] + PreparedStatementNotFound(String), } diff --git a/pgdog/src/frontend/router/rewrite/input.rs b/pgdog/src/frontend/router/rewrite/input.rs new file mode 100644 index 000000000..5dc5d6a32 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/input.rs @@ -0,0 +1,126 @@ +//! Rewrite input and output types. + +use pg_query::protobuf::{ParseResult, RawStmt}; + +use super::Error; +use crate::{ + frontend::PreparedStatements, + net::{Bind, Parse, Query}, +}; + +#[derive(Debug, Clone)] +pub struct Input<'a> { + // Most requeries won't require a rewrite. + // This is a clone-free way to check. + original: &'a ParseResult, + // If a rewrite was done, the statement is saved here. + rewrite: Option, + /// Original bind message, if any. + bind: Option<&'a Bind>, + /// Bind rewritten. + rewrite_bind: Option, +} + +impl<'a> Input<'a> { + /// Create new input. + pub fn new(original: &'a ParseResult, bind: Option<&'a Bind>) -> Self { + Self { + original, + bind, + rewrite: None, + rewrite_bind: None, + } + } + + /// Get the Bind message, if set. + pub fn bind(&'a self) -> Option<&'a Bind> { + if let Some(ref rewrite_bind) = self.rewrite_bind { + Some(rewrite_bind) + } else { + self.bind + } + } + + /// Take the Bind message for modification. + /// Don't forget to return it. + #[must_use] + pub fn bind_take(&mut self) -> Option { + if self.rewrite_bind.is_none() { + self.rewrite_bind = self.bind.cloned(); + } + + self.rewrite_bind.take() + } + + pub fn bind_put(&mut self, bind: Option) { + self.rewrite_bind = bind; + } + + /// Get the original (or modified) statement. + pub fn stmt(&'a self) -> Result<&'a RawStmt, Error> { + let stmt = if let Some(ref rewrite) = self.rewrite { + rewrite + } else { + self.original + }; + let root = stmt.stmts.first().ok_or(Error::EmptyQuery)?; + Ok(root) + } + + /// Get the mutable statement we're rewriting. + pub fn stmt_mut(&mut self) -> Result<&mut RawStmt, Error> { + let stmt = if let Some(ref mut rewrite) = self.rewrite { + rewrite + } else { + self.rewrite = Some(self.original.clone()); + self.rewrite.as_mut().unwrap() + }; + + Ok(stmt.stmts.first_mut().ok_or(Error::EmptyQuery)?) + } + + /// Assemble statement and add it to the global prepared statements cache. + pub fn build(mut self) -> Result { + if self.rewrite.is_none() { + Ok(Output::NoOp) + } else { + let bind = self.rewrite_bind.take(); + let stmt = self.rewrite.take().ok_or(Error::NoRewrite)?.deparse()?; + + if let Some(mut bind) = bind { + let mut parse = Parse::new_anonymous(stmt); + if bind.anonymous() { + Ok(Output::Extended { parse, bind }) + } else { + let (_, name) = PreparedStatements::global().write().insert(&parse); + parse.rename_fast(&name); + bind.rename(name); + Ok(Output::Extended { parse, bind }) + } + } else { + Ok(Output::Simple { + query: Query::new(stmt), + }) + } + } + } +} + +#[derive(Debug, Clone)] +pub enum Output { + NoOp, + Extended { parse: Parse, bind: Bind }, + Simple { query: Query }, + Multi(Vec>), +} + +impl Output { + /// Get rewritten query, if any. + pub fn query(&self) -> Result<&str, ()> { + match self { + Self::Extended { parse, .. } => Ok(parse.query()), + Self::Simple { query } => Ok(query.query()), + _ => Err(()), + } + } +} diff --git a/pgdog/src/frontend/router/rewrite/interface.rs b/pgdog/src/frontend/router/rewrite/interface.rs new file mode 100644 index 000000000..6099bdd41 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/interface.rs @@ -0,0 +1,14 @@ +//! Rewrite module interface. + +use super::{Error, Input}; + +/// Rewrite trait. +/// +/// All rewrite modules should follow this. +pub trait RewriteModule { + /// Take a statement and maybe rewrite it, if needed. + /// + /// If a rewrite is needed, the module should mutate the statement + /// and update the Bind message. + fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error>; +} diff --git a/pgdog/src/frontend/router/rewrite/mod.rs b/pgdog/src/frontend/router/rewrite/mod.rs index 4bd04d941..a63fd367c 100644 --- a/pgdog/src/frontend/router/rewrite/mod.rs +++ b/pgdog/src/frontend/router/rewrite/mod.rs @@ -7,6 +7,45 @@ //! 3. pgdog.unique_id() call: inject a unique ID //! pub mod error; +pub mod input; +pub mod interface; +pub mod prepared; pub mod unique_id; pub use error::Error; +pub use input::{Input, Output}; +pub use interface::RewriteModule; + +use crate::frontend::PreparedStatements; + +/// Combined rewrite engine that runs all rewrite modules. +pub struct Rewrite<'a> { + prepared_statements: &'a mut PreparedStatements, +} + +impl<'a> Rewrite<'a> { + pub fn new(prepared_statements: &'a mut PreparedStatements) -> Self { + Self { + prepared_statements, + } + } +} + +impl RewriteModule for Rewrite<'_> { + fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + // N.B.: the ordering here matters! + // + // First, we need to inject the unique ID into the query. Once that's done, + // we can proceed with additional rewrites. + + // Unique ID rewrites + unique_id::insert::InsertUniqueIdRewrite::default().rewrite(input)?; + unique_id::update::UpdateUniqueIdRewrite::default().rewrite(input)?; + unique_id::select::SelectUniqueIdRewrite::default().rewrite(input)?; + + // Prepared statement rewrites + prepared::PreparedRewrite::new(self.prepared_statements).rewrite(input)?; + + Ok(()) + } +} diff --git a/pgdog/src/frontend/router/rewrite/prepared/execute.rs b/pgdog/src/frontend/router/rewrite/prepared/execute.rs new file mode 100644 index 000000000..16c087ade --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/prepared/execute.rs @@ -0,0 +1,90 @@ +//! EXECUTE statement rewriter. + +use pg_query::NodeEnum; + +use super::super::{Error, Input, RewriteModule}; +use crate::frontend::PreparedStatements; + +/// Rewriter for EXECUTE statements. +/// +/// Renames the executed statement to use the globally cached name. +pub struct ExecuteRewrite<'a> { + prepared_statements: &'a PreparedStatements, +} + +impl<'a> ExecuteRewrite<'a> { + pub fn new(prepared_statements: &'a PreparedStatements) -> Self { + Self { + prepared_statements, + } + } +} + +impl RewriteModule for ExecuteRewrite<'_> { + fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + if let Some(NodeEnum::ExecuteStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + let parse = self + .prepared_statements + .parse(&stmt.name) + .ok_or_else(|| Error::PreparedStatementNotFound(stmt.name.clone()))?; + + let new_name = parse.name().to_string(); + + if let Some(NodeEnum::ExecuteStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + stmt.name = new_name; + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::{super::PrepareRewrite, *}; + + #[test] + fn test_execute_rewrite() { + let prepare_stmt = pg_query::parse("PREPARE test AS SELECT $1, $2, $3") + .unwrap() + .protobuf; + let mut prepared_statements = PreparedStatements::default(); + + // First prepare the statement + let mut prepare_rewrite = PrepareRewrite::new(&mut prepared_statements); + let mut input = Input::new(&prepare_stmt, None); + prepare_rewrite.rewrite(&mut input).unwrap(); + + // Now execute it + let execute_stmt = pg_query::parse("EXECUTE test(1, 2, 3)").unwrap().protobuf; + let mut execute_rewrite = ExecuteRewrite::new(&prepared_statements); + let mut input = Input::new(&execute_stmt, None); + execute_rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(query.contains("__pgdog_")); + assert!(!query.contains("EXECUTE test")); + } + + #[test] + fn test_execute_not_found() { + let execute_stmt = pg_query::parse("EXECUTE nonexistent(1, 2, 3)") + .unwrap() + .protobuf; + let prepared_statements = PreparedStatements::default(); + let mut execute_rewrite = ExecuteRewrite::new(&prepared_statements); + let mut input = Input::new(&execute_stmt, None); + let result = execute_rewrite.rewrite(&mut input); + assert!(result.is_err()); + } +} diff --git a/pgdog/src/frontend/router/rewrite/prepared/mod.rs b/pgdog/src/frontend/router/rewrite/prepared/mod.rs new file mode 100644 index 000000000..3bc7d40aa --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/prepared/mod.rs @@ -0,0 +1,33 @@ +//! Prepared statement rewriter. +//! +//! Rewrites PREPARE and EXECUTE statements to use globally cached names. + +mod execute; +mod prepare; + +pub use execute::ExecuteRewrite; +pub use prepare::PrepareRewrite; + +use super::{Error, Input, RewriteModule}; +use crate::frontend::PreparedStatements; + +/// Combined rewriter for PREPARE and EXECUTE statements. +pub struct PreparedRewrite<'a> { + prepared_statements: &'a mut PreparedStatements, +} + +impl<'a> PreparedRewrite<'a> { + pub fn new(prepared_statements: &'a mut PreparedStatements) -> Self { + Self { + prepared_statements, + } + } +} + +impl RewriteModule for PreparedRewrite<'_> { + fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + PrepareRewrite::new(self.prepared_statements).rewrite(input)?; + ExecuteRewrite::new(self.prepared_statements).rewrite(input)?; + Ok(()) + } +} diff --git a/pgdog/src/frontend/router/rewrite/prepared/prepare.rs b/pgdog/src/frontend/router/rewrite/prepared/prepare.rs new file mode 100644 index 000000000..8f76492ed --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/prepared/prepare.rs @@ -0,0 +1,74 @@ +//! PREPARE statement rewriter. + +use pg_query::NodeEnum; + +use super::super::{Error, Input, RewriteModule}; +use crate::frontend::PreparedStatements; + +/// Rewriter for PREPARE statements. +/// +/// Renames the prepared statement to use a globally unique name from the cache. +pub struct PrepareRewrite<'a> { + prepared_statements: &'a mut PreparedStatements, +} + +impl<'a> PrepareRewrite<'a> { + pub fn new(prepared_statements: &'a mut PreparedStatements) -> Self { + Self { + prepared_statements, + } + } +} + +impl RewriteModule for PrepareRewrite<'_> { + fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + if let Some(NodeEnum::PrepareStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + let statement = stmt + .query + .as_ref() + .ok_or(Error::EmptyQuery)? + .deparse() + .map_err(|_| Error::EmptyQuery)?; + + let mut parse = crate::net::Parse::named(&stmt.name, &statement); + self.prepared_statements.insert_anyway(&mut parse); + let new_name = parse.name().to_string(); + + if let Some(NodeEnum::PrepareStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + stmt.name = new_name; + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_prepare_rewrite() { + let stmt = pg_query::parse("PREPARE test AS SELECT $1, $2, $3") + .unwrap() + .protobuf; + let mut prepared_statements = PreparedStatements::default(); + let mut rewrite = PrepareRewrite::new(&mut prepared_statements); + let mut input = Input::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(query.contains("__pgdog_")); + assert!(!query.contains("PREPARE test")); + } +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs index d852bdb88..c7d9439ba 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs @@ -1,83 +1,66 @@ -use pg_query::{ - protobuf::{a_const::Val, AConst, InsertStmt, ParamRef}, - Node, NodeEnum, -}; +use pg_query::{protobuf::ParamRef, NodeEnum}; -use super::super::Error; +use super::{ + super::{Error, Input, RewriteModule}, + bigint_const, +}; use crate::{ - frontend::{ - router::parser::{Insert, Value}, - PreparedStatements, - }, - net::{Bind, Datum, Parse, Query}, + frontend::router::parser::{Insert, Value}, + net::Datum, unique_id, }; -pub struct InsertUniqueIdRewrite<'a> { - stmt: &'a InsertStmt, - bind: Option<&'a Bind>, -} - -#[derive(Debug, Clone)] -pub enum InsertRewriteResult { - Extended { parse: Parse, bind: Bind }, - Simple { query: Query }, -} - -#[derive(Debug, Clone)] -pub struct InsertRewriteOutput { - /// Rewritten AST. This should be used for any subsequent operations. - pub stmt: InsertStmt, - /// Rewritten Query or Parse and Bind. - pub rewrite: InsertRewriteResult, -} - -impl InsertRewriteOutput { - /// Get query text. - pub fn query(&self) -> &str { - match &self.rewrite { - InsertRewriteResult::Extended { parse, .. } => parse.query(), - InsertRewriteResult::Simple { query } => query.query(), - } - } -} - -impl<'a> InsertUniqueIdRewrite<'a> { - /// Create new INSERT statement rewriter - pub fn new(stmt: &'a InsertStmt, bind: Option<&'a Bind>) -> Self { - Self { stmt, bind } - } +#[derive(Default)] +pub struct InsertUniqueIdRewrite {} +impl RewriteModule for InsertUniqueIdRewrite { /// Handle statement rewrite - pub fn rewrite(&self) -> Result, Error> { + fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { let mut need_rewrite = false; - let wrapper = Insert::new(self.stmt); - - for tuple in wrapper.tuples() { - for value in tuple.values { - if let Value::Function(ref func) = value { - if *func == "pgdog.unique_id" { - need_rewrite = true; + if let Some(NodeEnum::InsertStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .map(|stmt| stmt.node.as_ref()) + .flatten() + { + let wrapper = Insert::new(stmt); + + for tuple in wrapper.tuples() { + for value in tuple.values { + if let Value::Function(ref func) = value { + if *func == "pgdog.unique_id" { + need_rewrite = true; + } } } } - } - if !need_rewrite { - return Ok(None); + if !need_rewrite { + return Ok(()); + } } - let mut bind = self.bind.cloned(); - let mut stmt = self.stmt.clone(); + let mut bind = input.bind_take(); - let select = stmt - .select_stmt - .as_mut() - .ok_or(Error::ParserError)? - .node + let select = if let Some(NodeEnum::InsertStmt(stmt)) = input + .stmt_mut()? + .stmt .as_mut() - .ok_or(Error::ParserError)?; + .map(|stmt| stmt.node.as_mut()) + .flatten() + { + stmt.select_stmt + .as_mut() + .ok_or(Error::ParserError)? + .node + .as_mut() + .ok_or(Error::ParserError)? + } else { + return Ok(()); + }; + if let NodeEnum::SelectStmt(stmt) = select { for tuple in stmt.values_lists.iter_mut() { if let Some(NodeEnum::List(ref mut tuple)) = tuple.node { @@ -93,12 +76,7 @@ impl<'a> InsertUniqueIdRewrite<'a> { ..Default::default() }) } else { - NodeEnum::AConst(AConst { - val: Some(Val::Sval(pg_query::protobuf::String { - sval: id.to_string(), - })), - ..Default::default() - }) + bigint_const(id) }; column.node = Some(node); @@ -109,75 +87,43 @@ impl<'a> InsertUniqueIdRewrite<'a> { } } - let wrapper = Node { - node: Some(NodeEnum::InsertStmt(Box::new(stmt.clone()))), - ..Default::default() - }; - - let rewrite = match bind { - Some(mut bind) => { - let mut parse = Parse::new_anonymous(&wrapper.deparse()?); - if !bind.anonymous() { - let (_, name) = PreparedStatements::global().write().insert(&parse); - parse.rename_fast(&name); - bind.rename(&name); - } + input.bind_put(bind); - InsertRewriteResult::Extended { parse, bind } - } - - None => InsertRewriteResult::Simple { - query: Query::new(wrapper.deparse()?), - }, - }; - - Ok(Some(InsertRewriteOutput { stmt, rewrite })) + Ok(()) } } #[cfg(test)] mod test { - - use std::env::set_var; - - use crate::net::bind::Parameter; - use super::*; - - fn insert_root(query: &str) -> InsertStmt { - let stmt = pg_query::parse(query).unwrap(); - let root = stmt - .protobuf - .stmts - .first() - .cloned() - .unwrap() - .stmt - .unwrap() - .node - .unwrap(); - if let NodeEnum::InsertStmt(stmt) = root { - *stmt.clone() - } else { - panic!("not an insert") - } - } + use crate::net::bind::{Bind, Parameter}; + use std::env::set_var; #[test] fn test_unique_id_insert() { unsafe { set_var("NODE_ID", "pgdog-prod-1"); } - let stmt = insert_root( + let stmt = pg_query::parse( r#" INSERT INTO omnisharded (id, settings) VALUES (pgdog.unique_id(), '{}'::JSONB), (pgdog.unique_id(), '{"hello": "world"}'::JSONB)"#, + ) + .unwrap() + .protobuf; + let mut insert = InsertUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + insert.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(!query.contains("pgdog.unique_id")); + assert!( + query.contains("bigint"), + "Query should contain bigint cast: {}", + query ); - let insert = InsertUniqueIdRewrite::new(&stmt, None); - let rewrite = insert.rewrite().unwrap().unwrap(); - assert!(!rewrite.query().contains("pgdog.unique_id")); } #[test] @@ -185,13 +131,15 @@ mod test { unsafe { set_var("NODE_ID", "pgdog-prod-1"); } - let stmt = insert_root( + let stmt = pg_query::parse( r#" INSERT INTO omnisharded (id, settings) VALUES (pgdog.unique_id(), $1::JSONB), (pgdog.unique_id(), $2::JSONB)"#, - ); + ) + .unwrap() + .protobuf; let bind = Bind::new_params( "", &[ @@ -205,12 +153,13 @@ mod test { }, ], ); - let rewrite = InsertUniqueIdRewrite::new(&stmt, Some(&bind)) - .rewrite() - .unwrap() + let mut input = Input::new(&stmt, Some(&bind)); + InsertUniqueIdRewrite::default() + .rewrite(&mut input) .unwrap(); + let output = input.build().unwrap(); assert_eq!( - rewrite.query(), + output.query().unwrap(), "INSERT INTO omnisharded (id, settings) VALUES ($3, $1::jsonb), ($4, $2::jsonb)" ); } diff --git a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs index 848372c91..5b7f21966 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs @@ -1,3 +1,42 @@ //! Unique ID rewrite engine. +use pg_query::{ + protobuf::{a_const::Val, AConst, Node, TypeCast, TypeName}, + NodeEnum, +}; + pub mod insert; +pub mod select; +pub mod update; + +pub struct UniqueIdRewrite; + +/// Create a bigint-typed constant node for the given ID. +fn bigint_const(id: i64) -> NodeEnum { + NodeEnum::TypeCast(Box::new(TypeCast { + arg: Some(Box::new(Node { + node: Some(NodeEnum::AConst(AConst { + val: Some(Val::Sval(pg_query::protobuf::String { + sval: id.to_string(), + })), + ..Default::default() + })), + })), + type_name: Some(TypeName { + names: vec![ + Node { + node: Some(NodeEnum::String(pg_query::protobuf::String { + sval: "pg_catalog".to_string(), + })), + }, + Node { + node: Some(NodeEnum::String(pg_query::protobuf::String { + sval: "int8".to_string(), + })), + }, + ], + ..Default::default() + }), + ..Default::default() + })) +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/select.rs b/pgdog/src/frontend/router/rewrite/unique_id/select.rs new file mode 100644 index 000000000..3c86fe83e --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/unique_id/select.rs @@ -0,0 +1,313 @@ +//! SELECT statement rewriter for unique_id. + +use pg_query::{ + protobuf::{Node, ParamRef, SelectStmt}, + NodeEnum, +}; + +use super::{ + super::{Error, Input, RewriteModule}, + bigint_const, +}; +use crate::{frontend::router::parser::Value, net::Datum, unique_id}; + +#[derive(Default)] +pub struct SelectUniqueIdRewrite {} + +impl SelectUniqueIdRewrite { + fn needs_rewrite(stmt: &SelectStmt) -> bool { + // Check target_list (SELECT columns) + for target in &stmt.target_list { + if let Some(NodeEnum::ResTarget(res)) = target.node.as_ref() { + if let Some(val) = &res.val { + if let Ok(Value::Function(ref func)) = Value::try_from(&val.node) { + if func == "pgdog.unique_id" { + return true; + } + } + } + } + } + + // Check CTEs recursively + if let Some(ref with_clause) = stmt.with_clause { + for cte in &with_clause.ctes { + if let Some(NodeEnum::CommonTableExpr(ref expr)) = cte.node { + if let Some(ref query) = expr.ctequery { + if let Some(NodeEnum::SelectStmt(ref inner)) = query.node { + if Self::needs_rewrite(inner) { + return true; + } + } + } + } + } + } + + // Check subqueries in FROM clause + for from in &stmt.from_clause { + if Self::needs_rewrite_from_node(from) { + return true; + } + } + + // Check UNION/INTERSECT/EXCEPT (larg/rarg are Box) + if let Some(ref larg) = stmt.larg { + if Self::needs_rewrite(larg) { + return true; + } + } + if let Some(ref rarg) = stmt.rarg { + if Self::needs_rewrite(rarg) { + return true; + } + } + + false + } + + fn needs_rewrite_from_node(node: &Node) -> bool { + match node.node.as_ref() { + Some(NodeEnum::RangeSubselect(subselect)) => { + if let Some(ref subquery) = subselect.subquery { + if let Some(NodeEnum::SelectStmt(ref inner)) = subquery.node { + return Self::needs_rewrite(inner); + } + } + false + } + Some(NodeEnum::JoinExpr(join)) => { + let left = join + .larg + .as_ref() + .map_or(false, |n| Self::needs_rewrite_from_node(n)); + let right = join + .rarg + .as_ref() + .map_or(false, |n| Self::needs_rewrite_from_node(n)); + left || right + } + _ => false, + } + } + + fn rewrite_select( + stmt: &mut SelectStmt, + bind: &mut Option, + ) -> Result<(), Error> { + // Rewrite target_list + for target in stmt.target_list.iter_mut() { + if let Some(NodeEnum::ResTarget(ref mut res)) = target.node { + if let Some(ref mut val) = res.val { + if let Ok(Value::Function(name)) = Value::try_from(&val.node) { + if name == "pgdog.unique_id" { + let id = unique_id::UniqueId::generator()?.next_id(); + + let node = if let Some(ref mut bind) = bind { + NodeEnum::ParamRef(ParamRef { + number: bind.add_parameter(Datum::Bigint(id))?, + ..Default::default() + }) + } else { + bigint_const(id) + }; + + val.node = Some(node); + } + } + } + } + } + + // Rewrite CTEs recursively + if let Some(ref mut with_clause) = stmt.with_clause { + for cte in with_clause.ctes.iter_mut() { + if let Some(NodeEnum::CommonTableExpr(ref mut expr)) = cte.node { + if let Some(ref mut query) = expr.ctequery { + if let Some(NodeEnum::SelectStmt(ref mut inner)) = query.node { + Self::rewrite_select(inner, bind)?; + } + } + } + } + } + + // Rewrite subqueries in FROM clause + for from in stmt.from_clause.iter_mut() { + Self::rewrite_from_node(from, bind)?; + } + + // Rewrite UNION/INTERSECT/EXCEPT (larg/rarg are Box) + if let Some(ref mut larg) = stmt.larg { + Self::rewrite_select(larg, bind)?; + } + if let Some(ref mut rarg) = stmt.rarg { + Self::rewrite_select(rarg, bind)?; + } + + Ok(()) + } + + fn rewrite_from_node( + node: &mut Node, + bind: &mut Option, + ) -> Result<(), Error> { + match node.node.as_mut() { + Some(NodeEnum::RangeSubselect(ref mut subselect)) => { + if let Some(ref mut subquery) = subselect.subquery { + if let Some(NodeEnum::SelectStmt(ref mut inner)) = subquery.node { + Self::rewrite_select(inner, bind)?; + } + } + } + Some(NodeEnum::JoinExpr(ref mut join)) => { + if let Some(ref mut larg) = join.larg { + Self::rewrite_from_node(larg, bind)?; + } + if let Some(ref mut rarg) = join.rarg { + Self::rewrite_from_node(rarg, bind)?; + } + } + _ => {} + } + Ok(()) + } +} + +impl RewriteModule for SelectUniqueIdRewrite { + fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + let need_rewrite = if let Some(NodeEnum::SelectStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + Self::needs_rewrite(stmt) + } else { + false + }; + + if !need_rewrite { + return Ok(()); + } + + let mut bind = input.bind_take(); + + if let Some(NodeEnum::SelectStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + Self::rewrite_select(stmt, &mut bind)?; + } + + input.bind_put(bind); + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::net::{bind::Parameter, Bind}; + use std::env::set_var; + + #[test] + fn test_unique_id_select() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse(r#"SELECT pgdog.unique_id() AS id"#) + .unwrap() + .protobuf; + let mut rewrite = SelectUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + println!("output: {}", output.query().unwrap()); + assert!(!output.query().unwrap().contains("pgdog.unique_id")); + } + + #[test] + fn test_unique_id_select_with_bind() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse(r#"SELECT pgdog.unique_id() AS id, $1 AS name"#) + .unwrap() + .protobuf; + let bind = Bind::new_params( + "", + &[Parameter { + len: 4, + data: "test".into(), + }], + ); + let mut input = Input::new(&stmt, Some(&bind)); + SelectUniqueIdRewrite::default() + .rewrite(&mut input) + .unwrap(); + let output = input.build().unwrap(); + assert_eq!(output.query().unwrap(), "SELECT $2 AS id, $1 AS name"); + } + + #[test] + fn test_unique_id_select_cte() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = + pg_query::parse(r#"WITH ids AS (SELECT pgdog.unique_id() AS id) SELECT * FROM ids"#) + .unwrap() + .protobuf; + let mut rewrite = SelectUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(!output.query().unwrap().contains("pgdog.unique_id")); + } + + #[test] + fn test_unique_id_select_subquery() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse(r#"SELECT * FROM (SELECT pgdog.unique_id() AS id) sub"#) + .unwrap() + .protobuf; + let mut rewrite = SelectUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(!output.query().unwrap().contains("pgdog.unique_id")); + } + + #[test] + fn test_unique_id_select_union() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse( + r#"SELECT pgdog.unique_id() AS id UNION ALL SELECT pgdog.unique_id() AS id"#, + ) + .unwrap() + .protobuf; + let mut rewrite = SelectUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(!output.query().unwrap().contains("pgdog.unique_id")); + } + + #[test] + fn test_no_rewrite_when_no_unique_id() { + let stmt = pg_query::parse(r#"SELECT id FROM users"#).unwrap().protobuf; + let mut rewrite = SelectUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(matches!(output, super::super::super::Output::NoOp)); + } +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/update.rs b/pgdog/src/frontend/router/rewrite/unique_id/update.rs new file mode 100644 index 000000000..ebf5f28a0 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/unique_id/update.rs @@ -0,0 +1,137 @@ +//! UPDATE statement rewriter for unique_id. + +use pg_query::{protobuf::ParamRef, NodeEnum}; + +use super::{ + super::{Error, Input, RewriteModule}, + bigint_const, +}; +use crate::{frontend::router::parser::Value, net::Datum, unique_id}; + +#[derive(Default)] +pub struct UpdateUniqueIdRewrite {} + +impl RewriteModule for UpdateUniqueIdRewrite { + fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + let mut need_rewrite = false; + + if let Some(NodeEnum::UpdateStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .map(|stmt| stmt.node.as_ref()) + .flatten() + { + for target in &stmt.target_list { + if let Some(NodeEnum::ResTarget(res)) = target.node.as_ref() { + if let Some(val) = &res.val { + if let Ok(Value::Function(ref func)) = Value::try_from(&val.node) { + if *func == "pgdog.unique_id" { + need_rewrite = true; + break; + } + } + } + } + } + + if !need_rewrite { + return Ok(()); + } + } + + let mut bind = input.bind_take(); + + if let Some(NodeEnum::UpdateStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .map(|stmt| stmt.node.as_mut()) + .flatten() + { + for target in stmt.target_list.iter_mut() { + if let Some(NodeEnum::ResTarget(ref mut res)) = target.node { + if let Some(ref mut val) = res.val { + if let Ok(Value::Function(name)) = Value::try_from(&val.node) { + if name == "pgdog.unique_id" { + let id = unique_id::UniqueId::generator()?.next_id(); + + let node = if let Some(ref mut bind) = bind { + NodeEnum::ParamRef(ParamRef { + number: bind.add_parameter(Datum::Bigint(id))?, + ..Default::default() + }) + } else { + bigint_const(id) + }; + + val.node = Some(node); + } + } + } + } + } + } + + input.bind_put(bind); + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::net::{bind::Parameter, Bind}; + use std::env::set_var; + + #[test] + fn test_unique_id_update() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = + pg_query::parse(r#"UPDATE omnisharded SET id = pgdog.unique_id() WHERE old_id = 123"#) + .unwrap() + .protobuf; + let mut update = UpdateUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + update.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(!output.query().unwrap().contains("pgdog.unique_id")); + } + + #[test] + fn test_unique_id_update_parse() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse( + r#"UPDATE omnisharded SET id = pgdog.unique_id(), settings = $1 WHERE old_id = $2"#, + ) + .unwrap() + .protobuf; + let bind = Bind::new_params( + "", + &[ + Parameter { + len: 2, + data: "{}".into(), + }, + Parameter { + len: 3, + data: "123".into(), + }, + ], + ); + let mut input = Input::new(&stmt, Some(&bind)); + UpdateUniqueIdRewrite::default() + .rewrite(&mut input) + .unwrap(); + let output = input.build().unwrap(); + assert_eq!( + output.query().unwrap(), + "UPDATE omnisharded SET id = $3, settings = $1 WHERE old_id = $2" + ); + } +} diff --git a/pgdog/src/net/messages/parse.rs b/pgdog/src/net/messages/parse.rs index 7481e70f8..315abd066 100644 --- a/pgdog/src/net/messages/parse.rs +++ b/pgdog/src/net/messages/parse.rs @@ -38,10 +38,10 @@ impl Parse { } /// New anonymous prepared statement. - pub fn new_anonymous(query: &str) -> Self { + pub fn new_anonymous(query: impl ToString) -> Self { Self { name: Bytes::from("\0"), - query: Bytes::from(query.to_owned() + "\0"), + query: Bytes::from(query.to_string() + "\0"), data_types: Bytes::copy_from_slice(&0i16.to_be_bytes()), original: None, } From b2edbf998bfe63dad6eb6baca10b5cc9a1007d96 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sat, 29 Nov 2025 13:07:28 -0800 Subject: [PATCH 03/23] clippy --- pgdog/src/frontend/prepared_statements/mod.rs | 3 +-- pgdog/src/frontend/router/parser/insert.rs | 2 +- pgdog/src/frontend/router/parser/query/mod.rs | 2 +- pgdog/src/frontend/router/rewrite/input.rs | 2 +- .../src/frontend/router/rewrite/unique_id/insert.rs | 6 ++---- .../src/frontend/router/rewrite/unique_id/select.rs | 4 ++-- .../src/frontend/router/rewrite/unique_id/update.rs | 6 ++---- pgdog/src/unique_id.rs | 13 ++----------- 8 files changed, 12 insertions(+), 26 deletions(-) diff --git a/pgdog/src/frontend/prepared_statements/mod.rs b/pgdog/src/frontend/prepared_statements/mod.rs index ef0f4c9b8..2fd75418f 100644 --- a/pgdog/src/frontend/prepared_statements/mod.rs +++ b/pgdog/src/frontend/prepared_statements/mod.rs @@ -123,8 +123,7 @@ impl PreparedStatements { pub fn parse(&self, name: &str) -> Option { self.local .get(name) - .map(|name| self.global.read().parse(name)) - .flatten() + .and_then(|name| self.global.read().parse(name)) } /// Number of prepared statements in the local cache. diff --git a/pgdog/src/frontend/router/parser/insert.rs b/pgdog/src/frontend/router/parser/insert.rs index 3213192d8..b15f4dc24 100644 --- a/pgdog/src/frontend/router/parser/insert.rs +++ b/pgdog/src/frontend/router/parser/insert.rs @@ -58,7 +58,7 @@ impl<'a> Insert<'a> { } pub fn stmt(&'a self) -> &'a InsertStmt { - &self.stmt + self.stmt } /// Get table name, if specified (should always be). diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index 773bd9173..01922ed6e 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -7,7 +7,7 @@ use crate::{ frontend::{ router::{ context::RouterContext, - parser::{rewrite::Rewrite, OrderBy, Shard}, + parser::{OrderBy, Shard}, rewrite::{self, RewriteModule}, round_robin, sharding::{Centroids, ContextBuilder, Value as ShardingValue}, diff --git a/pgdog/src/frontend/router/rewrite/input.rs b/pgdog/src/frontend/router/rewrite/input.rs index 5dc5d6a32..dd97afedc 100644 --- a/pgdog/src/frontend/router/rewrite/input.rs +++ b/pgdog/src/frontend/router/rewrite/input.rs @@ -76,7 +76,7 @@ impl<'a> Input<'a> { self.rewrite.as_mut().unwrap() }; - Ok(stmt.stmts.first_mut().ok_or(Error::EmptyQuery)?) + stmt.stmts.first_mut().ok_or(Error::EmptyQuery) } /// Assemble statement and add it to the global prepared statements cache. diff --git a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs index c7d9439ba..f308fa19a 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs @@ -22,8 +22,7 @@ impl RewriteModule for InsertUniqueIdRewrite { .stmt()? .stmt .as_ref() - .map(|stmt| stmt.node.as_ref()) - .flatten() + .and_then(|stmt| stmt.node.as_ref()) { let wrapper = Insert::new(stmt); @@ -48,8 +47,7 @@ impl RewriteModule for InsertUniqueIdRewrite { .stmt_mut()? .stmt .as_mut() - .map(|stmt| stmt.node.as_mut()) - .flatten() + .and_then(|stmt| stmt.node.as_mut()) { stmt.select_stmt .as_mut() diff --git a/pgdog/src/frontend/router/rewrite/unique_id/select.rs b/pgdog/src/frontend/router/rewrite/unique_id/select.rs index 3c86fe83e..672a08a15 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/select.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/select.rs @@ -80,11 +80,11 @@ impl SelectUniqueIdRewrite { let left = join .larg .as_ref() - .map_or(false, |n| Self::needs_rewrite_from_node(n)); + .is_some_and(|n| Self::needs_rewrite_from_node(n)); let right = join .rarg .as_ref() - .map_or(false, |n| Self::needs_rewrite_from_node(n)); + .is_some_and(|n| Self::needs_rewrite_from_node(n)); left || right } _ => false, diff --git a/pgdog/src/frontend/router/rewrite/unique_id/update.rs b/pgdog/src/frontend/router/rewrite/unique_id/update.rs index ebf5f28a0..e329ac3e2 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/update.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/update.rs @@ -19,8 +19,7 @@ impl RewriteModule for UpdateUniqueIdRewrite { .stmt()? .stmt .as_ref() - .map(|stmt| stmt.node.as_ref()) - .flatten() + .and_then(|stmt| stmt.node.as_ref()) { for target in &stmt.target_list { if let Some(NodeEnum::ResTarget(res)) = target.node.as_ref() { @@ -46,8 +45,7 @@ impl RewriteModule for UpdateUniqueIdRewrite { .stmt_mut()? .stmt .as_mut() - .map(|stmt| stmt.node.as_mut()) - .flatten() + .and_then(|stmt| stmt.node.as_mut()) { for target in stmt.target_list.iter_mut() { if let Some(NodeEnum::ResTarget(ref mut res)) = target.node { diff --git a/pgdog/src/unique_id.rs b/pgdog/src/unique_id.rs index 043612015..916f200f9 100644 --- a/pgdog/src/unique_id.rs +++ b/pgdog/src/unique_id.rs @@ -34,21 +34,12 @@ const MAX_OFFSET: u64 = i64::MAX as u64 static UNIQUE_ID: OnceCell = OnceCell::new(); -#[derive(Debug)] +#[derive(Debug, Default)] struct State { last_timestamp_ms: u64, sequence: u64, } -impl Default for State { - fn default() -> Self { - Self { - last_timestamp_ms: 0, - sequence: 0, - } - } -} - impl State { // Generate next unique ID in a distributed sequence. // The `node_id` argument must be globally unique. @@ -145,7 +136,7 @@ impl UniqueId { /// Get (and initialize, if necessary) the unique ID generator. pub fn generator() -> Result<&'static UniqueId, Error> { - UNIQUE_ID.get_or_try_init(|| Self::new()) + UNIQUE_ID.get_or_try_init(Self::new) } /// Generate a globally unique, monotonically increasing identifier. From f3171e89e8b85f8a8378b8e875eab9255ad64bdb Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sat, 29 Nov 2025 13:48:16 -0800 Subject: [PATCH 04/23] explain --- pgdog/src/frontend/client/query_engine/mod.rs | 13 +- .../frontend/client/query_engine/output.rs | 10 + pgdog/src/frontend/router/parser/query/mod.rs | 7 +- .../src/frontend/router/parser/rewrite/mod.rs | 127 -------- pgdog/src/frontend/router/rewrite/input.rs | 31 +- .../router/rewrite/insert_split/mod.rs | 1 + pgdog/src/frontend/router/rewrite/mod.rs | 14 +- pgdog/src/frontend/router/rewrite/output.rs | 43 +++ .../router/rewrite/unique_id/explain.rs | 273 ++++++++++++++++++ .../router/rewrite/unique_id/insert.rs | 95 +++--- .../frontend/router/rewrite/unique_id/mod.rs | 6 + .../router/rewrite/unique_id/select.rs | 6 +- .../router/rewrite/unique_id/update.rs | 99 ++++--- 13 files changed, 477 insertions(+), 248 deletions(-) create mode 100644 pgdog/src/frontend/client/query_engine/output.rs create mode 100644 pgdog/src/frontend/router/rewrite/insert_split/mod.rs create mode 100644 pgdog/src/frontend/router/rewrite/output.rs create mode 100644 pgdog/src/frontend/router/rewrite/unique_id/explain.rs diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index 9ff2c4381..31d2fb957 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -22,6 +22,7 @@ pub mod incomplete_requests; pub mod insert_split; pub mod internal_values; pub mod notify_buffer; +pub mod output; pub mod prepared_statements; pub mod pub_sub; pub mod query; @@ -37,6 +38,7 @@ pub mod unknown_command; use self::query::ExplainResponseState; pub use context::QueryEngineContext; use notify_buffer::NotifyBuffer; +pub use output::QueryEngineOutput; pub use two_pc::phase::TwoPcPhase; use two_pc::TwoPc; @@ -108,7 +110,10 @@ impl QueryEngine { } /// Handle client request. - pub async fn handle(&mut self, context: &mut QueryEngineContext<'_>) -> Result<(), Error> { + pub async fn handle( + &mut self, + context: &mut QueryEngineContext<'_>, + ) -> Result { self.stats .received(context.client_request.total_message_len()); self.set_state(State::Active); // Client is active. @@ -119,14 +124,14 @@ impl QueryEngine { // Intercept commands we don't have to forward to a server. if self.intercept_incomplete(context).await? { self.update_stats(context); - return Ok(()); + return Ok(QueryEngineOutput::Executed); } // Route transaction to the right servers. if !self.route_transaction(context).await? { self.update_stats(context); debug!("transaction has nowhere to go"); - return Ok(()); + return Ok(QueryEngineOutput::Executed); } self.hooks.before_execution(context)?; @@ -246,7 +251,7 @@ impl QueryEngine { self.update_stats(context); - Ok(()) + Ok(QueryEngineOutput::Executed) } fn update_stats(&mut self, context: &mut QueryEngineContext<'_>) { diff --git a/pgdog/src/frontend/client/query_engine/output.rs b/pgdog/src/frontend/client/query_engine/output.rs new file mode 100644 index 000000000..6256ec44d --- /dev/null +++ b/pgdog/src/frontend/client/query_engine/output.rs @@ -0,0 +1,10 @@ +use crate::frontend::ClientRequest; + +#[derive(Debug, Clone)] +pub enum QueryEngineOutput { + // The request has been executed as-is. + Executed, + // The request has been rewritten and needs to + // be resent. + Rewritten(Vec), +} diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index 01922ed6e..4dcc47c46 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -201,17 +201,16 @@ impl QueryParser { rewrite::Rewrite::new(context.prepared_statements()).rewrite(&mut input)?; match input.build()? { - rewrite::Output::NoOp => (), - rewrite::Output::Extended { parse, bind } => { + rewrite::StepOutput::NoOp => (), + rewrite::StepOutput::Extended { parse, bind } => { return Ok(Command::Rewrite(vec![ ProtocolMessage::from(parse), bind.into(), ])) } - rewrite::Output::Simple { query } => { + rewrite::StepOutput::Simple { query } => { return Ok(Command::Rewrite(vec![ProtocolMessage::from(query)])) } - _ => todo!("multi rewrite not supported yet"), } self.ensure_explain_recorder(statement.ast(), context); diff --git a/pgdog/src/frontend/router/parser/rewrite/mod.rs b/pgdog/src/frontend/router/parser/rewrite/mod.rs index 5bf8a3333..f33b12962 100644 --- a/pgdog/src/frontend/router/parser/rewrite/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/mod.rs @@ -1,132 +1,5 @@ -use pg_query::{NodeEnum, ParseResult}; - -use super::{Command, Error}; - mod insert_split; mod shard_key; -use crate::net::{Parse, ProtocolMessage}; -use crate::{frontend::PreparedStatements, net::Query}; pub use insert_split::{InsertSplitPlan, InsertSplitRow}; pub use shard_key::{Assignment, AssignmentValue, ShardKeyRewritePlan}; - -#[derive(Debug, Clone)] -pub struct Rewrite<'a> { - ast: &'a ParseResult, -} - -impl<'a> Rewrite<'a> { - pub fn new(ast: &'a ParseResult) -> Self { - Self { ast } - } - - /// Statement needs to be rewritten. - pub fn needs_rewrite(&self) -> bool { - for stmt in &self.ast.protobuf.stmts { - if let Some(ref stmt) = stmt.stmt { - if let Some(ref node) = stmt.node { - match node { - NodeEnum::PrepareStmt(_) => return true, - NodeEnum::ExecuteStmt(_) => return true, - NodeEnum::DeallocateStmt(_) => return true, - _ => (), - } - } - } - } - - false - } - - pub fn rewrite(&self, prepared_statements: &mut PreparedStatements) -> Result { - let mut ast = self.ast.protobuf.clone(); - - for stmt in &mut ast.stmts { - if let Some(ref mut stmt) = stmt.stmt { - if let Some(ref mut node) = stmt.node { - match node { - NodeEnum::PrepareStmt(ref mut stmt) => { - let statement = stmt - .query - .as_ref() - .ok_or(Error::EmptyQuery)? - .deparse() - .map_err(|_| Error::EmptyQuery)?; - - let mut parse = Parse::named(&stmt.name, &statement); - prepared_statements.insert_anyway(&mut parse); - stmt.name = parse.name().to_string(); - - return Ok(Command::Rewrite(vec![Query::new( - ast.deparse().map_err(|_| Error::EmptyQuery)?, - ) - .into()])); - } - - NodeEnum::ExecuteStmt(ref mut stmt) => { - let parse = prepared_statements.parse(&stmt.name); - if let Some(parse) = parse { - stmt.name = parse.name().to_string(); - - return Ok(Command::Rewrite(vec![ - ProtocolMessage::Prepare { - name: stmt.name.clone(), - statement: parse.query().to_string(), - }, - Query::new(ast.deparse().map_err(|_| Error::EmptyQuery)?) - .into(), - ])); - } else { - return Err(Error::PreparedStatementDoesntExist(stmt.name.clone())); - } - } - - NodeEnum::DeallocateStmt(_) => return Ok(Command::Deallocate), - - _ => (), - } - } - } - } - - Err(Error::EmptyQuery) - } -} - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use crate::net::{FromBytes, ToBytes}; - - use super::*; - - #[test] - fn test_rewrite_prepared() { - let ast = pg_query::parse("PREPARE test AS SELECT $1, $2, $3").unwrap(); - let rewrite = Rewrite::new(&ast); - assert!(rewrite.needs_rewrite()); - let mut prepared_statements = PreparedStatements::new(); - let queries = rewrite.rewrite(&mut prepared_statements).unwrap(); - match queries { - Command::Rewrite(messages) => { - let message = Query::from_bytes(messages[0].to_bytes().unwrap()).unwrap(); - assert_eq!(message.query(), "PREPARE __pgdog_1 AS SELECT $1, $2, $3"); - } - _ => panic!("not a rewrite"), - } - } - - #[test] - fn test_deallocate() { - for q in ["DEALLOCATE ALL", "DEALLOCATE test"] { - let ast = pg_query::parse(q).unwrap(); - let ast = Arc::new(ast); - let rewrite = Rewrite::new(&ast) - .rewrite(&mut PreparedStatements::new()) - .unwrap(); - - assert!(matches!(rewrite, Command::Deallocate)); - } - } -} diff --git a/pgdog/src/frontend/router/rewrite/input.rs b/pgdog/src/frontend/router/rewrite/input.rs index dd97afedc..2b7b10ff5 100644 --- a/pgdog/src/frontend/router/rewrite/input.rs +++ b/pgdog/src/frontend/router/rewrite/input.rs @@ -2,7 +2,7 @@ use pg_query::protobuf::{ParseResult, RawStmt}; -use super::Error; +use super::{Error, StepOutput}; use crate::{ frontend::PreparedStatements, net::{Bind, Parse, Query}, @@ -80,9 +80,9 @@ impl<'a> Input<'a> { } /// Assemble statement and add it to the global prepared statements cache. - pub fn build(mut self) -> Result { + pub fn build(mut self) -> Result { if self.rewrite.is_none() { - Ok(Output::NoOp) + Ok(StepOutput::NoOp) } else { let bind = self.rewrite_bind.take(); let stmt = self.rewrite.take().ok_or(Error::NoRewrite)?.deparse()?; @@ -90,37 +90,18 @@ impl<'a> Input<'a> { if let Some(mut bind) = bind { let mut parse = Parse::new_anonymous(stmt); if bind.anonymous() { - Ok(Output::Extended { parse, bind }) + Ok(StepOutput::Extended { parse, bind }) } else { let (_, name) = PreparedStatements::global().write().insert(&parse); parse.rename_fast(&name); bind.rename(name); - Ok(Output::Extended { parse, bind }) + Ok(StepOutput::Extended { parse, bind }) } } else { - Ok(Output::Simple { + Ok(StepOutput::Simple { query: Query::new(stmt), }) } } } } - -#[derive(Debug, Clone)] -pub enum Output { - NoOp, - Extended { parse: Parse, bind: Bind }, - Simple { query: Query }, - Multi(Vec>), -} - -impl Output { - /// Get rewritten query, if any. - pub fn query(&self) -> Result<&str, ()> { - match self { - Self::Extended { parse, .. } => Ok(parse.query()), - Self::Simple { query } => Ok(query.query()), - _ => Err(()), - } - } -} diff --git a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs new file mode 100644 index 000000000..c400252fb --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs @@ -0,0 +1 @@ +pub struct InsertSplitRewrite {} diff --git a/pgdog/src/frontend/router/rewrite/mod.rs b/pgdog/src/frontend/router/rewrite/mod.rs index a63fd367c..7e225f636 100644 --- a/pgdog/src/frontend/router/rewrite/mod.rs +++ b/pgdog/src/frontend/router/rewrite/mod.rs @@ -8,13 +8,16 @@ //! pub mod error; pub mod input; +pub mod insert_split; pub mod interface; +pub mod output; pub mod prepared; pub mod unique_id; pub use error::Error; -pub use input::{Input, Output}; +pub use input::Input; pub use interface::RewriteModule; +pub use output::StepOutput; use crate::frontend::PreparedStatements; @@ -38,10 +41,11 @@ impl RewriteModule for Rewrite<'_> { // First, we need to inject the unique ID into the query. Once that's done, // we can proceed with additional rewrites. - // Unique ID rewrites - unique_id::insert::InsertUniqueIdRewrite::default().rewrite(input)?; - unique_id::update::UpdateUniqueIdRewrite::default().rewrite(input)?; - unique_id::select::SelectUniqueIdRewrite::default().rewrite(input)?; + // Unique ID rewrites (including EXPLAIN wrappers) + unique_id::ExplainUniqueIdRewrite::default().rewrite(input)?; + unique_id::InsertUniqueIdRewrite::default().rewrite(input)?; + unique_id::UpdateUniqueIdRewrite::default().rewrite(input)?; + unique_id::SelectUniqueIdRewrite::default().rewrite(input)?; // Prepared statement rewrites prepared::PreparedRewrite::new(self.prepared_statements).rewrite(input)?; diff --git a/pgdog/src/frontend/router/rewrite/output.rs b/pgdog/src/frontend/router/rewrite/output.rs new file mode 100644 index 000000000..c1b8090ac --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/output.rs @@ -0,0 +1,43 @@ +use crate::net::{Bind, Parse, ProtocolMessage, Query}; + +#[derive(Debug, Clone)] +pub struct RewrittenRequest { + pub messages: Vec, + pub action: ExecutionAction, +} + +/// Output of a single rewrite step. +#[derive(Debug, Clone)] +pub enum StepOutput { + NoOp, + Extended { parse: Parse, bind: Bind }, + Simple { query: Query }, +} + +impl StepOutput { + /// Get rewritten query, if any. + pub fn query(&self) -> Result<&str, ()> { + match self { + Self::Extended { parse, .. } => Ok(parse.query()), + Self::Simple { query } => Ok(query.query()), + _ => Err(()), + } + } +} + +#[derive(Debug, Clone)] +pub enum Output { + Passthrough, + Simple(RewrittenRequest), + Chain(Vec), +} + +#[derive(Debug, Clone)] +pub enum ExecutionAction { + /// Drop result completely. + Drop, + /// Return result to client. + Return, + /// Forward result to next step in the chain. + Forward, +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs new file mode 100644 index 000000000..009a9c1c5 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs @@ -0,0 +1,273 @@ +//! EXPLAIN statement rewriter for unique_id. + +use pg_query::NodeEnum; + +use super::{ + super::{Error, Input, RewriteModule}, + InsertUniqueIdRewrite, SelectUniqueIdRewrite, UpdateUniqueIdRewrite, +}; + +#[derive(Default)] +pub struct ExplainUniqueIdRewrite {} + +impl RewriteModule for ExplainUniqueIdRewrite { + fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + // Check if this is an EXPLAIN statement + let is_explain = matches!( + input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()), + Some(NodeEnum::ExplainStmt(_)) + ); + + if !is_explain { + return Ok(()); + } + + // Get the inner query type and dispatch to appropriate rewriter + let inner_type = if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + stmt.query + .as_ref() + .and_then(|q| q.node.as_ref()) + .map(|n| match n { + NodeEnum::SelectStmt(_) => "select", + NodeEnum::InsertStmt(_) => "insert", + NodeEnum::UpdateStmt(_) => "update", + _ => "other", + }) + } else { + None + }; + + match inner_type { + Some("select") => self.rewrite_explain_select(input), + Some("insert") => self.rewrite_explain_insert(input), + Some("update") => self.rewrite_explain_update(input), + _ => Ok(()), + } + } +} + +impl ExplainUniqueIdRewrite { + fn rewrite_explain_select(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + // Check if the inner SELECT needs rewriting + let needs_rewrite = if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + if let Some(NodeEnum::SelectStmt(select)) = + stmt.query.as_ref().and_then(|q| q.node.as_ref()) + { + SelectUniqueIdRewrite::needs_rewrite(select) + } else { + false + } + } else { + false + }; + + if !needs_rewrite { + return Ok(()); + } + + let mut bind = input.bind_take(); + + if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + if let Some(NodeEnum::SelectStmt(select)) = + stmt.query.as_mut().and_then(|q| q.node.as_mut()) + { + SelectUniqueIdRewrite::rewrite_select(select, &mut bind)?; + } + } + + input.bind_put(bind); + Ok(()) + } + + fn rewrite_explain_insert(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + // Check if the inner INSERT needs rewriting + let needs_rewrite = if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + if let Some(NodeEnum::InsertStmt(insert)) = + stmt.query.as_ref().and_then(|q| q.node.as_ref()) + { + InsertUniqueIdRewrite::needs_rewrite(insert) + } else { + false + } + } else { + false + }; + + if !needs_rewrite { + return Ok(()); + } + + let mut bind = input.bind_take(); + + if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + if let Some(NodeEnum::InsertStmt(insert)) = + stmt.query.as_mut().and_then(|q| q.node.as_mut()) + { + InsertUniqueIdRewrite::rewrite_insert(insert, &mut bind)?; + } + } + + input.bind_put(bind); + Ok(()) + } + + fn rewrite_explain_update(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + // Check if the inner UPDATE needs rewriting + let needs_rewrite = if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + if let Some(NodeEnum::UpdateStmt(update)) = + stmt.query.as_ref().and_then(|q| q.node.as_ref()) + { + UpdateUniqueIdRewrite::needs_rewrite(update) + } else { + false + } + } else { + false + }; + + if !needs_rewrite { + return Ok(()); + } + + let mut bind = input.bind_take(); + + if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + if let Some(NodeEnum::UpdateStmt(update)) = + stmt.query.as_mut().and_then(|q| q.node.as_mut()) + { + UpdateUniqueIdRewrite::rewrite_update(update, &mut bind)?; + } + } + + input.bind_put(bind); + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::env::set_var; + + #[test] + fn test_explain_select_unique_id() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse(r#"EXPLAIN SELECT pgdog.unique_id() AS id"#) + .unwrap() + .protobuf; + let mut rewrite = ExplainUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(!query.contains("pgdog.unique_id")); + assert!(query.contains("EXPLAIN")); + assert!(query.contains("bigint")); + } + + #[test] + fn test_explain_insert_unique_id() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse(r#"EXPLAIN INSERT INTO test (id) VALUES (pgdog.unique_id())"#) + .unwrap() + .protobuf; + let mut rewrite = ExplainUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(!query.contains("pgdog.unique_id")); + assert!(query.contains("EXPLAIN")); + assert!(query.contains("bigint")); + } + + #[test] + fn test_explain_update_unique_id() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = + pg_query::parse(r#"EXPLAIN UPDATE test SET id = pgdog.unique_id() WHERE old_id = 1"#) + .unwrap() + .protobuf; + let mut rewrite = ExplainUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(!query.contains("pgdog.unique_id")); + assert!(query.contains("EXPLAIN")); + assert!(query.contains("bigint")); + } + + #[test] + fn test_explain_analyze_select_unique_id() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse(r#"EXPLAIN ANALYZE SELECT pgdog.unique_id() AS id"#) + .unwrap() + .protobuf; + let mut rewrite = ExplainUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(!query.contains("pgdog.unique_id")); + assert!(query.contains("EXPLAIN")); + assert!(query.contains("ANALYZE")); + } + + #[test] + fn test_explain_no_unique_id() { + let stmt = pg_query::parse(r#"EXPLAIN SELECT 1"#).unwrap().protobuf; + let mut rewrite = ExplainUniqueIdRewrite::default(); + let mut input = Input::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(matches!(output, super::super::super::StepOutput::NoOp)); + } +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs index f308fa19a..5a880def2 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs @@ -1,4 +1,7 @@ -use pg_query::{protobuf::ParamRef, NodeEnum}; +use pg_query::{ + protobuf::{InsertStmt, ParamRef}, + NodeEnum, +}; use super::{ super::{Error, Input, RewriteModule}, @@ -13,58 +16,40 @@ use crate::{ #[derive(Default)] pub struct InsertUniqueIdRewrite {} -impl RewriteModule for InsertUniqueIdRewrite { - /// Handle statement rewrite - fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { - let mut need_rewrite = false; +impl InsertUniqueIdRewrite { + pub fn needs_rewrite(stmt: &InsertStmt) -> bool { + let wrapper = Insert::new(stmt); - if let Some(NodeEnum::InsertStmt(stmt)) = input - .stmt()? - .stmt - .as_ref() - .and_then(|stmt| stmt.node.as_ref()) - { - let wrapper = Insert::new(stmt); - - for tuple in wrapper.tuples() { - for value in tuple.values { - if let Value::Function(ref func) = value { - if *func == "pgdog.unique_id" { - need_rewrite = true; - } + for tuple in wrapper.tuples() { + for value in tuple.values { + if let Value::Function(ref func) = value { + if *func == "pgdog.unique_id" { + return true; } } } - - if !need_rewrite { - return Ok(()); - } } - let mut bind = input.bind_take(); + false + } - let select = if let Some(NodeEnum::InsertStmt(stmt)) = input - .stmt_mut()? - .stmt + pub fn rewrite_insert( + stmt: &mut InsertStmt, + bind: &mut Option, + ) -> Result<(), Error> { + let select = stmt + .select_stmt .as_mut() - .and_then(|stmt| stmt.node.as_mut()) - { - stmt.select_stmt - .as_mut() - .ok_or(Error::ParserError)? - .node - .as_mut() - .ok_or(Error::ParserError)? - } else { - return Ok(()); - }; + .ok_or(Error::ParserError)? + .node + .as_mut() + .ok_or(Error::ParserError)?; if let NodeEnum::SelectStmt(stmt) = select { for tuple in stmt.values_lists.iter_mut() { if let Some(NodeEnum::List(ref mut tuple)) = tuple.node { for column in tuple.items.iter_mut() { if let Ok(Value::Function(name)) = Value::try_from(&column.node) { - // Replace function call with value. if name == "pgdog.unique_id" { let id = unique_id::UniqueId::generator()?.next_id(); @@ -85,6 +70,38 @@ impl RewriteModule for InsertUniqueIdRewrite { } } + Ok(()) + } +} + +impl RewriteModule for InsertUniqueIdRewrite { + fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + let need_rewrite = if let Some(NodeEnum::InsertStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + Self::needs_rewrite(stmt) + } else { + false + }; + + if !need_rewrite { + return Ok(()); + } + + let mut bind = input.bind_take(); + + if let Some(NodeEnum::InsertStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + Self::rewrite_insert(stmt, &mut bind)?; + } + input.bind_put(bind); Ok(()) diff --git a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs index 5b7f21966..0a28269a7 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs @@ -5,10 +5,16 @@ use pg_query::{ NodeEnum, }; +pub mod explain; pub mod insert; pub mod select; pub mod update; +pub use explain::ExplainUniqueIdRewrite; +pub use insert::InsertUniqueIdRewrite; +pub use select::SelectUniqueIdRewrite; +pub use update::UpdateUniqueIdRewrite; + pub struct UniqueIdRewrite; /// Create a bigint-typed constant node for the given ID. diff --git a/pgdog/src/frontend/router/rewrite/unique_id/select.rs b/pgdog/src/frontend/router/rewrite/unique_id/select.rs index 672a08a15..d41453796 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/select.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/select.rs @@ -15,7 +15,7 @@ use crate::{frontend::router::parser::Value, net::Datum, unique_id}; pub struct SelectUniqueIdRewrite {} impl SelectUniqueIdRewrite { - fn needs_rewrite(stmt: &SelectStmt) -> bool { + pub fn needs_rewrite(stmt: &SelectStmt) -> bool { // Check target_list (SELECT columns) for target in &stmt.target_list { if let Some(NodeEnum::ResTarget(res)) = target.node.as_ref() { @@ -91,7 +91,7 @@ impl SelectUniqueIdRewrite { } } - fn rewrite_select( + pub fn rewrite_select( stmt: &mut SelectStmt, bind: &mut Option, ) -> Result<(), Error> { @@ -308,6 +308,6 @@ mod test { let mut input = Input::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); - assert!(matches!(output, super::super::super::Output::NoOp)); + assert!(matches!(output, super::super::super::StepOutput::NoOp)); } } diff --git a/pgdog/src/frontend/router/rewrite/unique_id/update.rs b/pgdog/src/frontend/router/rewrite/unique_id/update.rs index e329ac3e2..6029202fd 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/update.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/update.rs @@ -1,6 +1,9 @@ //! UPDATE statement rewriter for unique_id. -use pg_query::{protobuf::ParamRef, NodeEnum}; +use pg_query::{ + protobuf::{ParamRef, UpdateStmt}, + NodeEnum, +}; use super::{ super::{Error, Input, RewriteModule}, @@ -11,32 +14,67 @@ use crate::{frontend::router::parser::Value, net::Datum, unique_id}; #[derive(Default)] pub struct UpdateUniqueIdRewrite {} +impl UpdateUniqueIdRewrite { + pub fn needs_rewrite(stmt: &UpdateStmt) -> bool { + for target in &stmt.target_list { + if let Some(NodeEnum::ResTarget(res)) = target.node.as_ref() { + if let Some(val) = &res.val { + if let Ok(Value::Function(ref func)) = Value::try_from(&val.node) { + if *func == "pgdog.unique_id" { + return true; + } + } + } + } + } + false + } + + pub fn rewrite_update( + stmt: &mut UpdateStmt, + bind: &mut Option, + ) -> Result<(), Error> { + for target in stmt.target_list.iter_mut() { + if let Some(NodeEnum::ResTarget(ref mut res)) = target.node { + if let Some(ref mut val) = res.val { + if let Ok(Value::Function(name)) = Value::try_from(&val.node) { + if name == "pgdog.unique_id" { + let id = unique_id::UniqueId::generator()?.next_id(); + + let node = if let Some(ref mut bind) = bind { + NodeEnum::ParamRef(ParamRef { + number: bind.add_parameter(Datum::Bigint(id))?, + ..Default::default() + }) + } else { + bigint_const(id) + }; + + val.node = Some(node); + } + } + } + } + } + Ok(()) + } +} + impl RewriteModule for UpdateUniqueIdRewrite { fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { - let mut need_rewrite = false; - - if let Some(NodeEnum::UpdateStmt(stmt)) = input + let need_rewrite = if let Some(NodeEnum::UpdateStmt(stmt)) = input .stmt()? .stmt .as_ref() .and_then(|stmt| stmt.node.as_ref()) { - for target in &stmt.target_list { - if let Some(NodeEnum::ResTarget(res)) = target.node.as_ref() { - if let Some(val) = &res.val { - if let Ok(Value::Function(ref func)) = Value::try_from(&val.node) { - if *func == "pgdog.unique_id" { - need_rewrite = true; - break; - } - } - } - } - } + Self::needs_rewrite(stmt) + } else { + false + }; - if !need_rewrite { - return Ok(()); - } + if !need_rewrite { + return Ok(()); } let mut bind = input.bind_take(); @@ -47,28 +85,7 @@ impl RewriteModule for UpdateUniqueIdRewrite { .as_mut() .and_then(|stmt| stmt.node.as_mut()) { - for target in stmt.target_list.iter_mut() { - if let Some(NodeEnum::ResTarget(ref mut res)) = target.node { - if let Some(ref mut val) = res.val { - if let Ok(Value::Function(name)) = Value::try_from(&val.node) { - if name == "pgdog.unique_id" { - let id = unique_id::UniqueId::generator()?.next_id(); - - let node = if let Some(ref mut bind) = bind { - NodeEnum::ParamRef(ParamRef { - number: bind.add_parameter(Datum::Bigint(id))?, - ..Default::default() - }) - } else { - bigint_const(id) - }; - - val.node = Some(node); - } - } - } - } - } + Self::rewrite_update(stmt, &mut bind)?; } input.bind_put(bind); From 4e5b37e2287de17525ffb638fb78c0b3363b011e Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 30 Nov 2025 09:09:29 -0800 Subject: [PATCH 05/23] Working --- .../replication/logical/subscriber/context.rs | 2 +- pgdog/src/frontend/client/query_engine/mod.rs | 4 +-- .../src/frontend/client/query_engine/query.rs | 5 ++- pgdog/src/frontend/client_request.rs | 14 ++++++-- .../prepared_statements/global_cache.rs | 10 ++++++ pgdog/src/frontend/prepared_statements/mod.rs | 11 ++++++ pgdog/src/frontend/router/cli.rs | 2 +- .../frontend/router/parser/query/explain.rs | 8 ++--- pgdog/src/frontend/router/parser/query/mod.rs | 2 +- pgdog/src/frontend/router/rewrite/input.rs | 2 +- pgdog/src/frontend/router/rewrite/mod.rs | 2 +- pgdog/src/frontend/router/rewrite/output.rs | 35 ++++++++++++++++++- 12 files changed, 82 insertions(+), 15 deletions(-) diff --git a/pgdog/src/backend/replication/logical/subscriber/context.rs b/pgdog/src/backend/replication/logical/subscriber/context.rs index da7ed70ac..0fe6b751d 100644 --- a/pgdog/src/backend/replication/logical/subscriber/context.rs +++ b/pgdog/src/backend/replication/logical/subscriber/context.rs @@ -51,7 +51,7 @@ impl<'a> StreamContext<'a> { /// Construct router context. pub fn router_context(&'a mut self) -> Result, Error> { Ok(RouterContext::new( - &self.request, + &mut self.request, self.cluster, &mut self.prepared_statements, &self.params, diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index 31d2fb957..d3443adf5 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -226,8 +226,8 @@ impl QueryEngine { self.set_route(context, route.clone()).await?; } Command::Copy(_) => self.execute(context, &route).await?, - Command::Rewrite(query) => { - context.client_request.rewrite(query)?; + Command::Rewrite(requests) => { + context.client_request.rewrite_extended(&requests)?; self.execute(context, &route).await?; } Command::InsertSplit(plan) => self.insert_split(context, *plan.clone()).await?, diff --git a/pgdog/src/frontend/client/query_engine/query.rs b/pgdog/src/frontend/client/query_engine/query.rs index b60561c35..7802e9985 100644 --- a/pgdog/src/frontend/client/query_engine/query.rs +++ b/pgdog/src/frontend/client/query_engine/query.rs @@ -48,7 +48,10 @@ impl QueryEngine { } if let Some(sql) = route.rewritten_sql() { - match context.client_request.rewrite(&[Query::new(sql).into()]) { + match context + .client_request + .rewrite_simple(&[Query::new(sql).into()]) + { Ok(()) => (), Err(crate::net::Error::OnlySimpleForRewrites) => { context.client_request.rewrite_prepared( diff --git a/pgdog/src/frontend/client_request.rs b/pgdog/src/frontend/client_request.rs index b289aea36..7cb1ff4af 100644 --- a/pgdog/src/frontend/client_request.rs +++ b/pgdog/src/frontend/client_request.rs @@ -180,16 +180,26 @@ impl ClientRequest { } /// Rewrite query in buffer. - pub fn rewrite(&mut self, request: &[ProtocolMessage]) -> Result<(), Error> { + pub fn rewrite_simple(&mut self, request: &[ProtocolMessage]) -> Result<(), Error> { + if self.messages.iter().any(|c| c.code() != 'Q') { + return Err(Error::OnlySimpleForRewrites); + } + self.messages.clear(); + self.messages.extend(request.to_vec()); + Ok(()) + } + + pub fn rewrite_extended(&mut self, request: &[ProtocolMessage]) -> Result<(), Error> { for new_message in request { if let Some(pos) = self .messages .iter() - .position(|message| message.code() == new_message.code()) + .position(|p| p.code() == new_message.code()) { self.messages[pos] = new_message.clone(); } } + Ok(()) } diff --git a/pgdog/src/frontend/prepared_statements/global_cache.rs b/pgdog/src/frontend/prepared_statements/global_cache.rs index e66a90daf..281c02372 100644 --- a/pgdog/src/frontend/prepared_statements/global_cache.rs +++ b/pgdog/src/frontend/prepared_statements/global_cache.rs @@ -280,6 +280,16 @@ impl GlobalCache { self.names.get(name).map(|p| p.parse.clone()) } + /// Get global prepared statement name. + pub fn name(&self, parse: &Parse) -> Option { + let cache_key = CacheKey { + query: parse.query_ref(), + data_types: parse.data_types_ref(), + version: 0, + }; + self.statements.get(&cache_key).map(|stmt| stmt.name()) + } + /// Get the RowDescription message for the prepared statement. /// /// It can be used to decode results received from executing the prepared diff --git a/pgdog/src/frontend/prepared_statements/mod.rs b/pgdog/src/frontend/prepared_statements/mod.rs index 2fd75418f..646848e1f 100644 --- a/pgdog/src/frontend/prepared_statements/mod.rs +++ b/pgdog/src/frontend/prepared_statements/mod.rs @@ -97,6 +97,17 @@ impl PreparedStatements { parse.rename_fast(&name) } + /// Store a rewritten statement in the global cache forever. + pub fn cache_rewritten(parse: &Parse) -> String { + let exists = Self::global().read().name(parse); + if let Some(exists) = exists { + exists + } else { + let (_, name) = Self::global().write().insert(parse); + name + } + } + /// Retrieve stored rewrite plan for a prepared statement, if any. pub fn rewrite_plan(&self, name: &str) -> Option { self.global.read().rewrite_plan(name) diff --git a/pgdog/src/frontend/router/cli.rs b/pgdog/src/frontend/router/cli.rs index ceb16b5ed..154823e36 100644 --- a/pgdog/src/frontend/router/cli.rs +++ b/pgdog/src/frontend/router/cli.rs @@ -49,7 +49,7 @@ impl RouterCli { let mut qp = QueryParser::default(); let req = vec![ProtocolMessage::from(Query::new(query))]; let cmd = qp.parse(RouterContext::new( - &req.into(), + &mut req.into(), &cluster, &mut stmt, ¶ms, diff --git a/pgdog/src/frontend/router/parser/query/explain.rs b/pgdog/src/frontend/router/parser/query/explain.rs index 8b38261a9..8c6b133e5 100644 --- a/pgdog/src/frontend/router/parser/query/explain.rs +++ b/pgdog/src/frontend/router/parser/query/explain.rs @@ -69,13 +69,13 @@ mod tests { // Helper function to route a plain SQL statement and return its `Route`. fn route(sql: &str) -> Route { enable_expanded_explain(); - let buffer = ClientRequest::from(vec![Query::new(sql).into()]); + let mut buffer = ClientRequest::from(vec![Query::new(sql).into()]); let cluster = Cluster::new_test(); let mut stmts = PreparedStatements::default(); let params = Parameters::default(); - let ctx = RouterContext::new(&buffer, &cluster, &mut stmts, ¶ms, None, 1).unwrap(); + let ctx = RouterContext::new(&mut buffer, &cluster, &mut stmts, ¶ms, None, 1).unwrap(); match QueryParser::default().parse(ctx).unwrap().clone() { Command::Query(route) => route, @@ -96,13 +96,13 @@ mod tests { .collect::>(); let bind = Bind::new_params("", ¶meters); - let buffer: ClientRequest = vec![parse_msg.into(), bind.into()].into(); + let mut buffer: ClientRequest = vec![parse_msg.into(), bind.into()].into(); let cluster = Cluster::new_test(); let mut stmts = PreparedStatements::default(); let params = Parameters::default(); - let ctx = RouterContext::new(&buffer, &cluster, &mut stmts, ¶ms, None, 1).unwrap(); + let ctx = RouterContext::new(&mut buffer, &cluster, &mut stmts, ¶ms, None, 1).unwrap(); match QueryParser::default().parse(ctx).unwrap().clone() { Command::Query(route) => route, diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index 4dcc47c46..96f195fc9 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -198,7 +198,7 @@ impl QueryParser { }; let mut input = rewrite::Input::new(&statement.ast().protobuf, context.router_context.bind); - rewrite::Rewrite::new(context.prepared_statements()).rewrite(&mut input)?; + rewrite::Rewrite::new(context.router_context.prepared_statements).rewrite(&mut input)?; match input.build()? { rewrite::StepOutput::NoOp => (), diff --git a/pgdog/src/frontend/router/rewrite/input.rs b/pgdog/src/frontend/router/rewrite/input.rs index 2b7b10ff5..04efe7d4e 100644 --- a/pgdog/src/frontend/router/rewrite/input.rs +++ b/pgdog/src/frontend/router/rewrite/input.rs @@ -92,7 +92,7 @@ impl<'a> Input<'a> { if bind.anonymous() { Ok(StepOutput::Extended { parse, bind }) } else { - let (_, name) = PreparedStatements::global().write().insert(&parse); + let name = PreparedStatements::cache_rewritten(&parse); parse.rename_fast(&name); bind.rename(name); Ok(StepOutput::Extended { parse, bind }) diff --git a/pgdog/src/frontend/router/rewrite/mod.rs b/pgdog/src/frontend/router/rewrite/mod.rs index 7e225f636..26bcd2a52 100644 --- a/pgdog/src/frontend/router/rewrite/mod.rs +++ b/pgdog/src/frontend/router/rewrite/mod.rs @@ -41,7 +41,7 @@ impl RewriteModule for Rewrite<'_> { // First, we need to inject the unique ID into the query. Once that's done, // we can proceed with additional rewrites. - // Unique ID rewrites (including EXPLAIN wrappers) + // Unique ID rewrites. unique_id::ExplainUniqueIdRewrite::default().rewrite(input)?; unique_id::InsertUniqueIdRewrite::default().rewrite(input)?; unique_id::UpdateUniqueIdRewrite::default().rewrite(input)?; diff --git a/pgdog/src/frontend/router/rewrite/output.rs b/pgdog/src/frontend/router/rewrite/output.rs index c1b8090ac..540be0629 100644 --- a/pgdog/src/frontend/router/rewrite/output.rs +++ b/pgdog/src/frontend/router/rewrite/output.rs @@ -1,9 +1,42 @@ -use crate::net::{Bind, Parse, ProtocolMessage, Query}; +use crate::{ + frontend::ClientRequest, + net::{Bind, Describe, Error, FromBytes, Parse, Protocol, ProtocolMessage, Query, ToBytes}, +}; #[derive(Debug, Clone)] pub struct RewrittenRequest { pub messages: Vec, pub action: ExecutionAction, + pub renamed: Option, +} + +impl RewrittenRequest { + /// Rewrite client request in-place, + /// making sure all messages use new prepared statement names. + pub fn rewrite_in_place(&self, request: &mut ClientRequest) -> Result<(), Error> { + for message in &self.messages { + let code = message.code(); + if let Some(pos) = request.messages.iter().position(|p| p.code() == code) { + request.messages[pos] = message.clone(); + } + } + + if let Some(ref renamed) = self.renamed { + for message in request.messages.iter_mut() { + // Rename describe to the new prepared statement. + if message.code() == 'D' { + let mut describe = Describe::from_bytes(message.to_bytes()?)?; + if !describe.is_statement() { + describe.rename(renamed); + } + + *message = ProtocolMessage::from(describe); + } + } + } + + Ok(()) + } } /// Output of a single rewrite step. From d9f8c69cf8e55533c77ec5031be3467c21406898 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 30 Nov 2025 09:11:18 -0800 Subject: [PATCH 06/23] clippy --- pgdog-config/src/core.rs | 20 ++++++++----------- pgdog-config/src/sharding.rs | 2 +- pgdog-config/src/url.rs | 2 +- .../replication/logical/subscriber/context.rs | 2 +- pgdog/src/frontend/client/query_engine/mod.rs | 2 +- pgdog/src/frontend/router/cli.rs | 2 +- 6 files changed, 13 insertions(+), 17 deletions(-) diff --git a/pgdog-config/src/core.rs b/pgdog-config/src/core.rs index a43c34437..fe58b1e17 100644 --- a/pgdog-config/src/core.rs +++ b/pgdog-config/src/core.rs @@ -195,8 +195,7 @@ impl Config { /// Organize all databases by name for quicker retrieval. pub fn databases(&self) -> HashMap>> { let mut databases = HashMap::new(); - let mut number = 0; - for database in &self.databases { + for (number, database) in self.databases.iter().enumerate() { let entry = databases .entry(database.name.clone()) .or_insert_with(Vec::new); @@ -210,7 +209,6 @@ impl Config { number, database: database.clone(), }); - number += 1; } databases } @@ -323,7 +321,7 @@ impl Config { ); } } else { - pooler_mode.insert(database.name.clone(), database.pooler_mode.clone()); + pooler_mode.insert(database.name.clone(), database.pooler_mode); } } @@ -352,17 +350,15 @@ impl Config { _ => (), } - if !self.general.two_phase_commit { - if self.rewrite.enabled { - if self.rewrite.shard_key == RewriteMode::Rewrite { - warn!("rewrite.shard_key=rewrite will apply non-atomic shard-key rewrites; enabling two_phase_commit is strongly recommended" + if !self.general.two_phase_commit && self.rewrite.enabled { + if self.rewrite.shard_key == RewriteMode::Rewrite { + warn!("rewrite.shard_key=rewrite will apply non-atomic shard-key rewrites; enabling two_phase_commit is strongly recommended" ); - } + } - if self.rewrite.split_inserts == RewriteMode::Rewrite { - warn!("rewrite.split_inserts=rewrite may commit partial multi-row INSERTs; enabling two_phase_commit is strongly recommended" + if self.rewrite.split_inserts == RewriteMode::Rewrite { + warn!("rewrite.split_inserts=rewrite may commit partial multi-row INSERTs; enabling two_phase_commit is strongly recommended" ); - } } } } diff --git a/pgdog-config/src/sharding.rs b/pgdog-config/src/sharding.rs index 35dcf3bbf..8e3121c44 100644 --- a/pgdog-config/src/sharding.rs +++ b/pgdog-config/src/sharding.rs @@ -213,7 +213,7 @@ impl ShardedSchema { } pub fn name(&self) -> &str { - self.name.as_ref().map(|name| name.as_str()).unwrap_or("*") + self.name.as_deref().unwrap_or("*") } pub fn shard(&self) -> Option { diff --git a/pgdog-config/src/url.rs b/pgdog-config/src/url.rs index 8440ff9bf..6d746fccb 100644 --- a/pgdog-config/src/url.rs +++ b/pgdog-config/src/url.rs @@ -136,7 +136,7 @@ impl ConfigAndUsers { let mirroring = mirror_strs .iter() - .map(|s| Mirroring::from_str(s).map_err(|e| Error::ParseError(e))) + .map(|s| Mirroring::from_str(s).map_err(Error::ParseError)) .collect::, _>>()?; self.config.mirroring = mirroring; diff --git a/pgdog/src/backend/replication/logical/subscriber/context.rs b/pgdog/src/backend/replication/logical/subscriber/context.rs index 0fe6b751d..da7ed70ac 100644 --- a/pgdog/src/backend/replication/logical/subscriber/context.rs +++ b/pgdog/src/backend/replication/logical/subscriber/context.rs @@ -51,7 +51,7 @@ impl<'a> StreamContext<'a> { /// Construct router context. pub fn router_context(&'a mut self) -> Result, Error> { Ok(RouterContext::new( - &mut self.request, + &self.request, self.cluster, &mut self.prepared_statements, &self.params, diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index d3443adf5..f555978c1 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -227,7 +227,7 @@ impl QueryEngine { } Command::Copy(_) => self.execute(context, &route).await?, Command::Rewrite(requests) => { - context.client_request.rewrite_extended(&requests)?; + context.client_request.rewrite_extended(requests)?; self.execute(context, &route).await?; } Command::InsertSplit(plan) => self.insert_split(context, *plan.clone()).await?, diff --git a/pgdog/src/frontend/router/cli.rs b/pgdog/src/frontend/router/cli.rs index 154823e36..ceb16b5ed 100644 --- a/pgdog/src/frontend/router/cli.rs +++ b/pgdog/src/frontend/router/cli.rs @@ -49,7 +49,7 @@ impl RouterCli { let mut qp = QueryParser::default(); let req = vec![ProtocolMessage::from(Query::new(query))]; let cmd = qp.parse(RouterContext::new( - &mut req.into(), + &req.into(), &cluster, &mut stmt, ¶ms, From c2b219e86f4e47bba21dc1cd5d4027a372de99ff Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 30 Nov 2025 09:35:59 -0800 Subject: [PATCH 07/23] use typecast for Describe works --- .../router/rewrite/unique_id/insert.rs | 14 +++------ .../frontend/router/rewrite/unique_id/mod.rs | 30 ++++++++++++++++++- .../router/rewrite/unique_id/select.rs | 14 ++++----- .../router/rewrite/unique_id/update.rs | 14 +++------ 4 files changed, 44 insertions(+), 28 deletions(-) diff --git a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs index 5a880def2..5655b25cb 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs @@ -1,11 +1,8 @@ -use pg_query::{ - protobuf::{InsertStmt, ParamRef}, - NodeEnum, -}; +use pg_query::{protobuf::InsertStmt, NodeEnum}; use super::{ super::{Error, Input, RewriteModule}, - bigint_const, + bigint_const, bigint_param, }; use crate::{ frontend::router::parser::{Insert, Value}, @@ -54,10 +51,7 @@ impl InsertUniqueIdRewrite { let id = unique_id::UniqueId::generator()?.next_id(); let node = if let Some(ref mut bind) = bind { - NodeEnum::ParamRef(ParamRef { - number: bind.add_parameter(Datum::Bigint(id))?, - ..Default::default() - }) + bigint_param(bind.add_parameter(Datum::Bigint(id))?) } else { bigint_const(id) }; @@ -175,7 +169,7 @@ mod test { let output = input.build().unwrap(); assert_eq!( output.query().unwrap(), - "INSERT INTO omnisharded (id, settings) VALUES ($3, $1::jsonb), ($4, $2::jsonb)" + "INSERT INTO omnisharded (id, settings) VALUES ($3::bigint, $1::jsonb), ($4::bigint, $2::jsonb)" ); } } diff --git a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs index 0a28269a7..7b08268e4 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs @@ -1,7 +1,7 @@ //! Unique ID rewrite engine. use pg_query::{ - protobuf::{a_const::Val, AConst, Node, TypeCast, TypeName}, + protobuf::{a_const::Val, AConst, Node, ParamRef, TypeCast, TypeName}, NodeEnum, }; @@ -17,6 +17,34 @@ pub use update::UpdateUniqueIdRewrite; pub struct UniqueIdRewrite; +/// Create a bigint-typed parameter reference node. +fn bigint_param(number: i32) -> NodeEnum { + NodeEnum::TypeCast(Box::new(TypeCast { + arg: Some(Box::new(Node { + node: Some(NodeEnum::ParamRef(ParamRef { + number, + ..Default::default() + })), + })), + type_name: Some(TypeName { + names: vec![ + Node { + node: Some(NodeEnum::String(pg_query::protobuf::String { + sval: "pg_catalog".to_string(), + })), + }, + Node { + node: Some(NodeEnum::String(pg_query::protobuf::String { + sval: "int8".to_string(), + })), + }, + ], + ..Default::default() + }), + ..Default::default() + })) +} + /// Create a bigint-typed constant node for the given ID. fn bigint_const(id: i64) -> NodeEnum { NodeEnum::TypeCast(Box::new(TypeCast { diff --git a/pgdog/src/frontend/router/rewrite/unique_id/select.rs b/pgdog/src/frontend/router/rewrite/unique_id/select.rs index d41453796..6ed7eecfe 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/select.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/select.rs @@ -1,13 +1,13 @@ //! SELECT statement rewriter for unique_id. use pg_query::{ - protobuf::{Node, ParamRef, SelectStmt}, + protobuf::{Node, SelectStmt}, NodeEnum, }; use super::{ super::{Error, Input, RewriteModule}, - bigint_const, + bigint_const, bigint_param, }; use crate::{frontend::router::parser::Value, net::Datum, unique_id}; @@ -104,10 +104,7 @@ impl SelectUniqueIdRewrite { let id = unique_id::UniqueId::generator()?.next_id(); let node = if let Some(ref mut bind) = bind { - NodeEnum::ParamRef(ParamRef { - number: bind.add_parameter(Datum::Bigint(id))?, - ..Default::default() - }) + bigint_param(bind.add_parameter(Datum::Bigint(id))?) } else { bigint_const(id) }; @@ -250,7 +247,10 @@ mod test { .rewrite(&mut input) .unwrap(); let output = input.build().unwrap(); - assert_eq!(output.query().unwrap(), "SELECT $2 AS id, $1 AS name"); + assert_eq!( + output.query().unwrap(), + "SELECT $2::bigint AS id, $1 AS name" + ); } #[test] diff --git a/pgdog/src/frontend/router/rewrite/unique_id/update.rs b/pgdog/src/frontend/router/rewrite/unique_id/update.rs index 6029202fd..f6c5bffc8 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/update.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/update.rs @@ -1,13 +1,10 @@ //! UPDATE statement rewriter for unique_id. -use pg_query::{ - protobuf::{ParamRef, UpdateStmt}, - NodeEnum, -}; +use pg_query::{protobuf::UpdateStmt, NodeEnum}; use super::{ super::{Error, Input, RewriteModule}, - bigint_const, + bigint_const, bigint_param, }; use crate::{frontend::router::parser::Value, net::Datum, unique_id}; @@ -42,10 +39,7 @@ impl UpdateUniqueIdRewrite { let id = unique_id::UniqueId::generator()?.next_id(); let node = if let Some(ref mut bind) = bind { - NodeEnum::ParamRef(ParamRef { - number: bind.add_parameter(Datum::Bigint(id))?, - ..Default::default() - }) + bigint_param(bind.add_parameter(Datum::Bigint(id))?) } else { bigint_const(id) }; @@ -146,7 +140,7 @@ mod test { let output = input.build().unwrap(); assert_eq!( output.query().unwrap(), - "UPDATE omnisharded SET id = $3, settings = $1 WHERE old_id = $2" + "UPDATE omnisharded SET id = $3::bigint, settings = $1 WHERE old_id = $2" ); } } From 49cd0b57bed98161b4c8accb02a2bf858c940314 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 30 Nov 2025 10:01:35 -0800 Subject: [PATCH 08/23] empty query --- pgdog/src/frontend/router/parser/query/mod.rs | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index 96f195fc9..ccc5bccf9 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -198,19 +198,22 @@ impl QueryParser { }; let mut input = rewrite::Input::new(&statement.ast().protobuf, context.router_context.bind); - rewrite::Rewrite::new(context.router_context.prepared_statements).rewrite(&mut input)?; - - match input.build()? { - rewrite::StepOutput::NoOp => (), - rewrite::StepOutput::Extended { parse, bind } => { - return Ok(Command::Rewrite(vec![ - ProtocolMessage::from(parse), - bind.into(), - ])) - } - rewrite::StepOutput::Simple { query } => { - return Ok(Command::Rewrite(vec![ProtocolMessage::from(query)])) - } + match rewrite::Rewrite::new(context.router_context.prepared_statements).rewrite(&mut input) + { + Ok(()) => match input.build()? { + rewrite::StepOutput::NoOp => (), + rewrite::StepOutput::Extended { parse, bind } => { + return Ok(Command::Rewrite(vec![ + ProtocolMessage::from(parse), + bind.into(), + ])) + } + rewrite::StepOutput::Simple { query } => { + return Ok(Command::Rewrite(vec![ProtocolMessage::from(query)])) + } + }, + Err(rewrite::Error::EmptyQuery) => (), // We handle empty queries below. + Err(err) => return Err(err.into()), } self.ensure_explain_recorder(statement.ast(), context); From e164dbe88b4da33607e9dc8970589a7fd3ffbef7 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 30 Nov 2025 13:40:24 -0800 Subject: [PATCH 09/23] rewrite --- pgdog/src/frontend/client/query_engine/mod.rs | 4 +- pgdog/src/frontend/client_request.rs | 2 + pgdog/src/frontend/router/parser/command.rs | 6 +- pgdog/src/frontend/router/parser/query/mod.rs | 14 ++--- .../router/rewrite/{input.rs => context.rs} | 47 +++++++++++--- .../src/frontend/router/rewrite/interface.rs | 4 +- pgdog/src/frontend/router/rewrite/mod.rs | 8 +-- pgdog/src/frontend/router/rewrite/output.rs | 62 ++++++++++--------- .../router/rewrite/prepared/execute.rs | 22 ++++--- .../frontend/router/rewrite/prepared/mod.rs | 4 +- .../router/rewrite/prepared/prepare.rs | 9 ++- .../router/rewrite/unique_id/explain.rs | 20 +++--- .../router/rewrite/unique_id/insert.rs | 8 +-- .../router/rewrite/unique_id/select.rs | 16 ++--- .../router/rewrite/unique_id/update.rs | 8 +-- pgdog/src/net/protocol_message.rs | 8 +++ 16 files changed, 143 insertions(+), 99 deletions(-) rename pgdog/src/frontend/router/rewrite/{input.rs => context.rs} (65%) diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index f555978c1..ea74beb43 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -227,7 +227,9 @@ impl QueryEngine { } Command::Copy(_) => self.execute(context, &route).await?, Command::Rewrite(requests) => { - context.client_request.rewrite_extended(requests)?; + for request in requests { + request.execute(context.client_request); + } self.execute(context, &route).await?; } Command::InsertSplit(plan) => self.insert_split(context, *plan.clone()).await?, diff --git a/pgdog/src/frontend/client_request.rs b/pgdog/src/frontend/client_request.rs index 7cb1ff4af..af685ab4c 100644 --- a/pgdog/src/frontend/client_request.rs +++ b/pgdog/src/frontend/client_request.rs @@ -197,6 +197,8 @@ impl ClientRequest { .position(|p| p.code() == new_message.code()) { self.messages[pos] = new_message.clone(); + } else { + self.messages.insert(0, new_message.clone()); } } diff --git a/pgdog/src/frontend/router/parser/command.rs b/pgdog/src/frontend/router/parser/command.rs index f924e8f29..da9e557e4 100644 --- a/pgdog/src/frontend/router/parser/command.rs +++ b/pgdog/src/frontend/router/parser/command.rs @@ -1,7 +1,7 @@ use super::*; use crate::{ - frontend::{client::TransactionType, BufferedQuery}, - net::{parameter::ParameterValue, ProtocolMessage}, + frontend::{client::TransactionType, router::rewrite::RewriteAction, BufferedQuery}, + net::parameter::ParameterValue, }; use lazy_static::lazy_static; @@ -28,7 +28,7 @@ pub enum Command { value: ParameterValue, }, PreparedStatement(Prepare), - Rewrite(Vec), + Rewrite(Vec), InternalField { name: String, value: String, diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index ccc5bccf9..43d184a31 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -17,7 +17,6 @@ use crate::{ net::{ messages::{Bind, Vector}, parameter::ParameterValue, - ProtocolMessage, }, plugin::plugins, }; @@ -197,19 +196,14 @@ impl QueryParser { } }; - let mut input = rewrite::Input::new(&statement.ast().protobuf, context.router_context.bind); + let mut input = + rewrite::Context::new(&statement.ast().protobuf, context.router_context.bind); match rewrite::Rewrite::new(context.router_context.prepared_statements).rewrite(&mut input) { Ok(()) => match input.build()? { rewrite::StepOutput::NoOp => (), - rewrite::StepOutput::Extended { parse, bind } => { - return Ok(Command::Rewrite(vec![ - ProtocolMessage::from(parse), - bind.into(), - ])) - } - rewrite::StepOutput::Simple { query } => { - return Ok(Command::Rewrite(vec![ProtocolMessage::from(query)])) + rewrite::StepOutput::Rewrite(req) => { + return Ok(Command::Rewrite(req)); } }, Err(rewrite::Error::EmptyQuery) => (), // We handle empty queries below. diff --git a/pgdog/src/frontend/router/rewrite/input.rs b/pgdog/src/frontend/router/rewrite/context.rs similarity index 65% rename from pgdog/src/frontend/router/rewrite/input.rs rename to pgdog/src/frontend/router/rewrite/context.rs index 04efe7d4e..003d6c4a3 100644 --- a/pgdog/src/frontend/router/rewrite/input.rs +++ b/pgdog/src/frontend/router/rewrite/context.rs @@ -2,14 +2,14 @@ use pg_query::protobuf::{ParseResult, RawStmt}; -use super::{Error, StepOutput}; +use super::{output::RewriteActionKind, Error, RewriteAction, StepOutput}; use crate::{ frontend::PreparedStatements, - net::{Bind, Parse, Query}, + net::{Bind, Parse, ProtocolMessage, Query}, }; #[derive(Debug, Clone)] -pub struct Input<'a> { +pub struct Context<'a> { // Most requeries won't require a rewrite. // This is a clone-free way to check. original: &'a ParseResult, @@ -19,9 +19,11 @@ pub struct Input<'a> { bind: Option<&'a Bind>, /// Bind rewritten. rewrite_bind: Option, + /// Additional messages to add to the request. + result: Vec, } -impl<'a> Input<'a> { +impl<'a> Context<'a> { /// Create new input. pub fn new(original: &'a ParseResult, bind: Option<&'a Bind>) -> Self { Self { @@ -29,6 +31,7 @@ impl<'a> Input<'a> { bind, rewrite: None, rewrite_bind: None, + result: vec![], } } @@ -79,6 +82,14 @@ impl<'a> Input<'a> { stmt.stmts.first_mut().ok_or(Error::EmptyQuery) } + /// New request mutable reference. + pub fn prepend(&mut self, message: ProtocolMessage) { + self.result.push(RewriteAction { + message, + action: RewriteActionKind::Prepend, + }); + } + /// Assemble statement and add it to the global prepared statements cache. pub fn build(mut self) -> Result { if self.rewrite.is_none() { @@ -86,22 +97,40 @@ impl<'a> Input<'a> { } else { let bind = self.rewrite_bind.take(); let stmt = self.rewrite.take().ok_or(Error::NoRewrite)?.deparse()?; + let mut result = self.result; if let Some(mut bind) = bind { let mut parse = Parse::new_anonymous(stmt); if bind.anonymous() { - Ok(StepOutput::Extended { parse, bind }) + result.push(RewriteAction { + message: parse.into(), + action: RewriteActionKind::Replace, + }); + result.push(RewriteAction { + message: bind.into(), + action: RewriteActionKind::Replace, + }); } else { let name = PreparedStatements::cache_rewritten(&parse); parse.rename_fast(&name); bind.rename(name); - Ok(StepOutput::Extended { parse, bind }) + result.push(RewriteAction { + message: parse.into(), + action: RewriteActionKind::Replace, + }); + result.push(RewriteAction { + message: bind.into(), + action: RewriteActionKind::Replace, + }); } } else { - Ok(StepOutput::Simple { - query: Query::new(stmt), - }) + result.push(RewriteAction { + message: Query::new(stmt).into(), + action: RewriteActionKind::Replace, + }); } + + Ok(StepOutput::Rewrite(result)) } } } diff --git a/pgdog/src/frontend/router/rewrite/interface.rs b/pgdog/src/frontend/router/rewrite/interface.rs index 6099bdd41..d0babb5ca 100644 --- a/pgdog/src/frontend/router/rewrite/interface.rs +++ b/pgdog/src/frontend/router/rewrite/interface.rs @@ -1,6 +1,6 @@ //! Rewrite module interface. -use super::{Error, Input}; +use super::{Context, Error}; /// Rewrite trait. /// @@ -10,5 +10,5 @@ pub trait RewriteModule { /// /// If a rewrite is needed, the module should mutate the statement /// and update the Bind message. - fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error>; + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error>; } diff --git a/pgdog/src/frontend/router/rewrite/mod.rs b/pgdog/src/frontend/router/rewrite/mod.rs index 26bcd2a52..72ae301c8 100644 --- a/pgdog/src/frontend/router/rewrite/mod.rs +++ b/pgdog/src/frontend/router/rewrite/mod.rs @@ -6,18 +6,18 @@ //! 2. Multi-tuple INSERT: rewrite to send multiple INSERTs //! 3. pgdog.unique_id() call: inject a unique ID //! +pub mod context; pub mod error; -pub mod input; pub mod insert_split; pub mod interface; pub mod output; pub mod prepared; pub mod unique_id; +pub use context::Context; pub use error::Error; -pub use input::Input; pub use interface::RewriteModule; -pub use output::StepOutput; +pub use output::{RewriteAction, StepOutput}; use crate::frontend::PreparedStatements; @@ -35,7 +35,7 @@ impl<'a> Rewrite<'a> { } impl RewriteModule for Rewrite<'_> { - fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { // N.B.: the ordering here matters! // // First, we need to inject the unique ID into the query. Once that's done, diff --git a/pgdog/src/frontend/router/rewrite/output.rs b/pgdog/src/frontend/router/rewrite/output.rs index 540be0629..a512a5859 100644 --- a/pgdog/src/frontend/router/rewrite/output.rs +++ b/pgdog/src/frontend/router/rewrite/output.rs @@ -1,6 +1,6 @@ use crate::{ frontend::ClientRequest, - net::{Bind, Describe, Error, FromBytes, Parse, Protocol, ProtocolMessage, Query, ToBytes}, + net::{Protocol, ProtocolMessage}, }; #[derive(Debug, Clone)] @@ -10,50 +10,56 @@ pub struct RewrittenRequest { pub renamed: Option, } -impl RewrittenRequest { - /// Rewrite client request in-place, - /// making sure all messages use new prepared statement names. - pub fn rewrite_in_place(&self, request: &mut ClientRequest) -> Result<(), Error> { - for message in &self.messages { - let code = message.code(); - if let Some(pos) = request.messages.iter().position(|p| p.code() == code) { - request.messages[pos] = message.clone(); - } - } - - if let Some(ref renamed) = self.renamed { - for message in request.messages.iter_mut() { - // Rename describe to the new prepared statement. - if message.code() == 'D' { - let mut describe = Describe::from_bytes(message.to_bytes()?)?; - if !describe.is_statement() { - describe.rename(renamed); - } +#[derive(Debug, Clone)] +pub struct RewriteAction { + pub(super) message: ProtocolMessage, + pub(super) action: RewriteActionKind, +} - *message = ProtocolMessage::from(describe); +impl RewriteAction { + /// Execute rewrite action. + pub fn execute(&self, request: &mut ClientRequest) { + match self.action { + RewriteActionKind::Append => request.push(self.message.clone()), + RewriteActionKind::Replace => { + if let Some(pos) = request.iter().position(|p| p.code() == self.message.code()) { + request[pos] = self.message.clone(); } } + RewriteActionKind::Prepend => request.insert(0, self.message.clone()), } - - Ok(()) } } +#[derive(Debug, Clone, PartialEq)] +pub(super) enum RewriteActionKind { + Replace, + Prepend, + #[allow(dead_code)] + Append, +} + /// Output of a single rewrite step. #[derive(Debug, Clone)] pub enum StepOutput { NoOp, - Extended { parse: Parse, bind: Bind }, - Simple { query: Query }, + Rewrite(Vec), } impl StepOutput { /// Get rewritten query, if any. pub fn query(&self) -> Result<&str, ()> { match self { - Self::Extended { parse, .. } => Ok(parse.query()), - Self::Simple { query } => Ok(query.query()), - _ => Err(()), + Self::NoOp => Err(()), + Self::Rewrite(actions) => { + for action in actions { + if let Some(query) = action.message.query() { + return Ok(query); + } + } + + Err(()) + } } } } diff --git a/pgdog/src/frontend/router/rewrite/prepared/execute.rs b/pgdog/src/frontend/router/rewrite/prepared/execute.rs index 16c087ade..9071c3460 100644 --- a/pgdog/src/frontend/router/rewrite/prepared/execute.rs +++ b/pgdog/src/frontend/router/rewrite/prepared/execute.rs @@ -2,8 +2,8 @@ use pg_query::NodeEnum; -use super::super::{Error, Input, RewriteModule}; -use crate::frontend::PreparedStatements; +use super::super::{Context, Error, RewriteModule}; +use crate::{frontend::PreparedStatements, net::ProtocolMessage}; /// Rewriter for EXECUTE statements. /// @@ -21,7 +21,7 @@ impl<'a> ExecuteRewrite<'a> { } impl RewriteModule for ExecuteRewrite<'_> { - fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { if let Some(NodeEnum::ExecuteStmt(stmt)) = input .stmt()? .stmt @@ -41,8 +41,13 @@ impl RewriteModule for ExecuteRewrite<'_> { .as_mut() .and_then(|stmt| stmt.node.as_mut()) { - stmt.name = new_name; + stmt.name = new_name.clone(); } + + input.prepend(ProtocolMessage::Prepare { + name: new_name, + statement: parse.query().to_string(), + }); } Ok(()) @@ -62,18 +67,17 @@ mod test { // First prepare the statement let mut prepare_rewrite = PrepareRewrite::new(&mut prepared_statements); - let mut input = Input::new(&prepare_stmt, None); + let mut input = Context::new(&prepare_stmt, None); prepare_rewrite.rewrite(&mut input).unwrap(); // Now execute it let execute_stmt = pg_query::parse("EXECUTE test(1, 2, 3)").unwrap().protobuf; let mut execute_rewrite = ExecuteRewrite::new(&prepared_statements); - let mut input = Input::new(&execute_stmt, None); + let mut input = Context::new(&execute_stmt, None); execute_rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); - assert!(query.contains("__pgdog_")); - assert!(!query.contains("EXECUTE test")); + assert_eq!(query, "EXECUTE __pgdog_1(1, 2, 3)"); } #[test] @@ -83,7 +87,7 @@ mod test { .protobuf; let prepared_statements = PreparedStatements::default(); let mut execute_rewrite = ExecuteRewrite::new(&prepared_statements); - let mut input = Input::new(&execute_stmt, None); + let mut input = Context::new(&execute_stmt, None); let result = execute_rewrite.rewrite(&mut input); assert!(result.is_err()); } diff --git a/pgdog/src/frontend/router/rewrite/prepared/mod.rs b/pgdog/src/frontend/router/rewrite/prepared/mod.rs index 3bc7d40aa..3b9df8089 100644 --- a/pgdog/src/frontend/router/rewrite/prepared/mod.rs +++ b/pgdog/src/frontend/router/rewrite/prepared/mod.rs @@ -8,7 +8,7 @@ mod prepare; pub use execute::ExecuteRewrite; pub use prepare::PrepareRewrite; -use super::{Error, Input, RewriteModule}; +use super::{Context, Error, RewriteModule}; use crate::frontend::PreparedStatements; /// Combined rewriter for PREPARE and EXECUTE statements. @@ -25,7 +25,7 @@ impl<'a> PreparedRewrite<'a> { } impl RewriteModule for PreparedRewrite<'_> { - fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { PrepareRewrite::new(self.prepared_statements).rewrite(input)?; ExecuteRewrite::new(self.prepared_statements).rewrite(input)?; Ok(()) diff --git a/pgdog/src/frontend/router/rewrite/prepared/prepare.rs b/pgdog/src/frontend/router/rewrite/prepared/prepare.rs index 8f76492ed..f4b63c2a1 100644 --- a/pgdog/src/frontend/router/rewrite/prepared/prepare.rs +++ b/pgdog/src/frontend/router/rewrite/prepared/prepare.rs @@ -2,7 +2,7 @@ use pg_query::NodeEnum; -use super::super::{Error, Input, RewriteModule}; +use super::super::{Context, Error, RewriteModule}; use crate::frontend::PreparedStatements; /// Rewriter for PREPARE statements. @@ -21,7 +21,7 @@ impl<'a> PrepareRewrite<'a> { } impl RewriteModule for PrepareRewrite<'_> { - fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { if let Some(NodeEnum::PrepareStmt(stmt)) = input .stmt()? .stmt @@ -64,11 +64,10 @@ mod test { .protobuf; let mut prepared_statements = PreparedStatements::default(); let mut rewrite = PrepareRewrite::new(&mut prepared_statements); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); - assert!(query.contains("__pgdog_")); - assert!(!query.contains("PREPARE test")); + assert_eq!(query, "PREPARE __pgdog_1 AS SELECT $1, $2, $3"); } } diff --git a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs index 009a9c1c5..c73cd4eee 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs @@ -3,7 +3,7 @@ use pg_query::NodeEnum; use super::{ - super::{Error, Input, RewriteModule}, + super::{Context, Error, RewriteModule}, InsertUniqueIdRewrite, SelectUniqueIdRewrite, UpdateUniqueIdRewrite, }; @@ -11,7 +11,7 @@ use super::{ pub struct ExplainUniqueIdRewrite {} impl RewriteModule for ExplainUniqueIdRewrite { - fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { // Check if this is an EXPLAIN statement let is_explain = matches!( input @@ -56,7 +56,7 @@ impl RewriteModule for ExplainUniqueIdRewrite { } impl ExplainUniqueIdRewrite { - fn rewrite_explain_select(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + fn rewrite_explain_select(&mut self, input: &mut Context<'_>) -> Result<(), Error> { // Check if the inner SELECT needs rewriting let needs_rewrite = if let Some(NodeEnum::ExplainStmt(stmt)) = input .stmt()? @@ -98,7 +98,7 @@ impl ExplainUniqueIdRewrite { Ok(()) } - fn rewrite_explain_insert(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + fn rewrite_explain_insert(&mut self, input: &mut Context<'_>) -> Result<(), Error> { // Check if the inner INSERT needs rewriting let needs_rewrite = if let Some(NodeEnum::ExplainStmt(stmt)) = input .stmt()? @@ -140,7 +140,7 @@ impl ExplainUniqueIdRewrite { Ok(()) } - fn rewrite_explain_update(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + fn rewrite_explain_update(&mut self, input: &mut Context<'_>) -> Result<(), Error> { // Check if the inner UPDATE needs rewriting let needs_rewrite = if let Some(NodeEnum::ExplainStmt(stmt)) = input .stmt()? @@ -197,7 +197,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -215,7 +215,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -234,7 +234,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -252,7 +252,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -265,7 +265,7 @@ mod test { fn test_explain_no_unique_id() { let stmt = pg_query::parse(r#"EXPLAIN SELECT 1"#).unwrap().protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(matches!(output, super::super::super::StepOutput::NoOp)); diff --git a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs index 5655b25cb..101d434ff 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs @@ -1,7 +1,7 @@ use pg_query::{protobuf::InsertStmt, NodeEnum}; use super::{ - super::{Error, Input, RewriteModule}, + super::{Context, Error, RewriteModule}, bigint_const, bigint_param, }; use crate::{ @@ -69,7 +69,7 @@ impl InsertUniqueIdRewrite { } impl RewriteModule for InsertUniqueIdRewrite { - fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { let need_rewrite = if let Some(NodeEnum::InsertStmt(stmt)) = input .stmt()? .stmt @@ -123,7 +123,7 @@ mod test { .unwrap() .protobuf; let mut insert = InsertUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); insert.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -162,7 +162,7 @@ mod test { }, ], ); - let mut input = Input::new(&stmt, Some(&bind)); + let mut input = Context::new(&stmt, Some(&bind)); InsertUniqueIdRewrite::default() .rewrite(&mut input) .unwrap(); diff --git a/pgdog/src/frontend/router/rewrite/unique_id/select.rs b/pgdog/src/frontend/router/rewrite/unique_id/select.rs index 6ed7eecfe..9b2375ef6 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/select.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/select.rs @@ -6,7 +6,7 @@ use pg_query::{ }; use super::{ - super::{Error, Input, RewriteModule}, + super::{Context, Error, RewriteModule}, bigint_const, bigint_param, }; use crate::{frontend::router::parser::Value, net::Datum, unique_id}; @@ -172,7 +172,7 @@ impl SelectUniqueIdRewrite { } impl RewriteModule for SelectUniqueIdRewrite { - fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { let need_rewrite = if let Some(NodeEnum::SelectStmt(stmt)) = input .stmt()? .stmt @@ -220,7 +220,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); println!("output: {}", output.query().unwrap()); @@ -242,7 +242,7 @@ mod test { data: "test".into(), }], ); - let mut input = Input::new(&stmt, Some(&bind)); + let mut input = Context::new(&stmt, Some(&bind)); SelectUniqueIdRewrite::default() .rewrite(&mut input) .unwrap(); @@ -263,7 +263,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -278,7 +278,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -295,7 +295,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -305,7 +305,7 @@ mod test { fn test_no_rewrite_when_no_unique_id() { let stmt = pg_query::parse(r#"SELECT id FROM users"#).unwrap().protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(matches!(output, super::super::super::StepOutput::NoOp)); diff --git a/pgdog/src/frontend/router/rewrite/unique_id/update.rs b/pgdog/src/frontend/router/rewrite/unique_id/update.rs index f6c5bffc8..462c37164 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/update.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/update.rs @@ -3,7 +3,7 @@ use pg_query::{protobuf::UpdateStmt, NodeEnum}; use super::{ - super::{Error, Input, RewriteModule}, + super::{Context, Error, RewriteModule}, bigint_const, bigint_param, }; use crate::{frontend::router::parser::Value, net::Datum, unique_id}; @@ -55,7 +55,7 @@ impl UpdateUniqueIdRewrite { } impl RewriteModule for UpdateUniqueIdRewrite { - fn rewrite(&mut self, input: &mut Input<'_>) -> Result<(), Error> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { let need_rewrite = if let Some(NodeEnum::UpdateStmt(stmt)) = input .stmt()? .stmt @@ -104,7 +104,7 @@ mod test { .unwrap() .protobuf; let mut update = UpdateUniqueIdRewrite::default(); - let mut input = Input::new(&stmt, None); + let mut input = Context::new(&stmt, None); update.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -133,7 +133,7 @@ mod test { }, ], ); - let mut input = Input::new(&stmt, Some(&bind)); + let mut input = Context::new(&stmt, Some(&bind)); UpdateUniqueIdRewrite::default() .rewrite(&mut input) .unwrap(); diff --git a/pgdog/src/net/protocol_message.rs b/pgdog/src/net/protocol_message.rs index 239a267c8..5cb19592b 100644 --- a/pgdog/src/net/protocol_message.rs +++ b/pgdog/src/net/protocol_message.rs @@ -58,6 +58,14 @@ impl ProtocolMessage { Self::CopyFail(copy_fail) => copy_fail.len(), } } + + pub fn query(&self) -> Option<&str> { + match self { + ProtocolMessage::Query(query) => Some(query.query()), + ProtocolMessage::Parse(parse) => Some(parse.query()), + _ => None, + } + } } impl Protocol for ProtocolMessage { From 23df336dd054c2194d020fc9ad193b1ad624d0ca Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 30 Nov 2025 19:50:47 -0800 Subject: [PATCH 10/23] Fix rewrite --- pgdog/src/frontend/router/parser/query/mod.rs | 2 +- pgdog/src/frontend/router/rewrite/output.rs | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index 43d184a31..c30ee8006 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -201,7 +201,7 @@ impl QueryParser { match rewrite::Rewrite::new(context.router_context.prepared_statements).rewrite(&mut input) { Ok(()) => match input.build()? { - rewrite::StepOutput::NoOp => (), + rewrite::StepOutput::NoOp => {} rewrite::StepOutput::Rewrite(req) => { return Ok(Command::Rewrite(req)); } diff --git a/pgdog/src/frontend/router/rewrite/output.rs b/pgdog/src/frontend/router/rewrite/output.rs index a512a5859..6a6e3cf79 100644 --- a/pgdog/src/frontend/router/rewrite/output.rs +++ b/pgdog/src/frontend/router/rewrite/output.rs @@ -1,7 +1,6 @@ -use crate::{ - frontend::ClientRequest, - net::{Protocol, ProtocolMessage}, -}; +use crate::{frontend::ClientRequest, net::ProtocolMessage}; + +use std::mem::discriminant; #[derive(Debug, Clone)] pub struct RewrittenRequest { @@ -22,7 +21,10 @@ impl RewriteAction { match self.action { RewriteActionKind::Append => request.push(self.message.clone()), RewriteActionKind::Replace => { - if let Some(pos) = request.iter().position(|p| p.code() == self.message.code()) { + if let Some(pos) = request + .iter() + .position(|p| discriminant(p) == discriminant(&self.message)) + { request[pos] = self.message.clone(); } } From c1b0eca481057a7260a2066d3291c351d2ecc073 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 1 Dec 2025 09:30:30 -0800 Subject: [PATCH 11/23] Start insert-split using new engine --- pgdog/src/frontend/router/rewrite/context.rs | 6 +- .../router/rewrite/insert_split/mod.rs | 115 +++++++++++++++++- pgdog/src/net/error.rs | 3 + pgdog/src/net/messages/bind.rs | 17 +++ 4 files changed, 139 insertions(+), 2 deletions(-) diff --git a/pgdog/src/frontend/router/rewrite/context.rs b/pgdog/src/frontend/router/rewrite/context.rs index 003d6c4a3..aa49a9f30 100644 --- a/pgdog/src/frontend/router/rewrite/context.rs +++ b/pgdog/src/frontend/router/rewrite/context.rs @@ -13,7 +13,7 @@ pub struct Context<'a> { // Most requeries won't require a rewrite. // This is a clone-free way to check. original: &'a ParseResult, - // If a rewrite was done, the statement is saved here. + // If an in-place rewrite was done, the statement is saved here. rewrite: Option, /// Original bind message, if any. bind: Option<&'a Bind>, @@ -82,6 +82,10 @@ impl<'a> Context<'a> { stmt.stmts.first_mut().ok_or(Error::EmptyQuery) } + pub fn proto_version(&self) -> i32 { + self.original.version + } + /// New request mutable reference. pub fn prepend(&mut self, message: ProtocolMessage) { self.result.push(RewriteAction { diff --git a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs index c400252fb..28c974621 100644 --- a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs +++ b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs @@ -1 +1,114 @@ -pub struct InsertSplitRewrite {} +use pg_query::{ + protobuf::{ParseResult, RawStmt}, + Node, NodeEnum, +}; + +use crate::net::Bind; + +use super::*; + +#[derive(Default)] +pub struct InsertSplitRewrite; + +impl RewriteModule for InsertSplitRewrite { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + let mut inserts = vec![]; + if let Some(NodeEnum::InsertStmt(insert)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + if let Some(NodeEnum::SelectStmt(ref select)) = insert + .select_stmt + .as_ref() + .and_then(|node| node.node.as_ref()) + { + if select.values_lists.len() <= 1 { + return Ok(()); + } + + // Clone the original statement only once. + let mut proto_select = select.clone(); + let mut proto_insert = insert.clone(); + proto_select.values_lists.clear(); + proto_insert.select_stmt = None; + + // Generate new INSERT statements, with one VALUES tuple each. + for values in &select.values_lists { + let mut new_insert = proto_insert.clone(); + let mut new_select = proto_select.clone(); + let mut new_values = values.clone(); + let mut new_bind = Bind::default(); + + // Rewrite the parameter references + // and create new Bind message for each INSERT statement. + if let Some(NodeEnum::List(list)) = new_values.node.as_mut() { + for value in list.items.iter_mut() { + if let Some(NodeEnum::ParamRef(param)) = value.node.as_mut() { + let parameter = input + .bind() + .map(|bind| bind.parameter(param.number as usize - 1).ok()) + .flatten() + .flatten(); + if let Some(parameter) = parameter { + param.number = new_bind.add_existing(parameter)?; + } + } + } + } + new_select.values_lists.push(new_values); + new_insert.select_stmt = Some(Box::new(Node { + node: Some(NodeEnum::SelectStmt(new_select)), + })); + let result = ParseResult { + version: input.proto_version(), + stmts: vec![RawStmt { + stmt: Some(Box::new(Node { + node: Some(NodeEnum::InsertStmt(new_insert)), + })), + ..Default::default() + }], + }; + inserts.push((result, new_bind)); + } + } + } + + drop(inserts); + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use crate::net::bind::Parameter; + + use super::*; + + #[test] + fn test_insert_split() { + let stmt = pg_query::parse( + "INSERT INTO users (id, email, created_at) + VALUES ($1, 'test@test.com', NOW()), (123, $2, '2025-01-01') RETURNING *", + ) + .unwrap(); + let bind = Bind::new_params( + "", + &[ + Parameter { + len: 4, + data: "1234".into(), + }, + Parameter { + len: 14, + data: "hello@test.com".into(), + }, + ], + ); + let mut context = Context::new(&stmt.protobuf, Some(&bind)); + let mut module = InsertSplitRewrite::default(); + module.rewrite(&mut context).unwrap(); + } +} diff --git a/pgdog/src/net/error.rs b/pgdog/src/net/error.rs index d28bf04ee..01c849861 100644 --- a/pgdog/src/net/error.rs +++ b/pgdog/src/net/error.rs @@ -99,4 +99,7 @@ pub enum Error { #[error("not a pg_lsn")] NotPgLsn, + + #[error("multiple parameter formats in same bind message")] + MultipleBindFormats, } diff --git a/pgdog/src/net/messages/bind.rs b/pgdog/src/net/messages/bind.rs index e1803593f..a3ddaaa4c 100644 --- a/pgdog/src/net/messages/bind.rs +++ b/pgdog/src/net/messages/bind.rs @@ -226,6 +226,23 @@ impl Bind { Ok(self.params.len() as i32) } + /// Add parameter with provided format. + pub fn add_existing(&mut self, param: ParameterWithFormat<'_>) -> Result { + let format = param.format; + let existing = self.codes.get(0).cloned(); + match (format, existing) { + (Format::Text, None) + | (Format::Text, Some(Format::Text)) + | (Format::Binary, Some(Format::Binary)) => (), + (Format::Binary, None) => self.codes.push(format), + (Format::Binary, Some(Format::Text)) | (Format::Text, Some(Format::Binary)) => { + return Err(Error::MultipleBindFormats); + } + } + self.params.push(param.parameter.clone()); + Ok(self.params.len() as i32) + } + pub fn new_statement(name: &str) -> Self { Self { statement: Bytes::from(name.to_string() + "\0"), From e422ba5c8dd3c987e0e52a91e74e8c428c5e4e47 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 1 Dec 2025 09:31:17 -0800 Subject: [PATCH 12/23] clippy --- pgdog/src/frontend/router/rewrite/insert_split/mod.rs | 3 +-- pgdog/src/net/messages/bind.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs index 28c974621..9d036af2d 100644 --- a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs +++ b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs @@ -48,8 +48,7 @@ impl RewriteModule for InsertSplitRewrite { if let Some(NodeEnum::ParamRef(param)) = value.node.as_mut() { let parameter = input .bind() - .map(|bind| bind.parameter(param.number as usize - 1).ok()) - .flatten() + .and_then(|bind| bind.parameter(param.number as usize - 1).ok()) .flatten(); if let Some(parameter) = parameter { param.number = new_bind.add_existing(parameter)?; diff --git a/pgdog/src/net/messages/bind.rs b/pgdog/src/net/messages/bind.rs index a3ddaaa4c..da562c653 100644 --- a/pgdog/src/net/messages/bind.rs +++ b/pgdog/src/net/messages/bind.rs @@ -229,7 +229,7 @@ impl Bind { /// Add parameter with provided format. pub fn add_existing(&mut self, param: ParameterWithFormat<'_>) -> Result { let format = param.format; - let existing = self.codes.get(0).cloned(); + let existing = self.codes.first().cloned(); match (format, existing) { (Format::Text, None) | (Format::Text, Some(Format::Text)) From 37d4fb83bae1142f4c649654ecd947764546b95a Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 1 Dec 2025 17:04:00 -0800 Subject: [PATCH 13/23] save --- integration/rust/tests/integration/mod.rs | 1 + .../rust/tests/integration/unique_id.rs | 37 ++++++ pgdog/src/backend/databases.rs | 4 +- pgdog/src/backend/pool/cluster.rs | 10 +- .../replication/logical/subscriber/context.rs | 1 + .../frontend/client/query_engine/context.rs | 5 + pgdog/src/frontend/client/query_engine/mod.rs | 28 ++++- .../client/query_engine/route_query.rs | 56 +++++---- pgdog/src/frontend/error.rs | 3 + pgdog/src/frontend/prepared_statements/mod.rs | 22 ++-- pgdog/src/frontend/router/cli.rs | 1 + pgdog/src/frontend/router/context.rs | 9 +- pgdog/src/frontend/router/parser/cache.rs | 62 +++++++--- pgdog/src/frontend/router/parser/context.rs | 22 +--- pgdog/src/frontend/router/parser/error.rs | 3 - .../frontend/router/parser/query/explain.rs | 6 +- pgdog/src/frontend/router/parser/query/mod.rs | 51 ++++----- .../src/frontend/router/parser/query/show.rs | 4 +- .../src/frontend/router/parser/query/test.rs | 93 ++++++++++++--- pgdog/src/frontend/router/rewrite/context.rs | 74 +++++++----- pgdog/src/frontend/router/rewrite/error.rs | 6 + .../router/rewrite/insert_split/mod.rs | 2 +- pgdog/src/frontend/router/rewrite/mod.rs | 3 + pgdog/src/frontend/router/rewrite/output.rs | 18 ++- .../router/rewrite/prepared/execute.rs | 6 +- .../router/rewrite/prepared/prepare.rs | 2 +- pgdog/src/frontend/router/rewrite/request.rs | 108 ++++++++++++++++++ pgdog/src/frontend/router/rewrite/state.rs | 1 + .../router/rewrite/unique_id/explain.rs | 31 +++-- .../router/rewrite/unique_id/insert.rs | 53 +++++++-- .../frontend/router/rewrite/unique_id/mod.rs | 20 +++- .../router/rewrite/unique_id/select.rs | 53 ++++++--- .../router/rewrite/unique_id/update.rs | 23 +++- 33 files changed, 609 insertions(+), 209 deletions(-) create mode 100644 integration/rust/tests/integration/unique_id.rs create mode 100644 pgdog/src/frontend/router/rewrite/request.rs create mode 100644 pgdog/src/frontend/router/rewrite/state.rs diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs index 4c9229a39..5f68ececf 100644 --- a/integration/rust/tests/integration/mod.rs +++ b/integration/rust/tests/integration/mod.rs @@ -20,3 +20,4 @@ pub mod timestamp_sorting; pub mod tls_enforced; pub mod tls_reload; pub mod transaction_state; +pub mod unique_id; diff --git a/integration/rust/tests/integration/unique_id.rs b/integration/rust/tests/integration/unique_id.rs new file mode 100644 index 000000000..725c7f197 --- /dev/null +++ b/integration/rust/tests/integration/unique_id.rs @@ -0,0 +1,37 @@ +use rust::setup::connections_sqlx; +use sqlx::{Executor, Row}; + +#[tokio::test] +async fn unique_id_returns_bigint() -> Result<(), Box> { + let conns = connections_sqlx().await; + let sharded = conns.get(1).cloned().unwrap(); + + // Simple query + let row = sharded.fetch_one("SELECT pgdog.unique_id() AS id").await?; + let mut id: i64 = row.get("id"); + + assert!( + id > 0, + "unique_id should return a positive bigint, got {id}" + ); + + for _ in 0..100 { + // Prepared statement + let row = sqlx::query("SELECT pgdog.unique_id() AS id") + .fetch_one(&sharded) + .await?; + let prepared_id: i64 = row.get("id"); + assert!( + prepared_id > 0, + "prepared unique_id should return a positive bigint, got {prepared_id}" + ); + + assert!( + prepared_id > id, + "prepared id should be greater than simple query id" + ); + id = prepared_id; + } + + Ok(()) +} diff --git a/pgdog/src/backend/databases.rs b/pgdog/src/backend/databases.rs index afacd902c..5ed91bf9f 100644 --- a/pgdog/src/backend/databases.rs +++ b/pgdog/src/backend/databases.rs @@ -359,9 +359,9 @@ impl Databases { for cluster in self.all().values() { cluster.launch(); - if cluster.pooler_mode() == PoolerMode::Session && cluster.router_needed() { + if cluster.pooler_mode() == PoolerMode::Session && cluster.use_parser() { warn!( - r#"user "{}" for database "{}" requires transaction mode to route queries"#, + r#"user "{}" for database "{}" requires transaction mode to parse and route queries"#, cluster.user(), cluster.name() ); diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index 1b5445530..73ae20762 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -413,10 +413,14 @@ impl Cluster { self.stats.clone() } - /// We'll need the query router to figure out - /// where a query should go. - pub fn router_needed(&self) -> bool { + /// We need to parse the query using pg_query. + pub fn use_parser(&self) -> bool { !(self.shards().len() == 1 && (self.read_only() || self.write_only())) + || self.query_parser_enabled + || self.multi_tenant.is_some() + || self.pub_sub_enabled() + || self.prepared_statements() == &PreparedStatements::Full + || self.dry_run } /// Multi-tenant config. diff --git a/pgdog/src/backend/replication/logical/subscriber/context.rs b/pgdog/src/backend/replication/logical/subscriber/context.rs index da7ed70ac..85c6ca785 100644 --- a/pgdog/src/backend/replication/logical/subscriber/context.rs +++ b/pgdog/src/backend/replication/logical/subscriber/context.rs @@ -57,6 +57,7 @@ impl<'a> StreamContext<'a> { &self.params, None, 1, + None, )?) } } diff --git a/pgdog/src/frontend/client/query_engine/context.rs b/pgdog/src/frontend/client/query_engine/context.rs index 87ede236a..368ba1bcc 100644 --- a/pgdog/src/frontend/client/query_engine/context.rs +++ b/pgdog/src/frontend/client/query_engine/context.rs @@ -4,6 +4,7 @@ use crate::{ backend::pool::{connection::mirror::Mirror, stats::MemoryStats}, frontend::{ client::{timeouts::Timeouts, TransactionType}, + router::parser::cache::CachedAst, Client, ClientRequest, PreparedStatements, }, net::{BackendKeyData, Parameters, Stream}, @@ -38,6 +39,8 @@ pub struct QueryEngineContext<'a> { pub(super) rollback: bool, /// Omnisharded modulo. pub(super) omni_sticky_index: usize, + /// Query AST. + pub(super) ast: Option, } impl<'a> QueryEngineContext<'a> { @@ -58,6 +61,7 @@ impl<'a> QueryEngineContext<'a> { requests_left: 0, rollback: false, omni_sticky_index: client.omni_sticky_index, + ast: None, } } @@ -83,6 +87,7 @@ impl<'a> QueryEngineContext<'a> { requests_left: 0, rollback: false, omni_sticky_index: thread_rng().gen_range(1..usize::MAX), + ast: None, } } diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index ea74beb43..5be73ebba 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -3,7 +3,11 @@ use crate::{ config::config, frontend::{ client::query_engine::hooks::QueryEngineHooks, - router::{parser::Shard, Route}, + router::{ + parser::Shard, + rewrite::{self, RewriteRequest}, + Route, + }, BufferedQuery, Client, Command, Comms, Error, Router, RouterContext, Stats, }, net::{BackendKeyData, ErrorResponse, Message, Parameters}, @@ -114,6 +118,28 @@ impl QueryEngine { &mut self, context: &mut QueryEngineContext<'_>, ) -> Result { + // Check that we have the latest version of the config. + if let Some(error) = self.ensure_cluster(context.in_transaction()).await { + self.error_response(context, error).await?; + return Ok(QueryEngineOutput::Executed); + } + + if let Ok(cluster) = self.backend.cluster() { + if cluster.use_parser() && context.ast.is_none() { + // Execute request rewrite, if needed. + let mut rewrite = RewriteRequest::new( + context.client_request, + self.backend.cluster()?, + context.prepared_statements, + ); + match rewrite.execute() { + Ok(ast) => context.ast = Some(ast), + Err(rewrite::Error::EmptyQuery) => (), + Err(err) => return Err(err.into()), + } + } + } + self.stats .received(context.client_request.total_message_len()); self.set_state(State::Active); // Client is active. diff --git a/pgdog/src/frontend/client/query_engine/route_query.rs b/pgdog/src/frontend/client/query_engine/route_query.rs index 6dd9382f0..a8aff8054 100644 --- a/pgdog/src/frontend/client/query_engine/route_query.rs +++ b/pgdog/src/frontend/client/query_engine/route_query.rs @@ -1,9 +1,41 @@ use pgdog_config::PoolerMode; use tracing::trace; +use crate::backend::Cluster; + use super::*; impl QueryEngine { + pub fn cluster(&self) -> Result<&Cluster, Error> { + Ok(self.backend.cluster()?) + } + + /// Check that cluster still exists. + pub async fn ensure_cluster(&mut self, in_transaction: bool) -> Option { + if let Ok(cluster) = self.backend.cluster() { + let identifier = cluster.identifier(); + + if !in_transaction && !cluster.online() { + // Reload cluster config. + if let Err(_) = self.backend.safe_reload().await { + return Some(ErrorResponse::connection( + &identifier.user, + &identifier.database, + )); + } + + if let Err(_) = self.backend.cluster() { + return Some(ErrorResponse::connection( + &identifier.user, + &identifier.database, + )); + } + } + } + + None + } + pub(super) async fn route_transaction( &mut self, context: &mut QueryEngineContext<'_>, @@ -18,28 +50,7 @@ impl QueryEngine { // Admin doesn't have a cluster. let cluster = if let Ok(cluster) = self.backend.cluster() { - if !context.in_transaction() && !cluster.online() { - let identifier = cluster.identifier(); - - // Reload cluster config. - self.backend.safe_reload().await?; - - match self.backend.cluster() { - Ok(cluster) => cluster, - Err(_) => { - // Cluster is gone. - self.error_response( - context, - ErrorResponse::connection(&identifier.user, &identifier.database), - ) - .await?; - - return Ok(false); - } - } - } else { - cluster - } + cluster } else { return Ok(true); }; @@ -51,6 +62,7 @@ impl QueryEngine { context.params, context.transaction, context.omni_sticky_index, + context.ast.as_ref(), )?; match self.router.query(router_context) { Ok(cmd) => { diff --git a/pgdog/src/frontend/error.rs b/pgdog/src/frontend/error.rs index ac23367b0..c43409e69 100644 --- a/pgdog/src/frontend/error.rs +++ b/pgdog/src/frontend/error.rs @@ -50,6 +50,9 @@ pub enum Error { #[error("unique id: {0}")] UniqueId(#[from] unique_id::Error), + + #[error("rewrite: {0}")] + Rewrite(#[from] crate::frontend::router::rewrite::Error), } impl Error { diff --git a/pgdog/src/frontend/prepared_statements/mod.rs b/pgdog/src/frontend/prepared_statements/mod.rs index 646848e1f..255055aef 100644 --- a/pgdog/src/frontend/prepared_statements/mod.rs +++ b/pgdog/src/frontend/prepared_statements/mod.rs @@ -9,7 +9,7 @@ use tracing::debug; use crate::{ config::{config, PreparedStatements as PreparedStatementsLevel}, - frontend::router::parser::RewritePlan, + frontend::router::parser::{cache::CachedAst, RewritePlan}, net::{Parse, ProtocolMessage}, stats::memory::MemoryUsage, }; @@ -31,6 +31,7 @@ pub struct PreparedStatements { pub(super) local: HashMap, pub(super) level: PreparedStatementsLevel, pub(super) memory_used: usize, + pub(super) rewrite: HashMap, } impl MemoryUsage for PreparedStatements { @@ -49,6 +50,7 @@ impl Default for PreparedStatements { local: HashMap::default(), level: PreparedStatementsLevel::Extended, memory_used: 0, + rewrite: HashMap::new(), } } } @@ -97,15 +99,15 @@ impl PreparedStatements { parse.rename_fast(&name) } - /// Store a rewritten statement in the global cache forever. - pub fn cache_rewritten(parse: &Parse) -> String { - let exists = Self::global().read().name(parse); - if let Some(exists) = exists { - exists - } else { - let (_, name) = Self::global().write().insert(parse); - name - } + /// Get original AST for a prepared statement + /// we have rewritten. + pub fn get_original_ast(&self, name: &str) -> Option<&CachedAst> { + self.rewrite.get(name) + } + + /// Save original AST for re-use by subsequent Bind messages. + pub fn save_original_ast(&mut self, name: &str, ast: &CachedAst) { + self.rewrite.insert(name.to_string(), ast.clone()); } /// Retrieve stored rewrite plan for a prepared statement, if any. diff --git a/pgdog/src/frontend/router/cli.rs b/pgdog/src/frontend/router/cli.rs index ceb16b5ed..c733c8a52 100644 --- a/pgdog/src/frontend/router/cli.rs +++ b/pgdog/src/frontend/router/cli.rs @@ -55,6 +55,7 @@ impl RouterCli { ¶ms, None, 1, + None, )?)?; result.push(cmd); } diff --git a/pgdog/src/frontend/router/context.rs b/pgdog/src/frontend/router/context.rs index f430cc0f8..cbef297be 100644 --- a/pgdog/src/frontend/router/context.rs +++ b/pgdog/src/frontend/router/context.rs @@ -1,7 +1,10 @@ use super::Error; use crate::{ backend::Cluster, - frontend::{client::TransactionType, BufferedQuery, ClientRequest, PreparedStatements}, + frontend::{ + client::TransactionType, router::parser::cache::CachedAst, BufferedQuery, ClientRequest, + PreparedStatements, + }, net::{Bind, Parameters}, }; @@ -27,6 +30,8 @@ pub struct RouterContext<'a> { pub two_pc: bool, /// Sticky omnisharded index. pub omni_sticky_index: usize, + /// Query ast. + pub ast: Option<&'a CachedAst>, } impl<'a> RouterContext<'a> { @@ -37,6 +42,7 @@ impl<'a> RouterContext<'a> { params: &'a Parameters, transaction: Option, omni_sticky_index: usize, + ast: Option<&'a CachedAst>, ) -> Result { let query = buffer.query()?; let bind = buffer.parameters()?; @@ -53,6 +59,7 @@ impl<'a> RouterContext<'a> { executable: buffer.executable(), two_pc: cluster.two_pc_enabled(), omni_sticky_index, + ast, }) } diff --git a/pgdog/src/frontend/router/parser/cache.rs b/pgdog/src/frontend/router/parser/cache.rs index 4156c8e66..0797c1ea4 100644 --- a/pgdog/src/frontend/router/parser/cache.rs +++ b/pgdog/src/frontend/router/parser/cache.rs @@ -55,7 +55,7 @@ pub struct CachedAstInner { /// Cached AST. pub ast: ParseResult, /// AST stats. - pub stats: Mutex, + pub stats: Arc>, } impl Deref for CachedAst { @@ -70,6 +70,15 @@ impl CachedAst { /// Create new cache entry from pg_query's AST. pub fn new(query: &str, schema: &ShardingSchema) -> std::result::Result { let ast = parse(query).map_err(super::Error::PgQuery)?; + Self::new_parsed(query, ast, schema) + } + + /// Create new cached ast entry with the AST already pre-parsed. + pub fn new_parsed( + query: &str, + ast: ParseResult, + schema: &ShardingSchema, + ) -> std::result::Result { let (shard, role) = comment(query, schema)?; Ok(Self { @@ -77,10 +86,10 @@ impl CachedAst { comment_shard: shard, comment_role: role, inner: Arc::new(CachedAstInner { - stats: Mutex::new(Stats { + stats: Arc::new(Mutex::new(Stats { hits: 1, ..Default::default() - }), + })), ast, }), }) @@ -185,6 +194,26 @@ impl Cache { debug!("ast cache size set to {}", capacity); } + /// Save a pre-parsed query into the cache. + pub fn save( + &self, + query: &str, + ast: ParseResult, + schema: &ShardingSchema, + ) -> std::result::Result { + if let Some(exists) = self.check_existing(query) { + return Ok(exists); + } + + let entry = CachedAst::new_parsed(query, ast, schema)?; + + let mut guard = self.inner.lock(); + guard.queries.put(query.to_owned(), entry.clone()); + guard.stats.misses += 1; + + Ok(entry) + } + /// Parse a statement by either getting it from cache /// or using pg_query parser. /// @@ -196,16 +225,8 @@ impl Cache { query: &str, schema: &ShardingSchema, ) -> std::result::Result { - { - let mut guard = self.inner.lock(); - let ast = guard.queries.get_mut(query).map(|entry| { - entry.stats.lock().hits += 1; // No contention on this. - entry.clone() - }); - if let Some(ast) = ast { - guard.stats.hits += 1; - return Ok(ast); - } + if let Some(exists) = self.check_existing(query) { + return Ok(exists); } // Parse query without holding lock. @@ -218,6 +239,21 @@ impl Cache { Ok(entry) } + /// Check existing entry and return it if exists. + fn check_existing(&self, query: &str) -> Option { + let mut guard = self.inner.lock(); + let ast = guard.queries.get_mut(query).map(|entry| { + entry.stats.lock().hits += 1; // No contention on this. + entry.clone() + }); + if let Some(ast) = ast { + guard.stats.hits += 1; + Some(ast) + } else { + None + } + } + /// Parse a statement but do not store it in the cache. pub fn parse_uncached( &self, diff --git a/pgdog/src/frontend/router/parser/context.rs b/pgdog/src/frontend/router/parser/context.rs index 62469e450..cccbbf4c7 100644 --- a/pgdog/src/frontend/router/parser/context.rs +++ b/pgdog/src/frontend/router/parser/context.rs @@ -2,7 +2,6 @@ use std::os::raw::c_void; -use pgdog_config::PreparedStatements as ConfigPreparedStatements; use pgdog_plugin::pg_query::protobuf::ParseResult; use pgdog_plugin::{PdParameters, PdRouterContext, PdStatement}; @@ -22,8 +21,6 @@ use super::Error; /// and its inputs. /// pub struct QueryParserContext<'a> { - /// whether query_parser_enabled has been set. - pub(super) query_parser_enabled: bool, /// Cluster is read-only, i.e. has no primary. pub(super) read_only: bool, /// Cluster has no replicas, only a primary. @@ -36,13 +33,9 @@ pub struct QueryParserContext<'a> { pub(super) router_context: RouterContext<'a>, /// How aggressively we want to send reads to replicas. pub(super) rw_strategy: &'a ReadWriteStrategy, - /// Are we re-writing prepared statements sent over the simple protocol? - pub(super) full_prepared_statements: bool, /// Do we need the router at all? Shortcut to bypass this for unsharded /// clusters with databases that only read or write. - pub(super) router_needed: bool, - /// Do we have support for LISTEN/NOTIFY enabled? - pub(super) pub_sub_enabled: bool, + pub(super) use_parser: bool, /// Are we running multi-tenant checks? pub(super) multi_tenant: &'a Option, /// Dry run enabled? @@ -66,11 +59,7 @@ impl<'a> QueryParserContext<'a> { shards: router_context.cluster.shards().len(), sharding_schema: router_context.cluster.sharding_schema(), rw_strategy: router_context.cluster.read_write_strategy(), - full_prepared_statements: router_context.cluster.prepared_statements() - == &ConfigPreparedStatements::Full, - query_parser_enabled: router_context.cluster.query_parser_enabled(), - router_needed: router_context.cluster.router_needed(), - pub_sub_enabled: router_context.cluster.pub_sub_enabled(), + use_parser: router_context.cluster.use_parser(), multi_tenant: router_context.cluster.multi_tenant(), dry_run: router_context.cluster.dry_run(), expanded_explain: router_context.cluster.expanded_explain(), @@ -98,12 +87,7 @@ impl<'a> QueryParserContext<'a> { /// /// Shortcut to avoid the overhead if we can. pub(super) fn use_parser(&self) -> bool { - self.query_parser_enabled - || self.full_prepared_statements - || self.router_needed - || self.pub_sub_enabled - || self.multi_tenant().is_some() - || self.dry_run + self.use_parser } /// Get the query we're parsing, if any. diff --git a/pgdog/src/frontend/router/parser/error.rs b/pgdog/src/frontend/router/parser/error.rs index aba4975e5..2d360f831 100644 --- a/pgdog/src/frontend/router/parser/error.rs +++ b/pgdog/src/frontend/router/parser/error.rs @@ -99,7 +99,4 @@ pub enum Error { #[error("prepared statement \"{0}\" doesn't exist")] PreparedStatementDoesntExist(String), - - #[error("rewrite: {0}")] - Rewrite(#[from] super::super::rewrite::Error), } diff --git a/pgdog/src/frontend/router/parser/query/explain.rs b/pgdog/src/frontend/router/parser/query/explain.rs index 8c6b133e5..fd6bc0a22 100644 --- a/pgdog/src/frontend/router/parser/query/explain.rs +++ b/pgdog/src/frontend/router/parser/query/explain.rs @@ -75,7 +75,8 @@ mod tests { let mut stmts = PreparedStatements::default(); let params = Parameters::default(); - let ctx = RouterContext::new(&mut buffer, &cluster, &mut stmts, ¶ms, None, 1).unwrap(); + let ctx = + RouterContext::new(&mut buffer, &cluster, &mut stmts, ¶ms, None, 1, None).unwrap(); match QueryParser::default().parse(ctx).unwrap().clone() { Command::Query(route) => route, @@ -102,7 +103,8 @@ mod tests { let mut stmts = PreparedStatements::default(); let params = Parameters::default(); - let ctx = RouterContext::new(&mut buffer, &cluster, &mut stmts, ¶ms, None, 1).unwrap(); + let ctx = + RouterContext::new(&mut buffer, &cluster, &mut stmts, ¶ms, None, 1, None).unwrap(); match QueryParser::default().parse(ctx).unwrap().clone() { Command::Query(route) => route, diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index c30ee8006..f9e68c31d 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -8,7 +8,6 @@ use crate::{ router::{ context::RouterContext, parser::{OrderBy, Shard}, - rewrite::{self, RewriteModule}, round_robin, sharding::{Centroids, ContextBuilder, Value as ShardingValue}, }, @@ -178,42 +177,32 @@ impl QueryParser { let cache = Cache::get(); // Get the AST from cache or parse the statement live. - let statement = match context.query()? { - // Only prepared statements (or just extended) are cached. - BufferedQuery::Prepared(query) => { - cache.parse(query.query(), &context.sharding_schema)? - } - // Don't cache simple queries. - // - // They contain parameter values, which makes the cache - // too large to be practical. - // - // Make your clients use prepared statements - // or at least send statements with placeholders using the - // extended protocol. - BufferedQuery::Query(query) => { - cache.parse_uncached(query.query(), &context.sharding_schema)? + let statement = if let Some(ast) = context.router_context.ast.cloned() { + ast // The AST was already parsed by the rewriter module. + } else { + match context.query()? { + // Only prepared statements (or just extended) are cached. + BufferedQuery::Prepared(query) => { + cache.parse(query.query(), &context.sharding_schema)? + } + // Don't cache simple queries. + // + // They contain parameter values, which makes the cache + // too large to be practical. + // + // Make your clients use prepared statements + // or at least send statements with placeholders using the + // extended protocol. + BufferedQuery::Query(query) => { + cache.parse_uncached(query.query(), &context.sharding_schema)? + } } }; - let mut input = - rewrite::Context::new(&statement.ast().protobuf, context.router_context.bind); - match rewrite::Rewrite::new(context.router_context.prepared_statements).rewrite(&mut input) - { - Ok(()) => match input.build()? { - rewrite::StepOutput::NoOp => {} - rewrite::StepOutput::Rewrite(req) => { - return Ok(Command::Rewrite(req)); - } - }, - Err(rewrite::Error::EmptyQuery) => (), // We handle empty queries below. - Err(err) => return Err(err.into()), - } - self.ensure_explain_recorder(statement.ast(), context); // Parse hardcoded shard from a query comment. - if context.router_needed || context.dry_run { + if context.use_parser() { self.shard = statement.comment_shard.clone(); let role_override = statement.comment_role; if let Some(role) = role_override { diff --git a/pgdog/src/frontend/router/parser/query/show.rs b/pgdog/src/frontend/router/parser/query/show.rs index 9a26f3c06..032d243db 100644 --- a/pgdog/src/frontend/router/parser/query/show.rs +++ b/pgdog/src/frontend/router/parser/query/show.rs @@ -42,7 +42,7 @@ mod test_show { // First call let query = "SHOW TRANSACTION ISOLATION LEVEL"; let buffer = ClientRequest::from(vec![Query::new(query).into()]); - let context = RouterContext::new(&buffer, &c, &mut ps, &p, None, 1).unwrap(); + let context = RouterContext::new(&buffer, &c, &mut ps, &p, None, 1, None).unwrap(); let first = parser.parse(context).unwrap().clone(); let first_shard = first.route().shard(); @@ -51,7 +51,7 @@ mod test_show { // Second call let query = "SHOW TRANSACTION ISOLATION LEVEL"; let buffer = ClientRequest::from(vec![Query::new(query).into()]); - let context = RouterContext::new(&buffer, &c, &mut ps, &p, None, 1).unwrap(); + let context = RouterContext::new(&buffer, &c, &mut ps, &p, None, 1, None).unwrap(); let second = parser.parse(context).unwrap().clone(); let second_shard = second.route().shard(); diff --git a/pgdog/src/frontend/router/parser/query/test.rs b/pgdog/src/frontend/router/parser/query/test.rs index 3a8d02e8c..aca076937 100644 --- a/pgdog/src/frontend/router/parser/query/test.rs +++ b/pgdog/src/frontend/router/parser/query/test.rs @@ -78,7 +78,7 @@ fn parse_query(query: &str) -> Command { let params = Parameters::default(); let context = - RouterContext::new(&client_request, &cluster, &mut stmt, ¶ms, None, 1).unwrap(); + RouterContext::new(&client_request, &cluster, &mut stmt, ¶ms, None, 1, None).unwrap(); let command = query_parser.parse(context).unwrap().clone(); command } @@ -107,6 +107,7 @@ macro_rules! command { ¶ms, transaction, 1, + None, ) .unwrap(); let command = query_parser.parse(context).unwrap().clone(); @@ -152,6 +153,7 @@ macro_rules! query_parser { ¶ms, maybe_transaction, 1, + None, ) .unwrap(); @@ -187,6 +189,7 @@ macro_rules! parse { &Parameters::default(), None, 1, + None, ) .unwrap(), ) @@ -209,8 +212,16 @@ fn parse_with_parameters(query: &str, params: Parameters) -> Result Result { ¶meters, None, 1, + None, ) .unwrap(); @@ -490,8 +502,16 @@ fn test_set() { let mut prep_stmts = PreparedStatements::default(); let params = Parameters::default(); let transaction = Some(TransactionType::ReadWrite); - let router_context = - RouterContext::new(&buffer, &cluster, &mut prep_stmts, ¶ms, transaction, 1).unwrap(); + let router_context = RouterContext::new( + &buffer, + &cluster, + &mut prep_stmts, + ¶ms, + transaction, + 1, + None, + ) + .unwrap(); let mut context = QueryParserContext::new(router_context); for read_only in [true, false] { @@ -571,8 +591,16 @@ fn update_sharding_key_errors_by_default() { let params = Parameters::default(); let client_request: ClientRequest = vec![Query::new(query).into()].into(); let cluster = Cluster::new_test(); - let router_context = - RouterContext::new(&client_request, &cluster, &mut prep_stmts, ¶ms, None, 1).unwrap(); + let router_context = RouterContext::new( + &client_request, + &cluster, + &mut prep_stmts, + ¶ms, + None, + 1, + None, + ) + .unwrap(); let result = QueryParser::default().parse(router_context); assert!( @@ -591,8 +619,16 @@ fn update_sharding_key_ignore_mode_allows() { let params = Parameters::default(); let client_request: ClientRequest = vec![Query::new(query).into()].into(); let cluster = Cluster::new_test(); - let router_context = - RouterContext::new(&client_request, &cluster, &mut prep_stmts, ¶ms, None, 1).unwrap(); + let router_context = RouterContext::new( + &client_request, + &cluster, + &mut prep_stmts, + ¶ms, + None, + 1, + None, + ) + .unwrap(); let command = QueryParser::default().parse(router_context).unwrap(); assert!(matches!(command, Command::Query(_))); @@ -608,8 +644,16 @@ fn update_sharding_key_rewrite_mode_not_supported() { let params = Parameters::default(); let client_request: ClientRequest = vec![Query::new(query).into()].into(); let cluster = Cluster::new_test(); - let router_context = - RouterContext::new(&client_request, &cluster, &mut prep_stmts, ¶ms, None, 1).unwrap(); + let router_context = RouterContext::new( + &client_request, + &cluster, + &mut prep_stmts, + ¶ms, + None, + 1, + None, + ) + .unwrap(); let result = QueryParser::default().parse(router_context); assert!( @@ -628,8 +672,16 @@ fn update_sharding_key_rewrite_plan_detected() { let params = Parameters::default(); let client_request: ClientRequest = vec![Query::new(query).into()].into(); let cluster = Cluster::new_test(); - let router_context = - RouterContext::new(&client_request, &cluster, &mut prep_stmts, ¶ms, None, 1).unwrap(); + let router_context = RouterContext::new( + &client_request, + &cluster, + &mut prep_stmts, + ¶ms, + None, + 1, + None, + ) + .unwrap(); let command = QueryParser::default().parse(router_context).unwrap(); match command { @@ -804,8 +856,16 @@ WHERE t2.account = ( .into()] .into(); let transaction = Some(TransactionType::ReadWrite); - let router_context = - RouterContext::new(&buffer, &cluster, &mut prep_stmts, ¶ms, transaction, 1).unwrap(); + let router_context = RouterContext::new( + &buffer, + &cluster, + &mut prep_stmts, + ¶ms, + transaction, + 1, + None, + ) + .unwrap(); let mut context = QueryParserContext::new(router_context); let route = qp.query(&mut context).unwrap(); match route { @@ -866,7 +926,8 @@ fn test_close_direct_one_shard() { let params = Parameters::default(); let transaction = None; - let context = RouterContext::new(&buf, &cluster, &mut pp, ¶ms, transaction, 1).unwrap(); + let context = + RouterContext::new(&buf, &cluster, &mut pp, ¶ms, transaction, 1, None).unwrap(); let cmd = qp.parse(context).unwrap(); diff --git a/pgdog/src/frontend/router/rewrite/context.rs b/pgdog/src/frontend/router/rewrite/context.rs index aa49a9f30..460c16ca6 100644 --- a/pgdog/src/frontend/router/rewrite/context.rs +++ b/pgdog/src/frontend/router/rewrite/context.rs @@ -3,10 +3,7 @@ use pg_query::protobuf::{ParseResult, RawStmt}; use super::{output::RewriteActionKind, Error, RewriteAction, StepOutput}; -use crate::{ - frontend::PreparedStatements, - net::{Bind, Parse, ProtocolMessage, Query}, -}; +use crate::net::{Bind, Parse, ProtocolMessage, Query}; #[derive(Debug, Clone)] pub struct Context<'a> { @@ -21,20 +18,37 @@ pub struct Context<'a> { rewrite_bind: Option, /// Additional messages to add to the request. result: Vec, + /// Extended protocol. + parse: Option<&'a Parse>, } impl<'a> Context<'a> { /// Create new input. - pub fn new(original: &'a ParseResult, bind: Option<&'a Bind>) -> Self { + pub(super) fn new( + original: &'a ParseResult, + bind: Option<&'a Bind>, + parse: Option<&'a Parse>, + ) -> Self { Self { original, bind, rewrite: None, rewrite_bind: None, result: vec![], + parse, } } + /// Get Parse reference. + pub fn parse(&'a self) -> Option<&'a Parse> { + self.parse + } + + /// We are rewriting an extended protocol request. + pub fn extended(&self) -> bool { + self.parse().is_some() || self.bind().is_some() + } + /// Get the Bind message, if set. pub fn bind(&'a self) -> Option<&'a Bind> { if let Some(ref rewrite_bind) = self.rewrite_bind { @@ -55,6 +69,7 @@ impl<'a> Context<'a> { self.rewrite_bind.take() } + /// Put the bind message back. pub fn bind_put(&mut self, bind: Option) { self.rewrite_bind = bind; } @@ -82,11 +97,17 @@ impl<'a> Context<'a> { stmt.stmts.first_mut().ok_or(Error::EmptyQuery) } + /// Get protocol version from the original statement. pub fn proto_version(&self) -> i32 { self.original.version } - /// New request mutable reference. + /// Get the parse result (original or rewritten). + pub fn parse_result(&self) -> &ParseResult { + self.rewrite.as_ref().unwrap_or(&self.original) + } + + /// Prepend new message to rewritten request. pub fn prepend(&mut self, message: ProtocolMessage) { self.result.push(RewriteAction { message, @@ -94,47 +115,42 @@ impl<'a> Context<'a> { }); } - /// Assemble statement and add it to the global prepared statements cache. + /// Assemble rewrite instructions. pub fn build(mut self) -> Result { if self.rewrite.is_none() { Ok(StepOutput::NoOp) } else { let bind = self.rewrite_bind.take(); - let stmt = self.rewrite.take().ok_or(Error::NoRewrite)?.deparse()?; - let mut result = self.result; + let ast = self.rewrite.take().ok_or(Error::NoRewrite)?; + let stmt = ast.deparse()?; + let extended = self.extended(); + let mut parse = self.parse().cloned(); - if let Some(mut bind) = bind { - let mut parse = Parse::new_anonymous(stmt); - if bind.anonymous() { - result.push(RewriteAction { - message: parse.into(), - action: RewriteActionKind::Replace, - }); - result.push(RewriteAction { - message: bind.into(), - action: RewriteActionKind::Replace, - }); - } else { - let name = PreparedStatements::cache_rewritten(&parse); - parse.rename_fast(&name); - bind.rename(name); - result.push(RewriteAction { + let mut actions = self.result; + + if extended { + if let Some(mut parse) = parse.take() { + parse.set_query(&stmt); + actions.push(RewriteAction { message: parse.into(), action: RewriteActionKind::Replace, }); - result.push(RewriteAction { + } + + if let Some(bind) = bind { + actions.push(RewriteAction { message: bind.into(), action: RewriteActionKind::Replace, }); } } else { - result.push(RewriteAction { - message: Query::new(stmt).into(), + actions.push(RewriteAction { + message: Query::new(stmt.clone()).into(), action: RewriteActionKind::Replace, }); } - Ok(StepOutput::Rewrite(result)) + Ok(StepOutput::RewriteInPlace { stmt, ast, actions }) } } } diff --git a/pgdog/src/frontend/router/rewrite/error.rs b/pgdog/src/frontend/router/rewrite/error.rs index cd52af219..f29c825db 100644 --- a/pgdog/src/frontend/router/rewrite/error.rs +++ b/pgdog/src/frontend/router/rewrite/error.rs @@ -25,4 +25,10 @@ pub enum Error { #[error("prepared statement not found: {0}")] PreparedStatementNotFound(String), + + #[error("statement parameters and bind count mismatch")] + ParameterCountMismatch, + + #[error("parser: {0}")] + Parser(#[from] crate::frontend::router::parser::Error), } diff --git a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs index 9d036af2d..84ab4eb55 100644 --- a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs +++ b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs @@ -106,7 +106,7 @@ mod test { }, ], ); - let mut context = Context::new(&stmt.protobuf, Some(&bind)); + let mut context = Context::new(&stmt.protobuf, Some(&bind), None); let mut module = InsertSplitRewrite::default(); module.rewrite(&mut context).unwrap(); } diff --git a/pgdog/src/frontend/router/rewrite/mod.rs b/pgdog/src/frontend/router/rewrite/mod.rs index 72ae301c8..bbe57d880 100644 --- a/pgdog/src/frontend/router/rewrite/mod.rs +++ b/pgdog/src/frontend/router/rewrite/mod.rs @@ -12,12 +12,15 @@ pub mod insert_split; pub mod interface; pub mod output; pub mod prepared; +pub mod request; +pub mod state; pub mod unique_id; pub use context::Context; pub use error::Error; pub use interface::RewriteModule; pub use output::{RewriteAction, StepOutput}; +pub use request::RewriteRequest; use crate::frontend::PreparedStatements; diff --git a/pgdog/src/frontend/router/rewrite/output.rs b/pgdog/src/frontend/router/rewrite/output.rs index 6a6e3cf79..13b4e387d 100644 --- a/pgdog/src/frontend/router/rewrite/output.rs +++ b/pgdog/src/frontend/router/rewrite/output.rs @@ -1,3 +1,5 @@ +use pg_query::protobuf::ParseResult; + use crate::{frontend::ClientRequest, net::ProtocolMessage}; use std::mem::discriminant; @@ -45,7 +47,11 @@ pub(super) enum RewriteActionKind { #[derive(Debug, Clone)] pub enum StepOutput { NoOp, - Rewrite(Vec), + RewriteInPlace { + actions: Vec, + ast: ParseResult, + stmt: String, + }, } impl StepOutput { @@ -53,15 +59,7 @@ impl StepOutput { pub fn query(&self) -> Result<&str, ()> { match self { Self::NoOp => Err(()), - Self::Rewrite(actions) => { - for action in actions { - if let Some(query) = action.message.query() { - return Ok(query); - } - } - - Err(()) - } + Self::RewriteInPlace { stmt, .. } => Ok(stmt.as_str()), } } } diff --git a/pgdog/src/frontend/router/rewrite/prepared/execute.rs b/pgdog/src/frontend/router/rewrite/prepared/execute.rs index 9071c3460..b59dbed6a 100644 --- a/pgdog/src/frontend/router/rewrite/prepared/execute.rs +++ b/pgdog/src/frontend/router/rewrite/prepared/execute.rs @@ -67,13 +67,13 @@ mod test { // First prepare the statement let mut prepare_rewrite = PrepareRewrite::new(&mut prepared_statements); - let mut input = Context::new(&prepare_stmt, None); + let mut input = Context::new(&prepare_stmt, None, None); prepare_rewrite.rewrite(&mut input).unwrap(); // Now execute it let execute_stmt = pg_query::parse("EXECUTE test(1, 2, 3)").unwrap().protobuf; let mut execute_rewrite = ExecuteRewrite::new(&prepared_statements); - let mut input = Context::new(&execute_stmt, None); + let mut input = Context::new(&execute_stmt, None, None); execute_rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -87,7 +87,7 @@ mod test { .protobuf; let prepared_statements = PreparedStatements::default(); let mut execute_rewrite = ExecuteRewrite::new(&prepared_statements); - let mut input = Context::new(&execute_stmt, None); + let mut input = Context::new(&execute_stmt, None, None); let result = execute_rewrite.rewrite(&mut input); assert!(result.is_err()); } diff --git a/pgdog/src/frontend/router/rewrite/prepared/prepare.rs b/pgdog/src/frontend/router/rewrite/prepared/prepare.rs index f4b63c2a1..d0f9f8d66 100644 --- a/pgdog/src/frontend/router/rewrite/prepared/prepare.rs +++ b/pgdog/src/frontend/router/rewrite/prepared/prepare.rs @@ -64,7 +64,7 @@ mod test { .protobuf; let mut prepared_statements = PreparedStatements::default(); let mut rewrite = PrepareRewrite::new(&mut prepared_statements); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); diff --git a/pgdog/src/frontend/router/rewrite/request.rs b/pgdog/src/frontend/router/rewrite/request.rs new file mode 100644 index 000000000..756fd4efb --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/request.rs @@ -0,0 +1,108 @@ +use pg_query::ParseResult; +use tracing::debug; + +use super::{Context, Error, Rewrite, StepOutput}; +use crate::{ + backend::Cluster, + frontend::{ + router::{ + parser::{cache::CachedAst, Cache}, + rewrite::RewriteModule, + }, + ClientRequest, PreparedStatements, + }, + net::ProtocolMessage, +}; + +pub struct RewriteRequest<'a> { + request: &'a mut ClientRequest, + cluster: &'a Cluster, + prepared_statements: &'a mut PreparedStatements, +} + +impl<'a> RewriteRequest<'a> { + /// Perform new rewrite request. + pub fn new( + request: &'a mut ClientRequest, + cluster: &'a Cluster, + prepared_statements: &'a mut PreparedStatements, + ) -> Self { + Self { + request, + cluster, + prepared_statements, + } + } + + /// Execute rewrite and return the query AST. + pub fn execute(&'a mut self) -> Result { + let schema = self.cluster.sharding_schema(); + + let (result, ast, extended) = { + let mut parse = None; + let mut bind = None; + let mut ast = None; + + let schema = self.cluster.sharding_schema(); + + for message in self.request.iter() { + match message { + ProtocolMessage::Parse(p) => { + ast = Some(Cache::get().parse(p.query(), &schema)?); + self.prepared_statements + .save_original_ast(p.name(), ast.as_ref().unwrap()); + parse = Some(p); + } + + ProtocolMessage::Query(query) => { + ast = Some(Cache::get().parse_uncached(query.query(), &schema)?); + } + + ProtocolMessage::Bind(b) => { + let existing = self.prepared_statements.get_original_ast(b.statement()); + if let Some(existing) = existing { + ast = Some(existing.clone()); + bind = Some(b); + } + } + + _ => (), + } + } + + let ast = ast.ok_or(Error::EmptyQuery)?; + + let mut context = Context::new(&ast.ast().protobuf, bind, parse); + let mut rewrite = Rewrite::new(self.prepared_statements); + + let result = match rewrite.rewrite(&mut context) { + Ok(_) => context.build()?, + Err(Error::EmptyQuery) => StepOutput::NoOp, + Err(err) => return Err(err), + }; + + (result, ast, parse.is_some()) + }; + + let ast = match result { + StepOutput::NoOp => { + debug!("rewrite was a no-op"); + ast + } + StepOutput::RewriteInPlace { stmt, ast, actions } => { + debug!("rewrite in-place: {}", stmt); + for action in actions { + action.execute(self.request); + } + let ast = ParseResult::new(ast, "".into()); + // Cache new rewritten prepared statement. + if extended { + Cache::get().save(&stmt, ast, &schema).unwrap() + } else { + CachedAst::new_parsed(&stmt, ast, &schema).unwrap() + } + } + }; + Ok(ast) + } +} diff --git a/pgdog/src/frontend/router/rewrite/state.rs b/pgdog/src/frontend/router/rewrite/state.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/state.rs @@ -0,0 +1 @@ + diff --git a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs index c73cd4eee..b683915b9 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs @@ -80,6 +80,8 @@ impl ExplainUniqueIdRewrite { } let mut bind = input.bind_take(); + let extended = input.extended(); + let mut parameter_counter = 0; if let Some(NodeEnum::ExplainStmt(stmt)) = input .stmt_mut()? @@ -90,7 +92,12 @@ impl ExplainUniqueIdRewrite { if let Some(NodeEnum::SelectStmt(select)) = stmt.query.as_mut().and_then(|q| q.node.as_mut()) { - SelectUniqueIdRewrite::rewrite_select(select, &mut bind)?; + SelectUniqueIdRewrite::rewrite_select( + select, + &mut bind, + extended, + &mut parameter_counter, + )?; } } @@ -122,6 +129,7 @@ impl ExplainUniqueIdRewrite { } let mut bind = input.bind_take(); + let extended = input.extended(); if let Some(NodeEnum::ExplainStmt(stmt)) = input .stmt_mut()? @@ -132,7 +140,7 @@ impl ExplainUniqueIdRewrite { if let Some(NodeEnum::InsertStmt(insert)) = stmt.query.as_mut().and_then(|q| q.node.as_mut()) { - InsertUniqueIdRewrite::rewrite_insert(insert, &mut bind)?; + InsertUniqueIdRewrite::rewrite_insert(insert, &mut bind, extended)?; } } @@ -164,6 +172,8 @@ impl ExplainUniqueIdRewrite { } let mut bind = input.bind_take(); + let extended = input.extended(); + let mut param_counter = super::max_param_number(input.parse_result()); if let Some(NodeEnum::ExplainStmt(stmt)) = input .stmt_mut()? @@ -174,7 +184,12 @@ impl ExplainUniqueIdRewrite { if let Some(NodeEnum::UpdateStmt(update)) = stmt.query.as_mut().and_then(|q| q.node.as_mut()) { - UpdateUniqueIdRewrite::rewrite_update(update, &mut bind)?; + UpdateUniqueIdRewrite::rewrite_update( + update, + &mut bind, + extended, + &mut param_counter, + )?; } } @@ -197,7 +212,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -215,7 +230,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -234,7 +249,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -252,7 +267,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -265,7 +280,7 @@ mod test { fn test_explain_no_unique_id() { let stmt = pg_query::parse(r#"EXPLAIN SELECT 1"#).unwrap().protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(matches!(output, super::super::super::StepOutput::NoOp)); diff --git a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs index 101d434ff..c6af0c125 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs @@ -33,7 +33,9 @@ impl InsertUniqueIdRewrite { pub fn rewrite_insert( stmt: &mut InsertStmt, bind: &mut Option, + extended: bool, ) -> Result<(), Error> { + let mut param_counter = Self::param_count(stmt)?; let select = stmt .select_stmt .as_mut() @@ -42,16 +44,24 @@ impl InsertUniqueIdRewrite { .as_mut() .ok_or(Error::ParserError)?; - if let NodeEnum::SelectStmt(stmt) = select { - for tuple in stmt.values_lists.iter_mut() { + if let NodeEnum::SelectStmt(select_stmt) = select { + for tuple in select_stmt.values_lists.iter_mut() { if let Some(NodeEnum::List(ref mut tuple)) = tuple.node { for column in tuple.items.iter_mut() { if let Ok(Value::Function(name)) = Value::try_from(&column.node) { if name == "pgdog.unique_id" { let id = unique_id::UniqueId::generator()?.next_id(); - let node = if let Some(ref mut bind) = bind { - bigint_param(bind.add_parameter(Datum::Bigint(id))?) + let node = if extended { + param_counter += 1; + if let Some(ref mut bind) = bind { + let count = bind.add_parameter(Datum::Bigint(id))?; + // The number of parameters in the query doesn't match what's in the bind message. + if count != param_counter { + return Err(Error::ParameterCountMismatch); + } + } + bigint_param(param_counter) } else { bigint_const(id) }; @@ -66,6 +76,34 @@ impl InsertUniqueIdRewrite { Ok(()) } + + fn param_count(stmt: &InsertStmt) -> Result { + let mut max = 0; + + let select = stmt + .select_stmt + .as_ref() + .ok_or(Error::ParserError)? + .node + .as_ref() + .ok_or(Error::ParserError)?; + + if let NodeEnum::SelectStmt(stmt) = select { + for tuple in stmt.values_lists.iter() { + if let Some(NodeEnum::List(ref tuple)) = tuple.node { + for column in tuple.items.iter() { + if let Some(NodeEnum::ParamRef(ref param)) = column.node { + if param.number > max { + max = param.number; + } + } + } + } + } + } + + Ok(max) + } } impl RewriteModule for InsertUniqueIdRewrite { @@ -86,6 +124,7 @@ impl RewriteModule for InsertUniqueIdRewrite { } let mut bind = input.bind_take(); + let extended = input.extended(); if let Some(NodeEnum::InsertStmt(stmt)) = input .stmt_mut()? @@ -93,7 +132,7 @@ impl RewriteModule for InsertUniqueIdRewrite { .as_mut() .and_then(|stmt| stmt.node.as_mut()) { - Self::rewrite_insert(stmt, &mut bind)?; + Self::rewrite_insert(stmt, &mut bind, extended)?; } input.bind_put(bind); @@ -123,7 +162,7 @@ mod test { .unwrap() .protobuf; let mut insert = InsertUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); insert.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -162,7 +201,7 @@ mod test { }, ], ); - let mut input = Context::new(&stmt, Some(&bind)); + let mut input = Context::new(&stmt, Some(&bind), None); InsertUniqueIdRewrite::default() .rewrite(&mut input) .unwrap(); diff --git a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs index 7b08268e4..ff2733f3f 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs @@ -1,8 +1,8 @@ //! Unique ID rewrite engine. use pg_query::{ - protobuf::{a_const::Val, AConst, Node, ParamRef, TypeCast, TypeName}, - NodeEnum, + protobuf::{a_const::Val, AConst, Node, ParamRef, ParseResult, TypeCast, TypeName}, + NodeEnum, NodeRef, }; pub mod explain; @@ -74,3 +74,19 @@ fn bigint_const(id: i64) -> NodeEnum { ..Default::default() })) } + +/// Find the maximum parameter number ($N) in a parse result. +pub fn max_param_number(result: &ParseResult) -> i32 { + result + .nodes() + .iter() + .filter_map(|(node, _, _, _)| { + if let NodeRef::ParamRef(p) = node { + Some(p.number) + } else { + None + } + }) + .max() + .unwrap_or(0) +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/select.rs b/pgdog/src/frontend/router/rewrite/unique_id/select.rs index 9b2375ef6..8ec520019 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/select.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/select.rs @@ -9,7 +9,11 @@ use super::{ super::{Context, Error, RewriteModule}, bigint_const, bigint_param, }; -use crate::{frontend::router::parser::Value, net::Datum, unique_id}; +use crate::{ + frontend::router::{parser::Value, rewrite::unique_id::max_param_number}, + net::Datum, + unique_id, +}; #[derive(Default)] pub struct SelectUniqueIdRewrite {} @@ -94,6 +98,8 @@ impl SelectUniqueIdRewrite { pub fn rewrite_select( stmt: &mut SelectStmt, bind: &mut Option, + extended: bool, + paramter_counter: &mut i32, ) -> Result<(), Error> { // Rewrite target_list for target in stmt.target_list.iter_mut() { @@ -103,8 +109,17 @@ impl SelectUniqueIdRewrite { if name == "pgdog.unique_id" { let id = unique_id::UniqueId::generator()?.next_id(); - let node = if let Some(ref mut bind) = bind { - bigint_param(bind.add_parameter(Datum::Bigint(id))?) + let node = if extended { + *paramter_counter += 1; + + if let Some(bind) = bind { + let counter = bind.add_parameter(Datum::Bigint(id))?; + if counter != *paramter_counter { + return Err(Error::ParameterCountMismatch); + } + } + + bigint_param(*paramter_counter) } else { bigint_const(id) }; @@ -122,7 +137,7 @@ impl SelectUniqueIdRewrite { if let Some(NodeEnum::CommonTableExpr(ref mut expr)) = cte.node { if let Some(ref mut query) = expr.ctequery { if let Some(NodeEnum::SelectStmt(ref mut inner)) = query.node { - Self::rewrite_select(inner, bind)?; + Self::rewrite_select(inner, bind, extended, paramter_counter)?; } } } @@ -131,15 +146,15 @@ impl SelectUniqueIdRewrite { // Rewrite subqueries in FROM clause for from in stmt.from_clause.iter_mut() { - Self::rewrite_from_node(from, bind)?; + Self::rewrite_from_node(from, bind, extended, paramter_counter)?; } // Rewrite UNION/INTERSECT/EXCEPT (larg/rarg are Box) if let Some(ref mut larg) = stmt.larg { - Self::rewrite_select(larg, bind)?; + Self::rewrite_select(larg, bind, extended, paramter_counter)?; } if let Some(ref mut rarg) = stmt.rarg { - Self::rewrite_select(rarg, bind)?; + Self::rewrite_select(rarg, bind, extended, paramter_counter)?; } Ok(()) @@ -148,21 +163,23 @@ impl SelectUniqueIdRewrite { fn rewrite_from_node( node: &mut Node, bind: &mut Option, + extended: bool, + paramter_counter: &mut i32, ) -> Result<(), Error> { match node.node.as_mut() { Some(NodeEnum::RangeSubselect(ref mut subselect)) => { if let Some(ref mut subquery) = subselect.subquery { if let Some(NodeEnum::SelectStmt(ref mut inner)) = subquery.node { - Self::rewrite_select(inner, bind)?; + Self::rewrite_select(inner, bind, extended, paramter_counter)?; } } } Some(NodeEnum::JoinExpr(ref mut join)) => { if let Some(ref mut larg) = join.larg { - Self::rewrite_from_node(larg, bind)?; + Self::rewrite_from_node(larg, bind, extended, paramter_counter)?; } if let Some(ref mut rarg) = join.rarg { - Self::rewrite_from_node(rarg, bind)?; + Self::rewrite_from_node(rarg, bind, extended, paramter_counter)?; } } _ => {} @@ -189,6 +206,8 @@ impl RewriteModule for SelectUniqueIdRewrite { } let mut bind = input.bind_take(); + let extended = input.extended(); + let mut parameter_counter = max_param_number(input.parse_result()); if let Some(NodeEnum::SelectStmt(stmt)) = input .stmt_mut()? @@ -196,7 +215,7 @@ impl RewriteModule for SelectUniqueIdRewrite { .as_mut() .and_then(|stmt| stmt.node.as_mut()) { - Self::rewrite_select(stmt, &mut bind)?; + Self::rewrite_select(stmt, &mut bind, extended, &mut parameter_counter)?; } input.bind_put(bind); @@ -220,7 +239,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); println!("output: {}", output.query().unwrap()); @@ -242,7 +261,7 @@ mod test { data: "test".into(), }], ); - let mut input = Context::new(&stmt, Some(&bind)); + let mut input = Context::new(&stmt, Some(&bind), None); SelectUniqueIdRewrite::default() .rewrite(&mut input) .unwrap(); @@ -263,7 +282,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -278,7 +297,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -295,7 +314,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -305,7 +324,7 @@ mod test { fn test_no_rewrite_when_no_unique_id() { let stmt = pg_query::parse(r#"SELECT id FROM users"#).unwrap().protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(matches!(output, super::super::super::StepOutput::NoOp)); diff --git a/pgdog/src/frontend/router/rewrite/unique_id/update.rs b/pgdog/src/frontend/router/rewrite/unique_id/update.rs index 462c37164..ee4c5041b 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/update.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/update.rs @@ -4,7 +4,7 @@ use pg_query::{protobuf::UpdateStmt, NodeEnum}; use super::{ super::{Context, Error, RewriteModule}, - bigint_const, bigint_param, + bigint_const, bigint_param, max_param_number, }; use crate::{frontend::router::parser::Value, net::Datum, unique_id}; @@ -30,6 +30,8 @@ impl UpdateUniqueIdRewrite { pub fn rewrite_update( stmt: &mut UpdateStmt, bind: &mut Option, + extended: bool, + param_counter: &mut i32, ) -> Result<(), Error> { for target in stmt.target_list.iter_mut() { if let Some(NodeEnum::ResTarget(ref mut res)) = target.node { @@ -38,8 +40,15 @@ impl UpdateUniqueIdRewrite { if name == "pgdog.unique_id" { let id = unique_id::UniqueId::generator()?.next_id(); - let node = if let Some(ref mut bind) = bind { - bigint_param(bind.add_parameter(Datum::Bigint(id))?) + let node = if extended { + *param_counter += 1; + if let Some(ref mut bind) = bind { + let count = bind.add_parameter(Datum::Bigint(id))?; + if count != *param_counter { + return Err(Error::ParameterCountMismatch); + } + } + bigint_param(*param_counter) } else { bigint_const(id) }; @@ -72,6 +81,8 @@ impl RewriteModule for UpdateUniqueIdRewrite { } let mut bind = input.bind_take(); + let extended = input.extended(); + let mut param_counter = max_param_number(input.parse_result()); if let Some(NodeEnum::UpdateStmt(stmt)) = input .stmt_mut()? @@ -79,7 +90,7 @@ impl RewriteModule for UpdateUniqueIdRewrite { .as_mut() .and_then(|stmt| stmt.node.as_mut()) { - Self::rewrite_update(stmt, &mut bind)?; + Self::rewrite_update(stmt, &mut bind, extended, &mut param_counter)?; } input.bind_put(bind); @@ -104,7 +115,7 @@ mod test { .unwrap() .protobuf; let mut update = UpdateUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None); + let mut input = Context::new(&stmt, None, None); update.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -133,7 +144,7 @@ mod test { }, ], ); - let mut input = Context::new(&stmt, Some(&bind)); + let mut input = Context::new(&stmt, Some(&bind), None); UpdateUniqueIdRewrite::default() .rewrite(&mut input) .unwrap(); From a3ebfc618973009afe876d0da191ff710b212830 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 1 Dec 2025 18:24:22 -0800 Subject: [PATCH 14/23] stats --- CLAUDE.md | 40 +----- pgdog/src/admin/mod.rs | 1 + pgdog/src/admin/parser.rs | 6 +- pgdog/src/admin/show_mirrors.rs | 2 +- pgdog/src/admin/show_rewrite.rs | 49 ++++++++ pgdog/src/admin/tests/mod.rs | 4 +- pgdog/src/backend/pool/cluster.rs | 8 +- .../{mirror_stats.rs => cluster_stats.rs} | 41 +++--- .../backend/pool/connection/mirror/handler.rs | 58 ++++----- .../src/backend/pool/connection/mirror/mod.rs | 14 +-- pgdog/src/backend/pool/mod.rs | 4 +- pgdog/src/frontend/client/query_engine/mod.rs | 4 +- pgdog/src/frontend/prepared_statements/mod.rs | 15 +-- pgdog/src/frontend/router/rewrite/context.rs | 17 ++- pgdog/src/frontend/router/rewrite/mod.rs | 2 + pgdog/src/frontend/router/rewrite/output.rs | 2 + pgdog/src/frontend/router/rewrite/request.rs | 34 +++-- pgdog/src/frontend/router/rewrite/state.rs | 36 ++++++ pgdog/src/frontend/router/rewrite/stats.rs | 22 ++++ pgdog/src/stats/http_server.rs | 11 +- pgdog/src/stats/mirror_stats.rs | 2 +- pgdog/src/stats/mod.rs | 2 + pgdog/src/stats/rewrite_stats.rs | 119 ++++++++++++++++++ 23 files changed, 362 insertions(+), 131 deletions(-) create mode 100644 pgdog/src/admin/show_rewrite.rs rename pgdog/src/backend/pool/{mirror_stats.rs => cluster_stats.rs} (82%) create mode 100644 pgdog/src/frontend/router/rewrite/stats.rs create mode 100644 pgdog/src/stats/rewrite_stats.rs diff --git a/CLAUDE.md b/CLAUDE.md index 5e70dfc80..0cba01475 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -7,45 +7,15 @@ # Code style -Use standard Rust code style. Use `cargo fmt` to reformat code automatically after every edit. -Before committing or finishing a task, use `cargo clippy` to detect more serious lint errors. - -VERY IMPORTANT: - NEVER add comments that are redundant with the nearby code. - ALWAYS be sure comments document "Why", not "What", the code is doing. - ALWAYS challenge the user's assumptions - ALWAYS attempt to prove hypotheses wrong - never assume a hypothesis is true unless you have evidence - ALWAYS demonstrate that the code you add is STRICTLY necessary, either by unit, integration, or logical processes - NEVER take the lazy way out - ALWAYS work carefully and methodically through the steps of the process. - NEVER use quick fixes. Always carefully work through the problem unless specifically asked. - ALWAYS Ask clarifying questions before implementing - ALWAYS Break large tasks into single-session chunks - -VERY IMPORTANT: you are to act as a detective, attempting to find ways to falsify the code or planning we've done by discovering gaps or inconsistencies. ONLY write code when it is absolutely required to pass tests, the build, or typecheck. - -VERY IMPORTANT: NEVER comment out code or skip tests unless specifically requested by the user - -## Principles -- **Data first**: Define types before implementation -- **Small Modules**: Try to keep files under 200 lines, unless required by implementation. NEVER allow files to exceed 1000 lines unless specifically instructed. +- Use standard Rust code style. +- Use `cargo fmt` to reformat code automatically after every edit. +- Don't write functions with many arguments: create a struct and use that as input instead. # Workflow - Prefer to run individual tests with `cargo nextest run --test-threads=1 --no-fail-fast `. This is much faster. -- A local PostgreSQL server is required for some tests to pass. Ensure it is set up, and if necessary create a database called "pgdog", and create a user called "pgdog" with password "pgdog". -- Focus on files in `./pgdog` and `./integration` - other files are LOWEST priority - -## Test-Driven Development (TDD) - STRICT ENFORCEMENT -- **MANDATORY WORKFLOW - NO EXCEPTIONS:** - 1. Write exactly ONE test that fails - 2. Write ONLY the minimal code to make that test pass - 3. Refactor if needed (tests must still pass) - 4. Return to step 1 for next test -- **CRITICAL RULES:** - - NO implementation code without a failing test first - - NO untested code is allowed to exist - - Every line of production code must be justified by a test +- A local PostgreSQL server is required for some tests to pass. Assume it's running, if not, stop and ask the user to start it. +- Coe # About the project diff --git a/pgdog/src/admin/mod.rs b/pgdog/src/admin/mod.rs index 521119c8c..3e4732d4c 100644 --- a/pgdog/src/admin/mod.rs +++ b/pgdog/src/admin/mod.rs @@ -30,6 +30,7 @@ pub mod show_pools; pub mod show_prepared_statements; pub mod show_query_cache; pub mod show_replication; +pub mod show_rewrite; pub mod show_server_memory; pub mod show_servers; pub mod show_stats; diff --git a/pgdog/src/admin/parser.rs b/pgdog/src/admin/parser.rs index 5d94ab445..16a2ee929 100644 --- a/pgdog/src/admin/parser.rs +++ b/pgdog/src/admin/parser.rs @@ -7,7 +7,7 @@ use super::{ show_client_memory::ShowClientMemory, show_clients::ShowClients, show_config::ShowConfig, show_instance_id::ShowInstanceId, show_lists::ShowLists, show_mirrors::ShowMirrors, show_peers::ShowPeers, show_pools::ShowPools, show_prepared_statements::ShowPreparedStatements, - show_query_cache::ShowQueryCache, show_replication::ShowReplication, + show_query_cache::ShowQueryCache, show_replication::ShowReplication, show_rewrite::ShowRewrite, show_server_memory::ShowServerMemory, show_servers::ShowServers, show_stats::ShowStats, show_transactions::ShowTransactions, show_version::ShowVersion, shutdown::Shutdown, Command, Error, @@ -30,6 +30,7 @@ pub enum ParseResult { ShowStats(ShowStats), ShowTransactions(ShowTransactions), ShowMirrors(ShowMirrors), + ShowRewrite(ShowRewrite), ShowVersion(ShowVersion), ShowInstanceId(ShowInstanceId), SetupSchema(SetupSchema), @@ -65,6 +66,7 @@ impl ParseResult { ShowStats(show_stats) => show_stats.execute().await, ShowTransactions(show_transactions) => show_transactions.execute().await, ShowMirrors(show_mirrors) => show_mirrors.execute().await, + ShowRewrite(show_rewrite) => show_rewrite.execute().await, ShowVersion(show_version) => show_version.execute().await, ShowInstanceId(show_instance_id) => show_instance_id.execute().await, SetupSchema(setup_schema) => setup_schema.execute().await, @@ -100,6 +102,7 @@ impl ParseResult { ShowStats(show_stats) => show_stats.name(), ShowTransactions(show_transactions) => show_transactions.name(), ShowMirrors(show_mirrors) => show_mirrors.name(), + ShowRewrite(show_rewrite) => show_rewrite.name(), ShowVersion(show_version) => show_version.name(), ShowInstanceId(show_instance_id) => show_instance_id.name(), SetupSchema(setup_schema) => setup_schema.name(), @@ -163,6 +166,7 @@ impl Parser { "lists" => ParseResult::ShowLists(ShowLists::parse(&sql)?), "prepared" => ParseResult::ShowPrepared(ShowPreparedStatements::parse(&sql)?), "replication" => ParseResult::ShowReplication(ShowReplication::parse(&sql)?), + "rewrite" => ParseResult::ShowRewrite(ShowRewrite::parse(&sql)?), command => { debug!("unknown admin show command: '{}'", command); return Err(Error::Syntax); diff --git a/pgdog/src/admin/show_mirrors.rs b/pgdog/src/admin/show_mirrors.rs index 896e62e61..1ae61a546 100644 --- a/pgdog/src/admin/show_mirrors.rs +++ b/pgdog/src/admin/show_mirrors.rs @@ -35,7 +35,7 @@ impl Command for ShowMirrors { let counts = { let stats = cluster.stats(); let stats = stats.lock(); - stats.counts + stats.mirrors }; // Create a data row for this cluster diff --git a/pgdog/src/admin/show_rewrite.rs b/pgdog/src/admin/show_rewrite.rs new file mode 100644 index 000000000..df8d4cfbf --- /dev/null +++ b/pgdog/src/admin/show_rewrite.rs @@ -0,0 +1,49 @@ +//! SHOW REWRITE - per-cluster rewrite statistics + +use crate::backend::databases::databases; + +use super::prelude::*; + +pub struct ShowRewrite; + +#[async_trait] +impl Command for ShowRewrite { + fn name(&self) -> String { + "SHOW REWRITE".into() + } + + fn parse(_: &str) -> Result { + Ok(Self) + } + + async fn execute(&self) -> Result, Error> { + let fields = vec![ + Field::text("database"), + Field::text("user"), + Field::numeric("parse"), + Field::numeric("bind"), + Field::numeric("simple"), + ]; + + let mut messages = vec![RowDescription::new(&fields).message()?]; + + for (user, cluster) in databases().all() { + let rewrite = { + let stats = cluster.stats(); + let stats = stats.lock(); + stats.rewrite + }; + + let mut dr = DataRow::new(); + dr.add(user.database.as_str()) + .add(user.user.as_str()) + .add(rewrite.parse as i64) + .add(rewrite.bind as i64) + .add(rewrite.simple as i64); + + messages.push(dr.message()?); + } + + Ok(messages) + } +} diff --git a/pgdog/src/admin/tests/mod.rs b/pgdog/src/admin/tests/mod.rs index a5fda216d..57f709d04 100644 --- a/pgdog/src/admin/tests/mod.rs +++ b/pgdog/src/admin/tests/mod.rs @@ -1,6 +1,6 @@ use crate::admin::Command; use crate::backend::databases::{databases, from_config, replace_databases, Databases}; -use crate::backend::pool::mirror_stats::Counts; +use crate::backend::pool::cluster_stats::MirrorStats; use crate::config::{self, ConfigAndUsers, Database, Role, User as ConfigUser}; use crate::net::messages::{DataRow, DataType, FromBytes, Protocol, RowDescription}; @@ -226,7 +226,7 @@ async fn show_mirrors_reports_counts() { { let cluster_stats = cluster.stats(); let mut stats = cluster_stats.lock(); - stats.counts = Counts { + stats.mirrors = MirrorStats { total_count: 5, mirrored_count: 4, dropped_count: 1, diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index 73ae20762..b0c606e9a 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -26,7 +26,7 @@ use crate::{ net::{messages::BackendKeyData, Query}, }; -use super::{Address, Config, Error, Guard, MirrorStats, Request, Shard, ShardConfig}; +use super::{Address, ClusterStats, Config, Error, Guard, Request, Shard, ShardConfig}; use crate::config::LoadBalancingStrategy; #[derive(Clone, Debug, Default)] @@ -53,7 +53,7 @@ pub struct Cluster { multi_tenant: Option, rw_strategy: ReadWriteStrategy, schema_admin: bool, - stats: Arc>, + stats: Arc>, cross_shard_disabled: bool, two_phase_commit: bool, two_phase_commit_auto: bool, @@ -245,7 +245,7 @@ impl Cluster { multi_tenant: multi_tenant.clone(), rw_strategy, schema_admin, - stats: Arc::new(Mutex::new(MirrorStats::default())), + stats: Arc::new(Mutex::new(ClusterStats::default())), cross_shard_disabled, two_phase_commit: two_pc && shards.len() > 1, two_phase_commit_auto: two_pc_auto && shards.len() > 1, @@ -409,7 +409,7 @@ impl Cluster { self.schema_admin = owner; } - pub fn stats(&self) -> Arc> { + pub fn stats(&self) -> Arc> { self.stats.clone() } diff --git a/pgdog/src/backend/pool/mirror_stats.rs b/pgdog/src/backend/pool/cluster_stats.rs similarity index 82% rename from pgdog/src/backend/pool/mirror_stats.rs rename to pgdog/src/backend/pool/cluster_stats.rs index a31d857b7..7d540ced3 100644 --- a/pgdog/src/backend/pool/mirror_stats.rs +++ b/pgdog/src/backend/pool/cluster_stats.rs @@ -5,8 +5,10 @@ use std::{ ops::{Add, Div, Sub}, }; +use crate::frontend::router::rewrite::stats::RewriteStats; + #[derive(Debug, Clone, Default, Copy)] -pub struct Counts { +pub struct MirrorStats { pub total_count: usize, pub mirrored_count: usize, pub dropped_count: usize, @@ -14,8 +16,8 @@ pub struct Counts { pub queue_length: usize, } -impl Sub for Counts { - type Output = Counts; +impl Sub for MirrorStats { + type Output = MirrorStats; fn sub(self, rhs: Self) -> Self::Output { Self { @@ -28,8 +30,8 @@ impl Sub for Counts { } } -impl Div for Counts { - type Output = Counts; +impl Div for MirrorStats { + type Output = MirrorStats; fn div(self, rhs: usize) -> Self::Output { Self { @@ -42,11 +44,11 @@ impl Div for Counts { } } -impl Add for Counts { - type Output = Counts; +impl Add for MirrorStats { + type Output = MirrorStats; - fn add(self, rhs: Counts) -> Self::Output { - Counts { + fn add(self, rhs: MirrorStats) -> Self::Output { + MirrorStats { total_count: self.total_count + rhs.total_count, mirrored_count: self.mirrored_count + rhs.mirrored_count, dropped_count: self.dropped_count + rhs.dropped_count, @@ -56,9 +58,9 @@ impl Add for Counts { } } -impl Sum for Counts { +impl Sum for MirrorStats { fn sum>(iter: I) -> Self { - let mut result = Counts::default(); + let mut result = MirrorStats::default(); for next in iter { result = result + next; } @@ -68,8 +70,9 @@ impl Sum for Counts { } #[derive(Debug, Clone, Default, Copy)] -pub struct MirrorStats { - pub counts: Counts, +pub struct ClusterStats { + pub mirrors: MirrorStats, + pub rewrite: RewriteStats, } #[cfg(test)] @@ -78,16 +81,16 @@ mod tests { #[test] fn test_queue_length_default_is_zero() { - let stats = MirrorStats::default(); + let stats = ClusterStats::default(); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should be 0 by default" ); } #[test] fn test_queue_length_arithmetic_operations() { - let counts1 = Counts { + let counts1 = MirrorStats { total_count: 10, mirrored_count: 5, dropped_count: 3, @@ -95,7 +98,7 @@ mod tests { queue_length: 7, }; - let counts2 = Counts { + let counts2 = MirrorStats { total_count: 5, mirrored_count: 3, dropped_count: 1, @@ -127,7 +130,7 @@ mod tests { #[test] fn test_queue_length_saturating_sub() { - let counts1 = Counts { + let counts1 = MirrorStats { total_count: 10, mirrored_count: 5, dropped_count: 3, @@ -135,7 +138,7 @@ mod tests { queue_length: 3, }; - let counts2 = Counts { + let counts2 = MirrorStats { total_count: 5, mirrored_count: 3, dropped_count: 1, diff --git a/pgdog/src/backend/pool/connection/mirror/handler.rs b/pgdog/src/backend/pool/connection/mirror/handler.rs index 92c2e406c..86147b276 100644 --- a/pgdog/src/backend/pool/connection/mirror/handler.rs +++ b/pgdog/src/backend/pool/connection/mirror/handler.rs @@ -4,7 +4,7 @@ //! use super::*; -use crate::backend::pool::MirrorStats; +use crate::backend::pool::ClusterStats; use parking_lot::Mutex; use std::sync::Arc; @@ -35,7 +35,7 @@ pub struct MirrorHandler { /// Request timer, to simulate delays between queries. timer: Instant, /// Reference to cluster stats for tracking mirror metrics. - stats: Arc>, + stats: Arc>, } impl MirrorHandler { @@ -45,7 +45,7 @@ impl MirrorHandler { } /// Create new mirror handle with exposure. - pub fn new(tx: Sender, exposure: f32, stats: Arc>) -> Self { + pub fn new(tx: Sender, exposure: f32, stats: Arc>) -> Self { Self { tx, exposure, @@ -139,44 +139,44 @@ impl MirrorHandler { /// Increment the total request count. pub fn increment_total_count(&self) { let mut stats = self.stats.lock(); - stats.counts.total_count += 1; + stats.mirrors.total_count += 1; } /// Increment the mirrored request count. pub fn increment_mirrored_count(&self) { let mut stats = self.stats.lock(); - stats.counts.mirrored_count += 1; + stats.mirrors.mirrored_count += 1; } /// Increment the dropped request count. pub fn increment_dropped_count(&self) { let mut stats = self.stats.lock(); - stats.counts.dropped_count += 1; + stats.mirrors.dropped_count += 1; } /// Increment the error count. pub fn increment_error_count(&self) { let mut stats = self.stats.lock(); - stats.counts.error_count += 1; + stats.mirrors.error_count += 1; } /// Increment the queue length. pub fn increment_queue_length(&self) { let mut stats = self.stats.lock(); - stats.counts.queue_length += 1; + stats.mirrors.queue_length += 1; } /// Decrement the queue length. pub fn decrement_queue_length(&self) { let mut stats = self.stats.lock(); - stats.counts.queue_length = stats.counts.queue_length.saturating_sub(1); + stats.mirrors.queue_length = stats.mirrors.queue_length.saturating_sub(1); } } #[cfg(test)] mod tests { use super::*; - use crate::backend::pool::MirrorStats; + use crate::backend::pool::ClusterStats; use parking_lot::Mutex; use std::sync::Arc; use tokio::sync::mpsc::{channel, Receiver}; @@ -185,22 +185,22 @@ mod tests { exposure: f32, ) -> ( MirrorHandler, - Arc>, + Arc>, Receiver, ) { let (tx, rx) = channel(1000); // Keep receiver to prevent channel closure - let stats = Arc::new(Mutex::new(MirrorStats::default())); + let stats = Arc::new(Mutex::new(ClusterStats::default())); let handler = MirrorHandler::new(tx, exposure, stats.clone()); (handler, stats, rx) } - fn get_stats_counts(stats: &Arc>) -> (usize, usize, usize, usize) { + fn get_stats_counts(stats: &Arc>) -> (usize, usize, usize, usize) { let stats = stats.lock(); ( - stats.counts.total_count, - stats.counts.mirrored_count, - stats.counts.dropped_count, - stats.counts.error_count, + stats.mirrors.total_count, + stats.mirrors.mirrored_count, + stats.mirrors.dropped_count, + stats.mirrors.error_count, ) } @@ -344,7 +344,7 @@ mod tests { { let stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should be 0 initially" ); } @@ -358,7 +358,7 @@ mod tests { { let stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should still be 0 before flush" ); } @@ -368,7 +368,7 @@ mod tests { { let stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 1, + stats.mirrors.queue_length, 1, "queue_length should be 1 after flush" ); } @@ -388,7 +388,7 @@ mod tests { { let stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should be 0 initially" ); } @@ -403,17 +403,17 @@ mod tests { { let stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should remain 0 for dropped transactions" ); - assert_eq!(stats.counts.dropped_count, 1, "dropped_count should be 1"); + assert_eq!(stats.mirrors.dropped_count, 1, "dropped_count should be 1"); } } #[test] fn test_queue_length_with_channel_overflow() { let (tx, _rx) = channel(1); // Channel with capacity of 1 - let stats = Arc::new(Mutex::new(MirrorStats::default())); + let stats = Arc::new(Mutex::new(ClusterStats::default())); let mut handler = MirrorHandler::new(tx, 1.0, stats.clone()); // Fill the channel @@ -428,11 +428,11 @@ mod tests { { let stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 1, + stats.mirrors.queue_length, 1, "queue_length should be 1 (first successful send)" ); assert_eq!( - stats.counts.error_count, 1, + stats.mirrors.error_count, 1, "error_count should be 1 due to overflow" ); } @@ -441,16 +441,16 @@ mod tests { #[test] fn test_queue_length_never_negative() { // Test to ensure queue_length never goes negative even with mismatched increment/decrement - let stats = Arc::new(Mutex::new(MirrorStats::default())); + let stats = Arc::new(Mutex::new(ClusterStats::default())); // Manually try to decrement without incrementing (should use saturating_sub) // This will be tested more thoroughly once decrement_queue_length is implemented { let mut stats = stats.lock(); // Simulating a decrement when queue is already 0 - stats.counts.queue_length = stats.counts.queue_length.saturating_sub(1); + stats.mirrors.queue_length = stats.mirrors.queue_length.saturating_sub(1); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should not go negative" ); } diff --git a/pgdog/src/backend/pool/connection/mirror/mod.rs b/pgdog/src/backend/pool/connection/mirror/mod.rs index 28363dea4..7acc88a07 100644 --- a/pgdog/src/backend/pool/connection/mirror/mod.rs +++ b/pgdog/src/backend/pool/connection/mirror/mod.rs @@ -119,14 +119,14 @@ impl Mirror { // Decrement queue_length when we receive a message from the channel { let mut stats = stats_for_errors.lock(); - stats.counts.queue_length = stats.counts.queue_length.saturating_sub(1); + stats.mirrors.queue_length = stats.mirrors.queue_length.saturating_sub(1); } // TODO: timeout these. if let Err(err) = mirror.handle(&mut req, &mut query_engine).await { error!("mirror error: {}", err); // Increment error count on mirror handling error let mut stats = stats_for_errors.lock(); - stats.counts.error_count += 1; + stats.mirrors.error_count += 1; } } else { debug!("mirror client shutting down"); @@ -170,12 +170,12 @@ mod test { #[tokio::test] async fn test_mirror_exposure() { - use crate::backend::pool::MirrorStats; + use crate::backend::pool::ClusterStats; use parking_lot::Mutex; use std::sync::Arc; let (tx, rx) = channel(25); - let stats = Arc::new(Mutex::new(MirrorStats::default())); + let stats = Arc::new(Mutex::new(ClusterStats::default())); let mut handle = MirrorHandler::new(tx.clone(), 1.0, stats.clone()); for _ in 0..25 { @@ -189,7 +189,7 @@ mod test { assert_eq!(rx.len(), 25); let (tx, rx) = channel(25); - let stats2 = Arc::new(Mutex::new(MirrorStats::default())); + let stats2 = Arc::new(Mutex::new(ClusterStats::default())); let mut handle = MirrorHandler::new(tx.clone(), 0.5, stats2); let dropped = (0..25) .into_iter() @@ -270,7 +270,7 @@ mod test { let initial_stats = { let stats_arc = cluster.stats(); let stats = stats_arc.lock(); - stats.counts + stats.mirrors }; let mut mirror = Mirror::spawn("pgdog", &cluster, None).unwrap(); @@ -290,7 +290,7 @@ mod test { let final_stats = { let stats_arc = cluster.stats(); let stats = stats_arc.lock(); - stats.counts + stats.mirrors }; assert_eq!( diff --git a/pgdog/src/backend/pool/mod.rs b/pgdog/src/backend/pool/mod.rs index b4f54d2cd..e79d47b32 100644 --- a/pgdog/src/backend/pool/mod.rs +++ b/pgdog/src/backend/pool/mod.rs @@ -3,6 +3,7 @@ pub mod address; pub mod cleanup; pub mod cluster; +pub mod cluster_stats; pub mod comms; pub mod config; pub mod connection; @@ -13,7 +14,6 @@ pub mod healthcheck; pub mod inner; pub mod lsn_monitor; pub mod mapping; -pub mod mirror_stats; pub mod monitor; pub mod oids; pub mod pool_impl; @@ -27,13 +27,13 @@ pub mod waiting; pub use address::Address; pub use cluster::{Cluster, ClusterConfig, ClusterShardConfig, PoolConfig, ShardingSchema}; +pub use cluster_stats::ClusterStats; pub use config::Config; pub use connection::Connection; pub use error::Error; pub use guard::Guard; pub use healthcheck::Healtcheck; pub use lsn_monitor::LsnStats; -pub use mirror_stats::MirrorStats; pub use monitor::Monitor; pub use oids::Oids; pub use pool_impl::Pool; diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index 5be73ebba..b5901d962 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -5,7 +5,7 @@ use crate::{ client::query_engine::hooks::QueryEngineHooks, router::{ parser::Shard, - rewrite::{self, RewriteRequest}, + rewrite::{self, RewriteRequest, RewriteState}, Route, }, BufferedQuery, Client, Command, Comms, Error, Router, RouterContext, Stats, @@ -61,6 +61,7 @@ pub struct QueryEngine { notify_buffer: NotifyBuffer, pending_explain: Option, hooks: QueryEngineHooks, + rewrite_state: RewriteState, } impl QueryEngine { @@ -131,6 +132,7 @@ impl QueryEngine { context.client_request, self.backend.cluster()?, context.prepared_statements, + &mut self.rewrite_state, ); match rewrite.execute() { Ok(ast) => context.ast = Some(ast), diff --git a/pgdog/src/frontend/prepared_statements/mod.rs b/pgdog/src/frontend/prepared_statements/mod.rs index 255055aef..2fd75418f 100644 --- a/pgdog/src/frontend/prepared_statements/mod.rs +++ b/pgdog/src/frontend/prepared_statements/mod.rs @@ -9,7 +9,7 @@ use tracing::debug; use crate::{ config::{config, PreparedStatements as PreparedStatementsLevel}, - frontend::router::parser::{cache::CachedAst, RewritePlan}, + frontend::router::parser::RewritePlan, net::{Parse, ProtocolMessage}, stats::memory::MemoryUsage, }; @@ -31,7 +31,6 @@ pub struct PreparedStatements { pub(super) local: HashMap, pub(super) level: PreparedStatementsLevel, pub(super) memory_used: usize, - pub(super) rewrite: HashMap, } impl MemoryUsage for PreparedStatements { @@ -50,7 +49,6 @@ impl Default for PreparedStatements { local: HashMap::default(), level: PreparedStatementsLevel::Extended, memory_used: 0, - rewrite: HashMap::new(), } } } @@ -99,17 +97,6 @@ impl PreparedStatements { parse.rename_fast(&name) } - /// Get original AST for a prepared statement - /// we have rewritten. - pub fn get_original_ast(&self, name: &str) -> Option<&CachedAst> { - self.rewrite.get(name) - } - - /// Save original AST for re-use by subsequent Bind messages. - pub fn save_original_ast(&mut self, name: &str, ast: &CachedAst) { - self.rewrite.insert(name.to_string(), ast.clone()); - } - /// Retrieve stored rewrite plan for a prepared statement, if any. pub fn rewrite_plan(&self, name: &str) -> Option { self.global.read().rewrite_plan(name) diff --git a/pgdog/src/frontend/router/rewrite/context.rs b/pgdog/src/frontend/router/rewrite/context.rs index 460c16ca6..f050301f8 100644 --- a/pgdog/src/frontend/router/rewrite/context.rs +++ b/pgdog/src/frontend/router/rewrite/context.rs @@ -1,8 +1,8 @@ -//! Rewrite input and output types. +//! Context passed throughout the rewrite engine. use pg_query::protobuf::{ParseResult, RawStmt}; -use super::{output::RewriteActionKind, Error, RewriteAction, StepOutput}; +use super::{output::RewriteActionKind, stats::RewriteStats, Error, RewriteAction, StepOutput}; use crate::net::{Bind, Parse, ProtocolMessage, Query}; #[derive(Debug, Clone)] @@ -104,7 +104,7 @@ impl<'a> Context<'a> { /// Get the parse result (original or rewritten). pub fn parse_result(&self) -> &ParseResult { - self.rewrite.as_ref().unwrap_or(&self.original) + self.rewrite.as_ref().unwrap_or(self.original) } /// Prepend new message to rewritten request. @@ -120,6 +120,7 @@ impl<'a> Context<'a> { if self.rewrite.is_none() { Ok(StepOutput::NoOp) } else { + let mut stats = RewriteStats::default(); let bind = self.rewrite_bind.take(); let ast = self.rewrite.take().ok_or(Error::NoRewrite)?; let stmt = ast.deparse()?; @@ -135,6 +136,7 @@ impl<'a> Context<'a> { message: parse.into(), action: RewriteActionKind::Replace, }); + stats.parse += 1; } if let Some(bind) = bind { @@ -142,15 +144,22 @@ impl<'a> Context<'a> { message: bind.into(), action: RewriteActionKind::Replace, }); + stats.bind += 1; } } else { actions.push(RewriteAction { message: Query::new(stmt.clone()).into(), action: RewriteActionKind::Replace, }); + stats.simple += 1; } - Ok(StepOutput::RewriteInPlace { stmt, ast, actions }) + Ok(StepOutput::RewriteInPlace { + stmt, + ast, + actions, + stats, + }) } } } diff --git a/pgdog/src/frontend/router/rewrite/mod.rs b/pgdog/src/frontend/router/rewrite/mod.rs index bbe57d880..77a77e10c 100644 --- a/pgdog/src/frontend/router/rewrite/mod.rs +++ b/pgdog/src/frontend/router/rewrite/mod.rs @@ -14,6 +14,7 @@ pub mod output; pub mod prepared; pub mod request; pub mod state; +pub mod stats; pub mod unique_id; pub use context::Context; @@ -21,6 +22,7 @@ pub use error::Error; pub use interface::RewriteModule; pub use output::{RewriteAction, StepOutput}; pub use request::RewriteRequest; +pub use state::RewriteState; use crate::frontend::PreparedStatements; diff --git a/pgdog/src/frontend/router/rewrite/output.rs b/pgdog/src/frontend/router/rewrite/output.rs index 13b4e387d..9c2049a1b 100644 --- a/pgdog/src/frontend/router/rewrite/output.rs +++ b/pgdog/src/frontend/router/rewrite/output.rs @@ -1,5 +1,6 @@ use pg_query::protobuf::ParseResult; +use super::stats::RewriteStats; use crate::{frontend::ClientRequest, net::ProtocolMessage}; use std::mem::discriminant; @@ -51,6 +52,7 @@ pub enum StepOutput { actions: Vec, ast: ParseResult, stmt: String, + stats: RewriteStats, }, } diff --git a/pgdog/src/frontend/router/rewrite/request.rs b/pgdog/src/frontend/router/rewrite/request.rs index 756fd4efb..90ee2f1df 100644 --- a/pgdog/src/frontend/router/rewrite/request.rs +++ b/pgdog/src/frontend/router/rewrite/request.rs @@ -1,14 +1,11 @@ use pg_query::ParseResult; use tracing::debug; -use super::{Context, Error, Rewrite, StepOutput}; +use super::{Context, Error, Rewrite, RewriteModule, RewriteState, StepOutput}; use crate::{ backend::Cluster, frontend::{ - router::{ - parser::{cache::CachedAst, Cache}, - rewrite::RewriteModule, - }, + router::parser::{cache::CachedAst, Cache}, ClientRequest, PreparedStatements, }, net::ProtocolMessage, @@ -18,6 +15,7 @@ pub struct RewriteRequest<'a> { request: &'a mut ClientRequest, cluster: &'a Cluster, prepared_statements: &'a mut PreparedStatements, + state: &'a mut RewriteState, } impl<'a> RewriteRequest<'a> { @@ -26,11 +24,13 @@ impl<'a> RewriteRequest<'a> { request: &'a mut ClientRequest, cluster: &'a Cluster, prepared_statements: &'a mut PreparedStatements, + state: &'a mut RewriteState, ) -> Self { Self { request, cluster, prepared_statements, + state, } } @@ -49,8 +49,7 @@ impl<'a> RewriteRequest<'a> { match message { ProtocolMessage::Parse(p) => { ast = Some(Cache::get().parse(p.query(), &schema)?); - self.prepared_statements - .save_original_ast(p.name(), ast.as_ref().unwrap()); + self.state.save(p.name(), ast.as_ref().unwrap()); parse = Some(p); } @@ -59,13 +58,19 @@ impl<'a> RewriteRequest<'a> { } ProtocolMessage::Bind(b) => { - let existing = self.prepared_statements.get_original_ast(b.statement()); + let existing = self.state.get(b.statement()); if let Some(existing) = existing { ast = Some(existing.clone()); bind = Some(b); } } + ProtocolMessage::Close(close) => { + if close.is_statement() { + self.state.remove(close.name()); + } + } + _ => (), } } @@ -89,11 +94,22 @@ impl<'a> RewriteRequest<'a> { debug!("rewrite was a no-op"); ast } - StepOutput::RewriteInPlace { stmt, ast, actions } => { + StepOutput::RewriteInPlace { + stmt, + ast, + actions, + stats, + } => { debug!("rewrite in-place: {}", stmt); for action in actions { action.execute(self.request); } + // Update stats. + { + let cluster_stats = self.cluster.stats(); + let mut lock = cluster_stats.lock(); + lock.rewrite = lock.rewrite.clone() + stats; + } let ast = ParseResult::new(ast, "".into()); // Cache new rewritten prepared statement. if extended { diff --git a/pgdog/src/frontend/router/rewrite/state.rs b/pgdog/src/frontend/router/rewrite/state.rs index 8b1378917..351680b97 100644 --- a/pgdog/src/frontend/router/rewrite/state.rs +++ b/pgdog/src/frontend/router/rewrite/state.rs @@ -1 +1,37 @@ +//! Rewrite engine state. To be preserved between requests. +use std::collections::HashMap; + +use crate::frontend::router::parser::cache::CachedAst; + +#[derive(Debug, Default, Clone)] +pub struct RewriteState { + originals: HashMap, +} + +impl RewriteState { + /// Save original AST into rewrite state. + /// + /// We use it to rewrite Bind messages. Instead of encapsulating + /// complex rewrite rules in an enum, we walk the AST and + /// perform whatever changes we need. + /// + pub fn save(&mut self, name: &str, ast: &CachedAst) { + self.originals.insert(name.to_string(), ast.clone()); + } + + /// Get the original AST by statement name. + pub fn get(&self, name: &str) -> Option<&CachedAst> { + self.originals.get(name) + } + + /// Remove AST from state. + pub fn remove(&mut self, name: &str) -> bool { + self.originals.remove(name).is_some() + } + + /// Number of ASTs in the state. + pub fn len(&self) -> usize { + self.originals.len() + } +} diff --git a/pgdog/src/frontend/router/rewrite/stats.rs b/pgdog/src/frontend/router/rewrite/stats.rs new file mode 100644 index 000000000..c1f80b37c --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/stats.rs @@ -0,0 +1,22 @@ +//! Rewrite engine stats. + +use std::ops::Add; + +#[derive(Debug, Default, Clone, Copy)] +pub struct RewriteStats { + pub parse: usize, + pub bind: usize, + pub simple: usize, +} + +impl Add for RewriteStats { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self { + parse: self.parse + rhs.parse, + bind: self.bind + rhs.bind, + simple: self.simple + rhs.simple, + } + } +} diff --git a/pgdog/src/stats/http_server.rs b/pgdog/src/stats/http_server.rs index 9f6628924..a8d5e2a17 100644 --- a/pgdog/src/stats/http_server.rs +++ b/pgdog/src/stats/http_server.rs @@ -10,7 +10,7 @@ use hyper_util::rt::TokioIo; use tokio::net::TcpListener; use tracing::info; -use super::{Clients, MirrorStatsMetrics, Pools, QueryCache}; +use super::{Clients, MirrorStatsMetrics, Pools, QueryCache, RewriteStatsMetrics}; async fn metrics(_: Request) -> Result>, Infallible> { let clients = Clients::load(); @@ -26,13 +26,20 @@ async fn metrics(_: Request) -> Result = RewriteStatsMetrics::load() + .into_iter() + .map(|m| m.to_string()) + .collect(); + let rewrite_stats = rewrite_stats.join("\n"); let metrics_data = clients.to_string() + "\n" + &pools.to_string() + "\n" + &mirror_stats + "\n" - + &query_cache; + + &query_cache + + "\n" + + &rewrite_stats; let response = Response::builder() .header( hyper::header::CONTENT_TYPE, diff --git a/pgdog/src/stats/mirror_stats.rs b/pgdog/src/stats/mirror_stats.rs index d67a34ca5..98d33f6a9 100644 --- a/pgdog/src/stats/mirror_stats.rs +++ b/pgdog/src/stats/mirror_stats.rs @@ -24,7 +24,7 @@ impl MirrorStatsMetrics { for (user, cluster) in databases().all() { let stats = cluster.stats(); let stats = stats.lock(); - let counts = stats.counts; + let counts = stats.mirrors; // Per-cluster metrics with labels let labels = vec![ diff --git a/pgdog/src/stats/mod.rs b/pgdog/src/stats/mod.rs index e7b1843cc..15c8d0a23 100644 --- a/pgdog/src/stats/mod.rs +++ b/pgdog/src/stats/mod.rs @@ -8,9 +8,11 @@ pub use open_metric::*; pub mod logger; pub mod memory; pub mod query_cache; +pub mod rewrite_stats; pub use clients::Clients; pub use logger::Logger as StatsLogger; pub use mirror_stats::MirrorStatsMetrics; pub use pools::{PoolMetric, Pools}; pub use query_cache::QueryCache; +pub use rewrite_stats::RewriteStatsMetrics; diff --git a/pgdog/src/stats/rewrite_stats.rs b/pgdog/src/stats/rewrite_stats.rs new file mode 100644 index 000000000..30b62244d --- /dev/null +++ b/pgdog/src/stats/rewrite_stats.rs @@ -0,0 +1,119 @@ +//! Rewrite stats OpenMetrics. + +use crate::backend::databases::databases; + +use super::{Measurement, Metric, OpenMetric}; + +pub struct RewriteStatsMetrics; + +impl RewriteStatsMetrics { + pub fn load() -> Vec { + let mut metrics = vec![]; + + let mut parse_measurements = vec![]; + let mut bind_measurements = vec![]; + let mut simple_measurements = vec![]; + + let mut global_parse = 0usize; + let mut global_bind = 0usize; + let mut global_simple = 0usize; + + for (user, cluster) in databases().all() { + let stats = cluster.stats(); + let stats = stats.lock(); + let rewrite = stats.rewrite; + + let labels = vec![ + ("user".into(), user.user.clone()), + ("database".into(), user.database.clone()), + ]; + + parse_measurements.push(Measurement { + labels: labels.clone(), + measurement: rewrite.parse.into(), + }); + + bind_measurements.push(Measurement { + labels: labels.clone(), + measurement: rewrite.bind.into(), + }); + + simple_measurements.push(Measurement { + labels, + measurement: rewrite.simple.into(), + }); + + global_parse += rewrite.parse; + global_bind += rewrite.bind; + global_simple += rewrite.simple; + } + + parse_measurements.push(Measurement { + labels: vec![], + measurement: global_parse.into(), + }); + + bind_measurements.push(Measurement { + labels: vec![], + measurement: global_bind.into(), + }); + + simple_measurements.push(Measurement { + labels: vec![], + measurement: global_simple.into(), + }); + + metrics.push(Metric::new(RewriteStatsMetric { + name: "rewrite_parse_count".into(), + measurements: parse_measurements, + help: "Number of Parse messages rewritten.".into(), + })); + + metrics.push(Metric::new(RewriteStatsMetric { + name: "rewrite_bind_count".into(), + measurements: bind_measurements, + help: "Number of Bind messages rewritten.".into(), + })); + + metrics.push(Metric::new(RewriteStatsMetric { + name: "rewrite_simple_count".into(), + measurements: simple_measurements, + help: "Number of simple queries rewritten.".into(), + })); + + metrics + } +} + +struct RewriteStatsMetric { + name: String, + measurements: Vec, + help: String, +} + +impl OpenMetric for RewriteStatsMetric { + fn name(&self) -> String { + self.name.clone() + } + + fn measurements(&self) -> Vec { + self.measurements.clone() + } + + fn help(&self) -> Option { + Some(self.help.clone()) + } + + fn metric_type(&self) -> String { + "counter".into() + } +} + +impl std::fmt::Display for RewriteStatsMetrics { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for metric in RewriteStatsMetrics::load() { + writeln!(f, "{}", metric)?; + } + Ok(()) + } +} From 076ba4c72f99be3cd982887c948cbf2e89851582 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Tue, 2 Dec 2025 14:50:20 -0800 Subject: [PATCH 15/23] fix --- .../router/rewrite/unique_id/explain.rs | 12 +- .../router/rewrite/unique_id/insert.rs | 41 +---- .../frontend/router/rewrite/unique_id/mod.rs | 141 ++++++++++++++++-- 3 files changed, 144 insertions(+), 50 deletions(-) diff --git a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs index b683915b9..c397e0dd2 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs @@ -4,7 +4,7 @@ use pg_query::NodeEnum; use super::{ super::{Context, Error, RewriteModule}, - InsertUniqueIdRewrite, SelectUniqueIdRewrite, UpdateUniqueIdRewrite, + max_param_number, InsertUniqueIdRewrite, SelectUniqueIdRewrite, UpdateUniqueIdRewrite, }; #[derive(Default)] @@ -81,7 +81,7 @@ impl ExplainUniqueIdRewrite { let mut bind = input.bind_take(); let extended = input.extended(); - let mut parameter_counter = 0; + let mut parameter_counter = max_param_number(input.parse_result()); if let Some(NodeEnum::ExplainStmt(stmt)) = input .stmt_mut()? @@ -130,6 +130,7 @@ impl ExplainUniqueIdRewrite { let mut bind = input.bind_take(); let extended = input.extended(); + let mut param_counter = max_param_number(input.parse_result()); if let Some(NodeEnum::ExplainStmt(stmt)) = input .stmt_mut()? @@ -140,7 +141,12 @@ impl ExplainUniqueIdRewrite { if let Some(NodeEnum::InsertStmt(insert)) = stmt.query.as_mut().and_then(|q| q.node.as_mut()) { - InsertUniqueIdRewrite::rewrite_insert(insert, &mut bind, extended)?; + InsertUniqueIdRewrite::rewrite_insert( + insert, + &mut bind, + extended, + &mut param_counter, + )?; } } diff --git a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs index c6af0c125..30daf6aac 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs @@ -2,7 +2,7 @@ use pg_query::{protobuf::InsertStmt, NodeEnum}; use super::{ super::{Context, Error, RewriteModule}, - bigint_const, bigint_param, + bigint_const, bigint_param, max_param_number, }; use crate::{ frontend::router::parser::{Insert, Value}, @@ -34,8 +34,8 @@ impl InsertUniqueIdRewrite { stmt: &mut InsertStmt, bind: &mut Option, extended: bool, + param_counter: &mut i32, ) -> Result<(), Error> { - let mut param_counter = Self::param_count(stmt)?; let select = stmt .select_stmt .as_mut() @@ -53,15 +53,15 @@ impl InsertUniqueIdRewrite { let id = unique_id::UniqueId::generator()?.next_id(); let node = if extended { - param_counter += 1; + *param_counter += 1; if let Some(ref mut bind) = bind { let count = bind.add_parameter(Datum::Bigint(id))?; // The number of parameters in the query doesn't match what's in the bind message. - if count != param_counter { + if count != *param_counter { return Err(Error::ParameterCountMismatch); } } - bigint_param(param_counter) + bigint_param(*param_counter) } else { bigint_const(id) }; @@ -76,34 +76,6 @@ impl InsertUniqueIdRewrite { Ok(()) } - - fn param_count(stmt: &InsertStmt) -> Result { - let mut max = 0; - - let select = stmt - .select_stmt - .as_ref() - .ok_or(Error::ParserError)? - .node - .as_ref() - .ok_or(Error::ParserError)?; - - if let NodeEnum::SelectStmt(stmt) = select { - for tuple in stmt.values_lists.iter() { - if let Some(NodeEnum::List(ref tuple)) = tuple.node { - for column in tuple.items.iter() { - if let Some(NodeEnum::ParamRef(ref param)) = column.node { - if param.number > max { - max = param.number; - } - } - } - } - } - } - - Ok(max) - } } impl RewriteModule for InsertUniqueIdRewrite { @@ -125,6 +97,7 @@ impl RewriteModule for InsertUniqueIdRewrite { let mut bind = input.bind_take(); let extended = input.extended(); + let mut param_counter = max_param_number(input.parse_result()); if let Some(NodeEnum::InsertStmt(stmt)) = input .stmt_mut()? @@ -132,7 +105,7 @@ impl RewriteModule for InsertUniqueIdRewrite { .as_mut() .and_then(|stmt| stmt.node.as_mut()) { - Self::rewrite_insert(stmt, &mut bind, extended)?; + Self::rewrite_insert(stmt, &mut bind, extended, &mut param_counter)?; } input.bind_put(bind); diff --git a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs index ff2733f3f..6353cb499 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs @@ -2,7 +2,7 @@ use pg_query::{ protobuf::{a_const::Val, AConst, Node, ParamRef, ParseResult, TypeCast, TypeName}, - NodeEnum, NodeRef, + NodeEnum, }; pub mod explain; @@ -77,16 +77,131 @@ fn bigint_const(id: i64) -> NodeEnum { /// Find the maximum parameter number ($N) in a parse result. pub fn max_param_number(result: &ParseResult) -> i32 { - result - .nodes() - .iter() - .filter_map(|(node, _, _, _)| { - if let NodeRef::ParamRef(p) = node { - Some(p.number) - } else { - None - } - }) - .max() - .unwrap_or(0) + let mut max = 0; + for stmt in &result.stmts { + if let Some(ref stmt) = stmt.stmt { + find_max_param(&stmt.node, &mut max); + } + } + max +} + +fn find_max_param(node: &Option, max: &mut i32) { + let Some(node) = node else { + return; + }; + + match node { + NodeEnum::ParamRef(param) => { + if param.number > *max { + *max = param.number; + } + } + NodeEnum::TypeCast(cast) => { + if let Some(ref arg) = cast.arg { + find_max_param(&arg.node, max); + } + } + NodeEnum::FuncCall(func) => { + for arg in &func.args { + find_max_param(&arg.node, max); + } + } + NodeEnum::AExpr(expr) => { + if let Some(ref lexpr) = expr.lexpr { + find_max_param(&lexpr.node, max); + } + if let Some(ref rexpr) = expr.rexpr { + find_max_param(&rexpr.node, max); + } + } + NodeEnum::SelectStmt(stmt) => { + for item in &stmt.target_list { + find_max_param(&item.node, max); + } + for item in &stmt.values_lists { + find_max_param(&item.node, max); + } + for item in &stmt.from_clause { + find_max_param(&item.node, max); + } + if let Some(ref clause) = stmt.where_clause { + find_max_param(&clause.node, max); + } + if let Some(ref limit) = stmt.limit_count { + find_max_param(&limit.node, max); + } + if let Some(ref offset) = stmt.limit_offset { + find_max_param(&offset.node, max); + } + } + NodeEnum::InsertStmt(stmt) => { + if let Some(ref select) = stmt.select_stmt { + find_max_param(&select.node, max); + } + } + NodeEnum::UpdateStmt(stmt) => { + for item in &stmt.target_list { + find_max_param(&item.node, max); + } + if let Some(ref clause) = stmt.where_clause { + find_max_param(&clause.node, max); + } + } + NodeEnum::DeleteStmt(stmt) => { + if let Some(ref clause) = stmt.where_clause { + find_max_param(&clause.node, max); + } + } + NodeEnum::ResTarget(res) => { + if let Some(ref val) = res.val { + find_max_param(&val.node, max); + } + } + NodeEnum::List(list) => { + for item in &list.items { + find_max_param(&item.node, max); + } + } + NodeEnum::CoalesceExpr(coalesce) => { + for arg in &coalesce.args { + find_max_param(&arg.node, max); + } + } + NodeEnum::CaseExpr(case) => { + if let Some(ref arg) = case.arg { + find_max_param(&arg.node, max); + } + for when in &case.args { + find_max_param(&when.node, max); + } + if let Some(ref defresult) = case.defresult { + find_max_param(&defresult.node, max); + } + } + NodeEnum::CaseWhen(when) => { + if let Some(ref expr) = when.expr { + find_max_param(&expr.node, max); + } + if let Some(ref result) = when.result { + find_max_param(&result.node, max); + } + } + NodeEnum::BoolExpr(expr) => { + for arg in &expr.args { + find_max_param(&arg.node, max); + } + } + NodeEnum::NullTest(test) => { + if let Some(ref arg) = test.arg { + find_max_param(&arg.node, max); + } + } + NodeEnum::ExplainStmt(stmt) => { + if let Some(ref query) = stmt.query { + find_max_param(&query.node, max); + } + } + _ => {} + } } From 1949aa81d656fd68e91a2e4c852060ed4f840070 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 11:51:45 -0800 Subject: [PATCH 16/23] rewrite plan --- pgdog/src/frontend/client/query_engine/mod.rs | 6 +- pgdog/src/frontend/client_request.rs | 11 ++ pgdog/src/frontend/router/rewrite/context.rs | 74 +++------ pgdog/src/frontend/router/rewrite/error.rs | 3 + .../router/rewrite/insert_split/mod.rs | 21 ++- pgdog/src/frontend/router/rewrite/mod.rs | 2 + pgdog/src/frontend/router/rewrite/output.rs | 2 + pgdog/src/frontend/router/rewrite/plan.rs | 60 +++++++ pgdog/src/frontend/router/rewrite/request.rs | 148 ++++++++++-------- pgdog/src/frontend/router/rewrite/state.rs | 47 +++--- .../router/rewrite/unique_id/explain.rs | 25 +-- .../router/rewrite/unique_id/insert.rs | 27 ++-- .../router/rewrite/unique_id/select.rs | 46 +++--- .../router/rewrite/unique_id/update.rs | 27 ++-- pgdog/src/net/messages/bind.rs | 4 + pgdog/src/net/messages/parse.rs | 4 + 16 files changed, 277 insertions(+), 230 deletions(-) create mode 100644 pgdog/src/frontend/router/rewrite/plan.rs diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index b5901d962..836f13469 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -134,11 +134,7 @@ impl QueryEngine { context.prepared_statements, &mut self.rewrite_state, ); - match rewrite.execute() { - Ok(ast) => context.ast = Some(ast), - Err(rewrite::Error::EmptyQuery) => (), - Err(err) => return Err(err.into()), - } + context.ast = rewrite.execute()?; } } diff --git a/pgdog/src/frontend/client_request.rs b/pgdog/src/frontend/client_request.rs index af685ab4c..5a90e2bbe 100644 --- a/pgdog/src/frontend/client_request.rs +++ b/pgdog/src/frontend/client_request.rs @@ -140,6 +140,17 @@ impl ClientRequest { Ok(None) } + /// Get mutable reference to parameters, if any. + pub fn parameters_mut(&mut self) -> Result, Error> { + for message in self.messages.iter_mut() { + if let ProtocolMessage::Bind(bind) = message { + return Ok(Some(bind)); + } + } + + Ok(None) + } + /// Get all CopyData messages. pub fn copy_data(&self) -> Result, Error> { let mut rows = vec![]; diff --git a/pgdog/src/frontend/router/rewrite/context.rs b/pgdog/src/frontend/router/rewrite/context.rs index f050301f8..1266ad044 100644 --- a/pgdog/src/frontend/router/rewrite/context.rs +++ b/pgdog/src/frontend/router/rewrite/context.rs @@ -2,7 +2,9 @@ use pg_query::protobuf::{ParseResult, RawStmt}; -use super::{output::RewriteActionKind, stats::RewriteStats, Error, RewriteAction, StepOutput}; +use super::{ + output::RewriteActionKind, stats::RewriteStats, Error, RewriteAction, RewritePlan, StepOutput, +}; use crate::net::{Bind, Parse, ProtocolMessage, Query}; #[derive(Debug, Clone)] @@ -12,30 +14,23 @@ pub struct Context<'a> { original: &'a ParseResult, // If an in-place rewrite was done, the statement is saved here. rewrite: Option, - /// Original bind message, if any. - bind: Option<&'a Bind>, - /// Bind rewritten. - rewrite_bind: Option, /// Additional messages to add to the request. result: Vec, /// Extended protocol. parse: Option<&'a Parse>, + /// Rewrite plan. + plan: RewritePlan, } impl<'a> Context<'a> { /// Create new input. - pub(super) fn new( - original: &'a ParseResult, - bind: Option<&'a Bind>, - parse: Option<&'a Parse>, - ) -> Self { + pub(super) fn new(original: &'a ParseResult, parse: Option<&'a Parse>) -> Self { Self { original, - bind, rewrite: None, - rewrite_bind: None, result: vec![], parse, + plan: RewritePlan::default(), } } @@ -46,32 +41,12 @@ impl<'a> Context<'a> { /// We are rewriting an extended protocol request. pub fn extended(&self) -> bool { - self.parse().is_some() || self.bind().is_some() + self.parse().is_some() } - /// Get the Bind message, if set. - pub fn bind(&'a self) -> Option<&'a Bind> { - if let Some(ref rewrite_bind) = self.rewrite_bind { - Some(rewrite_bind) - } else { - self.bind - } - } - - /// Take the Bind message for modification. - /// Don't forget to return it. - #[must_use] - pub fn bind_take(&mut self) -> Option { - if self.rewrite_bind.is_none() { - self.rewrite_bind = self.bind.cloned(); - } - - self.rewrite_bind.take() - } - - /// Put the bind message back. - pub fn bind_put(&mut self, bind: Option) { - self.rewrite_bind = bind; + /// Get reference to rewrite plan for modification. + pub fn plan(&mut self) -> &mut RewritePlan { + &mut self.plan } /// Get the original (or modified) statement. @@ -121,31 +96,19 @@ impl<'a> Context<'a> { Ok(StepOutput::NoOp) } else { let mut stats = RewriteStats::default(); - let bind = self.rewrite_bind.take(); let ast = self.rewrite.take().ok_or(Error::NoRewrite)?; let stmt = ast.deparse()?; - let extended = self.extended(); let mut parse = self.parse().cloned(); let mut actions = self.result; - if extended { - if let Some(mut parse) = parse.take() { - parse.set_query(&stmt); - actions.push(RewriteAction { - message: parse.into(), - action: RewriteActionKind::Replace, - }); - stats.parse += 1; - } - - if let Some(bind) = bind { - actions.push(RewriteAction { - message: bind.into(), - action: RewriteActionKind::Replace, - }); - stats.bind += 1; - } + if let Some(mut parse) = parse.take() { + parse.set_query(&stmt); + actions.push(RewriteAction { + message: parse.into(), + action: RewriteActionKind::Replace, + }); + stats.parse += 1; } else { actions.push(RewriteAction { message: Query::new(stmt.clone()).into(), @@ -159,6 +122,7 @@ impl<'a> Context<'a> { ast, actions, stats, + plan: self.plan.freeze(), }) } } diff --git a/pgdog/src/frontend/router/rewrite/error.rs b/pgdog/src/frontend/router/rewrite/error.rs index f29c825db..2034b6d82 100644 --- a/pgdog/src/frontend/router/rewrite/error.rs +++ b/pgdog/src/frontend/router/rewrite/error.rs @@ -31,4 +31,7 @@ pub enum Error { #[error("parser: {0}")] Parser(#[from] crate::frontend::router::parser::Error), + + #[error("no active rewrite plan set")] + NoActiveRewritePlan, } diff --git a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs index 84ab4eb55..4a9ab1b15 100644 --- a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs +++ b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs @@ -3,8 +3,6 @@ use pg_query::{ Node, NodeEnum, }; -use crate::net::Bind; - use super::*; #[derive(Default)] @@ -39,20 +37,19 @@ impl RewriteModule for InsertSplitRewrite { let mut new_insert = proto_insert.clone(); let mut new_select = proto_select.clone(); let mut new_values = values.clone(); - let mut new_bind = Bind::default(); // Rewrite the parameter references // and create new Bind message for each INSERT statement. if let Some(NodeEnum::List(list)) = new_values.node.as_mut() { for value in list.items.iter_mut() { - if let Some(NodeEnum::ParamRef(param)) = value.node.as_mut() { - let parameter = input - .bind() - .and_then(|bind| bind.parameter(param.number as usize - 1).ok()) - .flatten(); - if let Some(parameter) = parameter { - param.number = new_bind.add_existing(parameter)?; - } + if let Some(NodeEnum::ParamRef(_)) = value.node.as_mut() { + // let parameter = input + // .bind() + // .and_then(|bind| bind.parameter(param.number as usize - 1).ok()) + // .flatten(); + // if let Some(parameter) = parameter { + // param.number = new_bind.add_existing(parameter)?; + // } } } } @@ -69,7 +66,7 @@ impl RewriteModule for InsertSplitRewrite { ..Default::default() }], }; - inserts.push((result, new_bind)); + inserts.push(result); } } } diff --git a/pgdog/src/frontend/router/rewrite/mod.rs b/pgdog/src/frontend/router/rewrite/mod.rs index 77a77e10c..f3a8ac002 100644 --- a/pgdog/src/frontend/router/rewrite/mod.rs +++ b/pgdog/src/frontend/router/rewrite/mod.rs @@ -11,6 +11,7 @@ pub mod error; pub mod insert_split; pub mod interface; pub mod output; +pub mod plan; pub mod prepared; pub mod request; pub mod state; @@ -21,6 +22,7 @@ pub use context::Context; pub use error::Error; pub use interface::RewriteModule; pub use output::{RewriteAction, StepOutput}; +pub use plan::{ImmutableRewritePlan, RewritePlan, UniqueIdPlan}; pub use request::RewriteRequest; pub use state::RewriteState; diff --git a/pgdog/src/frontend/router/rewrite/output.rs b/pgdog/src/frontend/router/rewrite/output.rs index 9c2049a1b..7fdeb464b 100644 --- a/pgdog/src/frontend/router/rewrite/output.rs +++ b/pgdog/src/frontend/router/rewrite/output.rs @@ -1,6 +1,7 @@ use pg_query::protobuf::ParseResult; use super::stats::RewriteStats; +use super::ImmutableRewritePlan; use crate::{frontend::ClientRequest, net::ProtocolMessage}; use std::mem::discriminant; @@ -53,6 +54,7 @@ pub enum StepOutput { ast: ParseResult, stmt: String, stats: RewriteStats, + plan: ImmutableRewritePlan, }, } diff --git a/pgdog/src/frontend/router/rewrite/plan.rs b/pgdog/src/frontend/router/rewrite/plan.rs new file mode 100644 index 000000000..af8f11f77 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/plan.rs @@ -0,0 +1,60 @@ +use std::{ops::Deref, sync::Arc}; + +use super::Error; +use crate::{ + net::{Bind, Datum}, + unique_id::UniqueId, +}; + +#[derive(Debug, Clone, Default)] +pub struct UniqueIdPlan { + /// Parameter number. + pub(super) param_ref: i32, +} + +#[derive(Debug, Clone, Default)] +pub struct RewritePlan { + /// How many unique IDs to add to the Bind message. + pub(super) unique_ids: Vec, +} + +#[derive(Debug, Clone, Default)] +pub struct ImmutableRewritePlan { + /// Compiled rewrite plan, that cannot be modified further. + pub(super) plan: Arc, +} + +impl Deref for ImmutableRewritePlan { + type Target = RewritePlan; + + fn deref(&self) -> &Self::Target { + &self.plan + } +} + +impl RewritePlan { + /// Apply rewrite plan to Bind message. + /// + /// N.B. this isn't idempotent, run this only once. + /// + pub fn apply_bind(&self, bind: &mut Bind) -> Result<(), Error> { + for unique_id in &self.unique_ids { + let id = UniqueId::generator()?.next_id(); + let counter = bind.add_parameter(Datum::Bigint(id))?; + // Params should be added to plan in order. + // This validates it. + if counter != unique_id.param_ref { + return Err(Error::ParameterCountMismatch); + } + } + + Ok(()) + } + + /// Freeze rewrite plan, without any more modifications allowed. + pub fn freeze(self) -> ImmutableRewritePlan { + ImmutableRewritePlan { + plan: Arc::new(self), + } + } +} diff --git a/pgdog/src/frontend/router/rewrite/request.rs b/pgdog/src/frontend/router/rewrite/request.rs index 90ee2f1df..28b22fc1a 100644 --- a/pgdog/src/frontend/router/rewrite/request.rs +++ b/pgdog/src/frontend/router/rewrite/request.rs @@ -8,7 +8,7 @@ use crate::{ router::parser::{cache::CachedAst, Cache}, ClientRequest, PreparedStatements, }, - net::ProtocolMessage, + net::{Protocol, ProtocolMessage}, }; pub struct RewriteRequest<'a> { @@ -34,91 +34,117 @@ impl<'a> RewriteRequest<'a> { } } - /// Execute rewrite and return the query AST. - pub fn execute(&'a mut self) -> Result { + fn handle_parse(&mut self) -> Result { + let parse = self.request.iter().find(|p| p.code() == 'P'); + let parse = if let Some(ProtocolMessage::Parse(parse)) = parse { + parse + } else { + return Err(Error::EmptyQuery); + }; + let schema = self.cluster.sharding_schema(); + let ast = Cache::get().parse(parse.query(), &schema)?; + + let mut context = Context::new(&ast.ast().protobuf, Some(parse)); + Rewrite::new(self.prepared_statements).rewrite(&mut context)?; + let output = context.build()?; - let (result, ast, extended) = { - let mut parse = None; - let mut bind = None; - let mut ast = None; - - let schema = self.cluster.sharding_schema(); - - for message in self.request.iter() { - match message { - ProtocolMessage::Parse(p) => { - ast = Some(Cache::get().parse(p.query(), &schema)?); - self.state.save(p.name(), ast.as_ref().unwrap()); - parse = Some(p); - } - - ProtocolMessage::Query(query) => { - ast = Some(Cache::get().parse_uncached(query.query(), &schema)?); - } - - ProtocolMessage::Bind(b) => { - let existing = self.state.get(b.statement()); - if let Some(existing) = existing { - ast = Some(existing.clone()); - bind = Some(b); - } - } - - ProtocolMessage::Close(close) => { - if close.is_statement() { - self.state.remove(close.name()); - } - } - - _ => (), + let ast = match output { + StepOutput::NoOp => ast, + StepOutput::RewriteInPlace { + actions, + ast, + stmt, + stats, + plan, + } => { + debug!("rewrite (extended): {}", stmt); + + self.state.save_plan(Some(parse), plan); + + for action in actions { + action.execute(self.request); } - } - let ast = ast.ok_or(Error::EmptyQuery)?; + // Update stats. + { + let cluster_stats = self.cluster.stats(); + let mut lock = cluster_stats.lock(); + lock.rewrite = lock.rewrite.clone() + stats; + } - let mut context = Context::new(&ast.ast().protobuf, bind, parse); - let mut rewrite = Rewrite::new(self.prepared_statements); + let ast = ParseResult::new(ast, "".into()); + Cache::get().save(&stmt, ast, &schema)? + } + }; - let result = match rewrite.rewrite(&mut context) { - Ok(_) => context.build()?, - Err(Error::EmptyQuery) => StepOutput::NoOp, - Err(err) => return Err(err), - }; + Ok(ast) + } - (result, ast, parse.is_some()) + fn handle_query(&mut self) -> Result { + let query = self.request.iter().find(|p| p.code() == 'Q'); + let query = if let Some(ProtocolMessage::Query(query)) = query { + query + } else { + return Err(Error::EmptyQuery); }; - let ast = match result { - StepOutput::NoOp => { - debug!("rewrite was a no-op"); - ast - } + let schema = self.cluster.sharding_schema(); + let ast = Cache::get().parse_uncached(query.query(), &schema)?; + + let mut context = Context::new(&ast.ast().protobuf, None); + Rewrite::new(self.prepared_statements).rewrite(&mut context)?; + let output = context.build()?; + + let ast = match output { + StepOutput::NoOp => ast, StepOutput::RewriteInPlace { - stmt, - ast, actions, + ast, + stmt, stats, + plan, } => { - debug!("rewrite in-place: {}", stmt); + debug!("rewrite (simple): {}", stmt); + + self.state.save_plan(None, plan); + for action in actions { action.execute(self.request); } + // Update stats. { let cluster_stats = self.cluster.stats(); let mut lock = cluster_stats.lock(); lock.rewrite = lock.rewrite.clone() + stats; } + let ast = ParseResult::new(ast, "".into()); - // Cache new rewritten prepared statement. - if extended { - Cache::get().save(&stmt, ast, &schema).unwrap() - } else { - CachedAst::new_parsed(&stmt, ast, &schema).unwrap() - } + CachedAst::new_parsed(&stmt, ast, &schema)? } }; + + Ok(ast) + } + + /// Execute rewrite and return the query AST. + pub fn execute(&mut self) -> Result, Error> { + let mut ast: Option = None; + + for result in [self.handle_parse(), self.handle_query()] { + match result { + Ok(a) => ast = Some(a), + Err(Error::EmptyQuery) => continue, + Err(err) => return Err(err), + } + } + let parameters = self.request.parameters_mut()?; + if let Some(parameters) = parameters { + let plan = self.state.activate_plan(parameters)?; + plan.apply_bind(parameters)?; + } + Ok(ast) } } diff --git a/pgdog/src/frontend/router/rewrite/state.rs b/pgdog/src/frontend/router/rewrite/state.rs index 351680b97..643fab6f8 100644 --- a/pgdog/src/frontend/router/rewrite/state.rs +++ b/pgdog/src/frontend/router/rewrite/state.rs @@ -2,36 +2,43 @@ use std::collections::HashMap; -use crate::frontend::router::parser::cache::CachedAst; +use bytes::Bytes; + +use super::{Error, ImmutableRewritePlan}; +use crate::{ + frontend::router::parser::cache::CachedAst, + net::{Bind, Parse}, +}; #[derive(Debug, Default, Clone)] pub struct RewriteState { - originals: HashMap, + plans: HashMap, + active_plan: Option, } impl RewriteState { - /// Save original AST into rewrite state. - /// - /// We use it to rewrite Bind messages. Instead of encapsulating - /// complex rewrite rules in an enum, we walk the AST and - /// perform whatever changes we need. - /// - pub fn save(&mut self, name: &str, ast: &CachedAst) { - self.originals.insert(name.to_string(), ast.clone()); - } + /// Save rewrite plan for later use and active it for + /// this request. + pub fn save_plan(&mut self, parse: Option<&Parse>, plan: ImmutableRewritePlan) { + if let Some(parse) = parse { + self.plans.insert(parse.name_ref(), plan.clone()); + } - /// Get the original AST by statement name. - pub fn get(&self, name: &str) -> Option<&CachedAst> { - self.originals.get(name) + self.active_plan = Some(plan); } - /// Remove AST from state. - pub fn remove(&mut self, name: &str) -> bool { - self.originals.remove(name).is_some() + /// Activate plan for Bind, or error out if plan doesn't exist. + pub fn activate_plan(&mut self, bind: &Bind) -> Result<&ImmutableRewritePlan, Error> { + if let Some(plan) = self.plans.get(bind.statement_ref()) { + self.active_plan = Some(plan.clone()); + self.plan() + } else { + Err(Error::NoRewrite) + } } - /// Number of ASTs in the state. - pub fn len(&self) -> usize { - self.originals.len() + /// Get currently active rewrite plan. + pub fn plan(&self) -> Result<&ImmutableRewritePlan, Error> { + self.active_plan.as_ref().ok_or(Error::NoActiveRewritePlan) } } diff --git a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs index c397e0dd2..4123ef0a7 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs @@ -79,7 +79,6 @@ impl ExplainUniqueIdRewrite { return Ok(()); } - let mut bind = input.bind_take(); let extended = input.extended(); let mut parameter_counter = max_param_number(input.parse_result()); @@ -92,16 +91,14 @@ impl ExplainUniqueIdRewrite { if let Some(NodeEnum::SelectStmt(select)) = stmt.query.as_mut().and_then(|q| q.node.as_mut()) { - SelectUniqueIdRewrite::rewrite_select( + input.plan().unique_ids = SelectUniqueIdRewrite::rewrite_select( select, - &mut bind, extended, &mut parameter_counter, )?; } } - input.bind_put(bind); Ok(()) } @@ -128,7 +125,6 @@ impl ExplainUniqueIdRewrite { return Ok(()); } - let mut bind = input.bind_take(); let extended = input.extended(); let mut param_counter = max_param_number(input.parse_result()); @@ -141,16 +137,11 @@ impl ExplainUniqueIdRewrite { if let Some(NodeEnum::InsertStmt(insert)) = stmt.query.as_mut().and_then(|q| q.node.as_mut()) { - InsertUniqueIdRewrite::rewrite_insert( - insert, - &mut bind, - extended, - &mut param_counter, - )?; + input.plan().unique_ids = + InsertUniqueIdRewrite::rewrite_insert(insert, extended, &mut param_counter)?; } } - input.bind_put(bind); Ok(()) } @@ -177,7 +168,6 @@ impl ExplainUniqueIdRewrite { return Ok(()); } - let mut bind = input.bind_take(); let extended = input.extended(); let mut param_counter = super::max_param_number(input.parse_result()); @@ -190,16 +180,11 @@ impl ExplainUniqueIdRewrite { if let Some(NodeEnum::UpdateStmt(update)) = stmt.query.as_mut().and_then(|q| q.node.as_mut()) { - UpdateUniqueIdRewrite::rewrite_update( - update, - &mut bind, - extended, - &mut param_counter, - )?; + input.plan().unique_ids = + UpdateUniqueIdRewrite::rewrite_update(update, extended, &mut param_counter)?; } } - input.bind_put(bind); Ok(()) } } diff --git a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs index 30daf6aac..f112ddb2a 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs @@ -5,8 +5,10 @@ use super::{ bigint_const, bigint_param, max_param_number, }; use crate::{ - frontend::router::parser::{Insert, Value}, - net::Datum, + frontend::router::{ + parser::{Insert, Value}, + rewrite::UniqueIdPlan, + }, unique_id, }; @@ -32,10 +34,10 @@ impl InsertUniqueIdRewrite { pub fn rewrite_insert( stmt: &mut InsertStmt, - bind: &mut Option, extended: bool, param_counter: &mut i32, - ) -> Result<(), Error> { + ) -> Result, Error> { + let mut plans = vec![]; let select = stmt .select_stmt .as_mut() @@ -54,13 +56,9 @@ impl InsertUniqueIdRewrite { let node = if extended { *param_counter += 1; - if let Some(ref mut bind) = bind { - let count = bind.add_parameter(Datum::Bigint(id))?; - // The number of parameters in the query doesn't match what's in the bind message. - if count != *param_counter { - return Err(Error::ParameterCountMismatch); - } - } + plans.push(UniqueIdPlan { + param_ref: *param_counter, + }); bigint_param(*param_counter) } else { bigint_const(id) @@ -74,7 +72,7 @@ impl InsertUniqueIdRewrite { } } - Ok(()) + Ok(plans) } } @@ -95,7 +93,6 @@ impl RewriteModule for InsertUniqueIdRewrite { return Ok(()); } - let mut bind = input.bind_take(); let extended = input.extended(); let mut param_counter = max_param_number(input.parse_result()); @@ -105,11 +102,9 @@ impl RewriteModule for InsertUniqueIdRewrite { .as_mut() .and_then(|stmt| stmt.node.as_mut()) { - Self::rewrite_insert(stmt, &mut bind, extended, &mut param_counter)?; + input.plan().unique_ids = Self::rewrite_insert(stmt, extended, &mut param_counter)?; } - input.bind_put(bind); - Ok(()) } } diff --git a/pgdog/src/frontend/router/rewrite/unique_id/select.rs b/pgdog/src/frontend/router/rewrite/unique_id/select.rs index 8ec520019..ab5324d44 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/select.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/select.rs @@ -6,12 +6,11 @@ use pg_query::{ }; use super::{ - super::{Context, Error, RewriteModule}, + super::{Context, Error, RewriteModule, UniqueIdPlan}, bigint_const, bigint_param, }; use crate::{ frontend::router::{parser::Value, rewrite::unique_id::max_param_number}, - net::Datum, unique_id, }; @@ -97,10 +96,11 @@ impl SelectUniqueIdRewrite { pub fn rewrite_select( stmt: &mut SelectStmt, - bind: &mut Option, extended: bool, paramter_counter: &mut i32, - ) -> Result<(), Error> { + ) -> Result, Error> { + let mut plans = vec![]; + // Rewrite target_list for target in stmt.target_list.iter_mut() { if let Some(NodeEnum::ResTarget(ref mut res)) = target.node { @@ -111,13 +111,9 @@ impl SelectUniqueIdRewrite { let node = if extended { *paramter_counter += 1; - - if let Some(bind) = bind { - let counter = bind.add_parameter(Datum::Bigint(id))?; - if counter != *paramter_counter { - return Err(Error::ParameterCountMismatch); - } - } + plans.push(UniqueIdPlan { + param_ref: *paramter_counter, + }); bigint_param(*paramter_counter) } else { @@ -137,7 +133,7 @@ impl SelectUniqueIdRewrite { if let Some(NodeEnum::CommonTableExpr(ref mut expr)) = cte.node { if let Some(ref mut query) = expr.ctequery { if let Some(NodeEnum::SelectStmt(ref mut inner)) = query.node { - Self::rewrite_select(inner, bind, extended, paramter_counter)?; + Self::rewrite_select(inner, extended, paramter_counter)?; } } } @@ -146,45 +142,45 @@ impl SelectUniqueIdRewrite { // Rewrite subqueries in FROM clause for from in stmt.from_clause.iter_mut() { - Self::rewrite_from_node(from, bind, extended, paramter_counter)?; + Self::rewrite_from_node(from, extended, paramter_counter)?; } // Rewrite UNION/INTERSECT/EXCEPT (larg/rarg are Box) if let Some(ref mut larg) = stmt.larg { - Self::rewrite_select(larg, bind, extended, paramter_counter)?; + plans.extend(Self::rewrite_select(larg, extended, paramter_counter)?); } if let Some(ref mut rarg) = stmt.rarg { - Self::rewrite_select(rarg, bind, extended, paramter_counter)?; + plans.extend(Self::rewrite_select(rarg, extended, paramter_counter)?); } - Ok(()) + Ok(plans) } fn rewrite_from_node( node: &mut Node, - bind: &mut Option, extended: bool, paramter_counter: &mut i32, - ) -> Result<(), Error> { + ) -> Result, Error> { + let mut plans = vec![]; match node.node.as_mut() { Some(NodeEnum::RangeSubselect(ref mut subselect)) => { if let Some(ref mut subquery) = subselect.subquery { if let Some(NodeEnum::SelectStmt(ref mut inner)) = subquery.node { - Self::rewrite_select(inner, bind, extended, paramter_counter)?; + plans.extend(Self::rewrite_select(inner, extended, paramter_counter)?); } } } Some(NodeEnum::JoinExpr(ref mut join)) => { if let Some(ref mut larg) = join.larg { - Self::rewrite_from_node(larg, bind, extended, paramter_counter)?; + plans.extend(Self::rewrite_from_node(larg, extended, paramter_counter)?); } if let Some(ref mut rarg) = join.rarg { - Self::rewrite_from_node(rarg, bind, extended, paramter_counter)?; + plans.extend(Self::rewrite_from_node(rarg, extended, paramter_counter)?); } } _ => {} } - Ok(()) + Ok(plans) } } @@ -205,7 +201,6 @@ impl RewriteModule for SelectUniqueIdRewrite { return Ok(()); } - let mut bind = input.bind_take(); let extended = input.extended(); let mut parameter_counter = max_param_number(input.parse_result()); @@ -215,11 +210,10 @@ impl RewriteModule for SelectUniqueIdRewrite { .as_mut() .and_then(|stmt| stmt.node.as_mut()) { - Self::rewrite_select(stmt, &mut bind, extended, &mut parameter_counter)?; + let plans = Self::rewrite_select(stmt, extended, &mut parameter_counter)?; + input.plan().unique_ids = plans; } - input.bind_put(bind); - Ok(()) } } diff --git a/pgdog/src/frontend/router/rewrite/unique_id/update.rs b/pgdog/src/frontend/router/rewrite/unique_id/update.rs index ee4c5041b..0bf2c013d 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/update.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/update.rs @@ -3,10 +3,10 @@ use pg_query::{protobuf::UpdateStmt, NodeEnum}; use super::{ - super::{Context, Error, RewriteModule}, + super::{Context, Error, RewriteModule, UniqueIdPlan}, bigint_const, bigint_param, max_param_number, }; -use crate::{frontend::router::parser::Value, net::Datum, unique_id}; +use crate::{frontend::router::parser::Value, unique_id}; #[derive(Default)] pub struct UpdateUniqueIdRewrite {} @@ -29,10 +29,11 @@ impl UpdateUniqueIdRewrite { pub fn rewrite_update( stmt: &mut UpdateStmt, - bind: &mut Option, extended: bool, param_counter: &mut i32, - ) -> Result<(), Error> { + ) -> Result, Error> { + let mut plans = vec![]; + for target in stmt.target_list.iter_mut() { if let Some(NodeEnum::ResTarget(ref mut res)) = target.node { if let Some(ref mut val) = res.val { @@ -42,12 +43,10 @@ impl UpdateUniqueIdRewrite { let node = if extended { *param_counter += 1; - if let Some(ref mut bind) = bind { - let count = bind.add_parameter(Datum::Bigint(id))?; - if count != *param_counter { - return Err(Error::ParameterCountMismatch); - } - } + plans.push(UniqueIdPlan { + param_ref: *param_counter, + }); + bigint_param(*param_counter) } else { bigint_const(id) @@ -59,7 +58,8 @@ impl UpdateUniqueIdRewrite { } } } - Ok(()) + + Ok(plans) } } @@ -80,7 +80,6 @@ impl RewriteModule for UpdateUniqueIdRewrite { return Ok(()); } - let mut bind = input.bind_take(); let extended = input.extended(); let mut param_counter = max_param_number(input.parse_result()); @@ -90,11 +89,9 @@ impl RewriteModule for UpdateUniqueIdRewrite { .as_mut() .and_then(|stmt| stmt.node.as_mut()) { - Self::rewrite_update(stmt, &mut bind, extended, &mut param_counter)?; + input.plan().unique_ids = Self::rewrite_update(stmt, extended, &mut param_counter)?; } - input.bind_put(bind); - Ok(()) } } diff --git a/pgdog/src/net/messages/bind.rs b/pgdog/src/net/messages/bind.rs index da562c653..3cb076c1b 100644 --- a/pgdog/src/net/messages/bind.rs +++ b/pgdog/src/net/messages/bind.rs @@ -199,6 +199,10 @@ impl Bind { unsafe { from_utf8_unchecked(&self.statement[0..self.statement.len() - 1]) } } + pub fn statement_ref(&self) -> &Bytes { + &self.statement + } + /// Format codes, if any. pub fn codes(&self) -> &[Format] { &self.codes diff --git a/pgdog/src/net/messages/parse.rs b/pgdog/src/net/messages/parse.rs index 315abd066..1675d4555 100644 --- a/pgdog/src/net/messages/parse.rs +++ b/pgdog/src/net/messages/parse.rs @@ -67,6 +67,10 @@ impl Parse { unsafe { from_utf8_unchecked(&self.query[0..self.query.len() - 1]) } } + pub fn name_ref(&self) -> Bytes { + self.name.clone() + } + /// Get query reference. pub fn query_ref(&self) -> Bytes { self.query.clone() From 1657ad59e2f5d54f2c1c154c8de08c59a7d75d7b Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 12:15:59 -0800 Subject: [PATCH 17/23] Fix tests, clippy --- pgdog/src/frontend/client/query_engine/mod.rs | 2 +- .../client/query_engine/route_query.rs | 4 +- pgdog/src/frontend/router/rewrite/context.rs | 2 +- .../router/rewrite/insert_split/mod.rs | 26 +-- pgdog/src/frontend/router/rewrite/output.rs | 8 + pgdog/src/frontend/router/rewrite/plan.rs | 151 ++++++++++++++++++ .../router/rewrite/prepared/execute.rs | 6 +- .../router/rewrite/prepared/prepare.rs | 2 +- pgdog/src/frontend/router/rewrite/request.rs | 4 +- pgdog/src/frontend/router/rewrite/state.rs | 5 +- .../router/rewrite/unique_id/explain.rs | 10 +- .../router/rewrite/unique_id/insert.rs | 34 ++-- .../router/rewrite/unique_id/select.rs | 33 ++-- .../router/rewrite/unique_id/update.rs | 32 ++-- 14 files changed, 219 insertions(+), 100 deletions(-) diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index 836f13469..c42b50d3f 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -5,7 +5,7 @@ use crate::{ client::query_engine::hooks::QueryEngineHooks, router::{ parser::Shard, - rewrite::{self, RewriteRequest, RewriteState}, + rewrite::{RewriteRequest, RewriteState}, Route, }, BufferedQuery, Client, Command, Comms, Error, Router, RouterContext, Stats, diff --git a/pgdog/src/frontend/client/query_engine/route_query.rs b/pgdog/src/frontend/client/query_engine/route_query.rs index a8aff8054..7424a30bd 100644 --- a/pgdog/src/frontend/client/query_engine/route_query.rs +++ b/pgdog/src/frontend/client/query_engine/route_query.rs @@ -17,14 +17,14 @@ impl QueryEngine { if !in_transaction && !cluster.online() { // Reload cluster config. - if let Err(_) = self.backend.safe_reload().await { + self.backend.safe_reload().await.is_err() { return Some(ErrorResponse::connection( &identifier.user, &identifier.database, )); } - if let Err(_) = self.backend.cluster() { + if self.backend.cluster().is_err() { return Some(ErrorResponse::connection( &identifier.user, &identifier.database, diff --git a/pgdog/src/frontend/router/rewrite/context.rs b/pgdog/src/frontend/router/rewrite/context.rs index 1266ad044..881c88cfc 100644 --- a/pgdog/src/frontend/router/rewrite/context.rs +++ b/pgdog/src/frontend/router/rewrite/context.rs @@ -5,7 +5,7 @@ use pg_query::protobuf::{ParseResult, RawStmt}; use super::{ output::RewriteActionKind, stats::RewriteStats, Error, RewriteAction, RewritePlan, StepOutput, }; -use crate::net::{Bind, Parse, ProtocolMessage, Query}; +use crate::net::{Parse, ProtocolMessage, Query}; #[derive(Debug, Clone)] pub struct Context<'a> { diff --git a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs index 4a9ab1b15..a6bdbdbf2 100644 --- a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs +++ b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs @@ -79,31 +79,17 @@ impl RewriteModule for InsertSplitRewrite { #[cfg(test)] mod test { - use crate::net::bind::Parameter; + use crate::net::Parse; use super::*; #[test] fn test_insert_split() { - let stmt = pg_query::parse( - "INSERT INTO users (id, email, created_at) - VALUES ($1, 'test@test.com', NOW()), (123, $2, '2025-01-01') RETURNING *", - ) - .unwrap(); - let bind = Bind::new_params( - "", - &[ - Parameter { - len: 4, - data: "1234".into(), - }, - Parameter { - len: 14, - data: "hello@test.com".into(), - }, - ], - ); - let mut context = Context::new(&stmt.protobuf, Some(&bind), None); + let query = "INSERT INTO users (id, email, created_at) + VALUES ($1, 'test@test.com', NOW()), (123, $2, '2025-01-01') RETURNING *"; + let stmt = pg_query::parse(query).unwrap(); + let parse = Parse::new_anonymous(query); + let mut context = Context::new(&stmt.protobuf, Some(&parse)); let mut module = InsertSplitRewrite::default(); module.rewrite(&mut context).unwrap(); } diff --git a/pgdog/src/frontend/router/rewrite/output.rs b/pgdog/src/frontend/router/rewrite/output.rs index 7fdeb464b..72417acbe 100644 --- a/pgdog/src/frontend/router/rewrite/output.rs +++ b/pgdog/src/frontend/router/rewrite/output.rs @@ -66,6 +66,14 @@ impl StepOutput { Self::RewriteInPlace { stmt, .. } => Ok(stmt.as_str()), } } + + /// Get the rewrite plan, if any. + pub fn plan(&self) -> Result<&ImmutableRewritePlan, ()> { + match self { + Self::NoOp => Err(()), + Self::RewriteInPlace { plan, .. } => Ok(plan), + } + } } #[derive(Debug, Clone)] diff --git a/pgdog/src/frontend/router/rewrite/plan.rs b/pgdog/src/frontend/router/rewrite/plan.rs index af8f11f77..654093823 100644 --- a/pgdog/src/frontend/router/rewrite/plan.rs +++ b/pgdog/src/frontend/router/rewrite/plan.rs @@ -58,3 +58,154 @@ impl RewritePlan { } } } + +#[cfg(test)] +mod test { + use super::*; + use crate::net::bind::Parameter; + use std::env::set_var; + + #[test] + fn test_apply_bind_adds_parameters() { + unsafe { + set_var("NODE_ID", "pgdog-test-1"); + } + + // Create a rewrite plan expecting params $3 and $4 + let plan = RewritePlan { + unique_ids: vec![UniqueIdPlan { param_ref: 3 }, UniqueIdPlan { param_ref: 4 }], + }; + + // Create a Bind with 2 existing parameters + let mut bind = Bind::new_params( + "", + &[ + Parameter { + len: 2, + data: "{}".into(), + }, + Parameter { + len: 2, + data: "{}".into(), + }, + ], + ); + + // Apply the plan + plan.apply_bind(&mut bind).unwrap(); + + // Verify we now have 4 parameters + assert_eq!(bind.params_raw().len(), 4); + + // Verify the added parameters are BIGINT values (8 bytes in text format) + let param3 = bind.parameter(2).unwrap().unwrap(); + let param4 = bind.parameter(3).unwrap().unwrap(); + + // The parameters should be valid bigints + let id3 = param3.bigint(); + let id4 = param4.bigint(); + assert!(id3.is_some(), "Third parameter should be a valid bigint"); + assert!(id4.is_some(), "Fourth parameter should be a valid bigint"); + + // IDs should be different (unique) + assert_ne!(id3, id4); + } + + #[test] + fn test_apply_bind_parameter_count_mismatch() { + unsafe { + set_var("NODE_ID", "pgdog-test-1"); + } + + // Create a plan expecting param $5 (but bind only has 2 params, so next will be $3) + let plan = RewritePlan { + unique_ids: vec![UniqueIdPlan { param_ref: 5 }], + }; + + // Create a Bind with 2 existing parameters + let mut bind = Bind::new_params( + "", + &[ + Parameter { + len: 2, + data: "{}".into(), + }, + Parameter { + len: 2, + data: "{}".into(), + }, + ], + ); + + // Apply should fail due to mismatch + let result = plan.apply_bind(&mut bind); + assert!(result.is_err()); + } + + #[test] + fn test_apply_bind_empty_plan() { + // Empty plan should be a no-op + let plan = RewritePlan::default(); + + let mut bind = Bind::new_params( + "", + &[Parameter { + len: 4, + data: "test".into(), + }], + ); + + plan.apply_bind(&mut bind).unwrap(); + + // Should still have just 1 parameter + assert_eq!(bind.params_raw().len(), 1); + } + + #[test] + fn test_apply_bind_single_unique_id() { + unsafe { + set_var("NODE_ID", "pgdog-test-1"); + } + + // Create a plan for a single unique ID as $2 + let plan = RewritePlan { + unique_ids: vec![UniqueIdPlan { param_ref: 2 }], + }; + + // Create a Bind with 1 existing parameter + let mut bind = Bind::new_params( + "", + &[Parameter { + len: 4, + data: "test".into(), + }], + ); + + plan.apply_bind(&mut bind).unwrap(); + + // Should now have 2 parameters + assert_eq!(bind.params_raw().len(), 2); + + // Second parameter should be a bigint + let param2 = bind.parameter(1).unwrap().unwrap(); + assert!(param2.bigint().is_some()); + } + + #[test] + fn test_immutable_plan_apply_bind() { + unsafe { + set_var("NODE_ID", "pgdog-test-1"); + } + + // Test that ImmutableRewritePlan can also apply to binds via Deref + let plan = RewritePlan { + unique_ids: vec![UniqueIdPlan { param_ref: 1 }], + } + .freeze(); + + let mut bind = Bind::default(); + plan.apply_bind(&mut bind).unwrap(); + + assert_eq!(bind.params_raw().len(), 1); + } +} diff --git a/pgdog/src/frontend/router/rewrite/prepared/execute.rs b/pgdog/src/frontend/router/rewrite/prepared/execute.rs index b59dbed6a..9071c3460 100644 --- a/pgdog/src/frontend/router/rewrite/prepared/execute.rs +++ b/pgdog/src/frontend/router/rewrite/prepared/execute.rs @@ -67,13 +67,13 @@ mod test { // First prepare the statement let mut prepare_rewrite = PrepareRewrite::new(&mut prepared_statements); - let mut input = Context::new(&prepare_stmt, None, None); + let mut input = Context::new(&prepare_stmt, None); prepare_rewrite.rewrite(&mut input).unwrap(); // Now execute it let execute_stmt = pg_query::parse("EXECUTE test(1, 2, 3)").unwrap().protobuf; let mut execute_rewrite = ExecuteRewrite::new(&prepared_statements); - let mut input = Context::new(&execute_stmt, None, None); + let mut input = Context::new(&execute_stmt, None); execute_rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -87,7 +87,7 @@ mod test { .protobuf; let prepared_statements = PreparedStatements::default(); let mut execute_rewrite = ExecuteRewrite::new(&prepared_statements); - let mut input = Context::new(&execute_stmt, None, None); + let mut input = Context::new(&execute_stmt, None); let result = execute_rewrite.rewrite(&mut input); assert!(result.is_err()); } diff --git a/pgdog/src/frontend/router/rewrite/prepared/prepare.rs b/pgdog/src/frontend/router/rewrite/prepared/prepare.rs index d0f9f8d66..f4b63c2a1 100644 --- a/pgdog/src/frontend/router/rewrite/prepared/prepare.rs +++ b/pgdog/src/frontend/router/rewrite/prepared/prepare.rs @@ -64,7 +64,7 @@ mod test { .protobuf; let mut prepared_statements = PreparedStatements::default(); let mut rewrite = PrepareRewrite::new(&mut prepared_statements); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); diff --git a/pgdog/src/frontend/router/rewrite/request.rs b/pgdog/src/frontend/router/rewrite/request.rs index 28b22fc1a..a91595ff5 100644 --- a/pgdog/src/frontend/router/rewrite/request.rs +++ b/pgdog/src/frontend/router/rewrite/request.rs @@ -70,7 +70,7 @@ impl<'a> RewriteRequest<'a> { { let cluster_stats = self.cluster.stats(); let mut lock = cluster_stats.lock(); - lock.rewrite = lock.rewrite.clone() + stats; + lock.rewrite = lock.rewrite + stats; } let ast = ParseResult::new(ast, "".into()); @@ -117,7 +117,7 @@ impl<'a> RewriteRequest<'a> { { let cluster_stats = self.cluster.stats(); let mut lock = cluster_stats.lock(); - lock.rewrite = lock.rewrite.clone() + stats; + lock.rewrite = lock.rewrite + stats; } let ast = ParseResult::new(ast, "".into()); diff --git a/pgdog/src/frontend/router/rewrite/state.rs b/pgdog/src/frontend/router/rewrite/state.rs index 643fab6f8..13e2a8b31 100644 --- a/pgdog/src/frontend/router/rewrite/state.rs +++ b/pgdog/src/frontend/router/rewrite/state.rs @@ -5,10 +5,7 @@ use std::collections::HashMap; use bytes::Bytes; use super::{Error, ImmutableRewritePlan}; -use crate::{ - frontend::router::parser::cache::CachedAst, - net::{Bind, Parse}, -}; +use crate::net::{Bind, Parse}; #[derive(Debug, Default, Clone)] pub struct RewriteState { diff --git a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs index 4123ef0a7..26d5fce1b 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs @@ -203,7 +203,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -221,7 +221,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -240,7 +240,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -258,7 +258,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -271,7 +271,7 @@ mod test { fn test_explain_no_unique_id() { let stmt = pg_query::parse(r#"EXPLAIN SELECT 1"#).unwrap().protobuf; let mut rewrite = ExplainUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(matches!(output, super::super::super::StepOutput::NoOp)); diff --git a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs index f112ddb2a..fd74e6cf9 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs @@ -112,7 +112,7 @@ impl RewriteModule for InsertUniqueIdRewrite { #[cfg(test)] mod test { use super::*; - use crate::net::bind::{Bind, Parameter}; + use crate::net::Parse; use std::env::set_var; #[test] @@ -130,7 +130,7 @@ mod test { .unwrap() .protobuf; let mut insert = InsertUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); insert.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); let query = output.query().unwrap(); @@ -147,29 +147,14 @@ mod test { unsafe { set_var("NODE_ID", "pgdog-prod-1"); } - let stmt = pg_query::parse( - r#" + let query = r#" INSERT INTO omnisharded (id, settings) VALUES (pgdog.unique_id(), $1::JSONB), - (pgdog.unique_id(), $2::JSONB)"#, - ) - .unwrap() - .protobuf; - let bind = Bind::new_params( - "", - &[ - Parameter { - len: 2, - data: "{}".into(), - }, - Parameter { - len: 2, - data: "{}".into(), - }, - ], - ); - let mut input = Context::new(&stmt, Some(&bind), None); + (pgdog.unique_id(), $2::JSONB)"#; + let stmt = pg_query::parse(query).unwrap().protobuf; + let parse = Parse::new_anonymous(query); + let mut input = Context::new(&stmt, Some(&parse)); InsertUniqueIdRewrite::default() .rewrite(&mut input) .unwrap(); @@ -178,5 +163,10 @@ mod test { output.query().unwrap(), "INSERT INTO omnisharded (id, settings) VALUES ($3::bigint, $1::jsonb), ($4::bigint, $2::jsonb)" ); + // Verify the rewrite plan has the correct parameters + let plan = output.plan().unwrap(); + assert_eq!(plan.unique_ids.len(), 2); + assert_eq!(plan.unique_ids[0].param_ref, 3); + assert_eq!(plan.unique_ids[1].param_ref, 4); } } diff --git a/pgdog/src/frontend/router/rewrite/unique_id/select.rs b/pgdog/src/frontend/router/rewrite/unique_id/select.rs index ab5324d44..bf37aa0df 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/select.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/select.rs @@ -221,7 +221,7 @@ impl RewriteModule for SelectUniqueIdRewrite { #[cfg(test)] mod test { use super::*; - use crate::net::{bind::Parameter, Bind}; + use crate::net::Parse; use std::env::set_var; #[test] @@ -233,7 +233,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); println!("output: {}", output.query().unwrap()); @@ -241,21 +241,14 @@ mod test { } #[test] - fn test_unique_id_select_with_bind() { + fn test_unique_id_select_with_parse() { unsafe { set_var("NODE_ID", "pgdog-prod-1"); } - let stmt = pg_query::parse(r#"SELECT pgdog.unique_id() AS id, $1 AS name"#) - .unwrap() - .protobuf; - let bind = Bind::new_params( - "", - &[Parameter { - len: 4, - data: "test".into(), - }], - ); - let mut input = Context::new(&stmt, Some(&bind), None); + let query = r#"SELECT pgdog.unique_id() AS id, $1 AS name"#; + let stmt = pg_query::parse(query).unwrap().protobuf; + let parse = Parse::new_anonymous(query); + let mut input = Context::new(&stmt, Some(&parse)); SelectUniqueIdRewrite::default() .rewrite(&mut input) .unwrap(); @@ -264,6 +257,10 @@ mod test { output.query().unwrap(), "SELECT $2::bigint AS id, $1 AS name" ); + // Verify the rewrite plan has the correct parameters + let plan = output.plan().unwrap(); + assert_eq!(plan.unique_ids.len(), 1); + assert_eq!(plan.unique_ids[0].param_ref, 2); } #[test] @@ -276,7 +273,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -291,7 +288,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -308,7 +305,7 @@ mod test { .unwrap() .protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -318,7 +315,7 @@ mod test { fn test_no_rewrite_when_no_unique_id() { let stmt = pg_query::parse(r#"SELECT id FROM users"#).unwrap().protobuf; let mut rewrite = SelectUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); rewrite.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(matches!(output, super::super::super::StepOutput::NoOp)); diff --git a/pgdog/src/frontend/router/rewrite/unique_id/update.rs b/pgdog/src/frontend/router/rewrite/unique_id/update.rs index 0bf2c013d..4dd1fb643 100644 --- a/pgdog/src/frontend/router/rewrite/unique_id/update.rs +++ b/pgdog/src/frontend/router/rewrite/unique_id/update.rs @@ -99,7 +99,7 @@ impl RewriteModule for UpdateUniqueIdRewrite { #[cfg(test)] mod test { use super::*; - use crate::net::{bind::Parameter, Bind}; + use crate::net::Parse; use std::env::set_var; #[test] @@ -112,7 +112,7 @@ mod test { .unwrap() .protobuf; let mut update = UpdateUniqueIdRewrite::default(); - let mut input = Context::new(&stmt, None, None); + let mut input = Context::new(&stmt, None); update.rewrite(&mut input).unwrap(); let output = input.build().unwrap(); assert!(!output.query().unwrap().contains("pgdog.unique_id")); @@ -123,25 +123,11 @@ mod test { unsafe { set_var("NODE_ID", "pgdog-prod-1"); } - let stmt = pg_query::parse( - r#"UPDATE omnisharded SET id = pgdog.unique_id(), settings = $1 WHERE old_id = $2"#, - ) - .unwrap() - .protobuf; - let bind = Bind::new_params( - "", - &[ - Parameter { - len: 2, - data: "{}".into(), - }, - Parameter { - len: 3, - data: "123".into(), - }, - ], - ); - let mut input = Context::new(&stmt, Some(&bind), None); + let query = + r#"UPDATE omnisharded SET id = pgdog.unique_id(), settings = $1 WHERE old_id = $2"#; + let stmt = pg_query::parse(query).unwrap().protobuf; + let parse = Parse::new_anonymous(query); + let mut input = Context::new(&stmt, Some(&parse)); UpdateUniqueIdRewrite::default() .rewrite(&mut input) .unwrap(); @@ -150,5 +136,9 @@ mod test { output.query().unwrap(), "UPDATE omnisharded SET id = $3::bigint, settings = $1 WHERE old_id = $2" ); + // Verify the rewrite plan has the correct parameters + let plan = output.plan().unwrap(); + assert_eq!(plan.unique_ids.len(), 1); + assert_eq!(plan.unique_ids[0].param_ref, 3); } } From 9a9e9f435e3cacecd37d13863a32cec2ace51185 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 12:16:11 -0800 Subject: [PATCH 18/23] fix --- pgdog/src/frontend/client/query_engine/route_query.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgdog/src/frontend/client/query_engine/route_query.rs b/pgdog/src/frontend/client/query_engine/route_query.rs index 7424a30bd..685b1727d 100644 --- a/pgdog/src/frontend/client/query_engine/route_query.rs +++ b/pgdog/src/frontend/client/query_engine/route_query.rs @@ -17,7 +17,7 @@ impl QueryEngine { if !in_transaction && !cluster.online() { // Reload cluster config. - self.backend.safe_reload().await.is_err() { + if self.backend.safe_reload().await.is_err() { return Some(ErrorResponse::connection( &identifier.user, &identifier.database, From d7543815a272eb40e140f9e593acedba8c4889bc Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 12:27:47 -0800 Subject: [PATCH 19/23] thats why we write tests --- pgdog/src/frontend/router/rewrite/request.rs | 5 +++-- pgdog/src/frontend/router/rewrite/state.rs | 13 ++++++------- pgdog/src/util.rs | 3 +++ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/pgdog/src/frontend/router/rewrite/request.rs b/pgdog/src/frontend/router/rewrite/request.rs index a91595ff5..62ad0cf96 100644 --- a/pgdog/src/frontend/router/rewrite/request.rs +++ b/pgdog/src/frontend/router/rewrite/request.rs @@ -141,8 +141,9 @@ impl<'a> RewriteRequest<'a> { } let parameters = self.request.parameters_mut()?; if let Some(parameters) = parameters { - let plan = self.state.activate_plan(parameters)?; - plan.apply_bind(parameters)?; + if let Some(plan) = self.state.activate_plan(parameters) { + plan.apply_bind(parameters)?; + } } Ok(ast) diff --git a/pgdog/src/frontend/router/rewrite/state.rs b/pgdog/src/frontend/router/rewrite/state.rs index 13e2a8b31..7ee932875 100644 --- a/pgdog/src/frontend/router/rewrite/state.rs +++ b/pgdog/src/frontend/router/rewrite/state.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use bytes::Bytes; -use super::{Error, ImmutableRewritePlan}; +use super::ImmutableRewritePlan; use crate::net::{Bind, Parse}; #[derive(Debug, Default, Clone)] @@ -25,17 +25,16 @@ impl RewriteState { } /// Activate plan for Bind, or error out if plan doesn't exist. - pub fn activate_plan(&mut self, bind: &Bind) -> Result<&ImmutableRewritePlan, Error> { + pub fn activate_plan(&mut self, bind: &Bind) -> Option<&ImmutableRewritePlan> { if let Some(plan) = self.plans.get(bind.statement_ref()) { self.active_plan = Some(plan.clone()); - self.plan() - } else { - Err(Error::NoRewrite) } + + self.plan() } /// Get currently active rewrite plan. - pub fn plan(&self) -> Result<&ImmutableRewritePlan, Error> { - self.active_plan.as_ref().ok_or(Error::NoActiveRewritePlan) + pub fn plan(&self) -> Option<&ImmutableRewritePlan> { + self.active_plan.as_ref() } } diff --git a/pgdog/src/util.rs b/pgdog/src/util.rs index 0e970a2b9..7615ff31d 100644 --- a/pgdog/src/util.rs +++ b/pgdog/src/util.rs @@ -166,6 +166,9 @@ mod test { #[test] fn test_instance_id_format() { + unsafe { + remove_var("NODE_ID"); + } let id = instance_id(); assert_eq!(id.len(), 8); // All characters should be valid hex digits (0-9, a-f) From 2224845e6805a75ca7a638b4a730085194428e98 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 13:43:12 -0800 Subject: [PATCH 20/23] Actually make it work --- pgdog/src/frontend/client/query_engine/mod.rs | 22 +------- .../query_engine/prepared_statements.rs | 22 ++++++-- .../prepared_statements/global_cache.rs | 42 ++++++++++++++- .../frontend/router/parser/rewrite_plan.rs | 16 ++++-- pgdog/src/frontend/router/rewrite/mod.rs | 2 - pgdog/src/frontend/router/rewrite/plan.rs | 6 +-- pgdog/src/frontend/router/rewrite/request.rs | 51 ++++++++++++++----- pgdog/src/frontend/router/rewrite/state.rs | 40 --------------- 8 files changed, 114 insertions(+), 87 deletions(-) delete mode 100644 pgdog/src/frontend/router/rewrite/state.rs diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index c42b50d3f..142048109 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -3,11 +3,7 @@ use crate::{ config::config, frontend::{ client::query_engine::hooks::QueryEngineHooks, - router::{ - parser::Shard, - rewrite::{RewriteRequest, RewriteState}, - Route, - }, + router::{parser::Shard, rewrite::RewriteRequest, Route}, BufferedQuery, Client, Command, Comms, Error, Router, RouterContext, Stats, }, net::{BackendKeyData, ErrorResponse, Message, Parameters}, @@ -61,7 +57,6 @@ pub struct QueryEngine { notify_buffer: NotifyBuffer, pending_explain: Option, hooks: QueryEngineHooks, - rewrite_state: RewriteState, } impl QueryEngine { @@ -125,25 +120,12 @@ impl QueryEngine { return Ok(QueryEngineOutput::Executed); } - if let Ok(cluster) = self.backend.cluster() { - if cluster.use_parser() && context.ast.is_none() { - // Execute request rewrite, if needed. - let mut rewrite = RewriteRequest::new( - context.client_request, - self.backend.cluster()?, - context.prepared_statements, - &mut self.rewrite_state, - ); - context.ast = rewrite.execute()?; - } - } - self.stats .received(context.client_request.total_message_len()); self.set_state(State::Active); // Client is active. // Rewrite prepared statements. - self.rewrite_extended(context)?; + self.rewrite(context)?; // Intercept commands we don't have to forward to a server. if self.intercept_incomplete(context).await? { diff --git a/pgdog/src/frontend/client/query_engine/prepared_statements.rs b/pgdog/src/frontend/client/query_engine/prepared_statements.rs index 8fac4d2a6..775734e25 100644 --- a/pgdog/src/frontend/client/query_engine/prepared_statements.rs +++ b/pgdog/src/frontend/client/query_engine/prepared_statements.rs @@ -4,10 +4,8 @@ use super::*; impl QueryEngine { /// Rewrite extended protocol messages. - pub(super) fn rewrite_extended( - &mut self, - context: &mut QueryEngineContext<'_>, - ) -> Result<(), Error> { + pub(super) fn rewrite(&mut self, context: &mut QueryEngineContext<'_>) -> Result<(), Error> { + // Rewrite prepared statements to use global names. for message in context.client_request.iter_mut() { if message.extended() { let level = context.prepared_statements.level; @@ -20,6 +18,22 @@ impl QueryEngine { } } } + + // Rewrite the statement itself. + if let Ok(cluster) = self.backend.cluster() { + if cluster.use_parser() && context.ast.is_none() { + // Execute request rewrite, if needed. + let mut rewrite = RewriteRequest::new( + context.client_request, + cluster, + context.prepared_statements, + ); + context.ast = rewrite.execute()?; + } + } + + println!("after: {:#?}", context.client_request); + Ok(()) } } diff --git a/pgdog/src/frontend/prepared_statements/global_cache.rs b/pgdog/src/frontend/prepared_statements/global_cache.rs index 281c02372..3358108b2 100644 --- a/pgdog/src/frontend/prepared_statements/global_cache.rs +++ b/pgdog/src/frontend/prepared_statements/global_cache.rs @@ -1,7 +1,7 @@ use bytes::Bytes; use crate::{ - frontend::router::parser::RewritePlan, + frontend::router::{parser::RewritePlan, rewrite::ImmutableRewritePlan}, net::messages::{Parse, RowDescription}, stats::memory::MemoryUsage, }; @@ -17,13 +17,26 @@ fn global_name(counter: usize) -> String { #[derive(Debug, Clone)] pub struct Statement { + /// The rewritten Parse message. parse: Parse, + /// Saved RowDescription received from a backend. row_description: Option, + /// Version used to force-insert identical prepared + /// statements into the cache. #[allow(dead_code)] version: usize, + /// Rewrite plan used by the query parser. + /// TODO: Move this to the `crate::frontend::rewrite` module. rewrite_plan: Option, + /// The statement cache key, used to remove it from the other HashMap. cache_key: CacheKey, + /// Remove the statement from the cache upon closing it. + /// This is for one-time rewrites done by the parser rewrite module. + /// TODO: Remove this once code is moved + /// to the `crate::frontend::rewrite` module. evict_on_close: bool, + /// Rewrite engine. + rewrite_engine_plan: Option, } impl MemoryUsage for Statement { @@ -176,6 +189,7 @@ impl GlobalCache { rewrite_plan: None, cache_key, evict_on_close: false, + rewrite_engine_plan: None, }, ); @@ -215,6 +229,7 @@ impl GlobalCache { rewrite_plan: None, cache_key: key, evict_on_close: false, + rewrite_engine_plan: None, }, ); @@ -252,10 +267,35 @@ impl GlobalCache { } } + /// Get a copy of the rewrite plan. pub fn rewrite_plan(&self, name: &str) -> Option { self.names.get(name).and_then(|s| s.rewrite_plan.clone()) } + /// Get a mutable reference to the rewrite plan. + pub fn rewrite_plan_mut(&mut self, name: &str) -> Option<&mut RewritePlan> { + self.names + .get_mut(name) + .and_then(|statement| statement.rewrite_plan.as_mut()) + } + + /// Get the rewrite module rewrite plan by statement name. + pub fn rewrite_engine_plan(&self, name: &str) -> Option { + self.names + .get(name) + .and_then(|statement| statement.rewrite_engine_plan.clone()) + } + + /// Set the rewrite module rewrite plan for the given statement. + pub fn set_rewrite(&mut self, parse: &Parse, plan: ImmutableRewritePlan) { + if let Some(statement) = self.names.get_mut(parse.name()) { + if statement.rewrite_engine_plan.is_none() { + statement.rewrite_engine_plan = Some(plan); + statement.parse = parse.clone(); + } + } + } + pub fn reset(&mut self) { self.statements.clear(); self.names.clear(); diff --git a/pgdog/src/frontend/router/parser/rewrite_plan.rs b/pgdog/src/frontend/router/parser/rewrite_plan.rs index 030ea0cb6..c3eed0f26 100644 --- a/pgdog/src/frontend/router/parser/rewrite_plan.rs +++ b/pgdog/src/frontend/router/parser/rewrite_plan.rs @@ -24,31 +24,39 @@ pub struct RewritePlan { } impl RewritePlan { + /// Create new rewrite plan. pub fn new() -> Self { - Self { - drop_columns: Vec::new(), - helpers: Vec::new(), - } + Self::default() } + /// The plan doesn't do anything. + /// + /// N.B. This is a noop used inside the parser. As we move the + /// rewrite logic to its own rewrite module, the no_op will be updated + /// to include `unique_ids`. + /// pub fn is_noop(&self) -> bool { self.drop_columns.is_empty() && self.helpers.is_empty() } + /// Get column positions that should be removed from DataRows. pub fn drop_columns(&self) -> &[usize] { &self.drop_columns } + /// Add column to be removed from DataRows. pub fn add_drop_column(&mut self, column: usize) { if !self.drop_columns.contains(&column) { self.drop_columns.push(column); } } + /// Get per-column helpers, used for aggregation. pub fn helpers(&self) -> &[HelperMapping] { &self.helpers } + /// Add per-column aggregate helper. pub fn add_helper(&mut self, mapping: HelperMapping) { self.helpers.push(mapping); } diff --git a/pgdog/src/frontend/router/rewrite/mod.rs b/pgdog/src/frontend/router/rewrite/mod.rs index f3a8ac002..9f9232207 100644 --- a/pgdog/src/frontend/router/rewrite/mod.rs +++ b/pgdog/src/frontend/router/rewrite/mod.rs @@ -14,7 +14,6 @@ pub mod output; pub mod plan; pub mod prepared; pub mod request; -pub mod state; pub mod stats; pub mod unique_id; @@ -24,7 +23,6 @@ pub use interface::RewriteModule; pub use output::{RewriteAction, StepOutput}; pub use plan::{ImmutableRewritePlan, RewritePlan, UniqueIdPlan}; pub use request::RewriteRequest; -pub use state::RewriteState; use crate::frontend::PreparedStatements; diff --git a/pgdog/src/frontend/router/rewrite/plan.rs b/pgdog/src/frontend/router/rewrite/plan.rs index 654093823..f3c064b42 100644 --- a/pgdog/src/frontend/router/rewrite/plan.rs +++ b/pgdog/src/frontend/router/rewrite/plan.rs @@ -6,19 +6,19 @@ use crate::{ unique_id::UniqueId, }; -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, PartialEq)] pub struct UniqueIdPlan { /// Parameter number. pub(super) param_ref: i32, } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, PartialEq)] pub struct RewritePlan { /// How many unique IDs to add to the Bind message. pub(super) unique_ids: Vec, } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, PartialEq)] pub struct ImmutableRewritePlan { /// Compiled rewrite plan, that cannot be modified further. pub(super) plan: Arc, diff --git a/pgdog/src/frontend/router/rewrite/request.rs b/pgdog/src/frontend/router/rewrite/request.rs index 62ad0cf96..be9e0cbef 100644 --- a/pgdog/src/frontend/router/rewrite/request.rs +++ b/pgdog/src/frontend/router/rewrite/request.rs @@ -1,11 +1,14 @@ use pg_query::ParseResult; use tracing::debug; -use super::{Context, Error, Rewrite, RewriteModule, RewriteState, StepOutput}; +use super::{Context, Error, Rewrite, RewriteModule, StepOutput}; use crate::{ backend::Cluster, frontend::{ - router::parser::{cache::CachedAst, Cache}, + router::{ + parser::{cache::CachedAst, Cache}, + rewrite::ImmutableRewritePlan, + }, ClientRequest, PreparedStatements, }, net::{Protocol, ProtocolMessage}, @@ -15,7 +18,7 @@ pub struct RewriteRequest<'a> { request: &'a mut ClientRequest, cluster: &'a Cluster, prepared_statements: &'a mut PreparedStatements, - state: &'a mut RewriteState, + plan: Option, } impl<'a> RewriteRequest<'a> { @@ -24,13 +27,12 @@ impl<'a> RewriteRequest<'a> { request: &'a mut ClientRequest, cluster: &'a Cluster, prepared_statements: &'a mut PreparedStatements, - state: &'a mut RewriteState, ) -> Self { Self { request, cluster, prepared_statements, - state, + plan: None, } } @@ -50,7 +52,10 @@ impl<'a> RewriteRequest<'a> { let output = context.build()?; let ast = match output { - StepOutput::NoOp => ast, + StepOutput::NoOp => { + debug!("rewrite (extended) is a no-op"); + ast + } StepOutput::RewriteInPlace { actions, ast, @@ -60,12 +65,23 @@ impl<'a> RewriteRequest<'a> { } => { debug!("rewrite (extended): {}", stmt); - self.state.save_plan(Some(parse), plan); - for action in actions { action.execute(self.request); + + // Update the rewritten parse in the global cache + // and save its rewrite plan. + if let ProtocolMessage::Parse(parse) = action.message { + if !parse.anonymous() { + self.prepared_statements + .global + .write() + .set_rewrite(&parse, plan.clone()); + } + } } + self.plan = Some(plan); + // Update stats. { let cluster_stats = self.cluster.stats(); @@ -97,18 +113,19 @@ impl<'a> RewriteRequest<'a> { let output = context.build()?; let ast = match output { - StepOutput::NoOp => ast, + StepOutput::NoOp => { + debug!("rewrite (simple) is a no-op"); + ast + } StepOutput::RewriteInPlace { actions, ast, stmt, stats, - plan, + plan: _, } => { debug!("rewrite (simple): {}", stmt); - self.state.save_plan(None, plan); - for action in actions { action.execute(self.request); } @@ -141,7 +158,15 @@ impl<'a> RewriteRequest<'a> { } let parameters = self.request.parameters_mut()?; if let Some(parameters) = parameters { - if let Some(plan) = self.state.activate_plan(parameters) { + if self.plan.is_none() { + self.plan = self + .prepared_statements + .global + .read() + .rewrite_engine_plan(parameters.statement()); + } + + if let Some(plan) = self.plan.take() { plan.apply_bind(parameters)?; } } diff --git a/pgdog/src/frontend/router/rewrite/state.rs b/pgdog/src/frontend/router/rewrite/state.rs deleted file mode 100644 index 7ee932875..000000000 --- a/pgdog/src/frontend/router/rewrite/state.rs +++ /dev/null @@ -1,40 +0,0 @@ -//! Rewrite engine state. To be preserved between requests. - -use std::collections::HashMap; - -use bytes::Bytes; - -use super::ImmutableRewritePlan; -use crate::net::{Bind, Parse}; - -#[derive(Debug, Default, Clone)] -pub struct RewriteState { - plans: HashMap, - active_plan: Option, -} - -impl RewriteState { - /// Save rewrite plan for later use and active it for - /// this request. - pub fn save_plan(&mut self, parse: Option<&Parse>, plan: ImmutableRewritePlan) { - if let Some(parse) = parse { - self.plans.insert(parse.name_ref(), plan.clone()); - } - - self.active_plan = Some(plan); - } - - /// Activate plan for Bind, or error out if plan doesn't exist. - pub fn activate_plan(&mut self, bind: &Bind) -> Option<&ImmutableRewritePlan> { - if let Some(plan) = self.plans.get(bind.statement_ref()) { - self.active_plan = Some(plan.clone()); - } - - self.plan() - } - - /// Get currently active rewrite plan. - pub fn plan(&self) -> Option<&ImmutableRewritePlan> { - self.active_plan.as_ref() - } -} From f0b3ba60294937d609cdf82443db9c125463be7a Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 14:03:03 -0800 Subject: [PATCH 21/23] Remove println --- pgdog/src/frontend/client/query_engine/mod.rs | 2 +- .../frontend/client/query_engine/prepared_statements.rs | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index 142048109..29bc00cdb 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -125,7 +125,7 @@ impl QueryEngine { self.set_state(State::Active); // Client is active. // Rewrite prepared statements. - self.rewrite(context)?; + self.rewrite_request(context)?; // Intercept commands we don't have to forward to a server. if self.intercept_incomplete(context).await? { diff --git a/pgdog/src/frontend/client/query_engine/prepared_statements.rs b/pgdog/src/frontend/client/query_engine/prepared_statements.rs index 775734e25..db740b8eb 100644 --- a/pgdog/src/frontend/client/query_engine/prepared_statements.rs +++ b/pgdog/src/frontend/client/query_engine/prepared_statements.rs @@ -4,7 +4,10 @@ use super::*; impl QueryEngine { /// Rewrite extended protocol messages. - pub(super) fn rewrite(&mut self, context: &mut QueryEngineContext<'_>) -> Result<(), Error> { + pub(super) fn rewrite_request( + &mut self, + context: &mut QueryEngineContext<'_>, + ) -> Result<(), Error> { // Rewrite prepared statements to use global names. for message in context.client_request.iter_mut() { if message.extended() { @@ -32,8 +35,6 @@ impl QueryEngine { } } - println!("after: {:#?}", context.client_request); - Ok(()) } } From 15b44cb1ba6a1ff4fbeae0a0fc053ac3fd6fa2a4 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 14:16:50 -0800 Subject: [PATCH 22/23] Make unique_id work in CI --- integration/common.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integration/common.sh b/integration/common.sh index 8001131ed..e37c084d8 100644 --- a/integration/common.sh +++ b/integration/common.sh @@ -4,6 +4,8 @@ # correctly. # COMMON_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +export NODE_ID=pgdog-dev-0 + function wait_for_pgdog() { echo "Waiting for PgDog" while ! pg_isready -h 127.0.0.1 -p 6432 -U pgdog -d pgdog > /dev/null; do From 29502af64a2f91470597320c32e69e6b2a07e7de Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 14:30:56 -0800 Subject: [PATCH 23/23] Test bind rewrite --- integration/ruby/pg_spec.rb | 9 ++++ .../query_engine/prepared_statements.rs | 2 + pgdog/src/net/messages/bind.rs | 51 +++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/integration/ruby/pg_spec.rb b/integration/ruby/pg_spec.rb index 5076ffdce..621278633 100644 --- a/integration/ruby/pg_spec.rb +++ b/integration/ruby/pg_spec.rb @@ -111,4 +111,13 @@ def connect(dbname = 'pgdog', user = 'pgdog') expect(res[0]['one']).to eq('2') end end + + it 'unique_id' do + conn = connect "pgdog_sharded" + 100.times do |i| + res = conn.exec "SELECT pgdog.unique_id() AS id, $1 AS counter", [i] + expect(res[0]["id"].to_i).to be > 0 + expect(res[0]["counter"].to_i).to eq(i) + end + end end diff --git a/pgdog/src/frontend/client/query_engine/prepared_statements.rs b/pgdog/src/frontend/client/query_engine/prepared_statements.rs index db740b8eb..33115df46 100644 --- a/pgdog/src/frontend/client/query_engine/prepared_statements.rs +++ b/pgdog/src/frontend/client/query_engine/prepared_statements.rs @@ -35,6 +35,8 @@ impl QueryEngine { } } + println!("req: {:#?}", context.client_request); + Ok(()) } } diff --git a/pgdog/src/net/messages/bind.rs b/pgdog/src/net/messages/bind.rs index 3cb076c1b..2c7aacfc1 100644 --- a/pgdog/src/net/messages/bind.rs +++ b/pgdog/src/net/messages/bind.rs @@ -226,6 +226,7 @@ impl Bind { len: bytes.len() as i32, data: bytes, }); + self.original = None; // Param codes are 1-indexed. Ok(self.params.len() as i32) } @@ -244,6 +245,7 @@ impl Bind { } } self.params.push(param.parameter.clone()); + self.original = None; Ok(self.params.len() as i32) } @@ -506,4 +508,53 @@ mod test { assert_eq!(decoded.statement(), "__pgdog_large"); assert_eq!(bytes.len(), decoded.len()); } + + #[test] + fn test_add_parameter_produces_correct_bind() { + // Start with an existing Bind message parsed from bytes + let original_bind = Bind::new_params_codes( + "original_stmt", + &[Parameter::new(b"original_value")], + &[Format::Text], + ); + let original_bytes = original_bind.to_bytes().unwrap(); + let mut bind = Bind::from_bytes(original_bytes.clone()).unwrap(); + + // Verify original is set after parsing + assert!(bind.original.is_some()); + + // Add a new parameter - this should clear original + bind.add_parameter(Datum::Text("added_param".to_string())) + .unwrap(); + + // Verify original is now None + assert!(bind.original.is_none()); + + // Serialize to bytes + let new_bytes = bind.to_bytes().unwrap(); + + // The new bytes should be different from original (new param added) + assert_ne!(original_bytes, new_bytes); + + // Deserialize and verify the message is correct + let decoded = Bind::from_bytes(new_bytes.clone()).unwrap(); + + // Verify statement name preserved + assert_eq!(decoded.statement(), "original_stmt"); + + // Verify we have 2 parameters now (original + added) + assert_eq!(decoded.params_raw().len(), 2); + + // Verify first parameter (original) + let param0 = decoded.parameter(0).unwrap().unwrap(); + assert_eq!(param0.text(), Some("original_value")); + + // Verify second parameter (added) + let param1 = decoded.parameter(1).unwrap().unwrap(); + assert_eq!(param1.text(), Some("added_param")); + + // Verify round-trip produces identical bytes + let re_encoded = decoded.to_bytes().unwrap(); + assert_eq!(new_bytes, re_encoded); + } }