From fa6c2adfe1e47906c8a5f194f9c132830e8285cc Mon Sep 17 00:00:00 2001 From: Gabriel de Perthuis Date: Wed, 22 Sep 2021 12:08:12 +0200 Subject: [PATCH 01/16] Clippy fixes --- src/lib.rs | 36 +++++++++++++++++++++++------------- src/version1.rs | 38 +++++++++++++++++++++++--------------- src/version2.rs | 8 ++++---- 3 files changed, 50 insertions(+), 32 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a1fe67d..b60114d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -191,15 +191,15 @@ mod parse_tests { assert!(!buf.has_remaining()); // Consume the ENTIRE header! fn valid_v4( - (a, b, c, d): (u8, u8, u8, u8), - e: u16, - (f, g, h, i): (u8, u8, u8, u8), - j: u16, + (s0, s1, s2, s3): (u8, u8, u8, u8), + sp: u16, + (d0, d1, d2, d3): (u8, u8, u8, u8), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: version1::ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), e), - destination: SocketAddrV4::new(Ipv4Addr::new(f, g, h, i), j), + source: SocketAddrV4::new(Ipv4Addr::new(s0, s1, s2, s3), sp), + destination: SocketAddrV4::new(Ipv4Addr::new(d0, d1, d2, d3), dp), }, } } @@ -223,15 +223,25 @@ mod parse_tests { ); fn valid_v6( - (a, b, c, d, e, f, g, h): (u16, u16, u16, u16, u16, u16, u16, u16), - i: u16, - (j, k, l, m, n, o, p, q): (u16, u16, u16, u16, u16, u16, u16, u16), - r: u16, + (s0, s1, s2, s3, s4, s5, s6, s7): (u16, u16, u16, u16, u16, u16, u16, u16), + sp: u16, + (d0, d1, d2, d3, d4, d5, d6, d7): (u16, u16, u16, u16, u16, u16, u16, u16), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: version1::ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(Ipv6Addr::new(a, b, c, d, e, f, g, h), i, 0, 0), - destination: SocketAddrV6::new(Ipv6Addr::new(j, k, l, m, n, o, p, q), r, 0, 0), + source: SocketAddrV6::new( + Ipv6Addr::new(s0, s1, s2, s3, s4, s5, s6, s7), + sp, + 0, + 0, + ), + destination: SocketAddrV6::new( + Ipv6Addr::new(d0, d1, d2, d3, d4, d5, d6, d7), + dp, + 0, + 0, + ), }, } } @@ -313,7 +323,7 @@ mod parse_tests { 0x49, 0x54, 0x0A, - (2 << 4) | 0, + 2 << 4, ]; const PREFIX_PROXY: [u8; 13] = [ 0x0D, diff --git a/src/version1.rs b/src/version1.rs index 19fc38b..c9a0377 100644 --- a/src/version1.rs +++ b/src/version1.rs @@ -190,9 +190,7 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result unreachable!(), }; - Ok(super::ProxyHeader::Version1 { - addresses, - }) + Ok(super::ProxyHeader::Version1 { addresses }) } pub(crate) fn encode(addresses: ProxyAddresses) -> Result { @@ -278,15 +276,15 @@ mod parse_tests { #[test] fn test_valid_ipv4_cases() { fn valid( - (a, b, c, d): (u8, u8, u8, u8), - e: u16, - (f, g, h, i): (u8, u8, u8, u8), - j: u16, + (s0, s1, s2, s3): (u8, u8, u8, u8), + sp: u16, + (d0, d1, d2, d3): (u8, u8, u8, u8), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), e), - destination: SocketAddrV4::new(Ipv4Addr::new(f, g, h, i), j), + source: SocketAddrV4::new(Ipv4Addr::new(s0, s1, s2, s3), sp), + destination: SocketAddrV4::new(Ipv4Addr::new(d0, d1, d2, d3), dp), }, } } @@ -312,15 +310,25 @@ mod parse_tests { #[test] fn test_valid_ipv6_cases() { fn valid( - (a, b, c, d, e, f, g, h): (u16, u16, u16, u16, u16, u16, u16, u16), - i: u16, - (j, k, l, m, n, o, p, q): (u16, u16, u16, u16, u16, u16, u16, u16), - r: u16, + (s0, s1, s2, s3, s4, s5, s6, s7): (u16, u16, u16, u16, u16, u16, u16, u16), + sp: u16, + (d0, d1, d2, d3, d4, d5, d6, d7): (u16, u16, u16, u16, u16, u16, u16, u16), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(Ipv6Addr::new(a, b, c, d, e, f, g, h), i, 0, 0), - destination: SocketAddrV6::new(Ipv6Addr::new(j, k, l, m, n, o, p, q), r, 0, 0), + source: SocketAddrV6::new( + Ipv6Addr::new(s0, s1, s2, s3, s4, s5, s6, s7), + sp, + 0, + 0, + ), + destination: SocketAddrV6::new( + Ipv6Addr::new(d0, d1, d2, d3, d4, d5, d6, d7), + dp, + 0, + 0, + ), }, } } diff --git a/src/version2.rs b/src/version2.rs index 6148e62..fec2b99 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -733,7 +733,7 @@ mod encode_tests { ProxyTransportProtocol::Unspec, ProxyAddresses::Unspec, ), - signed(&[(2 << 4) | 0, 0, 0, 0][..]), + signed(&[2 << 4, 0, 0, 0][..]), ); assert_eq!( @@ -756,7 +756,7 @@ mod encode_tests { signed( &[ (2 << 4) | 1, - (1 << 4) | 0, + 1 << 4, 0, 12, 1, @@ -819,7 +819,7 @@ mod encode_tests { ), signed( &[ - (2 << 4) | 0, + 2 << 4, (1 << 4) | 2, 0, 12, @@ -863,7 +863,7 @@ mod encode_tests { ), signed( &[ - (2 << 4) | 0, + 2 << 4, (2 << 4) | 2, 0, 36, From f701d20e4dc021cb074b7f1e337e7088b8aec69b Mon Sep 17 00:00:00 2001 From: Gabriel de Perthuis Date: Mon, 27 Sep 2021 16:42:47 +0200 Subject: [PATCH 02/16] Fix mojibake in the protocol specification The document is served without an encoding header. --- proxy-protocol.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy-protocol.txt b/proxy-protocol.txt index 863eac6..196a445 100644 --- a/proxy-protocol.txt +++ b/proxy-protocol.txt @@ -33,7 +33,7 @@ Revision history reserved TLV type ranges, added TLV documentation, clarified string encoding. With contributions from Andriy Palamarchuk (Amazon.com). - 2020/03/05 - added the unique ID TLV type (Tim Düsterhus) + 2020/03/05 - added the unique ID TLV type (Tim Düsterhus) 1. Background From 93cd650e59cb44f4f67d9e170721c0a96b9feeb3 Mon Sep 17 00:00:00 2001 From: Gabriel de Perthuis Date: Wed, 22 Sep 2021 13:51:20 +0200 Subject: [PATCH 03/16] Refactor version2 parsing This factors in common parts of address parsing. There is a tweak to which errors are returned when the buffer is both too small for its declared length and for the minimum length of address data: in this case, we now return an UnexpectedEof indicating the first error, previously the second error was signaled. Tests were updated with larger vectors so they could still exercise InsufficientLengthSpecified error cases. --- src/lib.rs | 3 + src/version2.rs | 146 +++++++++++++++++++----------------------------- 2 files changed, 59 insertions(+), 90 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b60114d..94f43ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -639,6 +639,9 @@ mod parse_tests { // 3 bytes is clearly too few if we expect 2 IPv4s and ports 0, 3, + 0, + 0, + 0, ][..], ] .concat() diff --git a/src/version2.rs b/src/version2.rs index fec2b99..4d9fd44 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -100,18 +100,7 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result= length, UnexpectedEof); - buf.advance(length); - - return Ok(super::ProxyHeader::Version2 { - command, - transport_protocol, - addresses: ProxyAddresses::Unspec, - }); - } + ensure!(buf.remaining() >= length, UnexpectedEof); // Time to parse the following: // @@ -134,88 +123,73 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result } unix_addr; // > }; - if address_family == ProxyAddressFamily::Unix { - ensure!( - length >= 108 * 2, - InsufficientLengthSpecified { - given: length, - needs: 108usize * 2, - }, - ); - ensure!(buf.remaining() >= 108 * 2, UnexpectedEof); - let mut source = [0u8; 108]; - let mut destination = [0u8; 108]; - buf.copy_to_slice(&mut source[..]); - buf.copy_to_slice(&mut destination[..]); - // TODO(Mariell Hoversholm): Support TLVs - if length > 108 * 2 { - buf.advance(length - (108 * 2)); - } - - return Ok(super::ProxyHeader::Version2 { - command, - transport_protocol, - addresses: ProxyAddresses::Unix { - source, - destination, - }, - }); - } - - let port_length = 4; + // The full length of address data, + // including two addresses and two ports let address_length = match address_family { - ProxyAddressFamily::Inet => 8, - ProxyAddressFamily::Inet6 => 32, - _ => unreachable!(), + ProxyAddressFamily::Inet => (4 + 2) * 2, + ProxyAddressFamily::Inet6 => (16 + 2) * 2, + ProxyAddressFamily::Unix => 108 * 2, + ProxyAddressFamily::Unspec => 0, }; ensure!( - length >= port_length + address_length, + length >= address_length, InsufficientLengthSpecified { given: length, - needs: port_length + address_length, + needs: address_length, }, ); - ensure!( - buf.remaining() >= port_length + address_length, - UnexpectedEof, - ); - - let addresses = if address_family == ProxyAddressFamily::Inet { - let mut data = [0u8; 4]; - buf.copy_to_slice(&mut data[..]); - let source = Ipv4Addr::from(data); + ensure!(buf.remaining() >= address_length, UnexpectedEof,); + + let addresses = match address_family { + ProxyAddressFamily::Unspec => ProxyAddresses::Unspec, + ProxyAddressFamily::Unix => { + let mut source = [0u8; 108]; + let mut destination = [0u8; 108]; + buf.copy_to_slice(&mut source[..]); + buf.copy_to_slice(&mut destination[..]); + ProxyAddresses::Unix { + source, + destination, + } + } + ProxyAddressFamily::Inet => { + let mut data = [0u8; 4]; + buf.copy_to_slice(&mut data[..]); + let source = Ipv4Addr::from(data); - buf.copy_to_slice(&mut data); - let destination = Ipv4Addr::from(data); + buf.copy_to_slice(&mut data); + let destination = Ipv4Addr::from(data); - let source_port = buf.get_u16(); - let destination_port = buf.get_u16(); + let source_port = buf.get_u16(); + let destination_port = buf.get_u16(); - ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(source, source_port), - destination: SocketAddrV4::new(destination, destination_port), + ProxyAddresses::Ipv4 { + source: SocketAddrV4::new(source, source_port), + destination: SocketAddrV4::new(destination, destination_port), + } } - } else { - let mut data = [0u8; 16]; - buf.copy_to_slice(&mut data); - let source = Ipv6Addr::from(data); + ProxyAddressFamily::Inet6 => { + let mut data = [0u8; 16]; + buf.copy_to_slice(&mut data); + let source = Ipv6Addr::from(data); - buf.copy_to_slice(&mut data); - let destination = Ipv6Addr::from(data); + buf.copy_to_slice(&mut data); + let destination = Ipv6Addr::from(data); - let source_port = buf.get_u16(); - let destination_port = buf.get_u16(); + let source_port = buf.get_u16(); + let destination_port = buf.get_u16(); - ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(source, source_port, 0, 0), - destination: SocketAddrV6::new(destination, destination_port, 0, 0), + ProxyAddresses::Ipv6 { + source: SocketAddrV6::new(source, source_port, 0, 0), + destination: SocketAddrV6::new(destination, destination_port, 0, 0), + } } }; - if length > port_length + address_length { + if length > address_length { // TODO(Mariell Hoversholm): Implement TLVs - buf.advance(length - (port_length + address_length)); + buf.advance(length - address_length); } Ok(super::ProxyHeader::Version2 { @@ -343,15 +317,9 @@ pub(crate) fn encode( // > }; let len = match addresses { ProxyAddresses::Unspec => 0, - ProxyAddresses::Unix { .. } => { - 108 + 108 - } - ProxyAddresses::Ipv4 { .. } => { - 4 + 4 + 2 + 2 - } - ProxyAddresses::Ipv6 { .. } => { - 16 + 16 + 2 + 2 - } + ProxyAddresses::Unix { .. } => 108 + 108, + ProxyAddresses::Ipv4 { .. } => 4 + 4 + 2 + 2, + ProxyAddresses::Ipv6 { .. } => 16 + 16 + 2 + 2, }; let mut buf = BytesMut::with_capacity(16 + len); @@ -698,6 +666,9 @@ mod parse_tests { // 3 bytes is clearly too few if we expect 2 IPv4s and ports 0, 3, + 0, + 0, + 0, ][..] ), Err(ParseError::InsufficientLengthSpecified { @@ -847,12 +818,7 @@ mod encode_tests { ProxyCommand::Local, ProxyTransportProtocol::Datagram, ProxyAddresses::Ipv6 { - source: SocketAddrV6::new( - Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), - 8192, - 0, - 0, - ), + source: SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 8192, 0, 0,), destination: SocketAddrV6::new( Ipv6Addr::new(65535, 65535, 32767, 32766, 111, 222, 333, 444), 0, From de308f6970c6a48ae27c0f6dfc409e63b9a936b3 Mon Sep 17 00:00:00 2001 From: Gabriel de Perthuis Date: Mon, 27 Sep 2021 16:07:31 +0200 Subject: [PATCH 04/16] Add support for TLV extensions The ones supported are the ones documented in proxy-protocol.txt as of 2020/03/05. --- src/lib.rs | 18 +- src/version2.rs | 640 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 612 insertions(+), 46 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 94f43ff..57a0861 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,9 @@ pub enum EncodeError { /// An error occurred while encoding version 1. #[snafu(display("there was an error while encoding the v1 header: {}", source))] WriteVersion1 { source: version1::EncodeError }, + /// An error occurred while encoding version 2. + #[snafu(display("there was an error while encoding the v2 header: {}", source))] + WriteVersion2 { source: version2::EncodeError }, } /// The PROXY header emitted at most once at the start of a new connection. @@ -70,6 +73,7 @@ pub enum ProxyHeader { /// The addresses used to connect to the proxy. addresses: version2::ProxyAddresses, + extensions: Vec, }, } @@ -150,7 +154,9 @@ pub fn encode(header: ProxyHeader) -> Result { command, transport_protocol, addresses, - } => version2::encode(command, transport_protocol, addresses), + extensions, + } => version2::encode(command, transport_protocol, addresses, &extensions[..]) + .context(WriteVersion2)?, #[allow(unreachable_patterns)] // May be required to be exhaustive. _ => unimplemented!("Unimplemented version?"), @@ -347,6 +353,7 @@ mod parse_tests { command: version2::ProxyCommand::Local, addresses: version2::ProxyAddresses::Unspec, transport_protocol: version2::ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); assert_eq!( @@ -355,6 +362,7 @@ mod parse_tests { command: version2::ProxyCommand::Proxy, addresses: version2::ProxyAddresses::Unspec, transport_protocol: version2::ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); @@ -388,7 +396,7 @@ mod parse_tests { 1, 1, // TLV - 69, + version2::PP2_TYPE_NOOP, 0, 0, ][..] @@ -403,6 +411,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 257), }, + extensions: Vec::new(), }) ); @@ -449,6 +458,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), destination: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 255 << 8), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -507,7 +517,7 @@ mod parse_tests { 1, 1, // TLV - 69, + version2::PP2_TYPE_NOOP, 0, 0, ][..], @@ -532,6 +542,7 @@ mod parse_tests { 0, ) }, + extensions: Vec::new(), }) ); @@ -612,6 +623,7 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header diff --git a/src/version2.rs b/src/version2.rs index 4d9fd44..b3d6d3e 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -1,5 +1,6 @@ use bytes::{Buf, BufMut as _, BytesMut}; use snafu::{ensure, Snafu}; +use std::convert::TryInto; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; #[derive(Debug, Snafu)] @@ -19,6 +20,28 @@ pub enum ParseError { #[snafu(display("insufficient length specified: {}, requires minimum {}", given, needs))] InsufficientLengthSpecified { given: usize, needs: usize }, + + #[snafu(display("invalid length specified: {}, causes overflow", given))] + LengthOverflow { given: usize }, + + #[snafu(display("invalid TLV type id specified: {}", type_id))] + InvalidTlvTypeId { type_id: u8 }, + + #[snafu(display("invalid UTF-8: {:?}", bytes))] + InvalidUtf8 { bytes: Vec }, + + #[snafu(display("invalid ASCII: {:?}", bytes))] + InvalidAscii { bytes: Vec }, + + #[snafu(display("trailing data: {:?}", len))] + TrailingData { len: usize }, +} + +#[derive(Debug, Snafu)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub enum EncodeError { + #[snafu(display("value is too large to encode"))] + ValueTooLarge, } #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] @@ -59,6 +82,310 @@ enum ProxyAddressFamily { Unix, } +trait Tlv: Sized { + /// Identifies the type + fn type_id(&self) -> u8; + + /// The byte size of the value if encoded, or None if too big + fn value_len(&self) -> Result; + + /// Write the value to the provided buffer + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError>; + + fn encoded_len(&self) -> Result { + self.value_len()? + .checked_add(3) + .ok_or(EncodeError::ValueTooLarge) + } + + fn encode(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + let vlen = self.value_len()?; + if vlen + .checked_add(3) + .map_or(true, |tlv_len| buf.remaining_mut() < tlv_len.into()) + { + return Err(EncodeError::ValueTooLarge); + } + buf.put_u8(self.type_id()); + buf.put_u16(vlen); + self.encode_value(buf) + } + + // API note: + // We have to pass the len instead of using a view. + // Buf doesn't have a good view / subslice abstraction + // unlike plain slices or even the Bytes implementation + // IMHO (@g2p) it would be better for parse to receive a + // slice or a concrete type. + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result; + + fn parse(buf: &mut impl Buf) -> Result { + if buf.remaining() < 3 { + return Err(ParseError::UnexpectedEof); + } + let type_id = buf.get_u8(); + let vlen = buf.get_u16(); + let expected_rem = buf + .remaining() + .checked_sub(vlen.into()) + .ok_or(ParseError::UnexpectedEof)?; + let r = Self::parse_parts(type_id, vlen, buf)?; + // Assert, because it would be an internal error + assert_eq!(buf.remaining(), expected_rem); + Ok(r) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SslClientFlags(u8); + +impl SslClientFlags { + pub fn is_ssl_or_tls(&self) -> bool { + (self.0 & 1) == 1 + } + + pub fn client_authenticated_connection(&self) -> bool { + (self.0 & 2) == 2 + } + pub fn client_authenticated_session(&self) -> bool { + (self.0 & 4) == 4 + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SslVerifyStatus(u32); + +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(not(feature = "always_exhaustive"), non_exhaustive)] // Extensions may be added +pub enum SslExtensionTlv { + /// TLS or SSL version in ASCII + Version(String), + /// TLS or SSL cipher suite in ASCII, for example "ECDHE-RSA-AES128-GCM-SHA256" + Cipher(String), + /// TLS or SSL signature algorithm in ASCII + SigAlg(String), + /// TLS or SSL key algorithm in ASCII + KeyAlg(String), + /// With client authentication, the common name for the client certificate in UTF-8 + ClientCN(String), +} + +impl SslExtensionTlv { + fn as_str(&self) -> &str { + match self { + Self::Version(version) => version, + Self::Cipher(cipher) => cipher, + Self::SigAlg(sigalg) => sigalg, + Self::KeyAlg(keyalg) => keyalg, + Self::ClientCN(cn) => cn, + } + } +} + +impl Tlv for SslExtensionTlv { + fn type_id(&self) -> u8 { + match self { + Self::Version(_) => PP2_SUBTYPE_SSL_VERSION, + Self::ClientCN(_) => PP2_SUBTYPE_SSL_CN, + Self::Cipher(_) => PP2_SUBTYPE_SSL_CIPHER, + Self::SigAlg(_) => PP2_SUBTYPE_SSL_SIG_ALG, + Self::KeyAlg(_) => PP2_SUBTYPE_SSL_KEY_ALG, + } + } + + fn value_len(&self) -> Result { + self.as_str() + .len() + .try_into() + .map_err(|_| EncodeError::ValueTooLarge) + } + + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + buf.put_slice(self.as_str().as_bytes()); + Ok(()) + } + + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result { + Ok(match type_id { + PP2_SUBTYPE_SSL_VERSION => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_CIPHER => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_SIG_ALG => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_KEY_ALG => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_CN => Self::Version(str_from_buf(buf, len)?), + _ => return Err(ParseError::InvalidTlvTypeId { type_id }), + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ssl { + client: SslClientFlags, + verify: SslVerifyStatus, + extensions: Vec, +} + +impl Ssl { + fn parse(buf: &mut impl Buf, len: u16) -> Result { + if buf.remaining() < len.into() { + return Err(ParseError::UnexpectedEof); + } + let mut ext_len = len + .checked_sub(5) + .ok_or(ParseError::InsufficientLengthSpecified { + given: len.into(), + needs: 5, + })?; + let client = SslClientFlags(buf.get_u8()); + let verify = SslVerifyStatus(buf.get_u32()); + let mut extensions = Vec::new(); + while ext_len > 0 { + let rem0 = buf.remaining(); + extensions.push(SslExtensionTlv::parse(buf)?); + let rem = buf.remaining(); + // The assert enforces that Buf is implemented sanely + // and not rewound. + let parsed = rem0.checked_sub(rem).expect("Buf error"); + // We don't enforce u16-sized buffers. + // Since we don't pass a bound on how much to parse, + // we can't enforce that the extension parser won't read + // (slightly) more than 64k. + // The assert is safe since ext_len was already u16 and the + // new value is lower. + ext_len = usize::from(ext_len) + .checked_sub(parsed) + .ok_or(ParseError::InsufficientLengthSpecified { + given: ext_len.into(), + needs: parsed, + })? + .try_into() + .expect("Math error"); + } + Ok(Self { + client, + verify, + extensions, + }) + } + + fn encoded_len(&self) -> Result { + // 1 for flags, 4 for verify status, plus all nested TLVs + self.extensions + .iter() + .try_fold(5u16, |sum, subtlv| { + sum.checked_add(subtlv.encoded_len().ok()?) + }) + .ok_or(EncodeError::ValueTooLarge) + } + + fn encode(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + buf.put_u8(self.client.0); + buf.put_u32(self.verify.0); + for ext in self.extensions.iter() { + ext.encode(buf)?; + } + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(not(feature = "always_exhaustive"), non_exhaustive)] // Extensions may be added +pub enum ExtensionTlv { + Alpn(Vec), + Authority(String), + Crc32c(u32), + UniqueId(Vec), + Ssl(Ssl), + NetNs(String), +} + +pub(crate) const PP2_TYPE_ALPN: u8 = 0x01; +pub(crate) const PP2_TYPE_AUTHORITY: u8 = 0x02; +pub(crate) const PP2_TYPE_CRC32C: u8 = 0x03; +pub(crate) const PP2_TYPE_NOOP: u8 = 0x04; +pub(crate) const PP2_TYPE_UNIQUE_ID: u8 = 0x05; +pub(crate) const PP2_TYPE_SSL: u8 = 0x20; +pub(crate) const PP2_SUBTYPE_SSL_VERSION: u8 = 0x21; +pub(crate) const PP2_SUBTYPE_SSL_CN: u8 = 0x22; +pub(crate) const PP2_SUBTYPE_SSL_CIPHER: u8 = 0x23; +pub(crate) const PP2_SUBTYPE_SSL_SIG_ALG: u8 = 0x24; +pub(crate) const PP2_SUBTYPE_SSL_KEY_ALG: u8 = 0x25; +pub(crate) const PP2_TYPE_NETNS: u8 = 0x30; + +fn vec_from_buf(buf: &mut impl Buf, len: u16) -> Vec { + let mut r = vec![0; len.into()]; + buf.copy_to_slice(&mut r); + r +} + +fn str_from_buf(buf: &mut impl Buf, len: u16) -> Result { + let v = vec_from_buf(buf, len); + let r = String::from_utf8(v).map_err(|e| ParseError::InvalidUtf8 { + bytes: e.into_bytes(), + })?; + Ok(r) +} + +fn ascii_from_buf(buf: &mut impl Buf, len: u16) -> Result { + let s = str_from_buf(buf, len)?; + if !s.is_ascii() { + Err(ParseError::InvalidAscii { + bytes: s.into_bytes(), + }) + } else { + Ok(s) + } +} + +impl Tlv for ExtensionTlv { + fn type_id(&self) -> u8 { + match self { + Self::Alpn(_) => PP2_TYPE_ALPN, + Self::Authority(_) => PP2_TYPE_AUTHORITY, + Self::Crc32c(_) => PP2_TYPE_CRC32C, + Self::UniqueId(_) => PP2_TYPE_UNIQUE_ID, + Self::Ssl(_) => PP2_TYPE_SSL, + Self::NetNs(_) => PP2_TYPE_NETNS, + } + } + + fn value_len(&self) -> Result { + match self { + Self::Alpn(alpn) => alpn.len(), + Self::Authority(authority) => authority.len(), + Self::Crc32c(_) => 4, + Self::UniqueId(id) => id.len(), + Self::Ssl(data) => data.encoded_len()?.into(), + Self::NetNs(netns) => netns.len(), + } + .try_into() + .map_err(|_| EncodeError::ValueTooLarge) + } + + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + match self { + Self::Alpn(by) | Self::UniqueId(by) => buf.put_slice(by), + Self::Authority(st) | Self::NetNs(st) => buf.put_slice(st.as_bytes()), + Self::Crc32c(crc) => buf.put_u32(*crc), + Self::Ssl(ssl) => ssl.encode(buf)?, + }; + Ok(()) + } + + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result { + Ok(match type_id { + PP2_TYPE_ALPN => Self::Alpn(vec_from_buf(buf, len)), + PP2_TYPE_AUTHORITY => Self::Authority(str_from_buf(buf, len)?), + PP2_TYPE_CRC32C => Self::Crc32c(buf.get_u32()), + PP2_TYPE_UNIQUE_ID => Self::UniqueId(vec_from_buf(buf, len)), + PP2_TYPE_SSL => Self::Ssl(Ssl::parse(buf, len)?), + PP2_TYPE_NETNS => Self::NetNs(ascii_from_buf(buf, len)?), + _ => return Err(ParseError::InvalidTlvTypeId { type_id }), + }) + } +} + +// Note: this is internal, assumes the first 12 bytes were parsed, +// and ignores the version half of the first byte. pub(crate) fn parse(buf: &mut impl Buf) -> Result { // We need to parse the following: // @@ -71,9 +398,10 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result> 4; + let command = st << 4 >> 4; let command = match command { 0 => ProxyCommand::Local, 1 => ProxyCommand::Proxy, @@ -125,21 +453,21 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result (4 + 2) * 2, ProxyAddressFamily::Inet6 => (16 + 2) * 2, ProxyAddressFamily::Unix => 108 * 2, ProxyAddressFamily::Unspec => 0, }; - ensure!( - length >= address_length, - InsufficientLengthSpecified { - given: length, - needs: address_length, - }, - ); - ensure!(buf.remaining() >= address_length, UnexpectedEof,); + let mut ext_len = + length + .checked_sub(address_len) + .ok_or(ParseError::InsufficientLengthSpecified { + given: length, + needs: address_len, + })?; + ensure!(buf.remaining() >= address_len, UnexpectedEof,); let addresses = match address_family { ProxyAddressFamily::Unspec => ProxyAddresses::Unspec, @@ -187,23 +515,74 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result address_length { - // TODO(Mariell Hoversholm): Implement TLVs - buf.advance(length - address_length); + let mut extensions = Vec::new(); + while ext_len > 0 { + // At this point, we know that remaining() >= ext_len + if buf.chunk()[0] == PP2_TYPE_NOOP { + if ext_len < 3 { + return Err(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: 3, + }); + } + // Read/skip the type after peeking + buf.get_u8(); + let skip_len = buf.get_u16(); + let noop_len = 3u16 + .checked_add(skip_len) + .ok_or(ParseError::LengthOverflow { + given: skip_len.into(), + })? + .into(); + if noop_len > ext_len { + return Err(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: noop_len, + }); + } + ext_len -= noop_len; + } else { + let rem0 = buf.remaining(); + extensions.push(ExtensionTlv::parse(buf)?); + let rem = buf.remaining(); + let parsed = rem0.checked_sub(rem).expect("Buf error"); + ext_len = + ext_len + .checked_sub(parsed) + .ok_or(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: parsed, + })?; + } } Ok(super::ProxyHeader::Version2 { command, transport_protocol, addresses, + extensions, }) } +// Currently used in tests, has to be internal for the same reasons +// parse() currently is. +#[cfg(test)] +pub(crate) fn parse_fully(buf: &mut impl Buf) -> Result { + let r = parse(buf)?; + if buf.has_remaining() { + return Err(ParseError::TrailingData { + len: buf.remaining(), + }); + } + Ok(r) +} + pub(crate) fn encode( command: ProxyCommand, transport_protocol: ProxyTransportProtocol, addresses: ProxyAddresses, -) -> BytesMut { + extensions: &[ExtensionTlv], +) -> Result { // > struct proxy_hdr_v2 { // > uint8_t sig[12]; /* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */ // > uint8_t ver_cmd; /* protocol version and command */ @@ -315,17 +694,27 @@ pub(crate) fn encode( // > uint8_t dst_addr[108]; // > } unix_addr; // > }; - let len = match addresses { + let address_len: u16 = match addresses { ProxyAddresses::Unspec => 0, ProxyAddresses::Unix { .. } => 108 + 108, ProxyAddresses::Ipv4 { .. } => 4 + 4 + 2 + 2, ProxyAddresses::Ipv6 { .. } => 16 + 16 + 2 + 2, }; + // With extensions, we need to distinguish len and address_len + let len = extensions + .iter() + .try_fold(address_len, |acc, ext| { + acc.checked_add(ext.encoded_len().ok()?) + }) + .ok_or(EncodeError::ValueTooLarge)?; - let mut buf = BytesMut::with_capacity(16 + len); + let blen = 16usize + .checked_add(len.into()) + .ok_or(EncodeError::ValueTooLarge)?; + let mut buf = BytesMut::with_capacity(blen); buf.put_slice(&SIG[..]); buf.put_slice(&[ver_cmd, fam][..]); - buf.put_u16(len as u16); + buf.put_u16(len); match addresses { ProxyAddresses::Unspec => (), @@ -356,7 +745,13 @@ pub(crate) fn encode( } } - buf + for ext in extensions.iter() { + ext.encode(&mut buf)?; + } + + assert_eq!(buf.len(), blen); + + Ok(buf) } #[cfg(test)] @@ -371,23 +766,23 @@ mod parse_tests { #[test] fn test_unspec() { assert_eq!( - parse(&mut &[0u8; 16][..]), + parse_fully(&mut &[0u8; 4][..]), Ok(ProxyHeader::Version2 { command: ProxyCommand::Local, addresses: ProxyAddresses::Unspec, transport_protocol: ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); - let mut prefix = BytesMut::from(&[1u8][..]); - prefix.reserve(16); - prefix.extend_from_slice(&[0u8; 16][..]); + let mut prefix = BytesMut::from(&[1u8, 0, 0, 0][..]); assert_eq!( - parse(&mut prefix), + parse_fully(&mut prefix), Ok(ProxyHeader::Version2 { command: ProxyCommand::Proxy, addresses: ProxyAddresses::Unspec, transport_protocol: ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); } @@ -395,7 +790,7 @@ mod parse_tests { #[test] fn test_ipv4() { assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -424,7 +819,7 @@ mod parse_tests { 1, 1, // TLV - 69, + PP2_TYPE_NOOP, 0, 0, ][..] @@ -436,6 +831,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 257), }, + extensions: Vec::new(), }) ); @@ -480,6 +876,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), destination: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 255 << 8), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -488,7 +885,7 @@ mod parse_tests { #[test] fn test_ipv6() { assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -541,7 +938,7 @@ mod parse_tests { 1, 1, // TLV - 69, + PP2_TYPE_NOOP, 0, 0, ][..] @@ -563,6 +960,7 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); @@ -641,6 +1039,7 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -651,12 +1050,12 @@ mod parse_tests { let mut data = [0u8; 200]; rand::thread_rng().fill_bytes(&mut data); data[0] = 99; // Make 100% sure it's invalid. - assert!(parse(&mut &data[..]).is_err()); + assert!(parse_fully(&mut &data[..]).is_err()); - assert_eq!(parse(&mut &[0][..]), Err(ParseError::UnexpectedEof)); + assert_eq!(parse_fully(&mut &[0][..]), Err(ParseError::UnexpectedEof)); assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -677,12 +1076,86 @@ mod parse_tests { }), ); } + + #[test] + fn test_tlv() { + use super::ExtensionTlv::*; + use super::SslExtensionTlv::*; + + assert_eq!( + parse_fully( + &mut &[ + // Proxy command + 1u8, + // Connection type: Unknown + 0, + // TLV length: 3 + 2 + 3 + 11 + 3 + 15 + 0, + 37, + PP2_TYPE_ALPN, + 0, + 2, + // h2 + 0x68, + 0x32, + PP2_TYPE_AUTHORITY, + 0, + 11, + // example.org + 0x65, + 0x78, + 0x61, + 0x6d, + 0x70, + 0x6c, + 0x65, + 0x2e, + 0x6f, + 0x72, + 0x67, + PP2_TYPE_SSL, + 0, + 15, + 0x07, + 0, + 0, + 0, + 0, + PP2_SUBTYPE_SSL_VERSION, + 0, + 7, + // TLSv1.3. This is from OpenSSL, to match haproxy. + 0x54, + 0x4c, + 0x53, + 0x76, + 0x31, + 0x2e, + 0x33, + ][..] + ), + Ok(ProxyHeader::Version2 { + command: ProxyCommand::Proxy, + addresses: ProxyAddresses::Unspec, + transport_protocol: ProxyTransportProtocol::Unspec, + extensions: vec![ + Alpn(b"h2".to_vec()), + Authority("example.org".to_string()), + Ssl(super::Ssl { + client: SslClientFlags(7), + verify: SslVerifyStatus(0), + extensions: vec![Version("TLSv1.3".to_string()),], + }), + ], + }), + ); + } } #[cfg(test)] mod encode_tests { use super::*; - use bytes::{Bytes, BytesMut}; + use bytes::BytesMut; use pretty_assertions::assert_eq; use std::net::{Ipv4Addr, Ipv6Addr}; @@ -690,10 +1163,10 @@ mod encode_tests { 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, ]; - fn signed(buf: &[u8]) -> Bytes { + fn signed(buf: &[u8]) -> BytesMut { let mut bytes = BytesMut::from(&SIG[..]); bytes.extend_from_slice(buf); - bytes.freeze() + bytes } #[test] @@ -703,8 +1176,9 @@ mod encode_tests { ProxyCommand::Local, ProxyTransportProtocol::Unspec, ProxyAddresses::Unspec, + &[], ), - signed(&[2 << 4, 0, 0, 0][..]), + Ok(signed(&[2 << 4, 0, 0, 0][..])), ); assert_eq!( @@ -712,8 +1186,9 @@ mod encode_tests { ProxyCommand::Proxy, ProxyTransportProtocol::Unspec, ProxyAddresses::Unspec, + &[], ), - signed(&[(2 << 4) | 1, 0, 0, 0][..]), + Ok(signed(&[(2 << 4) | 1, 0, 0, 0][..])), ); assert_eq!( encode( @@ -723,8 +1198,9 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 9012), }, + &[], ), - signed( + Ok(signed( &[ (2 << 4) | 1, 1 << 4, @@ -743,7 +1219,7 @@ mod encode_tests { (9012u16 >> 8) as u8, 9012u16 as u8, ][..] - ), + )), ); } @@ -757,8 +1233,9 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 9012), }, + &[], ), - signed( + Ok(signed( &[ (2 << 4) | 1, (1 << 4) | 1, @@ -777,7 +1254,7 @@ mod encode_tests { (9012u16 >> 8) as u8, 9012u16 as u8, ][..] - ), + )), ); assert_eq!( encode( @@ -787,8 +1264,9 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 324), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 2187), }, + &[], ), - signed( + Ok(signed( &[ 2 << 4, (1 << 4) | 2, @@ -807,7 +1285,7 @@ mod encode_tests { (2187 >> 8) as u8, 2187u16 as u8, ][..] - ), + )), ); } @@ -825,9 +1303,10 @@ mod encode_tests { 0, 0, ), - } + }, + &[], ), - signed( + Ok(signed( &[ 2 << 4, (2 << 4) | 2, @@ -870,7 +1349,82 @@ mod encode_tests { 0, 0, ][..] + )), + ); + } + + #[test] + fn test_tlv() { + use super::ExtensionTlv::*; + use super::SslExtensionTlv::*; + + assert_eq!( + encode( + ProxyCommand::Proxy, + ProxyTransportProtocol::Unspec, + ProxyAddresses::Unspec, + &[ + Alpn(b"h2".to_vec()), + Authority("example.org".to_string()), + Ssl(super::Ssl { + client: SslClientFlags(7), + verify: SslVerifyStatus(0), + extensions: vec![Version("TLSv1.3".to_string()),], + }), + ], ), + Ok(signed( + &[ + // Version 2, + // Proxy command + 0x21u8, + // Connection type: Unknown + 0, + // TLV length: 3 + 2 + 3 + 11 + 3 + 15 + 0, + 37, + PP2_TYPE_ALPN, + 0, + 2, + // h2 + 0x68, + 0x32, + PP2_TYPE_AUTHORITY, + 0, + 11, + // example.org + 0x65, + 0x78, + 0x61, + 0x6d, + 0x70, + 0x6c, + 0x65, + 0x2e, + 0x6f, + 0x72, + 0x67, + PP2_TYPE_SSL, + 0, + 15, + 0x07, + 0, + 0, + 0, + 0, + PP2_SUBTYPE_SSL_VERSION, + 0, + 7, + // TLSv1.3. This is from OpenSSL, to match haproxy. + 0x54, + 0x4c, + 0x53, + 0x76, + 0x31, + 0x2e, + 0x33, + ][..] + )), ); } } From 2900ffd56cd664f28664477836fcf99ec4686994 Mon Sep 17 00:00:00 2001 From: junderw Date: Wed, 4 Sep 2024 22:38:35 +0900 Subject: [PATCH 05/16] Fix EOF panics --- src/version1.rs | 2 ++ src/version2.rs | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/version1.rs b/src/version1.rs index c9a0377..1862d0c 100644 --- a/src/version1.rs +++ b/src/version1.rs @@ -175,6 +175,8 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result= 1, UnexpectedEof); ensure!(buf.get_u8() == LF, IllegalHeaderEnding); let addresses = match (source, destination) { diff --git a/src/version2.rs b/src/version2.rs index b3d6d3e..e3a4089 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -459,6 +459,23 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result 108 * 2, ProxyAddressFamily::Unspec => 0, }; + if address_family == ProxyAddressFamily::Unix { + ensure!( + length >= 108 * 2, + InsufficientLengthSpecified { + given: length, + needs: 108usize * 2, + }, + ); + ensure!(buf.remaining() >= length, UnexpectedEof); + let mut source = [0u8; 108]; + let mut destination = [0u8; 108]; + buf.copy_to_slice(&mut source[..]); + buf.copy_to_slice(&mut destination[..]); + // TODO(Mariell Hoversholm): Support TLVs + if length > 108 * 2 { + buf.advance(length - (108 * 2)); + } let mut ext_len = length From c6a4b3bb940440877fcfb6824231189e104981ad Mon Sep 17 00:00:00 2001 From: timvisee Date: Fri, 19 Nov 2021 15:12:27 +0100 Subject: [PATCH 06/16] feat: add ProxyAddresses::from_socket_addrs helper --- src/version2.rs | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/version2.rs b/src/version2.rs index e3a4089..f280202 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -1,7 +1,6 @@ use bytes::{Buf, BufMut as _, BytesMut}; use snafu::{ensure, Snafu}; -use std::convert::TryInto; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; #[derive(Debug, Snafu)] #[cfg_attr(test, derive(PartialEq, Eq))] @@ -74,6 +73,25 @@ pub enum ProxyAddresses { }, } +impl ProxyAddresses { + /// Construct IPv4 or IPv6 from given source and destination address. + /// + /// Returns `None` if the addresses use a different IP version. + pub fn from_socket_addrs(source: SocketAddr, destination: SocketAddr) -> Option { + match (source, destination) { + (SocketAddr::V4(source), SocketAddr::V4(destination)) => Some(Self::Ipv4 { + source, + destination, + }), + (SocketAddr::V6(source), SocketAddr::V6(destination)) => Some(Self::Ipv6 { + source, + destination, + }), + (_, _) => None, + } + } +} + #[derive(PartialEq, Eq)] enum ProxyAddressFamily { Unspec, From 949db2c862b3d6619ee378295f27d651976043fa Mon Sep 17 00:00:00 2001 From: David Papp Date: Wed, 22 Oct 2025 16:14:47 +0200 Subject: [PATCH 07/16] fix: add missing closing brace in parse function to ensure proper syntax --- src/version2.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/version2.rs b/src/version2.rs index f280202..3e5b0a0 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -1,5 +1,6 @@ use bytes::{Buf, BufMut as _, BytesMut}; use snafu::{ensure, Snafu}; +use std::convert::TryInto; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; #[derive(Debug, Snafu)] @@ -494,6 +495,7 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result 108 * 2 { buf.advance(length - (108 * 2)); } + } let mut ext_len = length From cd1eb7f5962f9a15899a05f01e3015a43f40c0da Mon Sep 17 00:00:00 2001 From: David Papp Date: Wed, 22 Oct 2025 18:18:05 +0200 Subject: [PATCH 08/16] chore: update dependencies and refactor error handling --- Cargo.toml | 6 ++--- src/lib.rs | 23 +++++++++-------- src/version1.rs | 69 +++++++++++++++++++++++++------------------------ src/version2.rs | 41 +++++++++++++++-------------- 4 files changed, 71 insertions(+), 68 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fe7bd77..a55800c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,12 +10,12 @@ documentation = "https://docs.rs/proxy-protocol/" repository = "https://github.com/Proximyst/proxy-protocol.git" [dependencies] -snafu = "~0.6" +snafu = "~0.8" bytes = "~1" [dev-dependencies] -pretty_assertions = "^0.7" -rand = "~0.8" +pretty_assertions = "^1.4" +rand = "~0.9" [features] default = [] diff --git a/src/lib.rs b/src/lib.rs index 57a0861..aafc6d1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ pub mod version2; use bytes::{Buf, BytesMut}; use snafu::{ensure, ResultExt as _, Snafu}; + #[derive(Debug, Snafu)] #[cfg_attr(not(feature = "always_exhaustive"), non_exhaustive)] // A new version may be added #[cfg_attr(test, derive(PartialEq, Eq))] @@ -79,7 +80,7 @@ pub enum ProxyHeader { fn parse_version(buf: &mut impl Buf) -> Result { // There is a 6 byte header to v1, 12 byte to all binary versions. - ensure!(buf.remaining() >= 6, NotProxyHeader); + ensure!(buf.remaining() >= 6, NotProxyHeaderSnafu); // V1 is the only version that starts with "PROXY" (0x50 0x52 0x4F 0x58 // 0x59), and we can therefore decide version based on that. @@ -91,11 +92,11 @@ fn parse_version(buf: &mut impl Buf) -> Result { } // Now we require 13: 12 for the prefix, 1 for the version + command - ensure!(buf.remaining() >= 13, NotProxyHeader); + ensure!(buf.remaining() >= 13, NotProxyHeaderSnafu); ensure!( buf.chunk()[..12] == [0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A], - NotProxyHeader + NotProxyHeaderSnafu ); buf.advance(12); @@ -117,7 +118,7 @@ fn parse_version(buf: &mut impl Buf) -> Result { // Interesting edge-case! This is the only time version 1 would be invalid. if version == 1 { - return InvalidVersion { version: 1u32 }.fail(); + return InvalidVersionSnafu { version: 1u32 }.fail(); } Ok(version as u32) @@ -135,9 +136,9 @@ pub fn parse(buf: &mut impl Buf) -> Result { }; Ok(match version { - 1 => self::version1::parse(buf).context(Version1)?, - 2 => self::version2::parse(buf).context(Version2)?, - _ => return InvalidVersion { version }.fail(), + 1 => self::version1::parse(buf).context(Version1Snafu)?, + 2 => self::version2::parse(buf).context(Version2Snafu)?, + _ => return InvalidVersionSnafu { version }.fail(), }) } @@ -148,7 +149,7 @@ pub fn parse(buf: &mut impl Buf) -> Result { pub fn encode(header: ProxyHeader) -> Result { Ok(match header { ProxyHeader::Version1 { addresses, .. } => { - version1::encode(addresses).context(WriteVersion1)? + version1::encode(addresses).context(WriteVersion1Snafu)? } ProxyHeader::Version2 { command, @@ -156,7 +157,7 @@ pub fn encode(header: ProxyHeader) -> Result { addresses, extensions, } => version2::encode(command, transport_protocol, addresses, &extensions[..]) - .context(WriteVersion2)?, + .context(WriteVersion2Snafu)?, #[allow(unreachable_patterns)] // May be required to be exhaustive. _ => unimplemented!("Unimplemented version?"), @@ -188,7 +189,7 @@ mod parse_tests { ); let mut random = [0u8; 128]; - rand::thread_rng().fill_bytes(&mut random); + rand::rng().fill_bytes(&mut random); let mut header = b"PROXY UNKNOWN ".to_vec(); header.extend(&random[..]); header.extend(b"\r\n"); @@ -629,7 +630,7 @@ mod parse_tests { assert!(data.remaining() == 4); // Consume the entire header let mut data = [0u8; 200]; - rand::thread_rng().fill_bytes(&mut data); + rand::rng().fill_bytes(&mut data); data[0] = 99; // Make 100% sure it's invalid. assert!(parse(&mut &data[..]).is_err()); diff --git a/src/version1.rs b/src/version1.rs index 1862d0c..34b0d66 100644 --- a/src/version1.rs +++ b/src/version1.rs @@ -6,6 +6,7 @@ use std::{ str::{FromStr as _, Utf8Error}, }; + const CR: u8 = 0x0D; const LF: u8 = 0x0A; @@ -64,7 +65,7 @@ fn count_till_first(haystack: &[u8], needle: u8) -> Option { } pub(crate) fn parse(buf: &mut impl Buf) -> Result { - ensure!(buf.remaining() >= 4, UnexpectedEof); + ensure!(buf.remaining() >= 4, UnexpectedEofSnafu); let step = buf.get_u8(); @@ -83,23 +84,23 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result ProxyAddressFamily::Tcp4, b'6' => ProxyAddressFamily::Tcp6, - _ => return IllegalAddressFamily.fail(), + _ => return IllegalAddressFamilySnafu.fail(), } } b'U' => { // Unknown - ensure!(buf.remaining() >= 6, UnexpectedEof); // Not 7, we consumed 1. + ensure!(buf.remaining() >= 6, UnexpectedEofSnafu); // Not 7, we consumed 1. buf.advance(6); ProxyAddressFamily::Unknown } - _ => return IllegalAddressFamily.fail(), + _ => return IllegalAddressFamilySnafu.fail(), }; if address_family == ProxyAddressFamily::Unknown { // Just consume up to the end. let mut cr = false; loop { - ensure!(buf.has_remaining(), UnexpectedEof); + ensure!(buf.has_remaining(), UnexpectedEofSnafu); let b = buf.get_u8(); if cr && b == LF { break; @@ -112,72 +113,72 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result= 8, UnexpectedEof); + ensure!(buf.remaining() >= 8, UnexpectedEofSnafu); buf.advance(1); // Space - let count = count_till_first(buf.chunk(), b' ').context(MissingAddress)?; + let count = count_till_first(buf.chunk(), b' ').context(MissingAddressSnafu)?; let source = &buf.chunk()[..count]; - let source = std::str::from_utf8(source).context(NonAscii)?; + let source = std::str::from_utf8(source).context(NonAsciiSnafu)?; let source = match address_family { - ProxyAddressFamily::Tcp4 => IpAddr::V4(Ipv4Addr::from_str(source).context(InvalidAddress)?), - ProxyAddressFamily::Tcp6 => IpAddr::V6(Ipv6Addr::from_str(source).context(InvalidAddress)?), + ProxyAddressFamily::Tcp4 => IpAddr::V4(Ipv4Addr::from_str(source).context(InvalidAddressSnafu)?), + ProxyAddressFamily::Tcp6 => IpAddr::V6(Ipv6Addr::from_str(source).context(InvalidAddressSnafu)?), ProxyAddressFamily::Unknown => unreachable!("unknown should have its own branch"), }; buf.advance(count); // Same as above, another address incoming. - ensure!(buf.remaining() >= 8, UnexpectedEof); + ensure!(buf.remaining() >= 8, UnexpectedEofSnafu); buf.advance(1); // Space - let count = count_till_first(buf.chunk(), b' ').context(MissingAddress)?; + let count = count_till_first(buf.chunk(), b' ').context(MissingAddressSnafu)?; let destination = &buf.chunk()[..count]; - let destination = std::str::from_utf8(destination).context(NonAscii)?; + let destination = std::str::from_utf8(destination).context(NonAsciiSnafu)?; let destination = match address_family { ProxyAddressFamily::Tcp4 => { - IpAddr::V4(Ipv4Addr::from_str(destination).context(InvalidAddress)?) + IpAddr::V4(Ipv4Addr::from_str(destination).context(InvalidAddressSnafu)?) } ProxyAddressFamily::Tcp6 => { - IpAddr::V6(Ipv6Addr::from_str(destination).context(InvalidAddress)?) + IpAddr::V6(Ipv6Addr::from_str(destination).context(InvalidAddressSnafu)?) } ProxyAddressFamily::Unknown => unreachable!("unknown should have its own branch"), }; buf.advance(count); // Space, then a port. 0 is minimum valid port, so 1 byte. - ensure!(buf.remaining() >= 2, UnexpectedEof); + ensure!(buf.remaining() >= 2, UnexpectedEofSnafu); buf.advance(1); - let count = count_till_first(buf.chunk(), b' ').context(InvalidPort)?; + let count = count_till_first(buf.chunk(), b' ').context(InvalidPortSnafu)?; let source_port = &buf.chunk()[..count]; - let source_port = std::str::from_utf8(source_port).context(NonAscii)?; + let source_port = std::str::from_utf8(source_port).context(NonAsciiSnafu)?; ensure!( // The port 0 is itself valid, but 01 is not. !source_port.starts_with('0') || source_port == "0", - InvalidPort, + InvalidPortSnafu, ); - let source_port: u16 = source_port.parse().ok().context(InvalidPort)?; + let source_port: u16 = source_port.parse().ok().context(InvalidPortSnafu)?; buf.advance(count); // Space, then a port, then CRLF. 0 is minimum valid port, so 1 byte. - ensure!(buf.remaining() >= 4, UnexpectedEof); + ensure!(buf.remaining() >= 4, UnexpectedEofSnafu); buf.advance(1); // This is the last member of the string. Read until CR; that's next up. - let count = count_till_first(buf.chunk(), CR).context(InvalidPort)?; + let count = count_till_first(buf.chunk(), CR).context(InvalidPortSnafu)?; let destination_port = &buf.chunk()[..count]; - let destination_port = std::str::from_utf8(destination_port).context(NonAscii)?; + let destination_port = std::str::from_utf8(destination_port).context(NonAsciiSnafu)?; ensure!( // The port 0 is itself valid, but 01 is not. !destination_port.starts_with('0') || destination_port == "0", - InvalidPort, + InvalidPortSnafu, ); - let destination_port: u16 = destination_port.parse().ok().context(InvalidPort)?; + let destination_port: u16 = destination_port.parse().ok().context(InvalidPortSnafu)?; buf.advance(count); - ensure!(buf.get_u8() == CR, IllegalHeaderEnding); + ensure!(buf.get_u8() == CR, IllegalHeaderEndingSnafu); // We only checked up to the CR - ensure!(buf.remaining() >= 1, UnexpectedEof); - ensure!(buf.get_u8() == LF, IllegalHeaderEnding); + ensure!(buf.remaining() >= 1, UnexpectedEofSnafu); + ensure!(buf.get_u8() == LF, IllegalHeaderEndingSnafu); let addresses = match (source, destination) { (IpAddr::V4(source), IpAddr::V4(destination)) => ProxyAddresses::Ipv4 { @@ -202,14 +203,14 @@ pub(crate) fn encode(addresses: ProxyAddresses) -> Result // Reserve as much data as we're gonna need -- at most. let mut buf = BytesMut::with_capacity(107).writer(); - buf.write_all(&b"PROXY TCP"[..]).context(StdIo)?; + buf.write_all(&b"PROXY TCP"[..]).context(StdIoSnafu)?; match addresses { ProxyAddresses::Ipv4 { source, destination, } => { - buf.write(&b"4 "[..]).context(StdIo)?; + buf.write(&b"4 "[..]).context(StdIoSnafu)?; write!( buf, "{} {} {} {}\r\n", @@ -218,13 +219,13 @@ pub(crate) fn encode(addresses: ProxyAddresses) -> Result source.port(), destination.port(), ) - .context(StdIo)?; + .context(StdIoSnafu)?; } ProxyAddresses::Ipv6 { source, destination, } => { - buf.write(&b"6 "[..]).context(StdIo)?; + buf.write(&b"6 "[..]).context(StdIoSnafu)?; write!( buf, "{} {} {} {}\r\n", @@ -233,7 +234,7 @@ pub(crate) fn encode(addresses: ProxyAddresses) -> Result source.port(), destination.port(), ) - .context(StdIo)?; + .context(StdIoSnafu)?; } ProxyAddresses::Unknown => unreachable!(), } @@ -266,7 +267,7 @@ mod parse_tests { ); let mut random = [0u8; 128]; - rand::thread_rng().fill_bytes(&mut random); + rand::rng().fill_bytes(&mut random); let mut header = b"UNKNOWN ".to_vec(); header.extend(&random[..]); header.extend(b"\r\n"); diff --git a/src/version2.rs b/src/version2.rs index 3e5b0a0..90d7717 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -3,6 +3,7 @@ use snafu::{ensure, Snafu}; use std::convert::TryInto; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; + #[derive(Debug, Snafu)] #[cfg_attr(test, derive(PartialEq, Eq))] pub enum ParseError { @@ -140,14 +141,14 @@ trait Tlv: Sized { fn parse(buf: &mut impl Buf) -> Result { if buf.remaining() < 3 { - return Err(ParseError::UnexpectedEof); + return UnexpectedEofSnafu.fail(); } let type_id = buf.get_u8(); let vlen = buf.get_u16(); let expected_rem = buf .remaining() .checked_sub(vlen.into()) - .ok_or(ParseError::UnexpectedEof)?; + .ok_or_else(|| ParseError::UnexpectedEof)?; let r = Self::parse_parts(type_id, vlen, buf)?; // Assert, because it would be an internal error assert_eq!(buf.remaining(), expected_rem); @@ -231,7 +232,7 @@ impl Tlv for SslExtensionTlv { PP2_SUBTYPE_SSL_SIG_ALG => Self::Version(ascii_from_buf(buf, len)?), PP2_SUBTYPE_SSL_KEY_ALG => Self::Version(ascii_from_buf(buf, len)?), PP2_SUBTYPE_SSL_CN => Self::Version(str_from_buf(buf, len)?), - _ => return Err(ParseError::InvalidTlvTypeId { type_id }), + _ => return InvalidTlvTypeIdSnafu { type_id }.fail(), }) } } @@ -246,11 +247,11 @@ pub struct Ssl { impl Ssl { fn parse(buf: &mut impl Buf, len: u16) -> Result { if buf.remaining() < len.into() { - return Err(ParseError::UnexpectedEof); + return UnexpectedEofSnafu.fail(); } let mut ext_len = len .checked_sub(5) - .ok_or(ParseError::InsufficientLengthSpecified { + .ok_or_else(|| ParseError::InsufficientLengthSpecified { given: len.into(), needs: 5, })?; @@ -272,7 +273,7 @@ impl Ssl { // new value is lower. ext_len = usize::from(ext_len) .checked_sub(parsed) - .ok_or(ParseError::InsufficientLengthSpecified { + .ok_or_else(|| ParseError::InsufficientLengthSpecified { given: ext_len.into(), needs: parsed, })? @@ -398,7 +399,7 @@ impl Tlv for ExtensionTlv { PP2_TYPE_UNIQUE_ID => Self::UniqueId(vec_from_buf(buf, len)), PP2_TYPE_SSL => Self::Ssl(Ssl::parse(buf, len)?), PP2_TYPE_NETNS => Self::NetNs(ascii_from_buf(buf, len)?), - _ => return Err(ParseError::InvalidTlvTypeId { type_id }), + _ => return InvalidTlvTypeIdSnafu { type_id }.fail(), }) } } @@ -424,12 +425,12 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result ProxyCommand::Local, 1 => ProxyCommand::Proxy, - cmd => return UnknownCommand { cmd }.fail(), + cmd => return UnknownCommandSnafu { cmd }.fail(), }; // 4 bits for address family, 4 bits for transport protocol, // then 2 bytes for the length. - ensure!(buf.remaining() >= 3, UnexpectedEof); + ensure!(buf.remaining() >= 3, UnexpectedEofSnafu); let byte = buf.get_u8(); let address_family = match byte >> 4 { @@ -437,17 +438,17 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result ProxyAddressFamily::Inet, 2 => ProxyAddressFamily::Inet6, 3 => ProxyAddressFamily::Unix, - family => return UnknownAddressFamily { family }.fail(), + family => return UnknownAddressFamilySnafu { family }.fail(), }; let transport_protocol = match byte << 4 >> 4 { 0 => ProxyTransportProtocol::Unspec, 1 => ProxyTransportProtocol::Stream, 2 => ProxyTransportProtocol::Datagram, - protocol => return UnknownTransportProtocol { protocol }.fail(), + protocol => return UnknownTransportProtocolSnafu { protocol }.fail(), }; let length = buf.get_u16() as usize; - ensure!(buf.remaining() >= length, UnexpectedEof); + ensure!(buf.remaining() >= length, UnexpectedEofSnafu); // Time to parse the following: // @@ -481,12 +482,12 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result= 108 * 2, - InsufficientLengthSpecified { + InsufficientLengthSpecifiedSnafu { given: length, needs: 108usize * 2, - }, + } ); - ensure!(buf.remaining() >= length, UnexpectedEof); + ensure!(buf.remaining() >= length, UnexpectedEofSnafu); let mut source = [0u8; 108]; let mut destination = [0u8; 108]; buf.copy_to_slice(&mut source[..]); @@ -500,11 +501,11 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result= address_len, UnexpectedEof,); + ensure!(buf.remaining() >= address_len, UnexpectedEofSnafu,); let addresses = match address_family { ProxyAddressFamily::Unspec => ProxyAddresses::Unspec, @@ -567,7 +568,7 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result Result Date: Thu, 12 Mar 2026 20:27:46 +0100 Subject: [PATCH 09/16] Update Dependabot configuration for cargo and daily updates --- .github/dependabot.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..42d1417 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "cargo" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "daily" From be8c10bb150a5ab3a01e8a298a40447bd7ac6d11 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 12 Mar 2026 19:28:31 +0000 Subject: [PATCH 10/16] chore(deps): update snafu requirement from ~0.8 to ~0.9 Updates the requirements on [snafu](https://github.com/shepmaster/snafu) to permit the latest version. - [Changelog](https://github.com/shepmaster/snafu/blob/main/CHANGELOG.md) - [Commits](https://github.com/shepmaster/snafu/compare/0.8.0...0.9.0) --- updated-dependencies: - dependency-name: snafu dependency-version: 0.9.0 dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index a55800c..5046d8c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ documentation = "https://docs.rs/proxy-protocol/" repository = "https://github.com/Proximyst/proxy-protocol.git" [dependencies] -snafu = "~0.8" +snafu = "~0.9" bytes = "~1" [dev-dependencies] From 4b77eef07c1b55b348713440d77a2badbd091b2d Mon Sep 17 00:00:00 2001 From: David Papp Date: Thu, 12 Mar 2026 20:35:57 +0100 Subject: [PATCH 11/16] Add GitHub Actions to Dependabot configuration --- .github/dependabot.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 42d1417..84fca58 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -9,3 +9,7 @@ updates: directory: "/" # Location of package manifests schedule: interval: "daily" + - package-ecosystem: "github-actions" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "daily" From 4f7e7542f124e5333b71ce38bda3f6b436e79ed2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 12 Mar 2026 19:36:31 +0000 Subject: [PATCH 12/16] chore(deps): bump actions/cache from 2 to 5 Bumps [actions/cache](https://github.com/actions/cache) from 2 to 5. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/v2...v5) --- updated-dependencies: - dependency-name: actions/cache dependency-version: '5' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b11eae6..dcea6f2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -14,7 +14,7 @@ jobs: default: true components: rustfmt, clippy - name: Cache target - uses: actions/cache@v2 + uses: actions/cache@v5 with: path: | ~/.cargo/git/ From 1276d5db340dbde3634ef1489c55252ae1d1a9d6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 12 Mar 2026 19:36:36 +0000 Subject: [PATCH 13/16] chore(deps): bump actions/checkout from 2 to 6 Bumps [actions/checkout](https://github.com/actions/checkout) from 2 to 6. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v2...v6) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b11eae6..069cd13 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,7 +6,7 @@ jobs: name: Build, test, check runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v6 - uses: actions-rs/toolchain@v1 with: profile: minimal From 09af5e162637f18eb9244c85fc80494ff26b2ce3 Mon Sep 17 00:00:00 2001 From: David Papp Date: Thu, 12 Mar 2026 20:46:06 +0100 Subject: [PATCH 14/16] chore(deps): update rand to version 0.10 and add cargo fmt and test checks to CI --- .github/workflows/build.yml | 2 ++ Cargo.toml | 2 +- src/lib.rs | 1 - src/version1.rs | 9 ++++++--- src/version2.rs | 34 ++++++++++++++++------------------ 5 files changed, 25 insertions(+), 23 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index dee9f0b..afac59c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,6 +22,8 @@ jobs: target/ key: ${{ runner.os }}-proxy-protocol-${{ hashFiles('**/Cargo.toml') }} restore-keys: ${{ runner.os }}-proxy-protocol + - run: cargo fmt --check + - run: cargo test -- --include-ignored - run: cargo build - uses: actions-rs/cargo@v1 with: diff --git a/Cargo.toml b/Cargo.toml index 5046d8c..2652210 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ bytes = "~1" [dev-dependencies] pretty_assertions = "^1.4" -rand = "~0.9" +rand = "~0.10" [features] default = [] diff --git a/src/lib.rs b/src/lib.rs index aafc6d1..8483412 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,6 @@ pub mod version2; use bytes::{Buf, BytesMut}; use snafu::{ensure, ResultExt as _, Snafu}; - #[derive(Debug, Snafu)] #[cfg_attr(not(feature = "always_exhaustive"), non_exhaustive)] // A new version may be added #[cfg_attr(test, derive(PartialEq, Eq))] diff --git a/src/version1.rs b/src/version1.rs index 34b0d66..bbf198c 100644 --- a/src/version1.rs +++ b/src/version1.rs @@ -6,7 +6,6 @@ use std::{ str::{FromStr as _, Utf8Error}, }; - const CR: u8 = 0x0D; const LF: u8 = 0x0A; @@ -120,8 +119,12 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result IpAddr::V4(Ipv4Addr::from_str(source).context(InvalidAddressSnafu)?), - ProxyAddressFamily::Tcp6 => IpAddr::V6(Ipv6Addr::from_str(source).context(InvalidAddressSnafu)?), + ProxyAddressFamily::Tcp4 => { + IpAddr::V4(Ipv4Addr::from_str(source).context(InvalidAddressSnafu)?) + } + ProxyAddressFamily::Tcp6 => { + IpAddr::V6(Ipv6Addr::from_str(source).context(InvalidAddressSnafu)?) + } ProxyAddressFamily::Unknown => unreachable!("unknown should have its own branch"), }; buf.advance(count); diff --git a/src/version2.rs b/src/version2.rs index 90d7717..1fd735a 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -3,7 +3,6 @@ use snafu::{ensure, Snafu}; use std::convert::TryInto; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; - #[derive(Debug, Snafu)] #[cfg_attr(test, derive(PartialEq, Eq))] pub enum ParseError { @@ -228,10 +227,10 @@ impl Tlv for SslExtensionTlv { fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result { Ok(match type_id { PP2_SUBTYPE_SSL_VERSION => Self::Version(ascii_from_buf(buf, len)?), - PP2_SUBTYPE_SSL_CIPHER => Self::Version(ascii_from_buf(buf, len)?), - PP2_SUBTYPE_SSL_SIG_ALG => Self::Version(ascii_from_buf(buf, len)?), - PP2_SUBTYPE_SSL_KEY_ALG => Self::Version(ascii_from_buf(buf, len)?), - PP2_SUBTYPE_SSL_CN => Self::Version(str_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_CIPHER => Self::Cipher(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_SIG_ALG => Self::SigAlg(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_KEY_ALG => Self::KeyAlg(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_CN => Self::ClientCN(str_from_buf(buf, len)?), _ => return InvalidTlvTypeIdSnafu { type_id }.fail(), }) } @@ -249,12 +248,12 @@ impl Ssl { if buf.remaining() < len.into() { return UnexpectedEofSnafu.fail(); } - let mut ext_len = len - .checked_sub(5) - .ok_or_else(|| ParseError::InsufficientLengthSpecified { - given: len.into(), - needs: 5, - })?; + let mut ext_len = + len.checked_sub(5) + .ok_or_else(|| ParseError::InsufficientLengthSpecified { + given: len.into(), + needs: 5, + })?; let client = SslClientFlags(buf.get_u8()); let verify = SslVerifyStatus(buf.get_u32()); let mut extensions = Vec::new(); @@ -584,13 +583,12 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result Date: Thu, 12 Mar 2026 20:48:34 +0100 Subject: [PATCH 15/16] Refactor error handling in parsing functions for improved readability and consistency --- src/lib.rs | 5 +---- src/version2.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 8483412..7197dbd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -129,10 +129,7 @@ fn parse_version(buf: &mut impl Buf) -> Result { /// available through [Buf::chunk], at the very least for the header. Data that /// follows may be chunked as you wish. pub fn parse(buf: &mut impl Buf) -> Result { - let version = match parse_version(buf) { - Ok(ver) => ver, - Err(e) => return Err(e), - }; + let version = parse_version(buf)?; Ok(match version { 1 => self::version1::parse(buf).context(Version1Snafu)?, diff --git a/src/version2.rs b/src/version2.rs index 1fd735a..f3157bc 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -121,7 +121,7 @@ trait Tlv: Sized { let vlen = self.value_len()?; if vlen .checked_add(3) - .map_or(true, |tlv_len| buf.remaining_mut() < tlv_len.into()) + .is_none_or(|tlv_len| buf.remaining_mut() < tlv_len.into()) { return Err(EncodeError::ValueTooLarge); } @@ -147,7 +147,7 @@ trait Tlv: Sized { let expected_rem = buf .remaining() .checked_sub(vlen.into()) - .ok_or_else(|| ParseError::UnexpectedEof)?; + .ok_or(ParseError::UnexpectedEof)?; let r = Self::parse_parts(type_id, vlen, buf)?; // Assert, because it would be an internal error assert_eq!(buf.remaining(), expected_rem); @@ -500,7 +500,7 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result Result Date: Thu, 12 Mar 2026 20:50:38 +0100 Subject: [PATCH 16/16] Refactor error handling in `parse` function for improved readability --- src/version2.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/version2.rs b/src/version2.rs index f3157bc..238085c 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -583,12 +583,13 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result