diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f94fd4..6754709 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/). --- +## [0.3.0] - 2026-03-27 + +### Changed +- **Breaking:** `PgWireError::Io` now wraps `Arc` instead of `String` + - Consumers can match on `.kind()` (e.g. `ErrorKind::UnexpectedEof`, `ConnectionReset`) instead of brittle substring matching on error messages + - `PgWireError` remains `Clone` via the `Arc` wrapper + +### Improved +- Wrapped the replication stream in a 128KB `BufReader`, batching WAL messages into fewer `recv()` syscalls + - Reduces syscall overhead significantly during backlog drain scenarios + +### Removed +- Replaced `rustls-pemfile` dependency with `rustls-pki-types` PEM parsing (already in the dependency tree via `rustls`) + - Resolves RUSTSEC-2025-0134 (unmaintained crate) + +--- + ## [0.2.0] - 2026-02-08 ### Added @@ -70,7 +87,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/). - Fuzz testing for pgwire framing -[Unreleased]: https://github.com/vnvo/pgwire-replication/compare/v0.2.0...HEAD +[Unreleased]: https://github.com/vnvo/pgwire-replication/compare/v0.3.0...HEAD +[0.3.0]: https://github.com/vnvo/pgwire-replication/compare/v0.2.0...v0.3.0 [0.2.0]: https://github.com/vnvo/pgwire-replication/releases/tag/v0.2.0 [0.1.2]: https://github.com/vnvo/pgwire-replication/releases/tag/v0.1.2 [0.1.1]: https://github.com/vnvo/pgwire-replication/releases/tag/v0.1.1 diff --git a/Cargo.toml b/Cargo.toml index b5d30e1..13754b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgwire-replication" -version = "0.2.0" +version = "0.3.0" edition = "2021" rust-version = "1.88" resolver = "2" @@ -29,7 +29,6 @@ tls-rustls = [ "dep:rustls", "dep:tokio-rustls", "dep:webpki-roots", - "dep:rustls-pemfile", ] scram = ["dep:base64", "dep:hmac", "dep:sha2", "dep:rand"] md5 = ["dep:md5"] @@ -58,7 +57,6 @@ tracing = "0.1" rustls = { version = "0.23", features = ["ring"], optional = true } tokio-rustls = { version = "0.26", optional = true } webpki-roots = { version = "0.26", optional = true } -rustls-pemfile = { version = "2", optional = true } # Optional auth base64 = { version = "0.22", optional = true } @@ -67,7 +65,7 @@ sha2 = { version = "0.10", optional = true } rand = { version = "0.9", optional = true } md5 = { version = "0.7", optional = true } -testcontainers = { version = "0.25", optional = true } +testcontainers = { version = "0.27", optional = true } tokio-postgres = { version = "0.7", optional = true } tracing-subscriber = { version = "0.3", optional = true, features = [ "env-filter", diff --git a/src/client/tokio_client.rs b/src/client/tokio_client.rs index 47f8bc1..88f26c1 100644 --- a/src/client/tokio_client.rs +++ b/src/client/tokio_client.rs @@ -263,10 +263,10 @@ async fn run_worker(worker: &mut WorkerState, cfg: &ReplicationConfig) -> Result let path = cfg.unix_socket_path(); let mut stream = UnixStream::connect(&path).await.map_err(|e| { - PgWireError::Io(format!( - "failed to connect to Unix socket {}: {e}", - path.display() - )) + PgWireError::Io(std::sync::Arc::new(std::io::Error::new( + e.kind(), + format!("failed to connect to Unix socket {}: {e}", path.display()), + ))) })?; return worker.run_on_stream(&mut stream).await; diff --git a/src/client/worker.rs b/src/client/worker.rs index c3038d1..f4a4672 100644 --- a/src/client/worker.rs +++ b/src/client/worker.rs @@ -1,7 +1,7 @@ use bytes::Bytes; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, BufReader}; use tokio::sync::{mpsc, watch}; use tokio::time::Instant; @@ -153,10 +153,14 @@ impl WorkerState { &mut self, stream: &mut S, ) -> Result<()> { - self.startup(stream).await?; - self.authenticate(stream).await?; - self.start_replication(stream).await?; - self.stream_loop(stream).await + // Wrap in a 128KB read buffer to batch multiple WAL messages into fewer + // recv() syscalls. BufReader delegates AsyncWrite to the inner stream, + // so writes (standby status replies, etc.) are unaffected. + let mut stream = BufReader::with_capacity(128 * 1024, stream); + self.startup(&mut stream).await?; + self.authenticate(&mut stream).await?; + self.start_replication(&mut stream).await?; + self.stream_loop(&mut stream).await } /// Send startup message with replication parameters. diff --git a/src/error.rs b/src/error.rs index 47bafe7..7b7a2d8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,6 +8,7 @@ //! - TLS errors (handshake failure, certificate issues) //! - Task errors (worker panics, unexpected termination) +use std::sync::Arc; use thiserror::Error; /// Error type for all pgwire-replication operations. @@ -15,9 +16,10 @@ use thiserror::Error; pub enum PgWireError { /// I/O error (network, file system). /// - /// Note: `std::io::Error` is not `Clone`, so we store the message. + /// Wraps `std::io::Error` in an `Arc` to preserve `ErrorKind` while + /// keeping `PgWireError` `Clone`. Use `.kind()` via `Arc`'s `Deref`. #[error("io error: {0}")] - Io(String), + Io(Arc), /// Protocol error - malformed message or unexpected response. #[error("protocol error: {0}")] @@ -80,10 +82,9 @@ impl PgWireError { } } -// Manual From impl since io::Error isn't Clone impl From for PgWireError { fn from(err: std::io::Error) -> Self { - PgWireError::Io(err.to_string()) + PgWireError::Io(Arc::new(err)) } } diff --git a/src/tls/rustls.rs b/src/tls/rustls.rs index 665d3e8..7281322 100644 --- a/src/tls/rustls.rs +++ b/src/tls/rustls.rs @@ -50,8 +50,8 @@ //! ``` use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; -use std::{fs::File, io::BufReader, sync::Arc}; use rustls::{ClientConfig, RootCertStore}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -289,25 +289,22 @@ fn build_root_store(tls: &TlsConfig) -> Result { if let Some(path) = &tls.ca_pem_path { // Load custom CA certificates - let f = File::open(path).map_err(|e| { - PgWireError::Tls(format!( - "TLS config error: failed to open CA PEM '{}': {e}", - path.display() - )) - })?; - let mut rd = BufReader::new(f); + use rustls::pki_types::pem::PemObject; - let certs: Vec> = rustls_pemfile::certs(&mut rd) + let certs: Vec> = CertificateDer::pem_file_iter(path) + .map_err(|e| { + PgWireError::Tls(format!( + "TLS config error: failed to open CA PEM '{}': {e}", + path.display() + )) + })? .collect::, _>>() .map_err(|e| { PgWireError::Tls(format!( "TLS config error: failed to parse CA PEM '{}': {e}", path.display() )) - })? - .into_iter() - .map(|c| c.into_owned()) - .collect(); + })?; let (added, _ignored) = roots.add_parsable_certificates(certs); if added == 0 { @@ -328,27 +325,23 @@ fn build_root_store(tls: &TlsConfig) -> Result { fn load_cert_chain( path: &std::path::Path, ) -> Result>> { + use rustls::pki_types::pem::PemObject; use rustls::pki_types::CertificateDer; - let f = File::open(path).map_err(|e| { - PgWireError::Tls(format!( - "TLS config error: failed to open client certificate '{}': {e}", - path.display() - )) - })?; - let mut rd = BufReader::new(f); - - let certs: Vec> = rustls_pemfile::certs(&mut rd) + let certs: Vec> = CertificateDer::pem_file_iter(path) + .map_err(|e| { + PgWireError::Tls(format!( + "TLS config error: failed to open client certificate '{}': {e}", + path.display() + )) + })? .collect::, _>>() .map_err(|e| { PgWireError::Tls(format!( "TLS config error: failed to parse client certificate '{}': {e}", path.display() )) - })? - .into_iter() - .map(|c| c.into_owned()) - .collect(); + })?; if certs.is_empty() { return Err(PgWireError::Tls(format!( @@ -364,110 +357,16 @@ fn load_cert_chain( /// /// Supports PKCS#8, PKCS#1 (RSA), and SEC1 (EC) key formats. fn load_private_key(path: &std::path::Path) -> Result> { - // Try PKCS#8 first (most common modern format) - if let Some(key) = try_load_pkcs8_key(path)? { - return Ok(key); - } - - // Try RSA PKCS#1 format - if let Some(key) = try_load_rsa_key(path)? { - return Ok(key); - } - - // Try EC SEC1 format - if let Some(key) = try_load_ec_key(path)? { - return Ok(key); - } - - Err(PgWireError::Tls(format!( - "TLS config error: no private key found in '{}'. \ - Supported formats: PKCS#8, PKCS#1 (RSA), SEC1 (EC)", - path.display() - ))) -} - -fn try_load_pkcs8_key( - path: &std::path::Path, -) -> Result>> { - use rustls::pki_types::PrivateKeyDer; - - let f = File::open(path).map_err(|e| { - PgWireError::Tls(format!( - "TLS config error: failed to open private key '{}': {e}", - path.display() - )) - })?; - let mut rd = BufReader::new(f); - - let keys: Vec> = rustls_pemfile::pkcs8_private_keys(&mut rd) - .filter_map(|r| r.ok()) - .map(PrivateKeyDer::from) - .collect(); - - match keys.len() { - 0 => Ok(None), - 1 => Ok(Some(keys.into_iter().next().unwrap())), - n => Err(PgWireError::Tls(format!( - "TLS config error: found {n} PKCS#8 keys in '{}', expected 1", - path.display() - ))), - } -} - -fn try_load_rsa_key( - path: &std::path::Path, -) -> Result>> { + use rustls::pki_types::pem::PemObject; use rustls::pki_types::PrivateKeyDer; - let f = File::open(path).map_err(|e| { + PrivateKeyDer::from_pem_file(path).map_err(|e| { PgWireError::Tls(format!( - "TLS config error: failed to open private key '{}': {e}", + "TLS config error: failed to load private key from '{}': {e}. \ + Supported formats: PKCS#8, PKCS#1 (RSA), SEC1 (EC)", path.display() )) - })?; - let mut rd = BufReader::new(f); - - let keys: Vec> = rustls_pemfile::rsa_private_keys(&mut rd) - .filter_map(|r| r.ok()) - .map(PrivateKeyDer::from) - .collect(); - - match keys.len() { - 0 => Ok(None), - 1 => Ok(Some(keys.into_iter().next().unwrap())), - n => Err(PgWireError::Tls(format!( - "TLS config error: found {n} RSA keys in '{}', expected 1", - path.display() - ))), - } -} - -fn try_load_ec_key( - path: &std::path::Path, -) -> Result>> { - use rustls::pki_types::PrivateKeyDer; - - let f = File::open(path).map_err(|e| { - PgWireError::Tls(format!( - "TLS config error: failed to open private key '{}': {e}", - path.display() - )) - })?; - let mut rd = BufReader::new(f); - - let keys: Vec> = rustls_pemfile::ec_private_keys(&mut rd) - .filter_map(|r| r.ok()) - .map(PrivateKeyDer::from) - .collect(); - - match keys.len() { - 0 => Ok(None), - 1 => Ok(Some(keys.into_iter().next().unwrap())), - n => Err(PgWireError::Tls(format!( - "TLS config error: found {n} EC keys in '{}', expected 1", - path.display() - ))), - } + }) } // ==================== Custom Certificate Verifiers ==================== @@ -661,7 +560,7 @@ mod tests { let f = NamedTempFile::new().unwrap(); let err = load_private_key(f.path()).unwrap_err().to_string(); - assert!(err.contains("no private key")); + assert!(err.contains("failed to load private key")); } #[test]