Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .config/nextest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[profile.ci.junit]
path = "junit.xml"
48 changes: 48 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
name: Rust
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
env:
CARGO_TERM_COLOR: always
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Cache Postgres Image
id: cache-postgres-image
uses: actions/cache@v4
with:
path: /tmp/postgres-17.tar
key: postgres-17-${{ runner.os }}
- name: Pull and Save Postgres 17 (if cache miss)
if: steps.cache-postgres-image.outputs.cache-hit != 'true'
run: |
docker pull postgres:17
docker save postgres:17 -o /tmp/postgres-17.tar
- name: Load Postgres 17 (if cache hit)
if: steps.cache-postgres-image.outputs.cache-hit == 'true'
run: |
docker load -i /tmp/postgres-17.tar
- uses: actions/checkout@v4
- uses: taiki-e/install-action@cargo-nextest
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Run tests w/ aws-lc-rs (default)
run: cargo llvm-cov nextest --profile ci
- name: Run tests w/ ring (optional)
run: >
cargo llvm-cov nextest
--no-default-features
--features ring
- name: Generate coverage report
run: cargo llvm-cov report --codecov --output-path coverage.json
- name: Upload test results to Codecov
if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage report to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
8 changes: 8 additions & 0 deletions .yamllint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
extends: default
rules:
document-start: disable
brackets:
max-spaces-inside: 2
braces:
max-spaces-inside: 2
truthy: disable
33 changes: 23 additions & 10 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,31 +1,44 @@
[package]
name = "tokio-postgres-rustls"
description = "Rustls integration for tokio-postgres"
version = "0.13.0"
authors = ["Jasper Hugo <jasper@jasperhugo.com>"]
version = "0.14.0"
authors = [
"Jasper Hugo <jasper@jasperhugo.com>",
"Dwayne Sykes <github-public@sykes.pw>",
"Aumetra Weisman <aumetra@cryptolab.net>",
"Conrad Ludgate <conradludgate@gmail.com>",
"Karsten Borgwaldt <kb@spambri.de>",
"Philip Dubé <philip@peerdb.io>",
"Michael Sowka <soilfiction@gmail.com>",
"ol <ol@teuto.net>",
]
repository = "https://github.com/jbg/tokio-postgres-rustls"
edition = "2018"
license = "MIT"
readme = "README.md"

[features]
default = ["aws-lc-rs"]
aws-lc-rs = ["rustls/aws-lc-rs"]
ring = ["rustls/ring"]

[dependencies]
const-oid = { version = "0.9.6", default-features = false, features = ["db"] }
ring = { version = "0.17", default-features = false }
rustls = { version = "0.23", default-features = false }
sha2 = { version = "0.10", default-features = false, features = ["oid"] }
tokio = { version = "1", default-features = false }
tokio-postgres = { version = "0.7", default-features = false }
tokio-postgres = { version = "0.7", default-features = false, features = [
"runtime",
] }
tokio-rustls = { version = "0.26", default-features = false }
x509-cert = { version = "0.2.5", default-features = false, features = ["std"] }

[dev-dependencies]
bollard = { version = "0.19.2" }
env_logger = { version = "0.11", default-features = false }
tokio = { version = "1", default-features = false, features = ["macros", "rt"] }
tokio-postgres = { version = "0.7", default-features = false, features = [
"runtime",
] }
rustls = { version = "0.23", default-features = false, features = [
"std",
"logging",
"tls12",
"ring",
] }
sha2 = { version = "0.10", default-features = false, features = ["std", "oid"] }
tokio = { version = "1", default-features = false, features = ["macros", "rt"] }
158 changes: 36 additions & 122 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![doc = include_str!("../README.md")]
#![forbid(rust_2018_idioms)]
#![deny(missing_docs, unsafe_code)]
#![forbid(missing_docs, unsafe_code)]
#![warn(clippy::all, clippy::pedantic)]

use std::{convert::TryFrom, sync::Arc};
Expand All @@ -17,23 +17,23 @@ mod private {
task::{Context, Poll},
};

use const_oid::db::{
use rustls::pki_types::ServerName;
use sha2::digest::const_oid::db::{
rfc5912::{
ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
SHA_512_WITH_RSA_ENCRYPTION,
},
rfc8410::ID_ED_25519,
};
use ring::digest;
use rustls::pki_types::ServerName;
use sha2::{Digest, Sha256, Sha384, Sha512};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
use tokio_rustls::{client::TlsStream, TlsConnector};
use x509_cert::{der::Decode, TbsCertificate};
use x509_cert::{der::Decode, Certificate};

pub struct TlsConnectFuture<S> {
pub inner: tokio_rustls::Connect<S>,
inner: tokio_rustls::Connect<S>,
}

impl<S> Future for TlsConnectFuture<S>
Expand All @@ -42,11 +42,8 @@ mod private {
{
type Output = io::Result<RustlsStream<S>>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// SAFETY: If `self` is pinned, so is `inner`.
#[allow(unsafe_code)]
let fut = unsafe { self.map_unchecked_mut(|this| &mut this.inner) };
fut.poll(cx).map_ok(RustlsStream)
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
}
}

Expand Down Expand Up @@ -74,48 +71,39 @@ mod private {

pub struct RustlsStream<S>(TlsStream<S>);

impl<S> RustlsStream<S> {
pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream<S>> {
// SAFETY: When `Self` is pinned, so is the inner `TlsStream`.
#[allow(unsafe_code)]
unsafe {
self.map_unchecked_mut(|this| &mut this.0)
}
}
}

impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn channel_binding(&self) -> ChannelBinding {
let (_, session) = self.0.get_ref();
match session.peer_certificates() {
Some(certs) if !certs.is_empty() => TbsCertificate::from_der(&certs[0])
.ok()
.and_then(|cert| {
let digest = match cert.signature.oid {
Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0]).map_or_else(
|_| ChannelBinding::none(),
|cert| {
match cert.signature_algorithm.oid {
// Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
ID_SHA_1
| ID_SHA_256
| SHA_1_WITH_RSA_ENCRYPTION
| SHA_256_WITH_RSA_ENCRYPTION
| ECDSA_WITH_SHA_256 => &digest::SHA256,
| ECDSA_WITH_SHA_256 => ChannelBinding::tls_server_end_point(
Sha256::digest(certs[0].as_ref()).to_vec(),
),
ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
&digest::SHA384
ChannelBinding::tls_server_end_point(
Sha384::digest(certs[0].as_ref()).to_vec(),
)
}
ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => {
&digest::SHA512
ChannelBinding::tls_server_end_point(
Sha512::digest(certs[0].as_ref()).to_vec(),
)
}
_ => return None,
};

Some(digest)
})
.map_or_else(ChannelBinding::none, |algorithm| {
let hash = digest::digest(algorithm, certs[0].as_ref());
ChannelBinding::tls_server_end_point(hash.as_ref().into())
}),
_ => ChannelBinding::none(),
}
},
),
_ => ChannelBinding::none(),
}
}
Expand All @@ -126,11 +114,11 @@ mod private {
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<tokio::io::Result<()>> {
self.project_stream().poll_read(cx, buf)
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

Expand All @@ -139,22 +127,25 @@ mod private {
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<tokio::io::Result<usize>> {
self.project_stream().poll_write(cx, buf)
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
self.project_stream().poll_flush(cx)
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<tokio::io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<tokio::io::Result<()>> {
self.project_stream().poll_shutdown(cx)
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
}
Expand Down Expand Up @@ -194,80 +185,3 @@ where
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use rustls::pki_types::{CertificateDer, UnixTime};
use rustls::{
client::danger::ServerCertVerifier,
client::danger::{HandshakeSignatureValid, ServerCertVerified},
Error, SignatureScheme,
};

#[derive(Debug)]
struct AcceptAllVerifier {}
impl ServerCertVerifier for AcceptAllVerifier {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, Error> {
Ok(ServerCertVerified::assertion())
}

fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}

fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}

fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::RSA_PSS_SHA512,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::ED25519,
]
}
}

#[tokio::test]
async fn it_works() {
env_logger::builder().is_test(true).try_init().unwrap();

let mut config = rustls::ClientConfig::builder()
.with_root_certificates(rustls::RootCertStore::empty())
.with_no_client_auth();
config
.dangerous()
.set_certificate_verifier(Arc::new(AcceptAllVerifier {}));
let tls = super::MakeRustlsConnect::new(config);
let (client, conn) = tokio_postgres::connect(
"sslmode=require host=localhost port=5432 user=postgres",
tls,
)
.await
.expect("connect");
tokio::spawn(async move { conn.await.map_err(|e| panic!("{:?}", e)) });
let stmt = client.prepare("SELECT 1").await.expect("prepare");
let _ = client.query(&stmt, &[]).await.expect("query");
}
}
Loading