Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/).

---

## [0.3.0] - 2026-03-27

### Changed
- **Breaking:** `PgWireError::Io` now wraps `Arc<std::io::Error>` instead of `String`
- Consumers can match on `.kind()` (e.g. `ErrorKind::UnexpectedEof`, `ConnectionReset`) instead of brittle substring matching on error messages
- `PgWireError` remains `Clone` via the `Arc` wrapper

### Improved
- Wrapped the replication stream in a 128KB `BufReader`, batching WAL messages into fewer `recv()` syscalls
- Reduces syscall overhead significantly during backlog drain scenarios

### Removed
- Replaced `rustls-pemfile` dependency with `rustls-pki-types` PEM parsing (already in the dependency tree via `rustls`)
- Resolves RUSTSEC-2025-0134 (unmaintained crate)

---

## [0.2.0] - 2026-02-08

### Added
Expand Down Expand Up @@ -70,7 +87,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
- Fuzz testing for pgwire framing


[Unreleased]: https://github.com/vnvo/pgwire-replication/compare/v0.2.0...HEAD
[Unreleased]: https://github.com/vnvo/pgwire-replication/compare/v0.3.0...HEAD
[0.3.0]: https://github.com/vnvo/pgwire-replication/compare/v0.2.0...v0.3.0
[0.2.0]: https://github.com/vnvo/pgwire-replication/releases/tag/v0.2.0
[0.1.2]: https://github.com/vnvo/pgwire-replication/releases/tag/v0.1.2
[0.1.1]: https://github.com/vnvo/pgwire-replication/releases/tag/v0.1.1
Expand Down
6 changes: 2 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pgwire-replication"
version = "0.2.0"
version = "0.3.0"
edition = "2021"
rust-version = "1.88"
resolver = "2"
Expand Down Expand Up @@ -29,7 +29,6 @@ tls-rustls = [
"dep:rustls",
"dep:tokio-rustls",
"dep:webpki-roots",
"dep:rustls-pemfile",
]
scram = ["dep:base64", "dep:hmac", "dep:sha2", "dep:rand"]
md5 = ["dep:md5"]
Expand Down Expand Up @@ -58,7 +57,6 @@ tracing = "0.1"
rustls = { version = "0.23", features = ["ring"], optional = true }
tokio-rustls = { version = "0.26", optional = true }
webpki-roots = { version = "0.26", optional = true }
rustls-pemfile = { version = "2", optional = true }

# Optional auth
base64 = { version = "0.22", optional = true }
Expand All @@ -67,7 +65,7 @@ sha2 = { version = "0.10", optional = true }
rand = { version = "0.9", optional = true }
md5 = { version = "0.7", optional = true }

testcontainers = { version = "0.25", optional = true }
testcontainers = { version = "0.27", optional = true }
tokio-postgres = { version = "0.7", optional = true }
tracing-subscriber = { version = "0.3", optional = true, features = [
"env-filter",
Expand Down
8 changes: 4 additions & 4 deletions src/client/tokio_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,10 @@ async fn run_worker(worker: &mut WorkerState, cfg: &ReplicationConfig) -> Result

let path = cfg.unix_socket_path();
let mut stream = UnixStream::connect(&path).await.map_err(|e| {
PgWireError::Io(format!(
"failed to connect to Unix socket {}: {e}",
path.display()
))
PgWireError::Io(std::sync::Arc::new(std::io::Error::new(
e.kind(),
format!("failed to connect to Unix socket {}: {e}", path.display()),
)))
})?;

return worker.run_on_stream(&mut stream).await;
Expand Down
14 changes: 9 additions & 5 deletions src/client/worker.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use bytes::Bytes;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, BufReader};
use tokio::sync::{mpsc, watch};
use tokio::time::Instant;

Expand Down Expand Up @@ -153,10 +153,14 @@ impl WorkerState {
&mut self,
stream: &mut S,
) -> Result<()> {
self.startup(stream).await?;
self.authenticate(stream).await?;
self.start_replication(stream).await?;
self.stream_loop(stream).await
// Wrap in a 128KB read buffer to batch multiple WAL messages into fewer
// recv() syscalls. BufReader delegates AsyncWrite to the inner stream,
// so writes (standby status replies, etc.) are unaffected.
let mut stream = BufReader::with_capacity(128 * 1024, stream);
self.startup(&mut stream).await?;
self.authenticate(&mut stream).await?;
self.start_replication(&mut stream).await?;
self.stream_loop(&mut stream).await
}

/// Send startup message with replication parameters.
Expand Down
9 changes: 5 additions & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@
//! - TLS errors (handshake failure, certificate issues)
//! - Task errors (worker panics, unexpected termination)

use std::sync::Arc;
use thiserror::Error;

/// Error type for all pgwire-replication operations.
#[derive(Debug, Error, Clone)]
pub enum PgWireError {
/// I/O error (network, file system).
///
/// Note: `std::io::Error` is not `Clone`, so we store the message.
/// Wraps `std::io::Error` in an `Arc` to preserve `ErrorKind` while
/// keeping `PgWireError` `Clone`. Use `.kind()` via `Arc`'s `Deref`.
#[error("io error: {0}")]
Io(String),
Io(Arc<std::io::Error>),

/// Protocol error - malformed message or unexpected response.
#[error("protocol error: {0}")]
Expand Down Expand Up @@ -80,10 +82,9 @@ impl PgWireError {
}
}

// Manual From impl since io::Error isn't Clone
impl From<std::io::Error> for PgWireError {
fn from(err: std::io::Error) -> Self {
PgWireError::Io(err.to_string())
PgWireError::Io(Arc::new(err))
}
}

Expand Down
151 changes: 25 additions & 126 deletions src/tls/rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
//! ```

use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{fs::File, io::BufReader, sync::Arc};

use rustls::{ClientConfig, RootCertStore};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
Expand Down Expand Up @@ -289,25 +289,22 @@ fn build_root_store(tls: &TlsConfig) -> Result<RootCertStore> {

if let Some(path) = &tls.ca_pem_path {
// Load custom CA certificates
let f = File::open(path).map_err(|e| {
PgWireError::Tls(format!(
"TLS config error: failed to open CA PEM '{}': {e}",
path.display()
))
})?;
let mut rd = BufReader::new(f);
use rustls::pki_types::pem::PemObject;

let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut rd)
let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(path)
.map_err(|e| {
PgWireError::Tls(format!(
"TLS config error: failed to open CA PEM '{}': {e}",
path.display()
))
})?
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| {
PgWireError::Tls(format!(
"TLS config error: failed to parse CA PEM '{}': {e}",
path.display()
))
})?
.into_iter()
.map(|c| c.into_owned())
.collect();
})?;

let (added, _ignored) = roots.add_parsable_certificates(certs);
if added == 0 {
Expand All @@ -328,27 +325,23 @@ fn build_root_store(tls: &TlsConfig) -> Result<RootCertStore> {
fn load_cert_chain(
path: &std::path::Path,
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::CertificateDer;

let f = File::open(path).map_err(|e| {
PgWireError::Tls(format!(
"TLS config error: failed to open client certificate '{}': {e}",
path.display()
))
})?;
let mut rd = BufReader::new(f);

let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut rd)
let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(path)
.map_err(|e| {
PgWireError::Tls(format!(
"TLS config error: failed to open client certificate '{}': {e}",
path.display()
))
})?
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| {
PgWireError::Tls(format!(
"TLS config error: failed to parse client certificate '{}': {e}",
path.display()
))
})?
.into_iter()
.map(|c| c.into_owned())
.collect();
})?;

if certs.is_empty() {
return Err(PgWireError::Tls(format!(
Expand All @@ -364,110 +357,16 @@ fn load_cert_chain(
///
/// Supports PKCS#8, PKCS#1 (RSA), and SEC1 (EC) key formats.
fn load_private_key(path: &std::path::Path) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
// Try PKCS#8 first (most common modern format)
if let Some(key) = try_load_pkcs8_key(path)? {
return Ok(key);
}

// Try RSA PKCS#1 format
if let Some(key) = try_load_rsa_key(path)? {
return Ok(key);
}

// Try EC SEC1 format
if let Some(key) = try_load_ec_key(path)? {
return Ok(key);
}

Err(PgWireError::Tls(format!(
"TLS config error: no private key found in '{}'. \
Supported formats: PKCS#8, PKCS#1 (RSA), SEC1 (EC)",
path.display()
)))
}

fn try_load_pkcs8_key(
path: &std::path::Path,
) -> Result<Option<rustls::pki_types::PrivateKeyDer<'static>>> {
use rustls::pki_types::PrivateKeyDer;

let f = File::open(path).map_err(|e| {
PgWireError::Tls(format!(
"TLS config error: failed to open private key '{}': {e}",
path.display()
))
})?;
let mut rd = BufReader::new(f);

let keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::pkcs8_private_keys(&mut rd)
.filter_map(|r| r.ok())
.map(PrivateKeyDer::from)
.collect();

match keys.len() {
0 => Ok(None),
1 => Ok(Some(keys.into_iter().next().unwrap())),
n => Err(PgWireError::Tls(format!(
"TLS config error: found {n} PKCS#8 keys in '{}', expected 1",
path.display()
))),
}
}

fn try_load_rsa_key(
path: &std::path::Path,
) -> Result<Option<rustls::pki_types::PrivateKeyDer<'static>>> {
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::PrivateKeyDer;

let f = File::open(path).map_err(|e| {
PrivateKeyDer::from_pem_file(path).map_err(|e| {
PgWireError::Tls(format!(
"TLS config error: failed to open private key '{}': {e}",
"TLS config error: failed to load private key from '{}': {e}. \
Supported formats: PKCS#8, PKCS#1 (RSA), SEC1 (EC)",
path.display()
))
})?;
let mut rd = BufReader::new(f);

let keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::rsa_private_keys(&mut rd)
.filter_map(|r| r.ok())
.map(PrivateKeyDer::from)
.collect();

match keys.len() {
0 => Ok(None),
1 => Ok(Some(keys.into_iter().next().unwrap())),
n => Err(PgWireError::Tls(format!(
"TLS config error: found {n} RSA keys in '{}', expected 1",
path.display()
))),
}
}

fn try_load_ec_key(
path: &std::path::Path,
) -> Result<Option<rustls::pki_types::PrivateKeyDer<'static>>> {
use rustls::pki_types::PrivateKeyDer;

let f = File::open(path).map_err(|e| {
PgWireError::Tls(format!(
"TLS config error: failed to open private key '{}': {e}",
path.display()
))
})?;
let mut rd = BufReader::new(f);

let keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::ec_private_keys(&mut rd)
.filter_map(|r| r.ok())
.map(PrivateKeyDer::from)
.collect();

match keys.len() {
0 => Ok(None),
1 => Ok(Some(keys.into_iter().next().unwrap())),
n => Err(PgWireError::Tls(format!(
"TLS config error: found {n} EC keys in '{}', expected 1",
path.display()
))),
}
})
}

// ==================== Custom Certificate Verifiers ====================
Expand Down Expand Up @@ -661,7 +560,7 @@ mod tests {
let f = NamedTempFile::new().unwrap();

let err = load_private_key(f.path()).unwrap_err().to_string();
assert!(err.contains("no private key"));
assert!(err.contains("failed to load private key"));
}

#[test]
Expand Down
Loading