diff --git a/CLAUDE.md b/CLAUDE.md index 5e70dfc80..0cba01475 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -7,45 +7,15 @@ # Code style -Use standard Rust code style. Use `cargo fmt` to reformat code automatically after every edit. -Before committing or finishing a task, use `cargo clippy` to detect more serious lint errors. - -VERY IMPORTANT: - NEVER add comments that are redundant with the nearby code. - ALWAYS be sure comments document "Why", not "What", the code is doing. - ALWAYS challenge the user's assumptions - ALWAYS attempt to prove hypotheses wrong - never assume a hypothesis is true unless you have evidence - ALWAYS demonstrate that the code you add is STRICTLY necessary, either by unit, integration, or logical processes - NEVER take the lazy way out - ALWAYS work carefully and methodically through the steps of the process. - NEVER use quick fixes. Always carefully work through the problem unless specifically asked. - ALWAYS Ask clarifying questions before implementing - ALWAYS Break large tasks into single-session chunks - -VERY IMPORTANT: you are to act as a detective, attempting to find ways to falsify the code or planning we've done by discovering gaps or inconsistencies. ONLY write code when it is absolutely required to pass tests, the build, or typecheck. - -VERY IMPORTANT: NEVER comment out code or skip tests unless specifically requested by the user - -## Principles -- **Data first**: Define types before implementation -- **Small Modules**: Try to keep files under 200 lines, unless required by implementation. NEVER allow files to exceed 1000 lines unless specifically instructed. +- Use standard Rust code style. +- Use `cargo fmt` to reformat code automatically after every edit. +- Don't write functions with many arguments: create a struct and use that as input instead. # Workflow - Prefer to run individual tests with `cargo nextest run --test-threads=1 --no-fail-fast `. This is much faster. 
-- A local PostgreSQL server is required for some tests to pass. Ensure it is set up, and if necessary create a database called "pgdog", and create a user called "pgdog" with password "pgdog". -- Focus on files in `./pgdog` and `./integration` - other files are LOWEST priority - -## Test-Driven Development (TDD) - STRICT ENFORCEMENT -- **MANDATORY WORKFLOW - NO EXCEPTIONS:** - 1. Write exactly ONE test that fails - 2. Write ONLY the minimal code to make that test pass - 3. Refactor if needed (tests must still pass) - 4. Return to step 1 for next test -- **CRITICAL RULES:** - - NO implementation code without a failing test first - - NO untested code is allowed to exist - - Every line of production code must be justified by a test +- A local PostgreSQL server is required for some tests to pass. Assume it's running; if not, stop and ask the user to start it. +- Code in `./pgdog` and `./integration` is the priority; other files are lowest priority. # About the project diff --git a/integration/common.sh b/integration/common.sh index 8001131ed..e37c084d8 100644 --- a/integration/common.sh +++ b/integration/common.sh @@ -4,6 +4,8 @@ # correctly. # COMMON_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +export NODE_ID=pgdog-dev-0 + function wait_for_pgdog() { echo "Waiting for PgDog" while ! 
pg_isready -h 127.0.0.1 -p 6432 -U pgdog -d pgdog > /dev/null; do diff --git a/integration/ruby/pg_spec.rb b/integration/ruby/pg_spec.rb index 5076ffdce..621278633 100644 --- a/integration/ruby/pg_spec.rb +++ b/integration/ruby/pg_spec.rb @@ -111,4 +111,13 @@ def connect(dbname = 'pgdog', user = 'pgdog') expect(res[0]['one']).to eq('2') end end + + it 'unique_id' do + conn = connect "pgdog_sharded" + 100.times do |i| + res = conn.exec "SELECT pgdog.unique_id() AS id, $1 AS counter", [i] + expect(res[0]["id"].to_i).to be > 0 + expect(res[0]["counter"].to_i).to eq(i) + end + end end diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs index 4c9229a39..5f68ececf 100644 --- a/integration/rust/tests/integration/mod.rs +++ b/integration/rust/tests/integration/mod.rs @@ -20,3 +20,4 @@ pub mod timestamp_sorting; pub mod tls_enforced; pub mod tls_reload; pub mod transaction_state; +pub mod unique_id; diff --git a/integration/rust/tests/integration/unique_id.rs b/integration/rust/tests/integration/unique_id.rs new file mode 100644 index 000000000..725c7f197 --- /dev/null +++ b/integration/rust/tests/integration/unique_id.rs @@ -0,0 +1,37 @@ +use rust::setup::connections_sqlx; +use sqlx::{Executor, Row}; + +#[tokio::test] +async fn unique_id_returns_bigint() -> Result<(), Box> { + let conns = connections_sqlx().await; + let sharded = conns.get(1).cloned().unwrap(); + + // Simple query + let row = sharded.fetch_one("SELECT pgdog.unique_id() AS id").await?; + let mut id: i64 = row.get("id"); + + assert!( + id > 0, + "unique_id should return a positive bigint, got {id}" + ); + + for _ in 0..100 { + // Prepared statement + let row = sqlx::query("SELECT pgdog.unique_id() AS id") + .fetch_one(&sharded) + .await?; + let prepared_id: i64 = row.get("id"); + assert!( + prepared_id > 0, + "prepared unique_id should return a positive bigint, got {prepared_id}" + ); + + assert!( + prepared_id > id, + "prepared id should be greater than simple 
query id" + ); + id = prepared_id; + } + + Ok(()) +} diff --git a/pgdog-config/src/core.rs b/pgdog-config/src/core.rs index a43c34437..fe58b1e17 100644 --- a/pgdog-config/src/core.rs +++ b/pgdog-config/src/core.rs @@ -195,8 +195,7 @@ impl Config { /// Organize all databases by name for quicker retrieval. pub fn databases(&self) -> HashMap>> { let mut databases = HashMap::new(); - let mut number = 0; - for database in &self.databases { + for (number, database) in self.databases.iter().enumerate() { let entry = databases .entry(database.name.clone()) .or_insert_with(Vec::new); @@ -210,7 +209,6 @@ impl Config { number, database: database.clone(), }); - number += 1; } databases } @@ -323,7 +321,7 @@ impl Config { ); } } else { - pooler_mode.insert(database.name.clone(), database.pooler_mode.clone()); + pooler_mode.insert(database.name.clone(), database.pooler_mode); } } @@ -352,17 +350,15 @@ impl Config { _ => (), } - if !self.general.two_phase_commit { - if self.rewrite.enabled { - if self.rewrite.shard_key == RewriteMode::Rewrite { - warn!("rewrite.shard_key=rewrite will apply non-atomic shard-key rewrites; enabling two_phase_commit is strongly recommended" + if !self.general.two_phase_commit && self.rewrite.enabled { + if self.rewrite.shard_key == RewriteMode::Rewrite { + warn!("rewrite.shard_key=rewrite will apply non-atomic shard-key rewrites; enabling two_phase_commit is strongly recommended" ); - } + } - if self.rewrite.split_inserts == RewriteMode::Rewrite { - warn!("rewrite.split_inserts=rewrite may commit partial multi-row INSERTs; enabling two_phase_commit is strongly recommended" + if self.rewrite.split_inserts == RewriteMode::Rewrite { + warn!("rewrite.split_inserts=rewrite may commit partial multi-row INSERTs; enabling two_phase_commit is strongly recommended" ); - } } } } diff --git a/pgdog-config/src/sharding.rs b/pgdog-config/src/sharding.rs index 35dcf3bbf..8e3121c44 100644 --- a/pgdog-config/src/sharding.rs +++ b/pgdog-config/src/sharding.rs @@ 
-213,7 +213,7 @@ impl ShardedSchema { } pub fn name(&self) -> &str { - self.name.as_ref().map(|name| name.as_str()).unwrap_or("*") + self.name.as_deref().unwrap_or("*") } pub fn shard(&self) -> Option { diff --git a/pgdog-config/src/url.rs b/pgdog-config/src/url.rs index 8440ff9bf..6d746fccb 100644 --- a/pgdog-config/src/url.rs +++ b/pgdog-config/src/url.rs @@ -136,7 +136,7 @@ impl ConfigAndUsers { let mirroring = mirror_strs .iter() - .map(|s| Mirroring::from_str(s).map_err(|e| Error::ParseError(e))) + .map(|s| Mirroring::from_str(s).map_err(Error::ParseError)) .collect::, _>>()?; self.config.mirroring = mirroring; diff --git a/pgdog/src/admin/mod.rs b/pgdog/src/admin/mod.rs index 521119c8c..3e4732d4c 100644 --- a/pgdog/src/admin/mod.rs +++ b/pgdog/src/admin/mod.rs @@ -30,6 +30,7 @@ pub mod show_pools; pub mod show_prepared_statements; pub mod show_query_cache; pub mod show_replication; +pub mod show_rewrite; pub mod show_server_memory; pub mod show_servers; pub mod show_stats; diff --git a/pgdog/src/admin/parser.rs b/pgdog/src/admin/parser.rs index 5d94ab445..16a2ee929 100644 --- a/pgdog/src/admin/parser.rs +++ b/pgdog/src/admin/parser.rs @@ -7,7 +7,7 @@ use super::{ show_client_memory::ShowClientMemory, show_clients::ShowClients, show_config::ShowConfig, show_instance_id::ShowInstanceId, show_lists::ShowLists, show_mirrors::ShowMirrors, show_peers::ShowPeers, show_pools::ShowPools, show_prepared_statements::ShowPreparedStatements, - show_query_cache::ShowQueryCache, show_replication::ShowReplication, + show_query_cache::ShowQueryCache, show_replication::ShowReplication, show_rewrite::ShowRewrite, show_server_memory::ShowServerMemory, show_servers::ShowServers, show_stats::ShowStats, show_transactions::ShowTransactions, show_version::ShowVersion, shutdown::Shutdown, Command, Error, @@ -30,6 +30,7 @@ pub enum ParseResult { ShowStats(ShowStats), ShowTransactions(ShowTransactions), ShowMirrors(ShowMirrors), + ShowRewrite(ShowRewrite), ShowVersion(ShowVersion), 
ShowInstanceId(ShowInstanceId), SetupSchema(SetupSchema), @@ -65,6 +66,7 @@ impl ParseResult { ShowStats(show_stats) => show_stats.execute().await, ShowTransactions(show_transactions) => show_transactions.execute().await, ShowMirrors(show_mirrors) => show_mirrors.execute().await, + ShowRewrite(show_rewrite) => show_rewrite.execute().await, ShowVersion(show_version) => show_version.execute().await, ShowInstanceId(show_instance_id) => show_instance_id.execute().await, SetupSchema(setup_schema) => setup_schema.execute().await, @@ -100,6 +102,7 @@ impl ParseResult { ShowStats(show_stats) => show_stats.name(), ShowTransactions(show_transactions) => show_transactions.name(), ShowMirrors(show_mirrors) => show_mirrors.name(), + ShowRewrite(show_rewrite) => show_rewrite.name(), ShowVersion(show_version) => show_version.name(), ShowInstanceId(show_instance_id) => show_instance_id.name(), SetupSchema(setup_schema) => setup_schema.name(), @@ -163,6 +166,7 @@ impl Parser { "lists" => ParseResult::ShowLists(ShowLists::parse(&sql)?), "prepared" => ParseResult::ShowPrepared(ShowPreparedStatements::parse(&sql)?), "replication" => ParseResult::ShowReplication(ShowReplication::parse(&sql)?), + "rewrite" => ParseResult::ShowRewrite(ShowRewrite::parse(&sql)?), command => { debug!("unknown admin show command: '{}'", command); return Err(Error::Syntax); diff --git a/pgdog/src/admin/show_mirrors.rs b/pgdog/src/admin/show_mirrors.rs index 896e62e61..1ae61a546 100644 --- a/pgdog/src/admin/show_mirrors.rs +++ b/pgdog/src/admin/show_mirrors.rs @@ -35,7 +35,7 @@ impl Command for ShowMirrors { let counts = { let stats = cluster.stats(); let stats = stats.lock(); - stats.counts + stats.mirrors }; // Create a data row for this cluster diff --git a/pgdog/src/admin/show_rewrite.rs b/pgdog/src/admin/show_rewrite.rs new file mode 100644 index 000000000..df8d4cfbf --- /dev/null +++ b/pgdog/src/admin/show_rewrite.rs @@ -0,0 +1,49 @@ +//! 
SHOW REWRITE - per-cluster rewrite statistics + +use crate::backend::databases::databases; + +use super::prelude::*; + +pub struct ShowRewrite; + +#[async_trait] +impl Command for ShowRewrite { + fn name(&self) -> String { + "SHOW REWRITE".into() + } + + fn parse(_: &str) -> Result { + Ok(Self) + } + + async fn execute(&self) -> Result, Error> { + let fields = vec![ + Field::text("database"), + Field::text("user"), + Field::numeric("parse"), + Field::numeric("bind"), + Field::numeric("simple"), + ]; + + let mut messages = vec![RowDescription::new(&fields).message()?]; + + for (user, cluster) in databases().all() { + let rewrite = { + let stats = cluster.stats(); + let stats = stats.lock(); + stats.rewrite + }; + + let mut dr = DataRow::new(); + dr.add(user.database.as_str()) + .add(user.user.as_str()) + .add(rewrite.parse as i64) + .add(rewrite.bind as i64) + .add(rewrite.simple as i64); + + messages.push(dr.message()?); + } + + Ok(messages) + } +} diff --git a/pgdog/src/admin/tests/mod.rs b/pgdog/src/admin/tests/mod.rs index a5fda216d..57f709d04 100644 --- a/pgdog/src/admin/tests/mod.rs +++ b/pgdog/src/admin/tests/mod.rs @@ -1,6 +1,6 @@ use crate::admin::Command; use crate::backend::databases::{databases, from_config, replace_databases, Databases}; -use crate::backend::pool::mirror_stats::Counts; +use crate::backend::pool::cluster_stats::MirrorStats; use crate::config::{self, ConfigAndUsers, Database, Role, User as ConfigUser}; use crate::net::messages::{DataRow, DataType, FromBytes, Protocol, RowDescription}; @@ -226,7 +226,7 @@ async fn show_mirrors_reports_counts() { { let cluster_stats = cluster.stats(); let mut stats = cluster_stats.lock(); - stats.counts = Counts { + stats.mirrors = MirrorStats { total_count: 5, mirrored_count: 4, dropped_count: 1, diff --git a/pgdog/src/backend/databases.rs b/pgdog/src/backend/databases.rs index afacd902c..5ed91bf9f 100644 --- a/pgdog/src/backend/databases.rs +++ b/pgdog/src/backend/databases.rs @@ -359,9 +359,9 @@ impl 
Databases { for cluster in self.all().values() { cluster.launch(); - if cluster.pooler_mode() == PoolerMode::Session && cluster.router_needed() { + if cluster.pooler_mode() == PoolerMode::Session && cluster.use_parser() { warn!( - r#"user "{}" for database "{}" requires transaction mode to route queries"#, + r#"user "{}" for database "{}" requires transaction mode to parse and route queries"#, cluster.user(), cluster.name() ); diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index 1b5445530..b0c606e9a 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -26,7 +26,7 @@ use crate::{ net::{messages::BackendKeyData, Query}, }; -use super::{Address, Config, Error, Guard, MirrorStats, Request, Shard, ShardConfig}; +use super::{Address, ClusterStats, Config, Error, Guard, Request, Shard, ShardConfig}; use crate::config::LoadBalancingStrategy; #[derive(Clone, Debug, Default)] @@ -53,7 +53,7 @@ pub struct Cluster { multi_tenant: Option, rw_strategy: ReadWriteStrategy, schema_admin: bool, - stats: Arc>, + stats: Arc>, cross_shard_disabled: bool, two_phase_commit: bool, two_phase_commit_auto: bool, @@ -245,7 +245,7 @@ impl Cluster { multi_tenant: multi_tenant.clone(), rw_strategy, schema_admin, - stats: Arc::new(Mutex::new(MirrorStats::default())), + stats: Arc::new(Mutex::new(ClusterStats::default())), cross_shard_disabled, two_phase_commit: two_pc && shards.len() > 1, two_phase_commit_auto: two_pc_auto && shards.len() > 1, @@ -409,14 +409,18 @@ impl Cluster { self.schema_admin = owner; } - pub fn stats(&self) -> Arc> { + pub fn stats(&self) -> Arc> { self.stats.clone() } - /// We'll need the query router to figure out - /// where a query should go. - pub fn router_needed(&self) -> bool { + /// We need to parse the query using pg_query. 
+ pub fn use_parser(&self) -> bool { !(self.shards().len() == 1 && (self.read_only() || self.write_only())) + || self.query_parser_enabled + || self.multi_tenant.is_some() + || self.pub_sub_enabled() + || self.prepared_statements() == &PreparedStatements::Full + || self.dry_run } /// Multi-tenant config. diff --git a/pgdog/src/backend/pool/mirror_stats.rs b/pgdog/src/backend/pool/cluster_stats.rs similarity index 82% rename from pgdog/src/backend/pool/mirror_stats.rs rename to pgdog/src/backend/pool/cluster_stats.rs index a31d857b7..7d540ced3 100644 --- a/pgdog/src/backend/pool/mirror_stats.rs +++ b/pgdog/src/backend/pool/cluster_stats.rs @@ -5,8 +5,10 @@ use std::{ ops::{Add, Div, Sub}, }; +use crate::frontend::router::rewrite::stats::RewriteStats; + #[derive(Debug, Clone, Default, Copy)] -pub struct Counts { +pub struct MirrorStats { pub total_count: usize, pub mirrored_count: usize, pub dropped_count: usize, @@ -14,8 +16,8 @@ pub struct Counts { pub queue_length: usize, } -impl Sub for Counts { - type Output = Counts; +impl Sub for MirrorStats { + type Output = MirrorStats; fn sub(self, rhs: Self) -> Self::Output { Self { @@ -28,8 +30,8 @@ impl Sub for Counts { } } -impl Div for Counts { - type Output = Counts; +impl Div for MirrorStats { + type Output = MirrorStats; fn div(self, rhs: usize) -> Self::Output { Self { @@ -42,11 +44,11 @@ impl Div for Counts { } } -impl Add for Counts { - type Output = Counts; +impl Add for MirrorStats { + type Output = MirrorStats; - fn add(self, rhs: Counts) -> Self::Output { - Counts { + fn add(self, rhs: MirrorStats) -> Self::Output { + MirrorStats { total_count: self.total_count + rhs.total_count, mirrored_count: self.mirrored_count + rhs.mirrored_count, dropped_count: self.dropped_count + rhs.dropped_count, @@ -56,9 +58,9 @@ impl Add for Counts { } } -impl Sum for Counts { +impl Sum for MirrorStats { fn sum>(iter: I) -> Self { - let mut result = Counts::default(); + let mut result = MirrorStats::default(); for next in iter { 
result = result + next; } @@ -68,8 +70,9 @@ impl Sum for Counts { } #[derive(Debug, Clone, Default, Copy)] -pub struct MirrorStats { - pub counts: Counts, +pub struct ClusterStats { + pub mirrors: MirrorStats, + pub rewrite: RewriteStats, } #[cfg(test)] @@ -78,16 +81,16 @@ mod tests { #[test] fn test_queue_length_default_is_zero() { - let stats = MirrorStats::default(); + let stats = ClusterStats::default(); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should be 0 by default" ); } #[test] fn test_queue_length_arithmetic_operations() { - let counts1 = Counts { + let counts1 = MirrorStats { total_count: 10, mirrored_count: 5, dropped_count: 3, @@ -95,7 +98,7 @@ mod tests { queue_length: 7, }; - let counts2 = Counts { + let counts2 = MirrorStats { total_count: 5, mirrored_count: 3, dropped_count: 1, @@ -127,7 +130,7 @@ mod tests { #[test] fn test_queue_length_saturating_sub() { - let counts1 = Counts { + let counts1 = MirrorStats { total_count: 10, mirrored_count: 5, dropped_count: 3, @@ -135,7 +138,7 @@ mod tests { queue_length: 3, }; - let counts2 = Counts { + let counts2 = MirrorStats { total_count: 5, mirrored_count: 3, dropped_count: 1, diff --git a/pgdog/src/backend/pool/connection/mirror/handler.rs b/pgdog/src/backend/pool/connection/mirror/handler.rs index 92c2e406c..86147b276 100644 --- a/pgdog/src/backend/pool/connection/mirror/handler.rs +++ b/pgdog/src/backend/pool/connection/mirror/handler.rs @@ -4,7 +4,7 @@ //! use super::*; -use crate::backend::pool::MirrorStats; +use crate::backend::pool::ClusterStats; use parking_lot::Mutex; use std::sync::Arc; @@ -35,7 +35,7 @@ pub struct MirrorHandler { /// Request timer, to simulate delays between queries. timer: Instant, /// Reference to cluster stats for tracking mirror metrics. - stats: Arc>, + stats: Arc>, } impl MirrorHandler { @@ -45,7 +45,7 @@ impl MirrorHandler { } /// Create new mirror handle with exposure. 
- pub fn new(tx: Sender, exposure: f32, stats: Arc>) -> Self { + pub fn new(tx: Sender, exposure: f32, stats: Arc>) -> Self { Self { tx, exposure, @@ -139,44 +139,44 @@ impl MirrorHandler { /// Increment the total request count. pub fn increment_total_count(&self) { let mut stats = self.stats.lock(); - stats.counts.total_count += 1; + stats.mirrors.total_count += 1; } /// Increment the mirrored request count. pub fn increment_mirrored_count(&self) { let mut stats = self.stats.lock(); - stats.counts.mirrored_count += 1; + stats.mirrors.mirrored_count += 1; } /// Increment the dropped request count. pub fn increment_dropped_count(&self) { let mut stats = self.stats.lock(); - stats.counts.dropped_count += 1; + stats.mirrors.dropped_count += 1; } /// Increment the error count. pub fn increment_error_count(&self) { let mut stats = self.stats.lock(); - stats.counts.error_count += 1; + stats.mirrors.error_count += 1; } /// Increment the queue length. pub fn increment_queue_length(&self) { let mut stats = self.stats.lock(); - stats.counts.queue_length += 1; + stats.mirrors.queue_length += 1; } /// Decrement the queue length. 
pub fn decrement_queue_length(&self) { let mut stats = self.stats.lock(); - stats.counts.queue_length = stats.counts.queue_length.saturating_sub(1); + stats.mirrors.queue_length = stats.mirrors.queue_length.saturating_sub(1); } } #[cfg(test)] mod tests { use super::*; - use crate::backend::pool::MirrorStats; + use crate::backend::pool::ClusterStats; use parking_lot::Mutex; use std::sync::Arc; use tokio::sync::mpsc::{channel, Receiver}; @@ -185,22 +185,22 @@ mod tests { exposure: f32, ) -> ( MirrorHandler, - Arc>, + Arc>, Receiver, ) { let (tx, rx) = channel(1000); // Keep receiver to prevent channel closure - let stats = Arc::new(Mutex::new(MirrorStats::default())); + let stats = Arc::new(Mutex::new(ClusterStats::default())); let handler = MirrorHandler::new(tx, exposure, stats.clone()); (handler, stats, rx) } - fn get_stats_counts(stats: &Arc>) -> (usize, usize, usize, usize) { + fn get_stats_counts(stats: &Arc>) -> (usize, usize, usize, usize) { let stats = stats.lock(); ( - stats.counts.total_count, - stats.counts.mirrored_count, - stats.counts.dropped_count, - stats.counts.error_count, + stats.mirrors.total_count, + stats.mirrors.mirrored_count, + stats.mirrors.dropped_count, + stats.mirrors.error_count, ) } @@ -344,7 +344,7 @@ mod tests { { let stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should be 0 initially" ); } @@ -358,7 +358,7 @@ mod tests { { let stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should still be 0 before flush" ); } @@ -368,7 +368,7 @@ mod tests { { let stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 1, + stats.mirrors.queue_length, 1, "queue_length should be 1 after flush" ); } @@ -388,7 +388,7 @@ mod tests { { let stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should be 0 initially" ); } @@ -403,17 +403,17 @@ mod tests { { let 
stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should remain 0 for dropped transactions" ); - assert_eq!(stats.counts.dropped_count, 1, "dropped_count should be 1"); + assert_eq!(stats.mirrors.dropped_count, 1, "dropped_count should be 1"); } } #[test] fn test_queue_length_with_channel_overflow() { let (tx, _rx) = channel(1); // Channel with capacity of 1 - let stats = Arc::new(Mutex::new(MirrorStats::default())); + let stats = Arc::new(Mutex::new(ClusterStats::default())); let mut handler = MirrorHandler::new(tx, 1.0, stats.clone()); // Fill the channel @@ -428,11 +428,11 @@ mod tests { { let stats = stats.lock(); assert_eq!( - stats.counts.queue_length, 1, + stats.mirrors.queue_length, 1, "queue_length should be 1 (first successful send)" ); assert_eq!( - stats.counts.error_count, 1, + stats.mirrors.error_count, 1, "error_count should be 1 due to overflow" ); } @@ -441,16 +441,16 @@ mod tests { #[test] fn test_queue_length_never_negative() { // Test to ensure queue_length never goes negative even with mismatched increment/decrement - let stats = Arc::new(Mutex::new(MirrorStats::default())); + let stats = Arc::new(Mutex::new(ClusterStats::default())); // Manually try to decrement without incrementing (should use saturating_sub) // This will be tested more thoroughly once decrement_queue_length is implemented { let mut stats = stats.lock(); // Simulating a decrement when queue is already 0 - stats.counts.queue_length = stats.counts.queue_length.saturating_sub(1); + stats.mirrors.queue_length = stats.mirrors.queue_length.saturating_sub(1); assert_eq!( - stats.counts.queue_length, 0, + stats.mirrors.queue_length, 0, "queue_length should not go negative" ); } diff --git a/pgdog/src/backend/pool/connection/mirror/mod.rs b/pgdog/src/backend/pool/connection/mirror/mod.rs index 28363dea4..7acc88a07 100644 --- a/pgdog/src/backend/pool/connection/mirror/mod.rs +++ 
b/pgdog/src/backend/pool/connection/mirror/mod.rs @@ -119,14 +119,14 @@ impl Mirror { // Decrement queue_length when we receive a message from the channel { let mut stats = stats_for_errors.lock(); - stats.counts.queue_length = stats.counts.queue_length.saturating_sub(1); + stats.mirrors.queue_length = stats.mirrors.queue_length.saturating_sub(1); } // TODO: timeout these. if let Err(err) = mirror.handle(&mut req, &mut query_engine).await { error!("mirror error: {}", err); // Increment error count on mirror handling error let mut stats = stats_for_errors.lock(); - stats.counts.error_count += 1; + stats.mirrors.error_count += 1; } } else { debug!("mirror client shutting down"); @@ -170,12 +170,12 @@ mod test { #[tokio::test] async fn test_mirror_exposure() { - use crate::backend::pool::MirrorStats; + use crate::backend::pool::ClusterStats; use parking_lot::Mutex; use std::sync::Arc; let (tx, rx) = channel(25); - let stats = Arc::new(Mutex::new(MirrorStats::default())); + let stats = Arc::new(Mutex::new(ClusterStats::default())); let mut handle = MirrorHandler::new(tx.clone(), 1.0, stats.clone()); for _ in 0..25 { @@ -189,7 +189,7 @@ mod test { assert_eq!(rx.len(), 25); let (tx, rx) = channel(25); - let stats2 = Arc::new(Mutex::new(MirrorStats::default())); + let stats2 = Arc::new(Mutex::new(ClusterStats::default())); let mut handle = MirrorHandler::new(tx.clone(), 0.5, stats2); let dropped = (0..25) .into_iter() @@ -270,7 +270,7 @@ mod test { let initial_stats = { let stats_arc = cluster.stats(); let stats = stats_arc.lock(); - stats.counts + stats.mirrors }; let mut mirror = Mirror::spawn("pgdog", &cluster, None).unwrap(); @@ -290,7 +290,7 @@ mod test { let final_stats = { let stats_arc = cluster.stats(); let stats = stats_arc.lock(); - stats.counts + stats.mirrors }; assert_eq!( diff --git a/pgdog/src/backend/pool/mod.rs b/pgdog/src/backend/pool/mod.rs index b4f54d2cd..e79d47b32 100644 --- a/pgdog/src/backend/pool/mod.rs +++ b/pgdog/src/backend/pool/mod.rs @@ -3,6 
+3,7 @@ pub mod address; pub mod cleanup; pub mod cluster; +pub mod cluster_stats; pub mod comms; pub mod config; pub mod connection; @@ -13,7 +14,6 @@ pub mod healthcheck; pub mod inner; pub mod lsn_monitor; pub mod mapping; -pub mod mirror_stats; pub mod monitor; pub mod oids; pub mod pool_impl; @@ -27,13 +27,13 @@ pub mod waiting; pub use address::Address; pub use cluster::{Cluster, ClusterConfig, ClusterShardConfig, PoolConfig, ShardingSchema}; +pub use cluster_stats::ClusterStats; pub use config::Config; pub use connection::Connection; pub use error::Error; pub use guard::Guard; pub use healthcheck::Healtcheck; pub use lsn_monitor::LsnStats; -pub use mirror_stats::MirrorStats; pub use monitor::Monitor; pub use oids::Oids; pub use pool_impl::Pool; diff --git a/pgdog/src/backend/replication/logical/subscriber/context.rs b/pgdog/src/backend/replication/logical/subscriber/context.rs index da7ed70ac..85c6ca785 100644 --- a/pgdog/src/backend/replication/logical/subscriber/context.rs +++ b/pgdog/src/backend/replication/logical/subscriber/context.rs @@ -57,6 +57,7 @@ impl<'a> StreamContext<'a> { &self.params, None, 1, + None, )?) } } diff --git a/pgdog/src/frontend/client/query_engine/context.rs b/pgdog/src/frontend/client/query_engine/context.rs index 87ede236a..368ba1bcc 100644 --- a/pgdog/src/frontend/client/query_engine/context.rs +++ b/pgdog/src/frontend/client/query_engine/context.rs @@ -4,6 +4,7 @@ use crate::{ backend::pool::{connection::mirror::Mirror, stats::MemoryStats}, frontend::{ client::{timeouts::Timeouts, TransactionType}, + router::parser::cache::CachedAst, Client, ClientRequest, PreparedStatements, }, net::{BackendKeyData, Parameters, Stream}, @@ -38,6 +39,8 @@ pub struct QueryEngineContext<'a> { pub(super) rollback: bool, /// Omnisharded modulo. pub(super) omni_sticky_index: usize, + /// Query AST. 
+ pub(super) ast: Option, } impl<'a> QueryEngineContext<'a> { @@ -58,6 +61,7 @@ impl<'a> QueryEngineContext<'a> { requests_left: 0, rollback: false, omni_sticky_index: client.omni_sticky_index, + ast: None, } } @@ -83,6 +87,7 @@ impl<'a> QueryEngineContext<'a> { requests_left: 0, rollback: false, omni_sticky_index: thread_rng().gen_range(1..usize::MAX), + ast: None, } } diff --git a/pgdog/src/frontend/client/query_engine/internal_values.rs b/pgdog/src/frontend/client/query_engine/internal_values.rs index 7d7245932..4721d30b4 100644 --- a/pgdog/src/frontend/client/query_engine/internal_values.rs +++ b/pgdog/src/frontend/client/query_engine/internal_values.rs @@ -32,7 +32,7 @@ impl QueryEngine { &mut self, context: &mut QueryEngineContext<'_>, ) -> Result<(), Error> { - let id = unique_id::UniqueId::generator()?.next_id().await; + let id = unique_id::UniqueId::generator()?.next_id(); let bytes_sent = context .stream .send_many(&[ diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index 9ff2c4381..29bc00cdb 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -3,7 +3,7 @@ use crate::{ config::config, frontend::{ client::query_engine::hooks::QueryEngineHooks, - router::{parser::Shard, Route}, + router::{parser::Shard, rewrite::RewriteRequest, Route}, BufferedQuery, Client, Command, Comms, Error, Router, RouterContext, Stats, }, net::{BackendKeyData, ErrorResponse, Message, Parameters}, @@ -22,6 +22,7 @@ pub mod incomplete_requests; pub mod insert_split; pub mod internal_values; pub mod notify_buffer; +pub mod output; pub mod prepared_statements; pub mod pub_sub; pub mod query; @@ -37,6 +38,7 @@ pub mod unknown_command; use self::query::ExplainResponseState; pub use context::QueryEngineContext; use notify_buffer::NotifyBuffer; +pub use output::QueryEngineOutput; pub use two_pc::phase::TwoPcPhase; use two_pc::TwoPc; @@ -108,25 +110,34 @@ impl QueryEngine { 
} /// Handle client request. - pub async fn handle(&mut self, context: &mut QueryEngineContext<'_>) -> Result<(), Error> { + pub async fn handle( + &mut self, + context: &mut QueryEngineContext<'_>, + ) -> Result { + // Check that we have the latest version of the config. + if let Some(error) = self.ensure_cluster(context.in_transaction()).await { + self.error_response(context, error).await?; + return Ok(QueryEngineOutput::Executed); + } + self.stats .received(context.client_request.total_message_len()); self.set_state(State::Active); // Client is active. // Rewrite prepared statements. - self.rewrite_extended(context)?; + self.rewrite_request(context)?; // Intercept commands we don't have to forward to a server. if self.intercept_incomplete(context).await? { self.update_stats(context); - return Ok(()); + return Ok(QueryEngineOutput::Executed); } // Route transaction to the right servers. if !self.route_transaction(context).await? { self.update_stats(context); debug!("transaction has nowhere to go"); - return Ok(()); + return Ok(QueryEngineOutput::Executed); } self.hooks.before_execution(context)?; @@ -221,8 +232,10 @@ impl QueryEngine { self.set_route(context, route.clone()).await?; } Command::Copy(_) => self.execute(context, &route).await?, - Command::Rewrite(query) => { - context.client_request.rewrite(query)?; + Command::Rewrite(requests) => { + for request in requests { + request.execute(context.client_request); + } self.execute(context, &route).await?; } Command::InsertSplit(plan) => self.insert_split(context, *plan.clone()).await?, @@ -246,7 +259,7 @@ impl QueryEngine { self.update_stats(context); - Ok(()) + Ok(QueryEngineOutput::Executed) } fn update_stats(&mut self, context: &mut QueryEngineContext<'_>) { diff --git a/pgdog/src/frontend/client/query_engine/output.rs b/pgdog/src/frontend/client/query_engine/output.rs new file mode 100644 index 000000000..6256ec44d --- /dev/null +++ b/pgdog/src/frontend/client/query_engine/output.rs @@ -0,0 +1,10 @@ +use 
crate::frontend::ClientRequest; + +#[derive(Debug, Clone)] +pub enum QueryEngineOutput { + // The request has been executed as-is. + Executed, + // The request has been rewritten and needs to + // be resent. + Rewritten(Vec), +} diff --git a/pgdog/src/frontend/client/query_engine/prepared_statements.rs b/pgdog/src/frontend/client/query_engine/prepared_statements.rs index 8fac4d2a6..33115df46 100644 --- a/pgdog/src/frontend/client/query_engine/prepared_statements.rs +++ b/pgdog/src/frontend/client/query_engine/prepared_statements.rs @@ -4,10 +4,11 @@ use super::*; impl QueryEngine { /// Rewrite extended protocol messages. - pub(super) fn rewrite_extended( + pub(super) fn rewrite_request( &mut self, context: &mut QueryEngineContext<'_>, ) -> Result<(), Error> { + // Rewrite prepared statements to use global names. for message in context.client_request.iter_mut() { if message.extended() { let level = context.prepared_statements.level; @@ -20,6 +21,22 @@ impl QueryEngine { } } } + + // Rewrite the statement itself. + if let Ok(cluster) = self.backend.cluster() { + if cluster.use_parser() && context.ast.is_none() { + // Execute request rewrite, if needed. 
+ let mut rewrite = RewriteRequest::new( + context.client_request, + cluster, + context.prepared_statements, + ); + context.ast = rewrite.execute()?; + } + } + + println!("req: {:#?}", context.client_request); + Ok(()) } } diff --git a/pgdog/src/frontend/client/query_engine/query.rs b/pgdog/src/frontend/client/query_engine/query.rs index b60561c35..7802e9985 100644 --- a/pgdog/src/frontend/client/query_engine/query.rs +++ b/pgdog/src/frontend/client/query_engine/query.rs @@ -48,7 +48,10 @@ impl QueryEngine { } if let Some(sql) = route.rewritten_sql() { - match context.client_request.rewrite(&[Query::new(sql).into()]) { + match context + .client_request + .rewrite_simple(&[Query::new(sql).into()]) + { Ok(()) => (), Err(crate::net::Error::OnlySimpleForRewrites) => { context.client_request.rewrite_prepared( diff --git a/pgdog/src/frontend/client/query_engine/route_query.rs b/pgdog/src/frontend/client/query_engine/route_query.rs index 6dd9382f0..685b1727d 100644 --- a/pgdog/src/frontend/client/query_engine/route_query.rs +++ b/pgdog/src/frontend/client/query_engine/route_query.rs @@ -1,9 +1,41 @@ use pgdog_config::PoolerMode; use tracing::trace; +use crate::backend::Cluster; + use super::*; impl QueryEngine { + pub fn cluster(&self) -> Result<&Cluster, Error> { + Ok(self.backend.cluster()?) + } + + /// Check that cluster still exists. + pub async fn ensure_cluster(&mut self, in_transaction: bool) -> Option { + if let Ok(cluster) = self.backend.cluster() { + let identifier = cluster.identifier(); + + if !in_transaction && !cluster.online() { + // Reload cluster config. 
+ if self.backend.safe_reload().await.is_err() { + return Some(ErrorResponse::connection( + &identifier.user, + &identifier.database, + )); + } + + if self.backend.cluster().is_err() { + return Some(ErrorResponse::connection( + &identifier.user, + &identifier.database, + )); + } + } + } + + None + } + pub(super) async fn route_transaction( &mut self, context: &mut QueryEngineContext<'_>, @@ -18,28 +50,7 @@ impl QueryEngine { // Admin doesn't have a cluster. let cluster = if let Ok(cluster) = self.backend.cluster() { - if !context.in_transaction() && !cluster.online() { - let identifier = cluster.identifier(); - - // Reload cluster config. - self.backend.safe_reload().await?; - - match self.backend.cluster() { - Ok(cluster) => cluster, - Err(_) => { - // Cluster is gone. - self.error_response( - context, - ErrorResponse::connection(&identifier.user, &identifier.database), - ) - .await?; - - return Ok(false); - } - } - } else { - cluster - } + cluster } else { return Ok(true); }; @@ -51,6 +62,7 @@ impl QueryEngine { context.params, context.transaction, context.omni_sticky_index, + context.ast.as_ref(), )?; match self.router.query(router_context) { Ok(cmd) => { diff --git a/pgdog/src/frontend/client_request.rs b/pgdog/src/frontend/client_request.rs index 825073174..5a90e2bbe 100644 --- a/pgdog/src/frontend/client_request.rs +++ b/pgdog/src/frontend/client_request.rs @@ -140,6 +140,17 @@ impl ClientRequest { Ok(None) } + /// Get mutable reference to parameters, if any. + pub fn parameters_mut(&mut self) -> Result, Error> { + for message in self.messages.iter_mut() { + if let ProtocolMessage::Bind(bind) = message { + return Ok(Some(bind)); + } + } + + Ok(None) + } + /// Get all CopyData messages. pub fn copy_data(&self) -> Result, Error> { let mut rows = vec![]; @@ -180,7 +191,7 @@ impl ClientRequest { } /// Rewrite query in buffer. 
- pub fn rewrite(&mut self, request: &[ProtocolMessage]) -> Result<(), Error> { + pub fn rewrite_simple(&mut self, request: &[ProtocolMessage]) -> Result<(), Error> { if self.messages.iter().any(|c| c.code() != 'Q') { return Err(Error::OnlySimpleForRewrites); } @@ -189,6 +200,22 @@ impl ClientRequest { Ok(()) } + pub fn rewrite_extended(&mut self, request: &[ProtocolMessage]) -> Result<(), Error> { + for new_message in request { + if let Some(pos) = self + .messages + .iter() + .position(|p| p.code() == new_message.code()) + { + self.messages[pos] = new_message.clone(); + } else { + self.messages.insert(0, new_message.clone()); + } + } + + Ok(()) + } + /// Rewrite prepared statement SQL before sending it to the backend. pub fn rewrite_prepared( &mut self, diff --git a/pgdog/src/frontend/error.rs b/pgdog/src/frontend/error.rs index ac23367b0..c43409e69 100644 --- a/pgdog/src/frontend/error.rs +++ b/pgdog/src/frontend/error.rs @@ -50,6 +50,9 @@ pub enum Error { #[error("unique id: {0}")] UniqueId(#[from] unique_id::Error), + + #[error("rewrite: {0}")] + Rewrite(#[from] crate::frontend::router::rewrite::Error), } impl Error { diff --git a/pgdog/src/frontend/prepared_statements/global_cache.rs b/pgdog/src/frontend/prepared_statements/global_cache.rs index e66a90daf..3358108b2 100644 --- a/pgdog/src/frontend/prepared_statements/global_cache.rs +++ b/pgdog/src/frontend/prepared_statements/global_cache.rs @@ -1,7 +1,7 @@ use bytes::Bytes; use crate::{ - frontend::router::parser::RewritePlan, + frontend::router::{parser::RewritePlan, rewrite::ImmutableRewritePlan}, net::messages::{Parse, RowDescription}, stats::memory::MemoryUsage, }; @@ -17,13 +17,26 @@ fn global_name(counter: usize) -> String { #[derive(Debug, Clone)] pub struct Statement { + /// The rewritten Parse message. parse: Parse, + /// Saved RowDescription received from a backend. row_description: Option, + /// Version used to force-insert identical prepared + /// statements into the cache. 
#[allow(dead_code)] version: usize, + /// Rewrite plan used by the query parser. + /// TODO: Move this to the `crate::frontend::rewrite` module. rewrite_plan: Option, + /// The statement cache key, used to remove it from the other HashMap. cache_key: CacheKey, + /// Remove the statement from the cache upon closing it. + /// This is for one-time rewrites done by the parser rewrite module. + /// TODO: Remove this once code is moved + /// to the `crate::frontend::rewrite` module. evict_on_close: bool, + /// Rewrite engine. + rewrite_engine_plan: Option, } impl MemoryUsage for Statement { @@ -176,6 +189,7 @@ impl GlobalCache { rewrite_plan: None, cache_key, evict_on_close: false, + rewrite_engine_plan: None, }, ); @@ -215,6 +229,7 @@ impl GlobalCache { rewrite_plan: None, cache_key: key, evict_on_close: false, + rewrite_engine_plan: None, }, ); @@ -252,10 +267,35 @@ impl GlobalCache { } } + /// Get a copy of the rewrite plan. pub fn rewrite_plan(&self, name: &str) -> Option { self.names.get(name).and_then(|s| s.rewrite_plan.clone()) } + /// Get a mutable reference to the rewrite plan. + pub fn rewrite_plan_mut(&mut self, name: &str) -> Option<&mut RewritePlan> { + self.names + .get_mut(name) + .and_then(|statement| statement.rewrite_plan.as_mut()) + } + + /// Get the rewrite module rewrite plan by statement name. + pub fn rewrite_engine_plan(&self, name: &str) -> Option { + self.names + .get(name) + .and_then(|statement| statement.rewrite_engine_plan.clone()) + } + + /// Set the rewrite module rewrite plan for the given statement. 
+ pub fn set_rewrite(&mut self, parse: &Parse, plan: ImmutableRewritePlan) { + if let Some(statement) = self.names.get_mut(parse.name()) { + if statement.rewrite_engine_plan.is_none() { + statement.rewrite_engine_plan = Some(plan); + statement.parse = parse.clone(); + } + } + } + pub fn reset(&mut self) { self.statements.clear(); self.names.clear(); @@ -280,6 +320,16 @@ impl GlobalCache { self.names.get(name).map(|p| p.parse.clone()) } + /// Get global prepared statement name. + pub fn name(&self, parse: &Parse) -> Option { + let cache_key = CacheKey { + query: parse.query_ref(), + data_types: parse.data_types_ref(), + version: 0, + }; + self.statements.get(&cache_key).map(|stmt| stmt.name()) + } + /// Get the RowDescription message for the prepared statement. /// /// It can be used to decode results received from executing the prepared diff --git a/pgdog/src/frontend/prepared_statements/mod.rs b/pgdog/src/frontend/prepared_statements/mod.rs index ef0f4c9b8..2fd75418f 100644 --- a/pgdog/src/frontend/prepared_statements/mod.rs +++ b/pgdog/src/frontend/prepared_statements/mod.rs @@ -123,8 +123,7 @@ impl PreparedStatements { pub fn parse(&self, name: &str) -> Option { self.local .get(name) - .map(|name| self.global.read().parse(name)) - .flatten() + .and_then(|name| self.global.read().parse(name)) } /// Number of prepared statements in the local cache. 
diff --git a/pgdog/src/frontend/router/cli.rs b/pgdog/src/frontend/router/cli.rs index ceb16b5ed..c733c8a52 100644 --- a/pgdog/src/frontend/router/cli.rs +++ b/pgdog/src/frontend/router/cli.rs @@ -55,6 +55,7 @@ impl RouterCli { ¶ms, None, 1, + None, )?)?; result.push(cmd); } diff --git a/pgdog/src/frontend/router/context.rs b/pgdog/src/frontend/router/context.rs index f430cc0f8..cbef297be 100644 --- a/pgdog/src/frontend/router/context.rs +++ b/pgdog/src/frontend/router/context.rs @@ -1,7 +1,10 @@ use super::Error; use crate::{ backend::Cluster, - frontend::{client::TransactionType, BufferedQuery, ClientRequest, PreparedStatements}, + frontend::{ + client::TransactionType, router::parser::cache::CachedAst, BufferedQuery, ClientRequest, + PreparedStatements, + }, net::{Bind, Parameters}, }; @@ -27,6 +30,8 @@ pub struct RouterContext<'a> { pub two_pc: bool, /// Sticky omnisharded index. pub omni_sticky_index: usize, + /// Query ast. + pub ast: Option<&'a CachedAst>, } impl<'a> RouterContext<'a> { @@ -37,6 +42,7 @@ impl<'a> RouterContext<'a> { params: &'a Parameters, transaction: Option, omni_sticky_index: usize, + ast: Option<&'a CachedAst>, ) -> Result { let query = buffer.query()?; let bind = buffer.parameters()?; @@ -53,6 +59,7 @@ impl<'a> RouterContext<'a> { executable: buffer.executable(), two_pc: cluster.two_pc_enabled(), omni_sticky_index, + ast, }) } diff --git a/pgdog/src/frontend/router/mod.rs b/pgdog/src/frontend/router/mod.rs index eb78d3e9e..df5543cc2 100644 --- a/pgdog/src/frontend/router/mod.rs +++ b/pgdog/src/frontend/router/mod.rs @@ -5,6 +5,7 @@ pub mod context; pub mod copy; pub mod error; pub mod parser; +pub mod rewrite; pub mod round_robin; pub mod search_path; pub mod sharding; diff --git a/pgdog/src/frontend/router/parser/cache.rs b/pgdog/src/frontend/router/parser/cache.rs index 4156c8e66..0797c1ea4 100644 --- a/pgdog/src/frontend/router/parser/cache.rs +++ b/pgdog/src/frontend/router/parser/cache.rs @@ -55,7 +55,7 @@ pub struct CachedAstInner 
{ /// Cached AST. pub ast: ParseResult, /// AST stats. - pub stats: Mutex, + pub stats: Arc>, } impl Deref for CachedAst { @@ -70,6 +70,15 @@ impl CachedAst { /// Create new cache entry from pg_query's AST. pub fn new(query: &str, schema: &ShardingSchema) -> std::result::Result { let ast = parse(query).map_err(super::Error::PgQuery)?; + Self::new_parsed(query, ast, schema) + } + + /// Create new cached ast entry with the AST already pre-parsed. + pub fn new_parsed( + query: &str, + ast: ParseResult, + schema: &ShardingSchema, + ) -> std::result::Result { let (shard, role) = comment(query, schema)?; Ok(Self { @@ -77,10 +86,10 @@ impl CachedAst { comment_shard: shard, comment_role: role, inner: Arc::new(CachedAstInner { - stats: Mutex::new(Stats { + stats: Arc::new(Mutex::new(Stats { hits: 1, ..Default::default() - }), + })), ast, }), }) @@ -185,6 +194,26 @@ impl Cache { debug!("ast cache size set to {}", capacity); } + /// Save a pre-parsed query into the cache. + pub fn save( + &self, + query: &str, + ast: ParseResult, + schema: &ShardingSchema, + ) -> std::result::Result { + if let Some(exists) = self.check_existing(query) { + return Ok(exists); + } + + let entry = CachedAst::new_parsed(query, ast, schema)?; + + let mut guard = self.inner.lock(); + guard.queries.put(query.to_owned(), entry.clone()); + guard.stats.misses += 1; + + Ok(entry) + } + /// Parse a statement by either getting it from cache /// or using pg_query parser. /// @@ -196,16 +225,8 @@ impl Cache { query: &str, schema: &ShardingSchema, ) -> std::result::Result { - { - let mut guard = self.inner.lock(); - let ast = guard.queries.get_mut(query).map(|entry| { - entry.stats.lock().hits += 1; // No contention on this. - entry.clone() - }); - if let Some(ast) = ast { - guard.stats.hits += 1; - return Ok(ast); - } + if let Some(exists) = self.check_existing(query) { + return Ok(exists); } // Parse query without holding lock. 
@@ -218,6 +239,21 @@ impl Cache { Ok(entry) } + /// Check existing entry and return it if exists. + fn check_existing(&self, query: &str) -> Option { + let mut guard = self.inner.lock(); + let ast = guard.queries.get_mut(query).map(|entry| { + entry.stats.lock().hits += 1; // No contention on this. + entry.clone() + }); + if let Some(ast) = ast { + guard.stats.hits += 1; + Some(ast) + } else { + None + } + } + /// Parse a statement but do not store it in the cache. pub fn parse_uncached( &self, diff --git a/pgdog/src/frontend/router/parser/command.rs b/pgdog/src/frontend/router/parser/command.rs index f924e8f29..da9e557e4 100644 --- a/pgdog/src/frontend/router/parser/command.rs +++ b/pgdog/src/frontend/router/parser/command.rs @@ -1,7 +1,7 @@ use super::*; use crate::{ - frontend::{client::TransactionType, BufferedQuery}, - net::{parameter::ParameterValue, ProtocolMessage}, + frontend::{client::TransactionType, router::rewrite::RewriteAction, BufferedQuery}, + net::parameter::ParameterValue, }; use lazy_static::lazy_static; @@ -28,7 +28,7 @@ pub enum Command { value: ParameterValue, }, PreparedStatement(Prepare), - Rewrite(Vec), + Rewrite(Vec), InternalField { name: String, value: String, diff --git a/pgdog/src/frontend/router/parser/context.rs b/pgdog/src/frontend/router/parser/context.rs index 62469e450..cccbbf4c7 100644 --- a/pgdog/src/frontend/router/parser/context.rs +++ b/pgdog/src/frontend/router/parser/context.rs @@ -2,7 +2,6 @@ use std::os::raw::c_void; -use pgdog_config::PreparedStatements as ConfigPreparedStatements; use pgdog_plugin::pg_query::protobuf::ParseResult; use pgdog_plugin::{PdParameters, PdRouterContext, PdStatement}; @@ -22,8 +21,6 @@ use super::Error; /// and its inputs. /// pub struct QueryParserContext<'a> { - /// whether query_parser_enabled has been set. - pub(super) query_parser_enabled: bool, /// Cluster is read-only, i.e. has no primary. pub(super) read_only: bool, /// Cluster has no replicas, only a primary. 
@@ -36,13 +33,9 @@ pub struct QueryParserContext<'a> { pub(super) router_context: RouterContext<'a>, /// How aggressively we want to send reads to replicas. pub(super) rw_strategy: &'a ReadWriteStrategy, - /// Are we re-writing prepared statements sent over the simple protocol? - pub(super) full_prepared_statements: bool, /// Do we need the router at all? Shortcut to bypass this for unsharded /// clusters with databases that only read or write. - pub(super) router_needed: bool, - /// Do we have support for LISTEN/NOTIFY enabled? - pub(super) pub_sub_enabled: bool, + pub(super) use_parser: bool, /// Are we running multi-tenant checks? pub(super) multi_tenant: &'a Option, /// Dry run enabled? @@ -66,11 +59,7 @@ impl<'a> QueryParserContext<'a> { shards: router_context.cluster.shards().len(), sharding_schema: router_context.cluster.sharding_schema(), rw_strategy: router_context.cluster.read_write_strategy(), - full_prepared_statements: router_context.cluster.prepared_statements() - == &ConfigPreparedStatements::Full, - query_parser_enabled: router_context.cluster.query_parser_enabled(), - router_needed: router_context.cluster.router_needed(), - pub_sub_enabled: router_context.cluster.pub_sub_enabled(), + use_parser: router_context.cluster.use_parser(), multi_tenant: router_context.cluster.multi_tenant(), dry_run: router_context.cluster.dry_run(), expanded_explain: router_context.cluster.expanded_explain(), @@ -98,12 +87,7 @@ impl<'a> QueryParserContext<'a> { /// /// Shortcut to avoid the overhead if we can. pub(super) fn use_parser(&self) -> bool { - self.query_parser_enabled - || self.full_prepared_statements - || self.router_needed - || self.pub_sub_enabled - || self.multi_tenant().is_some() - || self.dry_run + self.use_parser } /// Get the query we're parsing, if any. 
diff --git a/pgdog/src/frontend/router/parser/insert.rs b/pgdog/src/frontend/router/parser/insert.rs index be04e0fbb..b15f4dc24 100644 --- a/pgdog/src/frontend/router/parser/insert.rs +++ b/pgdog/src/frontend/router/parser/insert.rs @@ -57,6 +57,10 @@ impl<'a> Insert<'a> { .unwrap_or(vec![]) } + pub fn stmt(&'a self) -> &'a InsertStmt { + self.stmt + } + /// Get table name, if specified (should always be). pub fn table(&'a self) -> Option> { self.stmt.relation.as_ref().map(Table::from) diff --git a/pgdog/src/frontend/router/parser/query/explain.rs b/pgdog/src/frontend/router/parser/query/explain.rs index 8b38261a9..fd6bc0a22 100644 --- a/pgdog/src/frontend/router/parser/query/explain.rs +++ b/pgdog/src/frontend/router/parser/query/explain.rs @@ -69,13 +69,14 @@ mod tests { // Helper function to route a plain SQL statement and return its `Route`. fn route(sql: &str) -> Route { enable_expanded_explain(); - let buffer = ClientRequest::from(vec![Query::new(sql).into()]); + let mut buffer = ClientRequest::from(vec![Query::new(sql).into()]); let cluster = Cluster::new_test(); let mut stmts = PreparedStatements::default(); let params = Parameters::default(); - let ctx = RouterContext::new(&buffer, &cluster, &mut stmts, ¶ms, None, 1).unwrap(); + let ctx = + RouterContext::new(&mut buffer, &cluster, &mut stmts, ¶ms, None, 1, None).unwrap(); match QueryParser::default().parse(ctx).unwrap().clone() { Command::Query(route) => route, @@ -96,13 +97,14 @@ mod tests { .collect::>(); let bind = Bind::new_params("", ¶meters); - let buffer: ClientRequest = vec![parse_msg.into(), bind.into()].into(); + let mut buffer: ClientRequest = vec![parse_msg.into(), bind.into()].into(); let cluster = Cluster::new_test(); let mut stmts = PreparedStatements::default(); let params = Parameters::default(); - let ctx = RouterContext::new(&buffer, &cluster, &mut stmts, ¶ms, None, 1).unwrap(); + let ctx = + RouterContext::new(&mut buffer, &cluster, &mut stmts, ¶ms, None, 1, None).unwrap(); match 
QueryParser::default().parse(ctx).unwrap().clone() { Command::Query(route) => route, diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index fd91b4d2b..f9e68c31d 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -7,7 +7,7 @@ use crate::{ frontend::{ router::{ context::RouterContext, - parser::{rewrite::Rewrite, OrderBy, Shard}, + parser::{OrderBy, Shard}, round_robin, sharding::{Centroids, ContextBuilder, Value as ShardingValue}, }, @@ -177,28 +177,32 @@ impl QueryParser { let cache = Cache::get(); // Get the AST from cache or parse the statement live. - let statement = match context.query()? { - // Only prepared statements (or just extended) are cached. - BufferedQuery::Prepared(query) => { - cache.parse(query.query(), &context.sharding_schema)? - } - // Don't cache simple queries. - // - // They contain parameter values, which makes the cache - // too large to be practical. - // - // Make your clients use prepared statements - // or at least send statements with placeholders using the - // extended protocol. - BufferedQuery::Query(query) => { - cache.parse_uncached(query.query(), &context.sharding_schema)? + let statement = if let Some(ast) = context.router_context.ast.cloned() { + ast // The AST was already parsed by the rewriter module. + } else { + match context.query()? { + // Only prepared statements (or just extended) are cached. + BufferedQuery::Prepared(query) => { + cache.parse(query.query(), &context.sharding_schema)? + } + // Don't cache simple queries. + // + // They contain parameter values, which makes the cache + // too large to be practical. + // + // Make your clients use prepared statements + // or at least send statements with placeholders using the + // extended protocol. + BufferedQuery::Query(query) => { + cache.parse_uncached(query.query(), &context.sharding_schema)? 
+ } } }; self.ensure_explain_recorder(statement.ast(), context); // Parse hardcoded shard from a query comment. - if context.router_needed || context.dry_run { + if context.use_parser() { self.shard = statement.comment_shard.clone(); let role_override = statement.comment_role; if let Some(role) = role_override { @@ -218,12 +222,6 @@ impl QueryParser { debug!("{}", context.query()?.query()); trace!("{:#?}", statement); - let rewrite = Rewrite::new(statement.ast()); - if rewrite.needs_rewrite() { - debug!("rewrite needed"); - return rewrite.rewrite(context.prepared_statements()); - } - if let Some(multi_tenant) = context.multi_tenant() { debug!("running multi-tenant check"); diff --git a/pgdog/src/frontend/router/parser/query/show.rs b/pgdog/src/frontend/router/parser/query/show.rs index 9a26f3c06..032d243db 100644 --- a/pgdog/src/frontend/router/parser/query/show.rs +++ b/pgdog/src/frontend/router/parser/query/show.rs @@ -42,7 +42,7 @@ mod test_show { // First call let query = "SHOW TRANSACTION ISOLATION LEVEL"; let buffer = ClientRequest::from(vec![Query::new(query).into()]); - let context = RouterContext::new(&buffer, &c, &mut ps, &p, None, 1).unwrap(); + let context = RouterContext::new(&buffer, &c, &mut ps, &p, None, 1, None).unwrap(); let first = parser.parse(context).unwrap().clone(); let first_shard = first.route().shard(); @@ -51,7 +51,7 @@ mod test_show { // Second call let query = "SHOW TRANSACTION ISOLATION LEVEL"; let buffer = ClientRequest::from(vec![Query::new(query).into()]); - let context = RouterContext::new(&buffer, &c, &mut ps, &p, None, 1).unwrap(); + let context = RouterContext::new(&buffer, &c, &mut ps, &p, None, 1, None).unwrap(); let second = parser.parse(context).unwrap().clone(); let second_shard = second.route().shard(); diff --git a/pgdog/src/frontend/router/parser/query/test.rs b/pgdog/src/frontend/router/parser/query/test.rs index 3a8d02e8c..aca076937 100644 --- a/pgdog/src/frontend/router/parser/query/test.rs +++ 
b/pgdog/src/frontend/router/parser/query/test.rs @@ -78,7 +78,7 @@ fn parse_query(query: &str) -> Command { let params = Parameters::default(); let context = - RouterContext::new(&client_request, &cluster, &mut stmt, ¶ms, None, 1).unwrap(); + RouterContext::new(&client_request, &cluster, &mut stmt, ¶ms, None, 1, None).unwrap(); let command = query_parser.parse(context).unwrap().clone(); command } @@ -107,6 +107,7 @@ macro_rules! command { ¶ms, transaction, 1, + None, ) .unwrap(); let command = query_parser.parse(context).unwrap().clone(); @@ -152,6 +153,7 @@ macro_rules! query_parser { ¶ms, maybe_transaction, 1, + None, ) .unwrap(); @@ -187,6 +189,7 @@ macro_rules! parse { &Parameters::default(), None, 1, + None, ) .unwrap(), ) @@ -209,8 +212,16 @@ fn parse_with_parameters(query: &str, params: Parameters) -> Result Result { ¶meters, None, 1, + None, ) .unwrap(); @@ -490,8 +502,16 @@ fn test_set() { let mut prep_stmts = PreparedStatements::default(); let params = Parameters::default(); let transaction = Some(TransactionType::ReadWrite); - let router_context = - RouterContext::new(&buffer, &cluster, &mut prep_stmts, ¶ms, transaction, 1).unwrap(); + let router_context = RouterContext::new( + &buffer, + &cluster, + &mut prep_stmts, + ¶ms, + transaction, + 1, + None, + ) + .unwrap(); let mut context = QueryParserContext::new(router_context); for read_only in [true, false] { @@ -571,8 +591,16 @@ fn update_sharding_key_errors_by_default() { let params = Parameters::default(); let client_request: ClientRequest = vec![Query::new(query).into()].into(); let cluster = Cluster::new_test(); - let router_context = - RouterContext::new(&client_request, &cluster, &mut prep_stmts, ¶ms, None, 1).unwrap(); + let router_context = RouterContext::new( + &client_request, + &cluster, + &mut prep_stmts, + ¶ms, + None, + 1, + None, + ) + .unwrap(); let result = QueryParser::default().parse(router_context); assert!( @@ -591,8 +619,16 @@ fn update_sharding_key_ignore_mode_allows() { let params 
= Parameters::default(); let client_request: ClientRequest = vec![Query::new(query).into()].into(); let cluster = Cluster::new_test(); - let router_context = - RouterContext::new(&client_request, &cluster, &mut prep_stmts, ¶ms, None, 1).unwrap(); + let router_context = RouterContext::new( + &client_request, + &cluster, + &mut prep_stmts, + ¶ms, + None, + 1, + None, + ) + .unwrap(); let command = QueryParser::default().parse(router_context).unwrap(); assert!(matches!(command, Command::Query(_))); @@ -608,8 +644,16 @@ fn update_sharding_key_rewrite_mode_not_supported() { let params = Parameters::default(); let client_request: ClientRequest = vec![Query::new(query).into()].into(); let cluster = Cluster::new_test(); - let router_context = - RouterContext::new(&client_request, &cluster, &mut prep_stmts, ¶ms, None, 1).unwrap(); + let router_context = RouterContext::new( + &client_request, + &cluster, + &mut prep_stmts, + ¶ms, + None, + 1, + None, + ) + .unwrap(); let result = QueryParser::default().parse(router_context); assert!( @@ -628,8 +672,16 @@ fn update_sharding_key_rewrite_plan_detected() { let params = Parameters::default(); let client_request: ClientRequest = vec![Query::new(query).into()].into(); let cluster = Cluster::new_test(); - let router_context = - RouterContext::new(&client_request, &cluster, &mut prep_stmts, ¶ms, None, 1).unwrap(); + let router_context = RouterContext::new( + &client_request, + &cluster, + &mut prep_stmts, + ¶ms, + None, + 1, + None, + ) + .unwrap(); let command = QueryParser::default().parse(router_context).unwrap(); match command { @@ -804,8 +856,16 @@ WHERE t2.account = ( .into()] .into(); let transaction = Some(TransactionType::ReadWrite); - let router_context = - RouterContext::new(&buffer, &cluster, &mut prep_stmts, ¶ms, transaction, 1).unwrap(); + let router_context = RouterContext::new( + &buffer, + &cluster, + &mut prep_stmts, + ¶ms, + transaction, + 1, + None, + ) + .unwrap(); let mut context = 
QueryParserContext::new(router_context); let route = qp.query(&mut context).unwrap(); match route { @@ -866,7 +926,8 @@ fn test_close_direct_one_shard() { let params = Parameters::default(); let transaction = None; - let context = RouterContext::new(&buf, &cluster, &mut pp, ¶ms, transaction, 1).unwrap(); + let context = + RouterContext::new(&buf, &cluster, &mut pp, ¶ms, transaction, 1, None).unwrap(); let cmd = qp.parse(context).unwrap(); diff --git a/pgdog/src/frontend/router/parser/rewrite/mod.rs b/pgdog/src/frontend/router/parser/rewrite/mod.rs index 5bf8a3333..f33b12962 100644 --- a/pgdog/src/frontend/router/parser/rewrite/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/mod.rs @@ -1,132 +1,5 @@ -use pg_query::{NodeEnum, ParseResult}; - -use super::{Command, Error}; - mod insert_split; mod shard_key; -use crate::net::{Parse, ProtocolMessage}; -use crate::{frontend::PreparedStatements, net::Query}; pub use insert_split::{InsertSplitPlan, InsertSplitRow}; pub use shard_key::{Assignment, AssignmentValue, ShardKeyRewritePlan}; - -#[derive(Debug, Clone)] -pub struct Rewrite<'a> { - ast: &'a ParseResult, -} - -impl<'a> Rewrite<'a> { - pub fn new(ast: &'a ParseResult) -> Self { - Self { ast } - } - - /// Statement needs to be rewritten. - pub fn needs_rewrite(&self) -> bool { - for stmt in &self.ast.protobuf.stmts { - if let Some(ref stmt) = stmt.stmt { - if let Some(ref node) = stmt.node { - match node { - NodeEnum::PrepareStmt(_) => return true, - NodeEnum::ExecuteStmt(_) => return true, - NodeEnum::DeallocateStmt(_) => return true, - _ => (), - } - } - } - } - - false - } - - pub fn rewrite(&self, prepared_statements: &mut PreparedStatements) -> Result { - let mut ast = self.ast.protobuf.clone(); - - for stmt in &mut ast.stmts { - if let Some(ref mut stmt) = stmt.stmt { - if let Some(ref mut node) = stmt.node { - match node { - NodeEnum::PrepareStmt(ref mut stmt) => { - let statement = stmt - .query - .as_ref() - .ok_or(Error::EmptyQuery)? 
- .deparse() - .map_err(|_| Error::EmptyQuery)?; - - let mut parse = Parse::named(&stmt.name, &statement); - prepared_statements.insert_anyway(&mut parse); - stmt.name = parse.name().to_string(); - - return Ok(Command::Rewrite(vec![Query::new( - ast.deparse().map_err(|_| Error::EmptyQuery)?, - ) - .into()])); - } - - NodeEnum::ExecuteStmt(ref mut stmt) => { - let parse = prepared_statements.parse(&stmt.name); - if let Some(parse) = parse { - stmt.name = parse.name().to_string(); - - return Ok(Command::Rewrite(vec![ - ProtocolMessage::Prepare { - name: stmt.name.clone(), - statement: parse.query().to_string(), - }, - Query::new(ast.deparse().map_err(|_| Error::EmptyQuery)?) - .into(), - ])); - } else { - return Err(Error::PreparedStatementDoesntExist(stmt.name.clone())); - } - } - - NodeEnum::DeallocateStmt(_) => return Ok(Command::Deallocate), - - _ => (), - } - } - } - } - - Err(Error::EmptyQuery) - } -} - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use crate::net::{FromBytes, ToBytes}; - - use super::*; - - #[test] - fn test_rewrite_prepared() { - let ast = pg_query::parse("PREPARE test AS SELECT $1, $2, $3").unwrap(); - let rewrite = Rewrite::new(&ast); - assert!(rewrite.needs_rewrite()); - let mut prepared_statements = PreparedStatements::new(); - let queries = rewrite.rewrite(&mut prepared_statements).unwrap(); - match queries { - Command::Rewrite(messages) => { - let message = Query::from_bytes(messages[0].to_bytes().unwrap()).unwrap(); - assert_eq!(message.query(), "PREPARE __pgdog_1 AS SELECT $1, $2, $3"); - } - _ => panic!("not a rewrite"), - } - } - - #[test] - fn test_deallocate() { - for q in ["DEALLOCATE ALL", "DEALLOCATE test"] { - let ast = pg_query::parse(q).unwrap(); - let ast = Arc::new(ast); - let rewrite = Rewrite::new(&ast) - .rewrite(&mut PreparedStatements::new()) - .unwrap(); - - assert!(matches!(rewrite, Command::Deallocate)); - } - } -} diff --git a/pgdog/src/frontend/router/parser/rewrite_plan.rs 
b/pgdog/src/frontend/router/parser/rewrite_plan.rs index 030ea0cb6..c3eed0f26 100644 --- a/pgdog/src/frontend/router/parser/rewrite_plan.rs +++ b/pgdog/src/frontend/router/parser/rewrite_plan.rs @@ -24,31 +24,39 @@ pub struct RewritePlan { } impl RewritePlan { + /// Create new rewrite plan. pub fn new() -> Self { - Self { - drop_columns: Vec::new(), - helpers: Vec::new(), - } + Self::default() } + /// The plan doesn't do anything. + /// + /// N.B. This is a noop used inside the parser. As we move the + /// rewrite logic to its own rewrite module, the no_op will be updated + /// to include `unique_ids`. + /// pub fn is_noop(&self) -> bool { self.drop_columns.is_empty() && self.helpers.is_empty() } + /// Get column positions that should be removed from DataRows. pub fn drop_columns(&self) -> &[usize] { &self.drop_columns } + /// Add column to be removed from DataRows. pub fn add_drop_column(&mut self, column: usize) { if !self.drop_columns.contains(&column) { self.drop_columns.push(column); } } + /// Get per-column helpers, used for aggregation. pub fn helpers(&self) -> &[HelperMapping] { &self.helpers } + /// Add per-column aggregate helper. 
pub fn add_helper(&mut self, mapping: HelperMapping) { self.helpers.push(mapping); } diff --git a/pgdog/src/frontend/router/parser/value.rs b/pgdog/src/frontend/router/parser/value.rs index 81cc1350f..d9f9a0e08 100644 --- a/pgdog/src/frontend/router/parser/value.rs +++ b/pgdog/src/frontend/router/parser/value.rs @@ -17,7 +17,7 @@ pub enum Value<'a> { Null, Placeholder(i32), Vector(Vector), - Function(&'a str), + Function(std::string::String), } impl Value<'_> { @@ -76,13 +76,19 @@ impl<'a> TryFrom<&'a Option> for Value<'a> { Some(NodeEnum::AConst(a_const)) => Ok(a_const.into()), Some(NodeEnum::ParamRef(param_ref)) => Ok(Value::Placeholder(param_ref.number)), Some(NodeEnum::FuncCall(func)) => { - if let Some(Node { - node: Some(NodeEnum::String(sval)), - }) = func.funcname.first() - { - Ok(Value::Function(&sval.sval)) - } else { + let mut name = Vec::new(); + for comp in func.funcname.iter() { + if let Node { + node: Some(NodeEnum::String(sval)), + } = comp + { + name.push(sval.sval.to_string()); + } + } + if name.is_empty() { Ok(Value::Null) + } else { + Ok(Value::Function(name.join("."))) } } Some(NodeEnum::TypeCast(cast)) => { diff --git a/pgdog/src/frontend/router/rewrite/context.rs b/pgdog/src/frontend/router/rewrite/context.rs new file mode 100644 index 000000000..881c88cfc --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/context.rs @@ -0,0 +1,129 @@ +//! Context passed throughout the rewrite engine. + +use pg_query::protobuf::{ParseResult, RawStmt}; + +use super::{ + output::RewriteActionKind, stats::RewriteStats, Error, RewriteAction, RewritePlan, StepOutput, +}; +use crate::net::{Parse, ProtocolMessage, Query}; + +#[derive(Debug, Clone)] +pub struct Context<'a> { + // Most requeries won't require a rewrite. + // This is a clone-free way to check. + original: &'a ParseResult, + // If an in-place rewrite was done, the statement is saved here. + rewrite: Option, + /// Additional messages to add to the request. + result: Vec, + /// Extended protocol. 
+ parse: Option<&'a Parse>, + /// Rewrite plan. + plan: RewritePlan, +} + +impl<'a> Context<'a> { + /// Create new input. + pub(super) fn new(original: &'a ParseResult, parse: Option<&'a Parse>) -> Self { + Self { + original, + rewrite: None, + result: vec![], + parse, + plan: RewritePlan::default(), + } + } + + /// Get Parse reference. + pub fn parse(&'a self) -> Option<&'a Parse> { + self.parse + } + + /// We are rewriting an extended protocol request. + pub fn extended(&self) -> bool { + self.parse().is_some() + } + + /// Get reference to rewrite plan for modification. + pub fn plan(&mut self) -> &mut RewritePlan { + &mut self.plan + } + + /// Get the original (or modified) statement. + pub fn stmt(&'a self) -> Result<&'a RawStmt, Error> { + let stmt = if let Some(ref rewrite) = self.rewrite { + rewrite + } else { + self.original + }; + let root = stmt.stmts.first().ok_or(Error::EmptyQuery)?; + Ok(root) + } + + /// Get the mutable statement we're rewriting. + pub fn stmt_mut(&mut self) -> Result<&mut RawStmt, Error> { + let stmt = if let Some(ref mut rewrite) = self.rewrite { + rewrite + } else { + self.rewrite = Some(self.original.clone()); + self.rewrite.as_mut().unwrap() + }; + + stmt.stmts.first_mut().ok_or(Error::EmptyQuery) + } + + /// Get protocol version from the original statement. + pub fn proto_version(&self) -> i32 { + self.original.version + } + + /// Get the parse result (original or rewritten). + pub fn parse_result(&self) -> &ParseResult { + self.rewrite.as_ref().unwrap_or(self.original) + } + + /// Prepend new message to rewritten request. + pub fn prepend(&mut self, message: ProtocolMessage) { + self.result.push(RewriteAction { + message, + action: RewriteActionKind::Prepend, + }); + } + + /// Assemble rewrite instructions. 
+ pub fn build(mut self) -> Result<StepOutput, Error> { + if self.rewrite.is_none() { + Ok(StepOutput::NoOp) + } else { + let mut stats = RewriteStats::default(); + let ast = self.rewrite.take().ok_or(Error::NoRewrite)?; + let stmt = ast.deparse()?; + let mut parse = self.parse().cloned(); + + let mut actions = self.result; + + if let Some(mut parse) = parse.take() { + parse.set_query(&stmt); + actions.push(RewriteAction { + message: parse.into(), + action: RewriteActionKind::Replace, + }); + stats.parse += 1; + } else { + actions.push(RewriteAction { + message: Query::new(stmt.clone()).into(), + action: RewriteActionKind::Replace, + }); + stats.simple += 1; + } + + Ok(StepOutput::RewriteInPlace { + stmt, + ast, + actions, + stats, + plan: self.plan.freeze(), + }) + } + } +} diff --git a/pgdog/src/frontend/router/rewrite/error.rs b/pgdog/src/frontend/router/rewrite/error.rs new file mode 100644 index 000000000..2034b6d82 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/error.rs @@ -0,0 +1,37 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum Error { + #[error("parser error")] + ParserError, + + #[error("unique id: {0}")] + UniqueId(#[from] crate::unique_id::Error), + + #[error("pg_query: {0}")] + PgQuery(#[from] pg_query::Error), + + #[error("net: {0}")] + Net(#[from] crate::net::Error), + + #[error("rewrite engine didn't rewrite bind")] + NoBind, + + #[error("empty query")] + EmptyQuery, + + #[error("no rewrite")] + NoRewrite, + + #[error("prepared statement not found: {0}")] + PreparedStatementNotFound(String), + + #[error("statement parameters and bind count mismatch")] + ParameterCountMismatch, + + #[error("parser: {0}")] + Parser(#[from] crate::frontend::router::parser::Error), + + #[error("no active rewrite plan set")] + NoActiveRewritePlan, +} diff --git a/pgdog/src/frontend/router/rewrite/insert_split/mod.rs b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs new file mode 100644 index 000000000..a6bdbdbf2 --- /dev/null +++ 
b/pgdog/src/frontend/router/rewrite/insert_split/mod.rs @@ -0,0 +1,96 @@ +use pg_query::{ + protobuf::{ParseResult, RawStmt}, + Node, NodeEnum, +}; + +use super::*; + +#[derive(Default)] +pub struct InsertSplitRewrite; + +impl RewriteModule for InsertSplitRewrite { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + let mut inserts = vec![]; + if let Some(NodeEnum::InsertStmt(insert)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + if let Some(NodeEnum::SelectStmt(ref select)) = insert + .select_stmt + .as_ref() + .and_then(|node| node.node.as_ref()) + { + if select.values_lists.len() <= 1 { + return Ok(()); + } + + // Clone the original statement only once. + let mut proto_select = select.clone(); + let mut proto_insert = insert.clone(); + proto_select.values_lists.clear(); + proto_insert.select_stmt = None; + + // Generate new INSERT statements, with one VALUES tuple each. + for values in &select.values_lists { + let mut new_insert = proto_insert.clone(); + let mut new_select = proto_select.clone(); + let mut new_values = values.clone(); + + // Rewrite the parameter references + // and create new Bind message for each INSERT statement. 
+ if let Some(NodeEnum::List(list)) = new_values.node.as_mut() { + for value in list.items.iter_mut() { + if let Some(NodeEnum::ParamRef(_)) = value.node.as_mut() { + // let parameter = input + // .bind() + // .and_then(|bind| bind.parameter(param.number as usize - 1).ok()) + // .flatten(); + // if let Some(parameter) = parameter { + // param.number = new_bind.add_existing(parameter)?; + // } + } + } + } + new_select.values_lists.push(new_values); + new_insert.select_stmt = Some(Box::new(Node { + node: Some(NodeEnum::SelectStmt(new_select)), + })); + let result = ParseResult { + version: input.proto_version(), + stmts: vec![RawStmt { + stmt: Some(Box::new(Node { + node: Some(NodeEnum::InsertStmt(new_insert)), + })), + ..Default::default() + }], + }; + inserts.push(result); + } + } + } + + drop(inserts); + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use crate::net::Parse; + + use super::*; + + #[test] + fn test_insert_split() { + let query = "INSERT INTO users (id, email, created_at) + VALUES ($1, 'test@test.com', NOW()), (123, $2, '2025-01-01') RETURNING *"; + let stmt = pg_query::parse(query).unwrap(); + let parse = Parse::new_anonymous(query); + let mut context = Context::new(&stmt.protobuf, Some(&parse)); + let mut module = InsertSplitRewrite::default(); + module.rewrite(&mut context).unwrap(); + } +} diff --git a/pgdog/src/frontend/router/rewrite/interface.rs b/pgdog/src/frontend/router/rewrite/interface.rs new file mode 100644 index 000000000..d0babb5ca --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/interface.rs @@ -0,0 +1,14 @@ +//! Rewrite module interface. + +use super::{Context, Error}; + +/// Rewrite trait. +/// +/// All rewrite modules should follow this. +pub trait RewriteModule { + /// Take a statement and maybe rewrite it, if needed. + /// + /// If a rewrite is needed, the module should mutate the statement + /// and update the Bind message. 
+ fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error>; +} diff --git a/pgdog/src/frontend/router/rewrite/mod.rs b/pgdog/src/frontend/router/rewrite/mod.rs new file mode 100644 index 000000000..9f9232207 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/mod.rs @@ -0,0 +1,60 @@ +//! Query rewrite engine. +//! +//! It handles the following scenarios: +//! +//! 1. Sharding key UPDATE: rewrite to send a DELETE and INSERT +//! 2. Multi-tuple INSERT: rewrite to send multiple INSERTs +//! 3. pgdog.unique_id() call: inject a unique ID +//! +pub mod context; +pub mod error; +pub mod insert_split; +pub mod interface; +pub mod output; +pub mod plan; +pub mod prepared; +pub mod request; +pub mod stats; +pub mod unique_id; + +pub use context::Context; +pub use error::Error; +pub use interface::RewriteModule; +pub use output::{RewriteAction, StepOutput}; +pub use plan::{ImmutableRewritePlan, RewritePlan, UniqueIdPlan}; +pub use request::RewriteRequest; + +use crate::frontend::PreparedStatements; + +/// Combined rewrite engine that runs all rewrite modules. +pub struct Rewrite<'a> { + prepared_statements: &'a mut PreparedStatements, +} + +impl<'a> Rewrite<'a> { + pub fn new(prepared_statements: &'a mut PreparedStatements) -> Self { + Self { + prepared_statements, + } + } +} + +impl RewriteModule for Rewrite<'_> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + // N.B.: the ordering here matters! + // + // First, we need to inject the unique ID into the query. Once that's done, + // we can proceed with additional rewrites. + + // Unique ID rewrites. 
+ unique_id::ExplainUniqueIdRewrite::default().rewrite(input)?; + unique_id::InsertUniqueIdRewrite::default().rewrite(input)?; + unique_id::UpdateUniqueIdRewrite::default().rewrite(input)?; + unique_id::SelectUniqueIdRewrite::default().rewrite(input)?; + + // Prepared statement rewrites + prepared::PreparedRewrite::new(self.prepared_statements).rewrite(input)?; + + Ok(()) + } +} diff --git a/pgdog/src/frontend/router/rewrite/output.rs b/pgdog/src/frontend/router/rewrite/output.rs new file mode 100644 index 000000000..72417acbe --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/output.rs @@ -0,0 +1,94 @@ +use pg_query::protobuf::ParseResult; + +use super::stats::RewriteStats; +use super::ImmutableRewritePlan; +use crate::{frontend::ClientRequest, net::ProtocolMessage}; + +use std::mem::discriminant; + +#[derive(Debug, Clone)] +pub struct RewrittenRequest { + pub messages: Vec, + pub action: ExecutionAction, + pub renamed: Option, +} + +#[derive(Debug, Clone)] +pub struct RewriteAction { + pub(super) message: ProtocolMessage, + pub(super) action: RewriteActionKind, +} + +impl RewriteAction { + /// Execute rewrite action. + pub fn execute(&self, request: &mut ClientRequest) { + match self.action { + RewriteActionKind::Append => request.push(self.message.clone()), + RewriteActionKind::Replace => { + if let Some(pos) = request + .iter() + .position(|p| discriminant(p) == discriminant(&self.message)) + { + request[pos] = self.message.clone(); + } + } + RewriteActionKind::Prepend => request.insert(0, self.message.clone()), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub(super) enum RewriteActionKind { + Replace, + Prepend, + #[allow(dead_code)] + Append, +} + +/// Output of a single rewrite step. +#[derive(Debug, Clone)] +pub enum StepOutput { + NoOp, + RewriteInPlace { + actions: Vec, + ast: ParseResult, + stmt: String, + stats: RewriteStats, + plan: ImmutableRewritePlan, + }, +} + +impl StepOutput { + /// Get rewritten query, if any. 
+ pub fn query(&self) -> Result<&str, ()> { + match self { + Self::NoOp => Err(()), + Self::RewriteInPlace { stmt, .. } => Ok(stmt.as_str()), + } + } + + /// Get the rewrite plan, if any. + pub fn plan(&self) -> Result<&ImmutableRewritePlan, ()> { + match self { + Self::NoOp => Err(()), + Self::RewriteInPlace { plan, .. } => Ok(plan), + } + } +} + +#[derive(Debug, Clone)] +pub enum Output { + Passthrough, + Simple(RewrittenRequest), + Chain(Vec), +} + +#[derive(Debug, Clone)] +pub enum ExecutionAction { + /// Drop result completely. + Drop, + /// Return result to client. + Return, + /// Forward result to next step in the chain. + Forward, +} diff --git a/pgdog/src/frontend/router/rewrite/plan.rs b/pgdog/src/frontend/router/rewrite/plan.rs new file mode 100644 index 000000000..f3c064b42 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/plan.rs @@ -0,0 +1,211 @@ +use std::{ops::Deref, sync::Arc}; + +use super::Error; +use crate::{ + net::{Bind, Datum}, + unique_id::UniqueId, +}; + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct UniqueIdPlan { + /// Parameter number. + pub(super) param_ref: i32, +} + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct RewritePlan { + /// How many unique IDs to add to the Bind message. + pub(super) unique_ids: Vec, +} + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct ImmutableRewritePlan { + /// Compiled rewrite plan, that cannot be modified further. + pub(super) plan: Arc, +} + +impl Deref for ImmutableRewritePlan { + type Target = RewritePlan; + + fn deref(&self) -> &Self::Target { + &self.plan + } +} + +impl RewritePlan { + /// Apply rewrite plan to Bind message. + /// + /// N.B. this isn't idempotent, run this only once. + /// + pub fn apply_bind(&self, bind: &mut Bind) -> Result<(), Error> { + for unique_id in &self.unique_ids { + let id = UniqueId::generator()?.next_id(); + let counter = bind.add_parameter(Datum::Bigint(id))?; + // Params should be added to plan in order. + // This validates it. 
+ if counter != unique_id.param_ref { + return Err(Error::ParameterCountMismatch); + } + } + + Ok(()) + } + + /// Freeze rewrite plan, without any more modifications allowed. + pub fn freeze(self) -> ImmutableRewritePlan { + ImmutableRewritePlan { + plan: Arc::new(self), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::net::bind::Parameter; + use std::env::set_var; + + #[test] + fn test_apply_bind_adds_parameters() { + unsafe { + set_var("NODE_ID", "pgdog-test-1"); + } + + // Create a rewrite plan expecting params $3 and $4 + let plan = RewritePlan { + unique_ids: vec![UniqueIdPlan { param_ref: 3 }, UniqueIdPlan { param_ref: 4 }], + }; + + // Create a Bind with 2 existing parameters + let mut bind = Bind::new_params( + "", + &[ + Parameter { + len: 2, + data: "{}".into(), + }, + Parameter { + len: 2, + data: "{}".into(), + }, + ], + ); + + // Apply the plan + plan.apply_bind(&mut bind).unwrap(); + + // Verify we now have 4 parameters + assert_eq!(bind.params_raw().len(), 4); + + // Verify the added parameters are BIGINT values (8 bytes in text format) + let param3 = bind.parameter(2).unwrap().unwrap(); + let param4 = bind.parameter(3).unwrap().unwrap(); + + // The parameters should be valid bigints + let id3 = param3.bigint(); + let id4 = param4.bigint(); + assert!(id3.is_some(), "Third parameter should be a valid bigint"); + assert!(id4.is_some(), "Fourth parameter should be a valid bigint"); + + // IDs should be different (unique) + assert_ne!(id3, id4); + } + + #[test] + fn test_apply_bind_parameter_count_mismatch() { + unsafe { + set_var("NODE_ID", "pgdog-test-1"); + } + + // Create a plan expecting param $5 (but bind only has 2 params, so next will be $3) + let plan = RewritePlan { + unique_ids: vec![UniqueIdPlan { param_ref: 5 }], + }; + + // Create a Bind with 2 existing parameters + let mut bind = Bind::new_params( + "", + &[ + Parameter { + len: 2, + data: "{}".into(), + }, + Parameter { + len: 2, + data: "{}".into(), + }, + ], + ); + + 
// Apply should fail due to mismatch + let result = plan.apply_bind(&mut bind); + assert!(result.is_err()); + } + + #[test] + fn test_apply_bind_empty_plan() { + // Empty plan should be a no-op + let plan = RewritePlan::default(); + + let mut bind = Bind::new_params( + "", + &[Parameter { + len: 4, + data: "test".into(), + }], + ); + + plan.apply_bind(&mut bind).unwrap(); + + // Should still have just 1 parameter + assert_eq!(bind.params_raw().len(), 1); + } + + #[test] + fn test_apply_bind_single_unique_id() { + unsafe { + set_var("NODE_ID", "pgdog-test-1"); + } + + // Create a plan for a single unique ID as $2 + let plan = RewritePlan { + unique_ids: vec![UniqueIdPlan { param_ref: 2 }], + }; + + // Create a Bind with 1 existing parameter + let mut bind = Bind::new_params( + "", + &[Parameter { + len: 4, + data: "test".into(), + }], + ); + + plan.apply_bind(&mut bind).unwrap(); + + // Should now have 2 parameters + assert_eq!(bind.params_raw().len(), 2); + + // Second parameter should be a bigint + let param2 = bind.parameter(1).unwrap().unwrap(); + assert!(param2.bigint().is_some()); + } + + #[test] + fn test_immutable_plan_apply_bind() { + unsafe { + set_var("NODE_ID", "pgdog-test-1"); + } + + // Test that ImmutableRewritePlan can also apply to binds via Deref + let plan = RewritePlan { + unique_ids: vec![UniqueIdPlan { param_ref: 1 }], + } + .freeze(); + + let mut bind = Bind::default(); + plan.apply_bind(&mut bind).unwrap(); + + assert_eq!(bind.params_raw().len(), 1); + } +} diff --git a/pgdog/src/frontend/router/rewrite/prepared/execute.rs b/pgdog/src/frontend/router/rewrite/prepared/execute.rs new file mode 100644 index 000000000..9071c3460 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/prepared/execute.rs @@ -0,0 +1,94 @@ +//! EXECUTE statement rewriter. + +use pg_query::NodeEnum; + +use super::super::{Context, Error, RewriteModule}; +use crate::{frontend::PreparedStatements, net::ProtocolMessage}; + +/// Rewriter for EXECUTE statements. 
+/// +/// Renames the executed statement to use the globally cached name. +pub struct ExecuteRewrite<'a> { + prepared_statements: &'a PreparedStatements, +} + +impl<'a> ExecuteRewrite<'a> { + pub fn new(prepared_statements: &'a PreparedStatements) -> Self { + Self { + prepared_statements, + } + } +} + +impl RewriteModule for ExecuteRewrite<'_> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + if let Some(NodeEnum::ExecuteStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + let parse = self + .prepared_statements + .parse(&stmt.name) + .ok_or_else(|| Error::PreparedStatementNotFound(stmt.name.clone()))?; + + let new_name = parse.name().to_string(); + + if let Some(NodeEnum::ExecuteStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + stmt.name = new_name.clone(); + } + + input.prepend(ProtocolMessage::Prepare { + name: new_name, + statement: parse.query().to_string(), + }); + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::{super::PrepareRewrite, *}; + + #[test] + fn test_execute_rewrite() { + let prepare_stmt = pg_query::parse("PREPARE test AS SELECT $1, $2, $3") + .unwrap() + .protobuf; + let mut prepared_statements = PreparedStatements::default(); + + // First prepare the statement + let mut prepare_rewrite = PrepareRewrite::new(&mut prepared_statements); + let mut input = Context::new(&prepare_stmt, None); + prepare_rewrite.rewrite(&mut input).unwrap(); + + // Now execute it + let execute_stmt = pg_query::parse("EXECUTE test(1, 2, 3)").unwrap().protobuf; + let mut execute_rewrite = ExecuteRewrite::new(&prepared_statements); + let mut input = Context::new(&execute_stmt, None); + execute_rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert_eq!(query, "EXECUTE __pgdog_1(1, 2, 3)"); + } + + #[test] + fn test_execute_not_found() { + let execute_stmt = 
pg_query::parse("EXECUTE nonexistent(1, 2, 3)") + .unwrap() + .protobuf; + let prepared_statements = PreparedStatements::default(); + let mut execute_rewrite = ExecuteRewrite::new(&prepared_statements); + let mut input = Context::new(&execute_stmt, None); + let result = execute_rewrite.rewrite(&mut input); + assert!(result.is_err()); + } +} diff --git a/pgdog/src/frontend/router/rewrite/prepared/mod.rs b/pgdog/src/frontend/router/rewrite/prepared/mod.rs new file mode 100644 index 000000000..3b9df8089 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/prepared/mod.rs @@ -0,0 +1,33 @@ +//! Prepared statement rewriter. +//! +//! Rewrites PREPARE and EXECUTE statements to use globally cached names. + +mod execute; +mod prepare; + +pub use execute::ExecuteRewrite; +pub use prepare::PrepareRewrite; + +use super::{Context, Error, RewriteModule}; +use crate::frontend::PreparedStatements; + +/// Combined rewriter for PREPARE and EXECUTE statements. +pub struct PreparedRewrite<'a> { + prepared_statements: &'a mut PreparedStatements, +} + +impl<'a> PreparedRewrite<'a> { + pub fn new(prepared_statements: &'a mut PreparedStatements) -> Self { + Self { + prepared_statements, + } + } +} + +impl RewriteModule for PreparedRewrite<'_> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + PrepareRewrite::new(self.prepared_statements).rewrite(input)?; + ExecuteRewrite::new(self.prepared_statements).rewrite(input)?; + Ok(()) + } +} diff --git a/pgdog/src/frontend/router/rewrite/prepared/prepare.rs b/pgdog/src/frontend/router/rewrite/prepared/prepare.rs new file mode 100644 index 000000000..f4b63c2a1 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/prepared/prepare.rs @@ -0,0 +1,73 @@ +//! PREPARE statement rewriter. + +use pg_query::NodeEnum; + +use super::super::{Context, Error, RewriteModule}; +use crate::frontend::PreparedStatements; + +/// Rewriter for PREPARE statements. 
+/// +/// Renames the prepared statement to use a globally unique name from the cache. +pub struct PrepareRewrite<'a> { + prepared_statements: &'a mut PreparedStatements, +} + +impl<'a> PrepareRewrite<'a> { + pub fn new(prepared_statements: &'a mut PreparedStatements) -> Self { + Self { + prepared_statements, + } + } +} + +impl RewriteModule for PrepareRewrite<'_> { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + if let Some(NodeEnum::PrepareStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + let statement = stmt + .query + .as_ref() + .ok_or(Error::EmptyQuery)? + .deparse() + .map_err(|_| Error::EmptyQuery)?; + + let mut parse = crate::net::Parse::named(&stmt.name, &statement); + self.prepared_statements.insert_anyway(&mut parse); + let new_name = parse.name().to_string(); + + if let Some(NodeEnum::PrepareStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + stmt.name = new_name; + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_prepare_rewrite() { + let stmt = pg_query::parse("PREPARE test AS SELECT $1, $2, $3") + .unwrap() + .protobuf; + let mut prepared_statements = PreparedStatements::default(); + let mut rewrite = PrepareRewrite::new(&mut prepared_statements); + let mut input = Context::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert_eq!(query, "PREPARE __pgdog_1 AS SELECT $1, $2, $3"); + } +} diff --git a/pgdog/src/frontend/router/rewrite/request.rs b/pgdog/src/frontend/router/rewrite/request.rs new file mode 100644 index 000000000..be9e0cbef --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/request.rs @@ -0,0 +1,176 @@ +use pg_query::ParseResult; +use tracing::debug; + +use super::{Context, Error, Rewrite, RewriteModule, StepOutput}; +use crate::{ + backend::Cluster, + frontend::{ + router::{ + 
parser::{cache::CachedAst, Cache}, + rewrite::ImmutableRewritePlan, + }, + ClientRequest, PreparedStatements, + }, + net::{Protocol, ProtocolMessage}, +}; + +pub struct RewriteRequest<'a> { + request: &'a mut ClientRequest, + cluster: &'a Cluster, + prepared_statements: &'a mut PreparedStatements, + plan: Option, +} + +impl<'a> RewriteRequest<'a> { + /// Perform new rewrite request. + pub fn new( + request: &'a mut ClientRequest, + cluster: &'a Cluster, + prepared_statements: &'a mut PreparedStatements, + ) -> Self { + Self { + request, + cluster, + prepared_statements, + plan: None, + } + } + + fn handle_parse(&mut self) -> Result { + let parse = self.request.iter().find(|p| p.code() == 'P'); + let parse = if let Some(ProtocolMessage::Parse(parse)) = parse { + parse + } else { + return Err(Error::EmptyQuery); + }; + + let schema = self.cluster.sharding_schema(); + let ast = Cache::get().parse(parse.query(), &schema)?; + + let mut context = Context::new(&ast.ast().protobuf, Some(parse)); + Rewrite::new(self.prepared_statements).rewrite(&mut context)?; + let output = context.build()?; + + let ast = match output { + StepOutput::NoOp => { + debug!("rewrite (extended) is a no-op"); + ast + } + StepOutput::RewriteInPlace { + actions, + ast, + stmt, + stats, + plan, + } => { + debug!("rewrite (extended): {}", stmt); + + for action in actions { + action.execute(self.request); + + // Update the rewritten parse in the global cache + // and save its rewrite plan. + if let ProtocolMessage::Parse(parse) = action.message { + if !parse.anonymous() { + self.prepared_statements + .global + .write() + .set_rewrite(&parse, plan.clone()); + } + } + } + + self.plan = Some(plan); + + // Update stats. + { + let cluster_stats = self.cluster.stats(); + let mut lock = cluster_stats.lock(); + lock.rewrite = lock.rewrite + stats; + } + + let ast = ParseResult::new(ast, "".into()); + Cache::get().save(&stmt, ast, &schema)? 
+ } + }; + + Ok(ast) + } + + fn handle_query(&mut self) -> Result { + let query = self.request.iter().find(|p| p.code() == 'Q'); + let query = if let Some(ProtocolMessage::Query(query)) = query { + query + } else { + return Err(Error::EmptyQuery); + }; + + let schema = self.cluster.sharding_schema(); + let ast = Cache::get().parse_uncached(query.query(), &schema)?; + + let mut context = Context::new(&ast.ast().protobuf, None); + Rewrite::new(self.prepared_statements).rewrite(&mut context)?; + let output = context.build()?; + + let ast = match output { + StepOutput::NoOp => { + debug!("rewrite (simple) is a no-op"); + ast + } + StepOutput::RewriteInPlace { + actions, + ast, + stmt, + stats, + plan: _, + } => { + debug!("rewrite (simple): {}", stmt); + + for action in actions { + action.execute(self.request); + } + + // Update stats. + { + let cluster_stats = self.cluster.stats(); + let mut lock = cluster_stats.lock(); + lock.rewrite = lock.rewrite + stats; + } + + let ast = ParseResult::new(ast, "".into()); + CachedAst::new_parsed(&stmt, ast, &schema)? + } + }; + + Ok(ast) + } + + /// Execute rewrite and return the query AST. + pub fn execute(&mut self) -> Result, Error> { + let mut ast: Option = None; + + for result in [self.handle_parse(), self.handle_query()] { + match result { + Ok(a) => ast = Some(a), + Err(Error::EmptyQuery) => continue, + Err(err) => return Err(err), + } + } + let parameters = self.request.parameters_mut()?; + if let Some(parameters) = parameters { + if self.plan.is_none() { + self.plan = self + .prepared_statements + .global + .read() + .rewrite_engine_plan(parameters.statement()); + } + + if let Some(plan) = self.plan.take() { + plan.apply_bind(parameters)?; + } + } + + Ok(ast) + } +} diff --git a/pgdog/src/frontend/router/rewrite/stats.rs b/pgdog/src/frontend/router/rewrite/stats.rs new file mode 100644 index 000000000..c1f80b37c --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/stats.rs @@ -0,0 +1,22 @@ +//! Rewrite engine stats. 
+ +use std::ops::Add; + +#[derive(Debug, Default, Clone, Copy)] +pub struct RewriteStats { + pub parse: usize, + pub bind: usize, + pub simple: usize, +} + +impl Add for RewriteStats { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self { + parse: self.parse + rhs.parse, + bind: self.bind + rhs.bind, + simple: self.simple + rhs.simple, + } + } +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/explain.rs b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs new file mode 100644 index 000000000..26d5fce1b --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/unique_id/explain.rs @@ -0,0 +1,279 @@ +//! EXPLAIN statement rewriter for unique_id. + +use pg_query::NodeEnum; + +use super::{ + super::{Context, Error, RewriteModule}, + max_param_number, InsertUniqueIdRewrite, SelectUniqueIdRewrite, UpdateUniqueIdRewrite, +}; + +#[derive(Default)] +pub struct ExplainUniqueIdRewrite {} + +impl RewriteModule for ExplainUniqueIdRewrite { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + // Check if this is an EXPLAIN statement + let is_explain = matches!( + input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()), + Some(NodeEnum::ExplainStmt(_)) + ); + + if !is_explain { + return Ok(()); + } + + // Get the inner query type and dispatch to appropriate rewriter + let inner_type = if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt()? 
+ .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + stmt.query + .as_ref() + .and_then(|q| q.node.as_ref()) + .map(|n| match n { + NodeEnum::SelectStmt(_) => "select", + NodeEnum::InsertStmt(_) => "insert", + NodeEnum::UpdateStmt(_) => "update", + _ => "other", + }) + } else { + None + }; + + match inner_type { + Some("select") => self.rewrite_explain_select(input), + Some("insert") => self.rewrite_explain_insert(input), + Some("update") => self.rewrite_explain_update(input), + _ => Ok(()), + } + } +} + +impl ExplainUniqueIdRewrite { + fn rewrite_explain_select(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + // Check if the inner SELECT needs rewriting + let needs_rewrite = if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + if let Some(NodeEnum::SelectStmt(select)) = + stmt.query.as_ref().and_then(|q| q.node.as_ref()) + { + SelectUniqueIdRewrite::needs_rewrite(select) + } else { + false + } + } else { + false + }; + + if !needs_rewrite { + return Ok(()); + } + + let extended = input.extended(); + let mut parameter_counter = max_param_number(input.parse_result()); + + if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + if let Some(NodeEnum::SelectStmt(select)) = + stmt.query.as_mut().and_then(|q| q.node.as_mut()) + { + input.plan().unique_ids = SelectUniqueIdRewrite::rewrite_select( + select, + extended, + &mut parameter_counter, + )?; + } + } + + Ok(()) + } + + fn rewrite_explain_insert(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + // Check if the inner INSERT needs rewriting + let needs_rewrite = if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt()? 
+ .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + if let Some(NodeEnum::InsertStmt(insert)) = + stmt.query.as_ref().and_then(|q| q.node.as_ref()) + { + InsertUniqueIdRewrite::needs_rewrite(insert) + } else { + false + } + } else { + false + }; + + if !needs_rewrite { + return Ok(()); + } + + let extended = input.extended(); + let mut param_counter = max_param_number(input.parse_result()); + + if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + if let Some(NodeEnum::InsertStmt(insert)) = + stmt.query.as_mut().and_then(|q| q.node.as_mut()) + { + input.plan().unique_ids = + InsertUniqueIdRewrite::rewrite_insert(insert, extended, &mut param_counter)?; + } + } + + Ok(()) + } + + fn rewrite_explain_update(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + // Check if the inner UPDATE needs rewriting + let needs_rewrite = if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + if let Some(NodeEnum::UpdateStmt(update)) = + stmt.query.as_ref().and_then(|q| q.node.as_ref()) + { + UpdateUniqueIdRewrite::needs_rewrite(update) + } else { + false + } + } else { + false + }; + + if !needs_rewrite { + return Ok(()); + } + + let extended = input.extended(); + let mut param_counter = super::max_param_number(input.parse_result()); + + if let Some(NodeEnum::ExplainStmt(stmt)) = input + .stmt_mut()? 
+ .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + if let Some(NodeEnum::UpdateStmt(update)) = + stmt.query.as_mut().and_then(|q| q.node.as_mut()) + { + input.plan().unique_ids = + UpdateUniqueIdRewrite::rewrite_update(update, extended, &mut param_counter)?; + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::env::set_var; + + #[test] + fn test_explain_select_unique_id() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse(r#"EXPLAIN SELECT pgdog.unique_id() AS id"#) + .unwrap() + .protobuf; + let mut rewrite = ExplainUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(!query.contains("pgdog.unique_id")); + assert!(query.contains("EXPLAIN")); + assert!(query.contains("bigint")); + } + + #[test] + fn test_explain_insert_unique_id() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse(r#"EXPLAIN INSERT INTO test (id) VALUES (pgdog.unique_id())"#) + .unwrap() + .protobuf; + let mut rewrite = ExplainUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(!query.contains("pgdog.unique_id")); + assert!(query.contains("EXPLAIN")); + assert!(query.contains("bigint")); + } + + #[test] + fn test_explain_update_unique_id() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = + pg_query::parse(r#"EXPLAIN UPDATE test SET id = pgdog.unique_id() WHERE old_id = 1"#) + .unwrap() + .protobuf; + let mut rewrite = ExplainUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(!query.contains("pgdog.unique_id")); + 
assert!(query.contains("EXPLAIN")); + assert!(query.contains("bigint")); + } + + #[test] + fn test_explain_analyze_select_unique_id() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse(r#"EXPLAIN ANALYZE SELECT pgdog.unique_id() AS id"#) + .unwrap() + .protobuf; + let mut rewrite = ExplainUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(!query.contains("pgdog.unique_id")); + assert!(query.contains("EXPLAIN")); + assert!(query.contains("ANALYZE")); + } + + #[test] + fn test_explain_no_unique_id() { + let stmt = pg_query::parse(r#"EXPLAIN SELECT 1"#).unwrap().protobuf; + let mut rewrite = ExplainUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(matches!(output, super::super::super::StepOutput::NoOp)); + } +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/insert.rs b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs new file mode 100644 index 000000000..fd74e6cf9 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/unique_id/insert.rs @@ -0,0 +1,172 @@ +use pg_query::{protobuf::InsertStmt, NodeEnum}; + +use super::{ + super::{Context, Error, RewriteModule}, + bigint_const, bigint_param, max_param_number, +}; +use crate::{ + frontend::router::{ + parser::{Insert, Value}, + rewrite::UniqueIdPlan, + }, + unique_id, +}; + +#[derive(Default)] +pub struct InsertUniqueIdRewrite {} + +impl InsertUniqueIdRewrite { + pub fn needs_rewrite(stmt: &InsertStmt) -> bool { + let wrapper = Insert::new(stmt); + + for tuple in wrapper.tuples() { + for value in tuple.values { + if let Value::Function(ref func) = value { + if *func == "pgdog.unique_id" { + return true; + } + } + } + } + + false + } + + pub fn rewrite_insert( + stmt: &mut InsertStmt, + extended: bool, + param_counter: 
&mut i32, + ) -> Result, Error> { + let mut plans = vec![]; + let select = stmt + .select_stmt + .as_mut() + .ok_or(Error::ParserError)? + .node + .as_mut() + .ok_or(Error::ParserError)?; + + if let NodeEnum::SelectStmt(select_stmt) = select { + for tuple in select_stmt.values_lists.iter_mut() { + if let Some(NodeEnum::List(ref mut tuple)) = tuple.node { + for column in tuple.items.iter_mut() { + if let Ok(Value::Function(name)) = Value::try_from(&column.node) { + if name == "pgdog.unique_id" { + let id = unique_id::UniqueId::generator()?.next_id(); + + let node = if extended { + *param_counter += 1; + plans.push(UniqueIdPlan { + param_ref: *param_counter, + }); + bigint_param(*param_counter) + } else { + bigint_const(id) + }; + + column.node = Some(node); + } + } + } + } + } + } + + Ok(plans) + } +} + +impl RewriteModule for InsertUniqueIdRewrite { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + let need_rewrite = if let Some(NodeEnum::InsertStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + Self::needs_rewrite(stmt) + } else { + false + }; + + if !need_rewrite { + return Ok(()); + } + + let extended = input.extended(); + let mut param_counter = max_param_number(input.parse_result()); + + if let Some(NodeEnum::InsertStmt(stmt)) = input + .stmt_mut()? 
+ .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + input.plan().unique_ids = Self::rewrite_insert(stmt, extended, &mut param_counter)?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::net::Parse; + use std::env::set_var; + + #[test] + fn test_unique_id_insert() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse( + r#" + INSERT INTO omnisharded (id, settings) + VALUES + (pgdog.unique_id(), '{}'::JSONB), + (pgdog.unique_id(), '{"hello": "world"}'::JSONB)"#, + ) + .unwrap() + .protobuf; + let mut insert = InsertUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + insert.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + let query = output.query().unwrap(); + assert!(!query.contains("pgdog.unique_id")); + assert!( + query.contains("bigint"), + "Query should contain bigint cast: {}", + query + ); + } + + #[test] + fn test_unique_id_insert_parse() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let query = r#" + INSERT INTO omnisharded (id, settings) + VALUES + (pgdog.unique_id(), $1::JSONB), + (pgdog.unique_id(), $2::JSONB)"#; + let stmt = pg_query::parse(query).unwrap().protobuf; + let parse = Parse::new_anonymous(query); + let mut input = Context::new(&stmt, Some(&parse)); + InsertUniqueIdRewrite::default() + .rewrite(&mut input) + .unwrap(); + let output = input.build().unwrap(); + assert_eq!( + output.query().unwrap(), + "INSERT INTO omnisharded (id, settings) VALUES ($3::bigint, $1::jsonb), ($4::bigint, $2::jsonb)" + ); + // Verify the rewrite plan has the correct parameters + let plan = output.plan().unwrap(); + assert_eq!(plan.unique_ids.len(), 2); + assert_eq!(plan.unique_ids[0].param_ref, 3); + assert_eq!(plan.unique_ids[1].param_ref, 4); + } +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/mod.rs b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs new file mode 100644 index 000000000..6353cb499 --- /dev/null +++ 
b/pgdog/src/frontend/router/rewrite/unique_id/mod.rs @@ -0,0 +1,207 @@ +//! Unique ID rewrite engine. + +use pg_query::{ + protobuf::{a_const::Val, AConst, Node, ParamRef, ParseResult, TypeCast, TypeName}, + NodeEnum, +}; + +pub mod explain; +pub mod insert; +pub mod select; +pub mod update; + +pub use explain::ExplainUniqueIdRewrite; +pub use insert::InsertUniqueIdRewrite; +pub use select::SelectUniqueIdRewrite; +pub use update::UpdateUniqueIdRewrite; + +pub struct UniqueIdRewrite; + +/// Create a bigint-typed parameter reference node. +fn bigint_param(number: i32) -> NodeEnum { + NodeEnum::TypeCast(Box::new(TypeCast { + arg: Some(Box::new(Node { + node: Some(NodeEnum::ParamRef(ParamRef { + number, + ..Default::default() + })), + })), + type_name: Some(TypeName { + names: vec![ + Node { + node: Some(NodeEnum::String(pg_query::protobuf::String { + sval: "pg_catalog".to_string(), + })), + }, + Node { + node: Some(NodeEnum::String(pg_query::protobuf::String { + sval: "int8".to_string(), + })), + }, + ], + ..Default::default() + }), + ..Default::default() + })) +} + +/// Create a bigint-typed constant node for the given ID. +fn bigint_const(id: i64) -> NodeEnum { + NodeEnum::TypeCast(Box::new(TypeCast { + arg: Some(Box::new(Node { + node: Some(NodeEnum::AConst(AConst { + val: Some(Val::Sval(pg_query::protobuf::String { + sval: id.to_string(), + })), + ..Default::default() + })), + })), + type_name: Some(TypeName { + names: vec![ + Node { + node: Some(NodeEnum::String(pg_query::protobuf::String { + sval: "pg_catalog".to_string(), + })), + }, + Node { + node: Some(NodeEnum::String(pg_query::protobuf::String { + sval: "int8".to_string(), + })), + }, + ], + ..Default::default() + }), + ..Default::default() + })) +} + +/// Find the maximum parameter number ($N) in a parse result. 
+pub fn max_param_number(result: &ParseResult) -> i32 { + let mut max = 0; + for stmt in &result.stmts { + if let Some(ref stmt) = stmt.stmt { + find_max_param(&stmt.node, &mut max); + } + } + max +} + +fn find_max_param(node: &Option, max: &mut i32) { + let Some(node) = node else { + return; + }; + + match node { + NodeEnum::ParamRef(param) => { + if param.number > *max { + *max = param.number; + } + } + NodeEnum::TypeCast(cast) => { + if let Some(ref arg) = cast.arg { + find_max_param(&arg.node, max); + } + } + NodeEnum::FuncCall(func) => { + for arg in &func.args { + find_max_param(&arg.node, max); + } + } + NodeEnum::AExpr(expr) => { + if let Some(ref lexpr) = expr.lexpr { + find_max_param(&lexpr.node, max); + } + if let Some(ref rexpr) = expr.rexpr { + find_max_param(&rexpr.node, max); + } + } + NodeEnum::SelectStmt(stmt) => { + for item in &stmt.target_list { + find_max_param(&item.node, max); + } + for item in &stmt.values_lists { + find_max_param(&item.node, max); + } + for item in &stmt.from_clause { + find_max_param(&item.node, max); + } + if let Some(ref clause) = stmt.where_clause { + find_max_param(&clause.node, max); + } + if let Some(ref limit) = stmt.limit_count { + find_max_param(&limit.node, max); + } + if let Some(ref offset) = stmt.limit_offset { + find_max_param(&offset.node, max); + } + } + NodeEnum::InsertStmt(stmt) => { + if let Some(ref select) = stmt.select_stmt { + find_max_param(&select.node, max); + } + } + NodeEnum::UpdateStmt(stmt) => { + for item in &stmt.target_list { + find_max_param(&item.node, max); + } + if let Some(ref clause) = stmt.where_clause { + find_max_param(&clause.node, max); + } + } + NodeEnum::DeleteStmt(stmt) => { + if let Some(ref clause) = stmt.where_clause { + find_max_param(&clause.node, max); + } + } + NodeEnum::ResTarget(res) => { + if let Some(ref val) = res.val { + find_max_param(&val.node, max); + } + } + NodeEnum::List(list) => { + for item in &list.items { + find_max_param(&item.node, max); + } + } + 
NodeEnum::CoalesceExpr(coalesce) => { + for arg in &coalesce.args { + find_max_param(&arg.node, max); + } + } + NodeEnum::CaseExpr(case) => { + if let Some(ref arg) = case.arg { + find_max_param(&arg.node, max); + } + for when in &case.args { + find_max_param(&when.node, max); + } + if let Some(ref defresult) = case.defresult { + find_max_param(&defresult.node, max); + } + } + NodeEnum::CaseWhen(when) => { + if let Some(ref expr) = when.expr { + find_max_param(&expr.node, max); + } + if let Some(ref result) = when.result { + find_max_param(&result.node, max); + } + } + NodeEnum::BoolExpr(expr) => { + for arg in &expr.args { + find_max_param(&arg.node, max); + } + } + NodeEnum::NullTest(test) => { + if let Some(ref arg) = test.arg { + find_max_param(&arg.node, max); + } + } + NodeEnum::ExplainStmt(stmt) => { + if let Some(ref query) = stmt.query { + find_max_param(&query.node, max); + } + } + _ => {} + } +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/select.rs b/pgdog/src/frontend/router/rewrite/unique_id/select.rs new file mode 100644 index 000000000..bf37aa0df --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/unique_id/select.rs @@ -0,0 +1,323 @@ +//! SELECT statement rewriter for unique_id. 
+ +use pg_query::{ + protobuf::{Node, SelectStmt}, + NodeEnum, +}; + +use super::{ + super::{Context, Error, RewriteModule, UniqueIdPlan}, + bigint_const, bigint_param, +}; +use crate::{ + frontend::router::{parser::Value, rewrite::unique_id::max_param_number}, + unique_id, +}; + +#[derive(Default)] +pub struct SelectUniqueIdRewrite {} + +impl SelectUniqueIdRewrite { + pub fn needs_rewrite(stmt: &SelectStmt) -> bool { + // Check target_list (SELECT columns) + for target in &stmt.target_list { + if let Some(NodeEnum::ResTarget(res)) = target.node.as_ref() { + if let Some(val) = &res.val { + if let Ok(Value::Function(ref func)) = Value::try_from(&val.node) { + if func == "pgdog.unique_id" { + return true; + } + } + } + } + } + + // Check CTEs recursively + if let Some(ref with_clause) = stmt.with_clause { + for cte in &with_clause.ctes { + if let Some(NodeEnum::CommonTableExpr(ref expr)) = cte.node { + if let Some(ref query) = expr.ctequery { + if let Some(NodeEnum::SelectStmt(ref inner)) = query.node { + if Self::needs_rewrite(inner) { + return true; + } + } + } + } + } + } + + // Check subqueries in FROM clause + for from in &stmt.from_clause { + if Self::needs_rewrite_from_node(from) { + return true; + } + } + + // Check UNION/INTERSECT/EXCEPT (larg/rarg are Box) + if let Some(ref larg) = stmt.larg { + if Self::needs_rewrite(larg) { + return true; + } + } + if let Some(ref rarg) = stmt.rarg { + if Self::needs_rewrite(rarg) { + return true; + } + } + + false + } + + fn needs_rewrite_from_node(node: &Node) -> bool { + match node.node.as_ref() { + Some(NodeEnum::RangeSubselect(subselect)) => { + if let Some(ref subquery) = subselect.subquery { + if let Some(NodeEnum::SelectStmt(ref inner)) = subquery.node { + return Self::needs_rewrite(inner); + } + } + false + } + Some(NodeEnum::JoinExpr(join)) => { + let left = join + .larg + .as_ref() + .is_some_and(|n| Self::needs_rewrite_from_node(n)); + let right = join + .rarg + .as_ref() + .is_some_and(|n| 
Self::needs_rewrite_from_node(n)); + left || right + } + _ => false, + } + } + + pub fn rewrite_select( + stmt: &mut SelectStmt, + extended: bool, + paramter_counter: &mut i32, + ) -> Result, Error> { + let mut plans = vec![]; + + // Rewrite target_list + for target in stmt.target_list.iter_mut() { + if let Some(NodeEnum::ResTarget(ref mut res)) = target.node { + if let Some(ref mut val) = res.val { + if let Ok(Value::Function(name)) = Value::try_from(&val.node) { + if name == "pgdog.unique_id" { + let id = unique_id::UniqueId::generator()?.next_id(); + + let node = if extended { + *paramter_counter += 1; + plans.push(UniqueIdPlan { + param_ref: *paramter_counter, + }); + + bigint_param(*paramter_counter) + } else { + bigint_const(id) + }; + + val.node = Some(node); + } + } + } + } + } + + // Rewrite CTEs recursively + if let Some(ref mut with_clause) = stmt.with_clause { + for cte in with_clause.ctes.iter_mut() { + if let Some(NodeEnum::CommonTableExpr(ref mut expr)) = cte.node { + if let Some(ref mut query) = expr.ctequery { + if let Some(NodeEnum::SelectStmt(ref mut inner)) = query.node { + Self::rewrite_select(inner, extended, paramter_counter)?; + } + } + } + } + } + + // Rewrite subqueries in FROM clause + for from in stmt.from_clause.iter_mut() { + Self::rewrite_from_node(from, extended, paramter_counter)?; + } + + // Rewrite UNION/INTERSECT/EXCEPT (larg/rarg are Box) + if let Some(ref mut larg) = stmt.larg { + plans.extend(Self::rewrite_select(larg, extended, paramter_counter)?); + } + if let Some(ref mut rarg) = stmt.rarg { + plans.extend(Self::rewrite_select(rarg, extended, paramter_counter)?); + } + + Ok(plans) + } + + fn rewrite_from_node( + node: &mut Node, + extended: bool, + paramter_counter: &mut i32, + ) -> Result, Error> { + let mut plans = vec![]; + match node.node.as_mut() { + Some(NodeEnum::RangeSubselect(ref mut subselect)) => { + if let Some(ref mut subquery) = subselect.subquery { + if let Some(NodeEnum::SelectStmt(ref mut inner)) = 
subquery.node { + plans.extend(Self::rewrite_select(inner, extended, paramter_counter)?); + } + } + } + Some(NodeEnum::JoinExpr(ref mut join)) => { + if let Some(ref mut larg) = join.larg { + plans.extend(Self::rewrite_from_node(larg, extended, paramter_counter)?); + } + if let Some(ref mut rarg) = join.rarg { + plans.extend(Self::rewrite_from_node(rarg, extended, paramter_counter)?); + } + } + _ => {} + } + Ok(plans) + } +} + +impl RewriteModule for SelectUniqueIdRewrite { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + let need_rewrite = if let Some(NodeEnum::SelectStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + Self::needs_rewrite(stmt) + } else { + false + }; + + if !need_rewrite { + return Ok(()); + } + + let extended = input.extended(); + let mut parameter_counter = max_param_number(input.parse_result()); + + if let Some(NodeEnum::SelectStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + let plans = Self::rewrite_select(stmt, extended, &mut parameter_counter)?; + input.plan().unique_ids = plans; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::net::Parse; + use std::env::set_var; + + #[test] + fn test_unique_id_select() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse(r#"SELECT pgdog.unique_id() AS id"#) + .unwrap() + .protobuf; + let mut rewrite = SelectUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + println!("output: {}", output.query().unwrap()); + assert!(!output.query().unwrap().contains("pgdog.unique_id")); + } + + #[test] + fn test_unique_id_select_with_parse() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let query = r#"SELECT pgdog.unique_id() AS id, $1 AS name"#; + let stmt = pg_query::parse(query).unwrap().protobuf; + let parse = 
Parse::new_anonymous(query); + let mut input = Context::new(&stmt, Some(&parse)); + SelectUniqueIdRewrite::default() + .rewrite(&mut input) + .unwrap(); + let output = input.build().unwrap(); + assert_eq!( + output.query().unwrap(), + "SELECT $2::bigint AS id, $1 AS name" + ); + // Verify the rewrite plan has the correct parameters + let plan = output.plan().unwrap(); + assert_eq!(plan.unique_ids.len(), 1); + assert_eq!(plan.unique_ids[0].param_ref, 2); + } + + #[test] + fn test_unique_id_select_cte() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = + pg_query::parse(r#"WITH ids AS (SELECT pgdog.unique_id() AS id) SELECT * FROM ids"#) + .unwrap() + .protobuf; + let mut rewrite = SelectUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(!output.query().unwrap().contains("pgdog.unique_id")); + } + + #[test] + fn test_unique_id_select_subquery() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse(r#"SELECT * FROM (SELECT pgdog.unique_id() AS id) sub"#) + .unwrap() + .protobuf; + let mut rewrite = SelectUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(!output.query().unwrap().contains("pgdog.unique_id")); + } + + #[test] + fn test_unique_id_select_union() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = pg_query::parse( + r#"SELECT pgdog.unique_id() AS id UNION ALL SELECT pgdog.unique_id() AS id"#, + ) + .unwrap() + .protobuf; + let mut rewrite = SelectUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(!output.query().unwrap().contains("pgdog.unique_id")); + } + + #[test] + fn test_no_rewrite_when_no_unique_id() { + let stmt = pg_query::parse(r#"SELECT id FROM 
users"#).unwrap().protobuf; + let mut rewrite = SelectUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + rewrite.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(matches!(output, super::super::super::StepOutput::NoOp)); + } +} diff --git a/pgdog/src/frontend/router/rewrite/unique_id/update.rs b/pgdog/src/frontend/router/rewrite/unique_id/update.rs new file mode 100644 index 000000000..4dd1fb643 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/unique_id/update.rs @@ -0,0 +1,144 @@ +//! UPDATE statement rewriter for unique_id. + +use pg_query::{protobuf::UpdateStmt, NodeEnum}; + +use super::{ + super::{Context, Error, RewriteModule, UniqueIdPlan}, + bigint_const, bigint_param, max_param_number, +}; +use crate::{frontend::router::parser::Value, unique_id}; + +#[derive(Default)] +pub struct UpdateUniqueIdRewrite {} + +impl UpdateUniqueIdRewrite { + pub fn needs_rewrite(stmt: &UpdateStmt) -> bool { + for target in &stmt.target_list { + if let Some(NodeEnum::ResTarget(res)) = target.node.as_ref() { + if let Some(val) = &res.val { + if let Ok(Value::Function(ref func)) = Value::try_from(&val.node) { + if *func == "pgdog.unique_id" { + return true; + } + } + } + } + } + false + } + + pub fn rewrite_update( + stmt: &mut UpdateStmt, + extended: bool, + param_counter: &mut i32, + ) -> Result, Error> { + let mut plans = vec![]; + + for target in stmt.target_list.iter_mut() { + if let Some(NodeEnum::ResTarget(ref mut res)) = target.node { + if let Some(ref mut val) = res.val { + if let Ok(Value::Function(name)) = Value::try_from(&val.node) { + if name == "pgdog.unique_id" { + let id = unique_id::UniqueId::generator()?.next_id(); + + let node = if extended { + *param_counter += 1; + plans.push(UniqueIdPlan { + param_ref: *param_counter, + }); + + bigint_param(*param_counter) + } else { + bigint_const(id) + }; + + val.node = Some(node); + } + } + } + } + } + + Ok(plans) + } +} + +impl RewriteModule for 
UpdateUniqueIdRewrite { + fn rewrite(&mut self, input: &mut Context<'_>) -> Result<(), Error> { + let need_rewrite = if let Some(NodeEnum::UpdateStmt(stmt)) = input + .stmt()? + .stmt + .as_ref() + .and_then(|stmt| stmt.node.as_ref()) + { + Self::needs_rewrite(stmt) + } else { + false + }; + + if !need_rewrite { + return Ok(()); + } + + let extended = input.extended(); + let mut param_counter = max_param_number(input.parse_result()); + + if let Some(NodeEnum::UpdateStmt(stmt)) = input + .stmt_mut()? + .stmt + .as_mut() + .and_then(|stmt| stmt.node.as_mut()) + { + input.plan().unique_ids = Self::rewrite_update(stmt, extended, &mut param_counter)?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::net::Parse; + use std::env::set_var; + + #[test] + fn test_unique_id_update() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let stmt = + pg_query::parse(r#"UPDATE omnisharded SET id = pgdog.unique_id() WHERE old_id = 123"#) + .unwrap() + .protobuf; + let mut update = UpdateUniqueIdRewrite::default(); + let mut input = Context::new(&stmt, None); + update.rewrite(&mut input).unwrap(); + let output = input.build().unwrap(); + assert!(!output.query().unwrap().contains("pgdog.unique_id")); + } + + #[test] + fn test_unique_id_update_parse() { + unsafe { + set_var("NODE_ID", "pgdog-prod-1"); + } + let query = + r#"UPDATE omnisharded SET id = pgdog.unique_id(), settings = $1 WHERE old_id = $2"#; + let stmt = pg_query::parse(query).unwrap().protobuf; + let parse = Parse::new_anonymous(query); + let mut input = Context::new(&stmt, Some(&parse)); + UpdateUniqueIdRewrite::default() + .rewrite(&mut input) + .unwrap(); + let output = input.build().unwrap(); + assert_eq!( + output.query().unwrap(), + "UPDATE omnisharded SET id = $3::bigint, settings = $1 WHERE old_id = $2" + ); + // Verify the rewrite plan has the correct parameters + let plan = output.plan().unwrap(); + assert_eq!(plan.unique_ids.len(), 1); + 
assert_eq!(plan.unique_ids[0].param_ref, 3); + } +} diff --git a/pgdog/src/net/error.rs b/pgdog/src/net/error.rs index d28bf04ee..01c849861 100644 --- a/pgdog/src/net/error.rs +++ b/pgdog/src/net/error.rs @@ -99,4 +99,7 @@ pub enum Error { #[error("not a pg_lsn")] NotPgLsn, + + #[error("multiple parameter formats in same bind message")] + MultipleBindFormats, } diff --git a/pgdog/src/net/messages/bind.rs b/pgdog/src/net/messages/bind.rs index b7042fc6d..2c7aacfc1 100644 --- a/pgdog/src/net/messages/bind.rs +++ b/pgdog/src/net/messages/bind.rs @@ -1,5 +1,6 @@ //! Bind (F) message. use crate::net::c_string_buf_len; +use crate::net::Datum; use uuid::Uuid; use super::code; @@ -198,11 +199,56 @@ impl Bind { unsafe { from_utf8_unchecked(&self.statement[0..self.statement.len() - 1]) } } + pub fn statement_ref(&self) -> &Bytes { + &self.statement + } + /// Format codes, if any. pub fn codes(&self) -> &[Format] { &self.codes } + /// Add parameter and return its placeholder number. + pub fn add_parameter(&mut self, data: Datum) -> Result { + let format = { + let codes = self.codes(); + if codes.is_empty() { + Format::Text + } else if codes.len() == 1 { + codes[0] + } else { + self.codes.push(Format::Text); + Format::Text + } + }; + let bytes = data.encode(format)?; + self.params.push(Parameter { + len: bytes.len() as i32, + data: bytes, + }); + self.original = None; + // Param codes are 1-indexed. + Ok(self.params.len() as i32) + } + + /// Add parameter with provided format. 
+ pub fn add_existing(&mut self, param: ParameterWithFormat<'_>) -> Result { + let format = param.format; + let existing = self.codes.first().cloned(); + match (format, existing) { + (Format::Text, None) + | (Format::Text, Some(Format::Text)) + | (Format::Binary, Some(Format::Binary)) => (), + (Format::Binary, None) => self.codes.push(format), + (Format::Binary, Some(Format::Text)) | (Format::Text, Some(Format::Binary)) => { + return Err(Error::MultipleBindFormats); + } + } + self.params.push(param.parameter.clone()); + self.original = None; + Ok(self.params.len() as i32) + } + pub fn new_statement(name: &str) -> Self { Self { statement: Bytes::from(name.to_string() + "\0"), @@ -462,4 +508,53 @@ mod test { assert_eq!(decoded.statement(), "__pgdog_large"); assert_eq!(bytes.len(), decoded.len()); } + + #[test] + fn test_add_parameter_produces_correct_bind() { + // Start with an existing Bind message parsed from bytes + let original_bind = Bind::new_params_codes( + "original_stmt", + &[Parameter::new(b"original_value")], + &[Format::Text], + ); + let original_bytes = original_bind.to_bytes().unwrap(); + let mut bind = Bind::from_bytes(original_bytes.clone()).unwrap(); + + // Verify original is set after parsing + assert!(bind.original.is_some()); + + // Add a new parameter - this should clear original + bind.add_parameter(Datum::Text("added_param".to_string())) + .unwrap(); + + // Verify original is now None + assert!(bind.original.is_none()); + + // Serialize to bytes + let new_bytes = bind.to_bytes().unwrap(); + + // The new bytes should be different from original (new param added) + assert_ne!(original_bytes, new_bytes); + + // Deserialize and verify the message is correct + let decoded = Bind::from_bytes(new_bytes.clone()).unwrap(); + + // Verify statement name preserved + assert_eq!(decoded.statement(), "original_stmt"); + + // Verify we have 2 parameters now (original + added) + assert_eq!(decoded.params_raw().len(), 2); + + // Verify first parameter (original) 
+ let param0 = decoded.parameter(0).unwrap().unwrap(); + assert_eq!(param0.text(), Some("original_value")); + + // Verify second parameter (added) + let param1 = decoded.parameter(1).unwrap().unwrap(); + assert_eq!(param1.text(), Some("added_param")); + + // Verify round-trip produces identical bytes + let re_encoded = decoded.to_bytes().unwrap(); + assert_eq!(new_bytes, re_encoded); + } } diff --git a/pgdog/src/net/messages/parse.rs b/pgdog/src/net/messages/parse.rs index 06e2be6a7..1675d4555 100644 --- a/pgdog/src/net/messages/parse.rs +++ b/pgdog/src/net/messages/parse.rs @@ -38,11 +38,10 @@ impl Parse { } /// New anonymous prepared statement. - #[cfg(test)] - pub fn new_anonymous(query: &str) -> Self { + pub fn new_anonymous(query: impl ToString) -> Self { Self { name: Bytes::from("\0"), - query: Bytes::from(query.to_owned() + "\0"), + query: Bytes::from(query.to_string() + "\0"), data_types: Bytes::copy_from_slice(&0i16.to_be_bytes()), original: None, } @@ -68,6 +67,10 @@ impl Parse { unsafe { from_utf8_unchecked(&self.query[0..self.query.len() - 1]) } } + pub fn name_ref(&self) -> Bytes { + self.name.clone() + } + /// Get query reference. 
pub fn query_ref(&self) -> Bytes { self.query.clone() diff --git a/pgdog/src/net/protocol_message.rs b/pgdog/src/net/protocol_message.rs index 239a267c8..5cb19592b 100644 --- a/pgdog/src/net/protocol_message.rs +++ b/pgdog/src/net/protocol_message.rs @@ -58,6 +58,14 @@ impl ProtocolMessage { Self::CopyFail(copy_fail) => copy_fail.len(), } } + + pub fn query(&self) -> Option<&str> { + match self { + ProtocolMessage::Query(query) => Some(query.query()), + ProtocolMessage::Parse(parse) => Some(parse.query()), + _ => None, + } + } } impl Protocol for ProtocolMessage { diff --git a/pgdog/src/stats/http_server.rs b/pgdog/src/stats/http_server.rs index 9f6628924..a8d5e2a17 100644 --- a/pgdog/src/stats/http_server.rs +++ b/pgdog/src/stats/http_server.rs @@ -10,7 +10,7 @@ use hyper_util::rt::TokioIo; use tokio::net::TcpListener; use tracing::info; -use super::{Clients, MirrorStatsMetrics, Pools, QueryCache}; +use super::{Clients, MirrorStatsMetrics, Pools, QueryCache, RewriteStatsMetrics}; async fn metrics(_: Request) -> Result>, Infallible> { let clients = Clients::load(); @@ -26,13 +26,20 @@ async fn metrics(_: Request) -> Result = RewriteStatsMetrics::load() + .into_iter() + .map(|m| m.to_string()) + .collect(); + let rewrite_stats = rewrite_stats.join("\n"); let metrics_data = clients.to_string() + "\n" + &pools.to_string() + "\n" + &mirror_stats + "\n" - + &query_cache; + + &query_cache + + "\n" + + &rewrite_stats; let response = Response::builder() .header( hyper::header::CONTENT_TYPE, diff --git a/pgdog/src/stats/mirror_stats.rs b/pgdog/src/stats/mirror_stats.rs index d67a34ca5..98d33f6a9 100644 --- a/pgdog/src/stats/mirror_stats.rs +++ b/pgdog/src/stats/mirror_stats.rs @@ -24,7 +24,7 @@ impl MirrorStatsMetrics { for (user, cluster) in databases().all() { let stats = cluster.stats(); let stats = stats.lock(); - let counts = stats.counts; + let counts = stats.mirrors; // Per-cluster metrics with labels let labels = vec![ diff --git a/pgdog/src/stats/mod.rs 
b/pgdog/src/stats/mod.rs index e7b1843cc..15c8d0a23 100644 --- a/pgdog/src/stats/mod.rs +++ b/pgdog/src/stats/mod.rs @@ -8,9 +8,11 @@ pub use open_metric::*; pub mod logger; pub mod memory; pub mod query_cache; +pub mod rewrite_stats; pub use clients::Clients; pub use logger::Logger as StatsLogger; pub use mirror_stats::MirrorStatsMetrics; pub use pools::{PoolMetric, Pools}; pub use query_cache::QueryCache; +pub use rewrite_stats::RewriteStatsMetrics; diff --git a/pgdog/src/stats/rewrite_stats.rs b/pgdog/src/stats/rewrite_stats.rs new file mode 100644 index 000000000..30b62244d --- /dev/null +++ b/pgdog/src/stats/rewrite_stats.rs @@ -0,0 +1,119 @@ +//! Rewrite stats OpenMetrics. + +use crate::backend::databases::databases; + +use super::{Measurement, Metric, OpenMetric}; + +pub struct RewriteStatsMetrics; + +impl RewriteStatsMetrics { + pub fn load() -> Vec { + let mut metrics = vec![]; + + let mut parse_measurements = vec![]; + let mut bind_measurements = vec![]; + let mut simple_measurements = vec![]; + + let mut global_parse = 0usize; + let mut global_bind = 0usize; + let mut global_simple = 0usize; + + for (user, cluster) in databases().all() { + let stats = cluster.stats(); + let stats = stats.lock(); + let rewrite = stats.rewrite; + + let labels = vec![ + ("user".into(), user.user.clone()), + ("database".into(), user.database.clone()), + ]; + + parse_measurements.push(Measurement { + labels: labels.clone(), + measurement: rewrite.parse.into(), + }); + + bind_measurements.push(Measurement { + labels: labels.clone(), + measurement: rewrite.bind.into(), + }); + + simple_measurements.push(Measurement { + labels, + measurement: rewrite.simple.into(), + }); + + global_parse += rewrite.parse; + global_bind += rewrite.bind; + global_simple += rewrite.simple; + } + + parse_measurements.push(Measurement { + labels: vec![], + measurement: global_parse.into(), + }); + + bind_measurements.push(Measurement { + labels: vec![], + measurement: global_bind.into(), + }); + + 
simple_measurements.push(Measurement { + labels: vec![], + measurement: global_simple.into(), + }); + + metrics.push(Metric::new(RewriteStatsMetric { + name: "rewrite_parse_count".into(), + measurements: parse_measurements, + help: "Number of Parse messages rewritten.".into(), + })); + + metrics.push(Metric::new(RewriteStatsMetric { + name: "rewrite_bind_count".into(), + measurements: bind_measurements, + help: "Number of Bind messages rewritten.".into(), + })); + + metrics.push(Metric::new(RewriteStatsMetric { + name: "rewrite_simple_count".into(), + measurements: simple_measurements, + help: "Number of simple queries rewritten.".into(), + })); + + metrics + } +} + +struct RewriteStatsMetric { + name: String, + measurements: Vec, + help: String, +} + +impl OpenMetric for RewriteStatsMetric { + fn name(&self) -> String { + self.name.clone() + } + + fn measurements(&self) -> Vec { + self.measurements.clone() + } + + fn help(&self) -> Option { + Some(self.help.clone()) + } + + fn metric_type(&self) -> String { + "counter".into() + } +} + +impl std::fmt::Display for RewriteStatsMetrics { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for metric in RewriteStatsMetrics::load() { + writeln!(f, "{}", metric)?; + } + Ok(()) + } +} diff --git a/pgdog/src/unique_id.rs b/pgdog/src/unique_id.rs index 660e2814b..916f200f9 100644 --- a/pgdog/src/unique_id.rs +++ b/pgdog/src/unique_id.rs @@ -8,13 +8,13 @@ //! clock, so `std::time::SystemTime` returns a good value. //! 
use std::sync::Arc; +use std::thread::sleep; use std::time::UNIX_EPOCH; use std::time::{Duration, SystemTime}; use once_cell::sync::OnceCell; +use parking_lot::Mutex; use thiserror::Error; -use tokio::sync::Mutex; -use tokio::time::sleep; use crate::config::config; use crate::util::{instance_id, node_id}; @@ -34,32 +34,23 @@ const MAX_OFFSET: u64 = i64::MAX as u64 static UNIQUE_ID: OnceCell<UniqueId> = OnceCell::new(); -#[derive(Debug)] +#[derive(Debug, Default)] struct State { last_timestamp_ms: u64, sequence: u64, } -impl Default for State { - fn default() -> Self { - Self { - last_timestamp_ms: 0, - sequence: 0, - } - } -} - impl State { // Generate next unique ID in a distributed sequence. // The `node_id` argument must be globally unique. - async fn next_id(&mut self, node_id: u64, id_offset: u64) -> u64 { - let mut now = wait_until(self.last_timestamp_ms).await; + fn next_id(&mut self, node_id: u64, id_offset: u64) -> u64 { + let mut now = wait_until(self.last_timestamp_ms); if now == self.last_timestamp_ms { self.sequence = (self.sequence + 1) & MAX_SEQUENCE; // Wraparound. if self.sequence == 0 { - now = wait_until(now + 1).await; + now = wait_until(now + 1); } } else { // Reset sequence to zero once we reach next ms. @@ -92,13 +83,13 @@ fn now_ms() -> u64 { // Get a monotonically increasing timestamp in ms. // Protects against clock drift. -async fn wait_until(target_ms: u64) -> u64 { +fn wait_until(target_ms: u64) -> u64 { loop { let now = now_ms(); if now >= target_ms { return now; } - sleep(Duration::from_millis(1)).await; + sleep(Duration::from_millis(1)); } } @@ -145,16 +136,12 @@ impl UniqueId { /// Get (and initialize, if necessary) the unique ID generator. pub fn generator() -> Result<&'static UniqueId, Error> { - UNIQUE_ID.get_or_try_init(|| Self::new()) + UNIQUE_ID.get_or_try_init(Self::new) } /// Generate a globally unique, monotonically increasing identifier. 
- pub async fn next_id(&self) -> i64 { - self.inner - .lock() - .await - .next_id(self.node_id, self.id_offset) - .await as i64 + pub fn next_id(&self) -> i64 { + self.inner.lock().next_id(self.node_id, self.id_offset) as i64 } } @@ -164,8 +151,8 @@ mod test { use super::*; - #[tokio::test] - async fn test_unique_ids() { + #[test] + fn test_unique_ids() { unsafe { set_var("NODE_ID", "pgdog-1"); } @@ -174,32 +161,32 @@ mod test { let mut ids = HashSet::new(); for _ in 0..num_ids { - ids.insert(UniqueId::generator().unwrap().next_id().await); + ids.insert(UniqueId::generator().unwrap().next_id()); } assert_eq!(ids.len(), num_ids); } - #[tokio::test] - async fn test_ids_monotonically_increasing() { + #[test] + fn test_ids_monotonically_increasing() { let mut state = State::default(); let node_id = 1u64; let mut prev_id = 0u64; for _ in 0..10_000 { - let id = state.next_id(node_id, 0).await; + let id = state.next_id(node_id, 0); assert!(id > prev_id, "ID {id} not greater than previous {prev_id}"); prev_id = id; } } - #[tokio::test] - async fn test_ids_always_positive() { + #[test] + fn test_ids_always_positive() { let mut state = State::default(); let node_id = MAX_NODE_ID; // Use max node to maximize bits used for _ in 0..10_000 { - let id = state.next_id(node_id, 0).await; + let id = state.next_id(node_id, 0); let signed = id as i64; assert!(signed > 0, "ID should be positive, got {signed}"); } @@ -227,12 +214,12 @@ mod test { assert_eq!(id >> 63, 0, "Bit 63 should be clear"); } - #[tokio::test] - async fn test_extract_components() { + #[test] + fn test_extract_components() { let node: u64 = 42; let mut state = State::default(); - let id = state.next_id(node, 0).await; + let id = state.next_id(node, 0); // Extract components back let extracted_seq = id & MAX_SEQUENCE; @@ -244,7 +231,7 @@ mod test { assert!(extracted_elapsed > 0); // Elapsed time since epoch // Generate another ID and verify sequence increments - let id2 = state.next_id(node, 0).await; + let id2 = 
state.next_id(node, 0); let extracted_seq2 = id2 & MAX_SEQUENCE; let extracted_node2 = (id2 >> NODE_SHIFT) & MAX_NODE_ID; @@ -252,14 +239,14 @@ mod test { assert!(matches!(extracted_seq2, 1 | 0)); // Sequence incremented (or time advanced and reset to 0) } - #[tokio::test] - async fn test_id_offset() { + #[test] + fn test_id_offset() { let offset: u64 = 1_000_000_000; let node: u64 = 5; let mut state = State::default(); for _ in 0..1000 { - let id = state.next_id(node, offset).await; + let id = state.next_id(node, offset); assert!( id > offset, "ID {id} should be greater than offset {offset}" @@ -267,15 +254,15 @@ mod test { } } - #[tokio::test] - async fn test_id_offset_monotonic() { + #[test] + fn test_id_offset_monotonic() { let offset: u64 = 1_000_000_000; let node: u64 = 5; let mut state = State::default(); let mut prev_id = 0u64; for _ in 0..1000 { - let id = state.next_id(node, offset).await; + let id = state.next_id(node, offset); assert!(id > prev_id, "ID {id} not greater than previous {prev_id}"); prev_id = id; } diff --git a/pgdog/src/util.rs b/pgdog/src/util.rs index 0e970a2b9..7615ff31d 100644 --- a/pgdog/src/util.rs +++ b/pgdog/src/util.rs @@ -166,6 +166,9 @@ mod test { #[test] fn test_instance_id_format() { + unsafe { + remove_var("NODE_ID"); + } let id = instance_id(); assert_eq!(id.len(), 8); // All characters should be valid hex digits (0-9, a-f)