From fa6c2adfe1e47906c8a5f194f9c132830e8285cc Mon Sep 17 00:00:00 2001 From: Gabriel de Perthuis Date: Wed, 22 Sep 2021 12:08:12 +0200 Subject: [PATCH 1/4] Clippy fixes --- src/lib.rs | 36 +++++++++++++++++++++++------------- src/version1.rs | 38 +++++++++++++++++++++++--------------- src/version2.rs | 8 ++++---- 3 files changed, 50 insertions(+), 32 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a1fe67d..b60114d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -191,15 +191,15 @@ mod parse_tests { assert!(!buf.has_remaining()); // Consume the ENTIRE header! fn valid_v4( - (a, b, c, d): (u8, u8, u8, u8), - e: u16, - (f, g, h, i): (u8, u8, u8, u8), - j: u16, + (s0, s1, s2, s3): (u8, u8, u8, u8), + sp: u16, + (d0, d1, d2, d3): (u8, u8, u8, u8), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: version1::ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), e), - destination: SocketAddrV4::new(Ipv4Addr::new(f, g, h, i), j), + source: SocketAddrV4::new(Ipv4Addr::new(s0, s1, s2, s3), sp), + destination: SocketAddrV4::new(Ipv4Addr::new(d0, d1, d2, d3), dp), }, } } @@ -223,15 +223,25 @@ mod parse_tests { ); fn valid_v6( - (a, b, c, d, e, f, g, h): (u16, u16, u16, u16, u16, u16, u16, u16), - i: u16, - (j, k, l, m, n, o, p, q): (u16, u16, u16, u16, u16, u16, u16, u16), - r: u16, + (s0, s1, s2, s3, s4, s5, s6, s7): (u16, u16, u16, u16, u16, u16, u16, u16), + sp: u16, + (d0, d1, d2, d3, d4, d5, d6, d7): (u16, u16, u16, u16, u16, u16, u16, u16), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: version1::ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(Ipv6Addr::new(a, b, c, d, e, f, g, h), i, 0, 0), - destination: SocketAddrV6::new(Ipv6Addr::new(j, k, l, m, n, o, p, q), r, 0, 0), + source: SocketAddrV6::new( + Ipv6Addr::new(s0, s1, s2, s3, s4, s5, s6, s7), + sp, + 0, + 0, + ), + destination: SocketAddrV6::new( + Ipv6Addr::new(d0, d1, d2, d3, d4, d5, d6, d7), + dp, + 0, + 0, + ), }, } } @@ -313,7 +323,7 @@ mod parse_tests { 0x49, 0x54, 0x0A, - (2 << 4) | 0, + 2 << 4, ]; const PREFIX_PROXY: [u8; 13] = [ 0x0D, diff --git a/src/version1.rs b/src/version1.rs index 19fc38b..c9a0377 100644 --- a/src/version1.rs +++ b/src/version1.rs @@ -190,9 +190,7 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result unreachable!(), }; - Ok(super::ProxyHeader::Version1 { - addresses, - }) + Ok(super::ProxyHeader::Version1 { addresses }) } pub(crate) fn encode(addresses: ProxyAddresses) -> Result { @@ -278,15 +276,15 @@ mod parse_tests { #[test] fn test_valid_ipv4_cases() { fn valid( - (a, b, c, d): (u8, u8, u8, u8), - e: u16, - (f, g, h, i): (u8, u8, u8, u8), - j: u16, + (s0, s1, s2, s3): (u8, u8, u8, u8), + sp: u16, + (d0, d1, d2, d3): (u8, u8, u8, u8), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), e), - destination: SocketAddrV4::new(Ipv4Addr::new(f, g, h, i), j), + source: SocketAddrV4::new(Ipv4Addr::new(s0, s1, s2, s3), sp), + destination: SocketAddrV4::new(Ipv4Addr::new(d0, d1, d2, d3), dp), }, } } @@ -312,15 +310,25 @@ mod parse_tests { #[test] fn test_valid_ipv6_cases() { fn valid( - (a, b, c, d, e, f, g, h): (u16, u16, u16, u16, u16, u16, u16, u16), - i: u16, - (j, k, l, m, n, o, p, q): (u16, u16, u16, u16, u16, u16, u16, u16), - r: u16, + (s0, s1, s2, s3, s4, s5, s6, s7): (u16, u16, u16, u16, u16, u16, u16, u16), + sp: u16, + (d0, d1, d2, d3, d4, d5, d6, d7): (u16, u16, u16, u16, u16, u16, u16, u16), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(Ipv6Addr::new(a, b, c, d, e, f, g, h), i, 0, 0), - destination: SocketAddrV6::new(Ipv6Addr::new(j, k, l, m, n, o, p, q), r, 0, 0), + source: SocketAddrV6::new( + Ipv6Addr::new(s0, s1, s2, s3, s4, s5, s6, s7), + sp, + 0, + 0, + ), + destination: SocketAddrV6::new( + Ipv6Addr::new(d0, d1, d2, d3, d4, d5, d6, d7), + dp, + 0, + 0, + ), }, } } diff --git a/src/version2.rs b/src/version2.rs index 6148e62..fec2b99 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -733,7 +733,7 @@ mod encode_tests { ProxyTransportProtocol::Unspec, ProxyAddresses::Unspec, ), - signed(&[(2 << 4) | 0, 0, 0, 0][..]), + signed(&[2 << 4, 0, 0, 0][..]), ); assert_eq!( @@ -756,7 +756,7 @@ mod encode_tests { signed( &[ (2 << 4) | 1, - (1 << 4) | 0, + 1 << 4, 0, 12, 1, @@ -819,7 +819,7 @@ mod encode_tests { ), signed( &[ - (2 << 4) | 0, + 2 << 4, (1 << 4) | 2, 0, 12, @@ -863,7 +863,7 @@ mod encode_tests { ), signed( &[ - (2 << 4) | 0, + 2 << 4, (2 << 4) | 2, 0, 36, From f701d20e4dc021cb074b7f1e337e7088b8aec69b Mon Sep 17 00:00:00 2001 From: Gabriel de Perthuis Date: Mon, 27 Sep 2021 16:42:47 +0200 Subject: [PATCH 2/4] Fix mojibake in the protocol specification The document is served without an encoding header. --- proxy-protocol.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy-protocol.txt b/proxy-protocol.txt index 863eac6..196a445 100644 --- a/proxy-protocol.txt +++ b/proxy-protocol.txt @@ -33,7 +33,7 @@ Revision history reserved TLV type ranges, added TLV documentation, clarified string encoding. With contributions from Andriy Palamarchuk (Amazon.com). - 2020/03/05 - added the unique ID TLV type (Tim Düsterhus) + 2020/03/05 - added the unique ID TLV type (Tim Düsterhus) 1. Background From 93cd650e59cb44f4f67d9e170721c0a96b9feeb3 Mon Sep 17 00:00:00 2001 From: Gabriel de Perthuis Date: Wed, 22 Sep 2021 13:51:20 +0200 Subject: [PATCH 3/4] Refactor version2 parsing This factors in common parts of address parsing. There is a tweak to which errors are returned when the buffer is both too small for its declared length and for the minimum length of address data: in this case, we now return an UnexpectedEof indicating the first error, previously the second error was signaled. Tests were updated with larger vectors so they could still exercise InsufficientLengthSpecified error cases. --- src/lib.rs | 3 + src/version2.rs | 146 +++++++++++++++++++----------------------------- 2 files changed, 59 insertions(+), 90 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b60114d..94f43ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -639,6 +639,9 @@ mod parse_tests { // 3 bytes is clearly too few if we expect 2 IPv4s and ports 0, 3, + 0, + 0, + 0, ][..], ] .concat() diff --git a/src/version2.rs b/src/version2.rs index fec2b99..4d9fd44 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -100,18 +100,7 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result= length, UnexpectedEof); - buf.advance(length); - - return Ok(super::ProxyHeader::Version2 { - command, - transport_protocol, - addresses: ProxyAddresses::Unspec, - }); - } + ensure!(buf.remaining() >= length, UnexpectedEof); // Time to parse the following: // @@ -134,88 +123,73 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result } unix_addr; // > }; - if address_family == ProxyAddressFamily::Unix { - ensure!( - length >= 108 * 2, - InsufficientLengthSpecified { - given: length, - needs: 108usize * 2, - }, - ); - ensure!(buf.remaining() >= 108 * 2, UnexpectedEof); - let mut source = [0u8; 108]; - let mut destination = [0u8; 108]; - buf.copy_to_slice(&mut source[..]); - buf.copy_to_slice(&mut destination[..]); - // TODO(Mariell Hoversholm): Support TLVs - if length > 108 * 2 { - buf.advance(length - (108 * 2)); - } - - return Ok(super::ProxyHeader::Version2 { - command, - transport_protocol, - addresses: ProxyAddresses::Unix { - source, - destination, - }, - }); - } - - let port_length = 4; + // The full length of address data, + // including two addresses and two ports let address_length = match address_family { - ProxyAddressFamily::Inet => 8, - ProxyAddressFamily::Inet6 => 32, - _ => unreachable!(), + ProxyAddressFamily::Inet => (4 + 2) * 2, + ProxyAddressFamily::Inet6 => (16 + 2) * 2, + ProxyAddressFamily::Unix => 108 * 2, + ProxyAddressFamily::Unspec => 0, }; ensure!( - length >= port_length + address_length, + length >= address_length, InsufficientLengthSpecified { given: length, - needs: port_length + address_length, + needs: address_length, }, ); - ensure!( - buf.remaining() >= port_length + address_length, - UnexpectedEof, - ); - - let addresses = if address_family == ProxyAddressFamily::Inet { - let mut data = [0u8; 4]; - buf.copy_to_slice(&mut data[..]); - let source = Ipv4Addr::from(data); + ensure!(buf.remaining() >= address_length, UnexpectedEof,); + + let addresses = match address_family { + ProxyAddressFamily::Unspec => ProxyAddresses::Unspec, + ProxyAddressFamily::Unix => { + let mut source = [0u8; 108]; + let mut destination = [0u8; 108]; + buf.copy_to_slice(&mut source[..]); + buf.copy_to_slice(&mut destination[..]); + ProxyAddresses::Unix { + source, + destination, + } + } + ProxyAddressFamily::Inet => { + let mut data = [0u8; 4]; + buf.copy_to_slice(&mut data[..]); + let source = Ipv4Addr::from(data); - buf.copy_to_slice(&mut data); - let destination = Ipv4Addr::from(data); + buf.copy_to_slice(&mut data); + let destination = Ipv4Addr::from(data); - let source_port = buf.get_u16(); - let destination_port = buf.get_u16(); + let source_port = buf.get_u16(); + let destination_port = buf.get_u16(); - ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(source, source_port), - destination: SocketAddrV4::new(destination, destination_port), + ProxyAddresses::Ipv4 { + source: SocketAddrV4::new(source, source_port), + destination: SocketAddrV4::new(destination, destination_port), + } } - } else { - let mut data = [0u8; 16]; - buf.copy_to_slice(&mut data); - let source = Ipv6Addr::from(data); + ProxyAddressFamily::Inet6 => { + let mut data = [0u8; 16]; + buf.copy_to_slice(&mut data); + let source = Ipv6Addr::from(data); - buf.copy_to_slice(&mut data); - let destination = Ipv6Addr::from(data); + buf.copy_to_slice(&mut data); + let destination = Ipv6Addr::from(data); - let source_port = buf.get_u16(); - let destination_port = buf.get_u16(); + let source_port = buf.get_u16(); + let destination_port = buf.get_u16(); - ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(source, source_port, 0, 0), - destination: SocketAddrV6::new(destination, destination_port, 0, 0), + ProxyAddresses::Ipv6 { + source: SocketAddrV6::new(source, source_port, 0, 0), + destination: SocketAddrV6::new(destination, destination_port, 0, 0), + } } }; - if length > port_length + address_length { + if length > address_length { // TODO(Mariell Hoversholm): Implement TLVs - buf.advance(length - (port_length + address_length)); + buf.advance(length - address_length); } Ok(super::ProxyHeader::Version2 { @@ -343,15 +317,9 @@ pub(crate) fn encode( // > }; let len = match addresses { ProxyAddresses::Unspec => 0, - ProxyAddresses::Unix { .. } => { - 108 + 108 - } - ProxyAddresses::Ipv4 { .. } => { - 4 + 4 + 2 + 2 - } - ProxyAddresses::Ipv6 { .. } => { - 16 + 16 + 2 + 2 - } + ProxyAddresses::Unix { .. } => 108 + 108, + ProxyAddresses::Ipv4 { .. } => 4 + 4 + 2 + 2, + ProxyAddresses::Ipv6 { .. } => 16 + 16 + 2 + 2, }; let mut buf = BytesMut::with_capacity(16 + len); @@ -698,6 +666,9 @@ mod parse_tests { // 3 bytes is clearly too few if we expect 2 IPv4s and ports 0, 3, + 0, + 0, + 0, ][..] ), Err(ParseError::InsufficientLengthSpecified { @@ -847,12 +818,7 @@ mod encode_tests { ProxyCommand::Local, ProxyTransportProtocol::Datagram, ProxyAddresses::Ipv6 { - source: SocketAddrV6::new( - Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), - 8192, - 0, - 0, - ), + source: SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 8192, 0, 0,), destination: SocketAddrV6::new( Ipv6Addr::new(65535, 65535, 32767, 32766, 111, 222, 333, 444), 0, From de308f6970c6a48ae27c0f6dfc409e63b9a936b3 Mon Sep 17 00:00:00 2001 From: Gabriel de Perthuis Date: Mon, 27 Sep 2021 16:07:31 +0200 Subject: [PATCH 4/4] Add support for TLV extensions The ones supported are the ones documented in proxy-protocol.txt as of 2020/03/05. --- src/lib.rs | 18 +- src/version2.rs | 640 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 612 insertions(+), 46 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 94f43ff..57a0861 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,9 @@ pub enum EncodeError { /// An error occurred while encoding version 1. #[snafu(display("there was an error while encoding the v1 header: {}", source))] WriteVersion1 { source: version1::EncodeError }, + /// An error occurred while encoding version 2. + #[snafu(display("there was an error while encoding the v2 header: {}", source))] + WriteVersion2 { source: version2::EncodeError }, } /// The PROXY header emitted at most once at the start of a new connection. @@ -70,6 +73,7 @@ pub enum ProxyHeader { /// The addresses used to connect to the proxy. addresses: version2::ProxyAddresses, + extensions: Vec, }, } @@ -150,7 +154,9 @@ pub fn encode(header: ProxyHeader) -> Result { command, transport_protocol, addresses, - } => version2::encode(command, transport_protocol, addresses), + extensions, + } => version2::encode(command, transport_protocol, addresses, &extensions[..]) + .context(WriteVersion2)?, #[allow(unreachable_patterns)] // May be required to be exhaustive. _ => unimplemented!("Unimplemented version?"), @@ -347,6 +353,7 @@ mod parse_tests { command: version2::ProxyCommand::Local, addresses: version2::ProxyAddresses::Unspec, transport_protocol: version2::ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); assert_eq!( @@ -355,6 +362,7 @@ mod parse_tests { command: version2::ProxyCommand::Proxy, addresses: version2::ProxyAddresses::Unspec, transport_protocol: version2::ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); @@ -388,7 +396,7 @@ mod parse_tests { 1, 1, // TLV - 69, + version2::PP2_TYPE_NOOP, 0, 0, ][..] @@ -403,6 +411,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 257), }, + extensions: Vec::new(), }) ); @@ -449,6 +458,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), destination: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 255 << 8), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -507,7 +517,7 @@ mod parse_tests { 1, 1, // TLV - 69, + version2::PP2_TYPE_NOOP, 0, 0, ][..], @@ -532,6 +542,7 @@ mod parse_tests { 0, ) }, + extensions: Vec::new(), }) ); @@ -612,6 +623,7 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header diff --git a/src/version2.rs b/src/version2.rs index 4d9fd44..b3d6d3e 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -1,5 +1,6 @@ use bytes::{Buf, BufMut as _, BytesMut}; use snafu::{ensure, Snafu}; +use std::convert::TryInto; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; #[derive(Debug, Snafu)] @@ -19,6 +20,28 @@ pub enum ParseError { #[snafu(display("insufficient length specified: {}, requires minimum {}", given, needs))] InsufficientLengthSpecified { given: usize, needs: usize }, + + #[snafu(display("invalid length specified: {}, causes overflow", given))] + LengthOverflow { given: usize }, + + #[snafu(display("invalid TLV type id specified: {}", type_id))] + InvalidTlvTypeId { type_id: u8 }, + + #[snafu(display("invalid UTF-8: {:?}", bytes))] + InvalidUtf8 { bytes: Vec }, + + #[snafu(display("invalid ASCII: {:?}", bytes))] + InvalidAscii { bytes: Vec }, + + #[snafu(display("trailing data: {:?}", len))] + TrailingData { len: usize }, +} + +#[derive(Debug, Snafu)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub enum EncodeError { + #[snafu(display("value is too large to encode"))] + ValueTooLarge, } #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] @@ -59,6 +82,310 @@ enum ProxyAddressFamily { Unix, } +trait Tlv: Sized { + /// Identifies the type + fn type_id(&self) -> u8; + + /// The byte size of the value if encoded, or None if too big + fn value_len(&self) -> Result; + + /// Write the value to the provided buffer + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError>; + + fn encoded_len(&self) -> Result { + self.value_len()? + .checked_add(3) + .ok_or(EncodeError::ValueTooLarge) + } + + fn encode(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + let vlen = self.value_len()?; + if vlen + .checked_add(3) + .map_or(true, |tlv_len| buf.remaining_mut() < tlv_len.into()) + { + return Err(EncodeError::ValueTooLarge); + } + buf.put_u8(self.type_id()); + buf.put_u16(vlen); + self.encode_value(buf) + } + + // API note: + // We have to pass the len instead of using a view. + // Buf doesn't have a good view / subslice abstraction + // unlike plain slices or even the Bytes implementation + // IMHO (@g2p) it would be better for parse to receive a + // slice or a concrete type. + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result; + + fn parse(buf: &mut impl Buf) -> Result { + if buf.remaining() < 3 { + return Err(ParseError::UnexpectedEof); + } + let type_id = buf.get_u8(); + let vlen = buf.get_u16(); + let expected_rem = buf + .remaining() + .checked_sub(vlen.into()) + .ok_or(ParseError::UnexpectedEof)?; + let r = Self::parse_parts(type_id, vlen, buf)?; + // Assert, because it would be an internal error + assert_eq!(buf.remaining(), expected_rem); + Ok(r) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SslClientFlags(u8); + +impl SslClientFlags { + pub fn is_ssl_or_tls(&self) -> bool { + (self.0 & 1) == 1 + } + + pub fn client_authenticated_connection(&self) -> bool { + (self.0 & 2) == 2 + } + pub fn client_authenticated_session(&self) -> bool { + (self.0 & 4) == 4 + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SslVerifyStatus(u32); + +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(not(feature = "always_exhaustive"), non_exhaustive)] // Extensions may be added +pub enum SslExtensionTlv { + /// TLS or SSL version in ASCII + Version(String), + /// TLS or SSL cipher suite in ASCII, for example "ECDHE-RSA-AES128-GCM-SHA256" + Cipher(String), + /// TLS or SSL signature algorithm in ASCII + SigAlg(String), + /// TLS or SSL key algorithm in ASCII + KeyAlg(String), + /// With client authentication, the common name for the client certificate in UTF-8 + ClientCN(String), +} + +impl SslExtensionTlv { + fn as_str(&self) -> &str { + match self { + Self::Version(version) => version, + Self::Cipher(cipher) => cipher, + Self::SigAlg(sigalg) => sigalg, + Self::KeyAlg(keyalg) => keyalg, + Self::ClientCN(cn) => cn, + } + } +} + +impl Tlv for SslExtensionTlv { + fn type_id(&self) -> u8 { + match self { + Self::Version(_) => PP2_SUBTYPE_SSL_VERSION, + Self::ClientCN(_) => PP2_SUBTYPE_SSL_CN, + Self::Cipher(_) => PP2_SUBTYPE_SSL_CIPHER, + Self::SigAlg(_) => PP2_SUBTYPE_SSL_SIG_ALG, + Self::KeyAlg(_) => PP2_SUBTYPE_SSL_KEY_ALG, + } + } + + fn value_len(&self) -> Result { + self.as_str() + .len() + .try_into() + .map_err(|_| EncodeError::ValueTooLarge) + } + + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + buf.put_slice(self.as_str().as_bytes()); + Ok(()) + } + + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result { + Ok(match type_id { + PP2_SUBTYPE_SSL_VERSION => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_CIPHER => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_SIG_ALG => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_KEY_ALG => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_CN => Self::Version(str_from_buf(buf, len)?), + _ => return Err(ParseError::InvalidTlvTypeId { type_id }), + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ssl { + client: SslClientFlags, + verify: SslVerifyStatus, + extensions: Vec, +} + +impl Ssl { + fn parse(buf: &mut impl Buf, len: u16) -> Result { + if buf.remaining() < len.into() { + return Err(ParseError::UnexpectedEof); + } + let mut ext_len = len + .checked_sub(5) + .ok_or(ParseError::InsufficientLengthSpecified { + given: len.into(), + needs: 5, + })?; + let client = SslClientFlags(buf.get_u8()); + let verify = SslVerifyStatus(buf.get_u32()); + let mut extensions = Vec::new(); + while ext_len > 0 { + let rem0 = buf.remaining(); + extensions.push(SslExtensionTlv::parse(buf)?); + let rem = buf.remaining(); + // The assert enforces that Buf is implemented sanely + // and not rewound. + let parsed = rem0.checked_sub(rem).expect("Buf error"); + // We don't enforce u16-sized buffers. + // Since we don't pass a bound on how much to parse, + // we can't enforce that the extension parser won't read + // (slightly) more than 64k. + // The assert is safe since ext_len was already u16 and the + // new value is lower. + ext_len = usize::from(ext_len) + .checked_sub(parsed) + .ok_or(ParseError::InsufficientLengthSpecified { + given: ext_len.into(), + needs: parsed, + })? + .try_into() + .expect("Math error"); + } + Ok(Self { + client, + verify, + extensions, + }) + } + + fn encoded_len(&self) -> Result { + // 1 for flags, 4 for verify status, plus all nested TLVs + self.extensions + .iter() + .try_fold(5u16, |sum, subtlv| { + sum.checked_add(subtlv.encoded_len().ok()?) + }) + .ok_or(EncodeError::ValueTooLarge) + } + + fn encode(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + buf.put_u8(self.client.0); + buf.put_u32(self.verify.0); + for ext in self.extensions.iter() { + ext.encode(buf)?; + } + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(not(feature = "always_exhaustive"), non_exhaustive)] // Extensions may be added +pub enum ExtensionTlv { + Alpn(Vec), + Authority(String), + Crc32c(u32), + UniqueId(Vec), + Ssl(Ssl), + NetNs(String), +} + +pub(crate) const PP2_TYPE_ALPN: u8 = 0x01; +pub(crate) const PP2_TYPE_AUTHORITY: u8 = 0x02; +pub(crate) const PP2_TYPE_CRC32C: u8 = 0x03; +pub(crate) const PP2_TYPE_NOOP: u8 = 0x04; +pub(crate) const PP2_TYPE_UNIQUE_ID: u8 = 0x05; +pub(crate) const PP2_TYPE_SSL: u8 = 0x20; +pub(crate) const PP2_SUBTYPE_SSL_VERSION: u8 = 0x21; +pub(crate) const PP2_SUBTYPE_SSL_CN: u8 = 0x22; +pub(crate) const PP2_SUBTYPE_SSL_CIPHER: u8 = 0x23; +pub(crate) const PP2_SUBTYPE_SSL_SIG_ALG: u8 = 0x24; +pub(crate) const PP2_SUBTYPE_SSL_KEY_ALG: u8 = 0x25; +pub(crate) const PP2_TYPE_NETNS: u8 = 0x30; + +fn vec_from_buf(buf: &mut impl Buf, len: u16) -> Vec { + let mut r = vec![0; len.into()]; + buf.copy_to_slice(&mut r); + r +} + +fn str_from_buf(buf: &mut impl Buf, len: u16) -> Result { + let v = vec_from_buf(buf, len); + let r = String::from_utf8(v).map_err(|e| ParseError::InvalidUtf8 { + bytes: e.into_bytes(), + })?; + Ok(r) +} + +fn ascii_from_buf(buf: &mut impl Buf, len: u16) -> Result { + let s = str_from_buf(buf, len)?; + if !s.is_ascii() { + Err(ParseError::InvalidAscii { + bytes: s.into_bytes(), + }) + } else { + Ok(s) + } +} + +impl Tlv for ExtensionTlv { + fn type_id(&self) -> u8 { + match self { + Self::Alpn(_) => PP2_TYPE_ALPN, + Self::Authority(_) => PP2_TYPE_AUTHORITY, + Self::Crc32c(_) => PP2_TYPE_CRC32C, + Self::UniqueId(_) => PP2_TYPE_UNIQUE_ID, + Self::Ssl(_) => PP2_TYPE_SSL, + Self::NetNs(_) => PP2_TYPE_NETNS, + } + } + + fn value_len(&self) -> Result { + match self { + Self::Alpn(alpn) => alpn.len(), + Self::Authority(authority) => authority.len(), + Self::Crc32c(_) => 4, + Self::UniqueId(id) => id.len(), + Self::Ssl(data) => data.encoded_len()?.into(), + Self::NetNs(netns) => netns.len(), + } + .try_into() + .map_err(|_| EncodeError::ValueTooLarge) + } + + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + match self { + Self::Alpn(by) | Self::UniqueId(by) => buf.put_slice(by), + Self::Authority(st) | Self::NetNs(st) => buf.put_slice(st.as_bytes()), + Self::Crc32c(crc) => buf.put_u32(*crc), + Self::Ssl(ssl) => ssl.encode(buf)?, + }; + Ok(()) + } + + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result { + Ok(match type_id { + PP2_TYPE_ALPN => Self::Alpn(vec_from_buf(buf, len)), + PP2_TYPE_AUTHORITY => Self::Authority(str_from_buf(buf, len)?), + PP2_TYPE_CRC32C => Self::Crc32c(buf.get_u32()), + PP2_TYPE_UNIQUE_ID => Self::UniqueId(vec_from_buf(buf, len)), + PP2_TYPE_SSL => Self::Ssl(Ssl::parse(buf, len)?), + PP2_TYPE_NETNS => Self::NetNs(ascii_from_buf(buf, len)?), + _ => return Err(ParseError::InvalidTlvTypeId { type_id }), + }) + } +} + +// Note: this is internal, assumes the first 12 bytes were parsed, +// and ignores the version half of the first byte. pub(crate) fn parse(buf: &mut impl Buf) -> Result { // We need to parse the following: // @@ -71,9 +398,10 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result> 4; + let command = st << 4 >> 4; let command = match command { 0 => ProxyCommand::Local, 1 => ProxyCommand::Proxy, @@ -125,21 +453,21 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result (4 + 2) * 2, ProxyAddressFamily::Inet6 => (16 + 2) * 2, ProxyAddressFamily::Unix => 108 * 2, ProxyAddressFamily::Unspec => 0, }; - ensure!( - length >= address_length, - InsufficientLengthSpecified { - given: length, - needs: address_length, - }, - ); - ensure!(buf.remaining() >= address_length, UnexpectedEof,); + let mut ext_len = + length + .checked_sub(address_len) + .ok_or(ParseError::InsufficientLengthSpecified { + given: length, + needs: address_len, + })?; + ensure!(buf.remaining() >= address_len, UnexpectedEof,); let addresses = match address_family { ProxyAddressFamily::Unspec => ProxyAddresses::Unspec, @@ -187,23 +515,74 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result address_length { - // TODO(Mariell Hoversholm): Implement TLVs - buf.advance(length - address_length); + let mut extensions = Vec::new(); + while ext_len > 0 { + // At this point, we know that remaining() >= ext_len + if buf.chunk()[0] == PP2_TYPE_NOOP { + if ext_len < 3 { + return Err(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: 3, + }); + } + // Read/skip the type after peeking + buf.get_u8(); + let skip_len = buf.get_u16(); + let noop_len = 3u16 + .checked_add(skip_len) + .ok_or(ParseError::LengthOverflow { + given: skip_len.into(), + })? + .into(); + if noop_len > ext_len { + return Err(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: noop_len, + }); + } + ext_len -= noop_len; + } else { + let rem0 = buf.remaining(); + extensions.push(ExtensionTlv::parse(buf)?); + let rem = buf.remaining(); + let parsed = rem0.checked_sub(rem).expect("Buf error"); + ext_len = + ext_len + .checked_sub(parsed) + .ok_or(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: parsed, + })?; + } } Ok(super::ProxyHeader::Version2 { command, transport_protocol, addresses, + extensions, }) } +// Currently used in tests, has to be internal for the same reasons +// parse() currently is. +#[cfg(test)] +pub(crate) fn parse_fully(buf: &mut impl Buf) -> Result { + let r = parse(buf)?; + if buf.has_remaining() { + return Err(ParseError::TrailingData { + len: buf.remaining(), + }); + } + Ok(r) +} + pub(crate) fn encode( command: ProxyCommand, transport_protocol: ProxyTransportProtocol, addresses: ProxyAddresses, -) -> BytesMut { + extensions: &[ExtensionTlv], +) -> Result { // > struct proxy_hdr_v2 { // > uint8_t sig[12]; /* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */ // > uint8_t ver_cmd; /* protocol version and command */ @@ -315,17 +694,27 @@ pub(crate) fn encode( // > uint8_t dst_addr[108]; // > } unix_addr; // > }; - let len = match addresses { + let address_len: u16 = match addresses { ProxyAddresses::Unspec => 0, ProxyAddresses::Unix { .. } => 108 + 108, ProxyAddresses::Ipv4 { .. } => 4 + 4 + 2 + 2, ProxyAddresses::Ipv6 { .. } => 16 + 16 + 2 + 2, }; + // With extensions, we need to distinguish len and address_len + let len = extensions + .iter() + .try_fold(address_len, |acc, ext| { + acc.checked_add(ext.encoded_len().ok()?) + }) + .ok_or(EncodeError::ValueTooLarge)?; - let mut buf = BytesMut::with_capacity(16 + len); + let blen = 16usize + .checked_add(len.into()) + .ok_or(EncodeError::ValueTooLarge)?; + let mut buf = BytesMut::with_capacity(blen); buf.put_slice(&SIG[..]); buf.put_slice(&[ver_cmd, fam][..]); - buf.put_u16(len as u16); + buf.put_u16(len); match addresses { ProxyAddresses::Unspec => (), @@ -356,7 +745,13 @@ pub(crate) fn encode( } } - buf + for ext in extensions.iter() { + ext.encode(&mut buf)?; + } + + assert_eq!(buf.len(), blen); + + Ok(buf) } #[cfg(test)] @@ -371,23 +766,23 @@ mod parse_tests { #[test] fn test_unspec() { assert_eq!( - parse(&mut &[0u8; 16][..]), + parse_fully(&mut &[0u8; 4][..]), Ok(ProxyHeader::Version2 { command: ProxyCommand::Local, addresses: ProxyAddresses::Unspec, transport_protocol: ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); - let mut prefix = BytesMut::from(&[1u8][..]); - prefix.reserve(16); - prefix.extend_from_slice(&[0u8; 16][..]); + let mut prefix = BytesMut::from(&[1u8, 0, 0, 0][..]); assert_eq!( - parse(&mut prefix), + parse_fully(&mut prefix), Ok(ProxyHeader::Version2 { command: ProxyCommand::Proxy, addresses: ProxyAddresses::Unspec, transport_protocol: ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); } @@ -395,7 +790,7 @@ mod parse_tests { #[test] fn test_ipv4() { assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -424,7 +819,7 @@ mod parse_tests { 1, 1, // TLV - 69, + PP2_TYPE_NOOP, 0, 0, ][..] @@ -436,6 +831,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 257), }, + extensions: Vec::new(), }) ); @@ -480,6 +876,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), destination: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 255 << 8), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -488,7 +885,7 @@ mod parse_tests { #[test] fn test_ipv6() { assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -541,7 +938,7 @@ mod parse_tests { 1, 1, // TLV - 69, + PP2_TYPE_NOOP, 0, 0, ][..] @@ -563,6 +960,7 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); @@ -641,6 +1039,7 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -651,12 +1050,12 @@ mod parse_tests { let mut data = [0u8; 200]; rand::thread_rng().fill_bytes(&mut data); data[0] = 99; // Make 100% sure it's invalid. - assert!(parse(&mut &data[..]).is_err()); + assert!(parse_fully(&mut &data[..]).is_err()); - assert_eq!(parse(&mut &[0][..]), Err(ParseError::UnexpectedEof)); + assert_eq!(parse_fully(&mut &[0][..]), Err(ParseError::UnexpectedEof)); assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -677,12 +1076,86 @@ mod parse_tests { }), ); } + + #[test] + fn test_tlv() { + use super::ExtensionTlv::*; + use super::SslExtensionTlv::*; + + assert_eq!( + parse_fully( + &mut &[ + // Proxy command + 1u8, + // Connection type: Unknown + 0, + // TLV length: 3 + 2 + 3 + 11 + 3 + 15 + 0, + 37, + PP2_TYPE_ALPN, + 0, + 2, + // h2 + 0x68, + 0x32, + PP2_TYPE_AUTHORITY, + 0, + 11, + // example.org + 0x65, + 0x78, + 0x61, + 0x6d, + 0x70, + 0x6c, + 0x65, + 0x2e, + 0x6f, + 0x72, + 0x67, + PP2_TYPE_SSL, + 0, + 15, + 0x07, + 0, + 0, + 0, + 0, + PP2_SUBTYPE_SSL_VERSION, + 0, + 7, + // TLSv1.3. This is from OpenSSL, to match haproxy. + 0x54, + 0x4c, + 0x53, + 0x76, + 0x31, + 0x2e, + 0x33, + ][..] + ), + Ok(ProxyHeader::Version2 { + command: ProxyCommand::Proxy, + addresses: ProxyAddresses::Unspec, + transport_protocol: ProxyTransportProtocol::Unspec, + extensions: vec![ + Alpn(b"h2".to_vec()), + Authority("example.org".to_string()), + Ssl(super::Ssl { + client: SslClientFlags(7), + verify: SslVerifyStatus(0), + extensions: vec![Version("TLSv1.3".to_string()),], + }), + ], + }), + ); + } } #[cfg(test)] mod encode_tests { use super::*; - use bytes::{Bytes, BytesMut}; + use bytes::BytesMut; use pretty_assertions::assert_eq; use std::net::{Ipv4Addr, Ipv6Addr}; @@ -690,10 +1163,10 @@ mod encode_tests { 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, ]; - fn signed(buf: &[u8]) -> Bytes { + fn signed(buf: &[u8]) -> BytesMut { let mut bytes = BytesMut::from(&SIG[..]); bytes.extend_from_slice(buf); - bytes.freeze() + bytes } #[test] @@ -703,8 +1176,9 @@ mod encode_tests { ProxyCommand::Local, ProxyTransportProtocol::Unspec, ProxyAddresses::Unspec, + &[], ), - signed(&[2 << 4, 0, 0, 0][..]), + Ok(signed(&[2 << 4, 0, 0, 0][..])), ); assert_eq!( @@ -712,8 +1186,9 @@ mod encode_tests { ProxyCommand::Proxy, ProxyTransportProtocol::Unspec, ProxyAddresses::Unspec, + &[], ), - signed(&[(2 << 4) | 1, 0, 0, 0][..]), + Ok(signed(&[(2 << 4) | 1, 0, 0, 0][..])), ); assert_eq!( encode( @@ -723,8 +1198,9 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 9012), }, + &[], ), - signed( + Ok(signed( &[ (2 << 4) | 1, 1 << 4, @@ -743,7 +1219,7 @@ mod encode_tests { (9012u16 >> 8) as u8, 9012u16 as u8, ][..] - ), + )), ); } @@ -757,8 +1233,9 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 9012), }, + &[], ), - signed( + Ok(signed( &[ (2 << 4) | 1, (1 << 4) | 1, @@ -777,7 +1254,7 @@ mod encode_tests { (9012u16 >> 8) as u8, 9012u16 as u8, ][..] - ), + )), ); assert_eq!( encode( @@ -787,8 +1264,9 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 324), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 2187), }, + &[], ), - signed( + Ok(signed( &[ 2 << 4, (1 << 4) | 2, @@ -807,7 +1285,7 @@ mod encode_tests { (2187 >> 8) as u8, 2187u16 as u8, ][..] - ), + )), ); } @@ -825,9 +1303,10 @@ mod encode_tests { 0, 0, ), - } + }, + &[], ), - signed( + Ok(signed( &[ 2 << 4, (2 << 4) | 2, @@ -870,7 +1349,82 @@ mod encode_tests { 0, 0, ][..] + )), + ); + } + + #[test] + fn test_tlv() { + use super::ExtensionTlv::*; + use super::SslExtensionTlv::*; + + assert_eq!( + encode( + ProxyCommand::Proxy, + ProxyTransportProtocol::Unspec, + ProxyAddresses::Unspec, + &[ + Alpn(b"h2".to_vec()), + Authority("example.org".to_string()), + Ssl(super::Ssl { + client: SslClientFlags(7), + verify: SslVerifyStatus(0), + extensions: vec![Version("TLSv1.3".to_string()),], + }), + ], ), + Ok(signed( + &[ + // Version 2, + // Proxy command + 0x21u8, + // Connection type: Unknown + 0, + // TLV length: 3 + 2 + 3 + 11 + 3 + 15 + 0, + 37, + PP2_TYPE_ALPN, + 0, + 2, + // h2 + 0x68, + 0x32, + PP2_TYPE_AUTHORITY, + 0, + 11, + // example.org + 0x65, + 0x78, + 0x61, + 0x6d, + 0x70, + 0x6c, + 0x65, + 0x2e, + 0x6f, + 0x72, + 0x67, + PP2_TYPE_SSL, + 0, + 15, + 0x07, + 0, + 0, + 0, + 0, + PP2_SUBTYPE_SSL_VERSION, + 0, + 7, + // TLSv1.3. This is from OpenSSL, to match haproxy. + 0x54, + 0x4c, + 0x53, + 0x76, + 0x31, + 0x2e, + 0x33, + ][..] + )), ); } }