diff --git a/proxy-protocol.txt b/proxy-protocol.txt index 863eac6..196a445 100644 --- a/proxy-protocol.txt +++ b/proxy-protocol.txt @@ -33,7 +33,7 @@ Revision history reserved TLV type ranges, added TLV documentation, clarified string encoding. With contributions from Andriy Palamarchuk (Amazon.com). - 2020/03/05 - added the unique ID TLV type (Tim Düsterhus) + 2020/03/05 - added the unique ID TLV type (Tim Düsterhus) 1. Background diff --git a/src/lib.rs b/src/lib.rs index a1fe67d..57a0861 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,9 @@ pub enum EncodeError { /// An error occurred while encoding version 1. #[snafu(display("there was an error while encoding the v1 header: {}", source))] WriteVersion1 { source: version1::EncodeError }, + /// An error occurred while encoding version 2. + #[snafu(display("there was an error while encoding the v2 header: {}", source))] + WriteVersion2 { source: version2::EncodeError }, } /// The PROXY header emitted at most once at the start of a new connection. @@ -70,6 +73,7 @@ pub enum ProxyHeader { /// The addresses used to connect to the proxy. addresses: version2::ProxyAddresses, + extensions: Vec, }, } @@ -150,7 +154,9 @@ pub fn encode(header: ProxyHeader) -> Result { command, transport_protocol, addresses, - } => version2::encode(command, transport_protocol, addresses), + extensions, + } => version2::encode(command, transport_protocol, addresses, &extensions[..]) + .context(WriteVersion2)?, #[allow(unreachable_patterns)] // May be required to be exhaustive. _ => unimplemented!("Unimplemented version?"), @@ -191,15 +197,15 @@ mod parse_tests { assert!(!buf.has_remaining()); // Consume the ENTIRE header! fn valid_v4( - (a, b, c, d): (u8, u8, u8, u8), - e: u16, - (f, g, h, i): (u8, u8, u8, u8), - j: u16, + (s0, s1, s2, s3): (u8, u8, u8, u8), + sp: u16, + (d0, d1, d2, d3): (u8, u8, u8, u8), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: version1::ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), e), - destination: SocketAddrV4::new(Ipv4Addr::new(f, g, h, i), j), + source: SocketAddrV4::new(Ipv4Addr::new(s0, s1, s2, s3), sp), + destination: SocketAddrV4::new(Ipv4Addr::new(d0, d1, d2, d3), dp), }, } } @@ -223,15 +229,25 @@ mod parse_tests { ); fn valid_v6( - (a, b, c, d, e, f, g, h): (u16, u16, u16, u16, u16, u16, u16, u16), - i: u16, - (j, k, l, m, n, o, p, q): (u16, u16, u16, u16, u16, u16, u16, u16), - r: u16, + (s0, s1, s2, s3, s4, s5, s6, s7): (u16, u16, u16, u16, u16, u16, u16, u16), + sp: u16, + (d0, d1, d2, d3, d4, d5, d6, d7): (u16, u16, u16, u16, u16, u16, u16, u16), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: version1::ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(Ipv6Addr::new(a, b, c, d, e, f, g, h), i, 0, 0), - destination: SocketAddrV6::new(Ipv6Addr::new(j, k, l, m, n, o, p, q), r, 0, 0), + source: SocketAddrV6::new( + Ipv6Addr::new(s0, s1, s2, s3, s4, s5, s6, s7), + sp, + 0, + 0, + ), + destination: SocketAddrV6::new( + Ipv6Addr::new(d0, d1, d2, d3, d4, d5, d6, d7), + dp, + 0, + 0, + ), }, } } @@ -313,7 +329,7 @@ mod parse_tests { 0x49, 0x54, 0x0A, - (2 << 4) | 0, + 2 << 4, ]; const PREFIX_PROXY: [u8; 13] = [ 0x0D, @@ -337,6 +353,7 @@ mod parse_tests { command: version2::ProxyCommand::Local, addresses: version2::ProxyAddresses::Unspec, transport_protocol: version2::ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); assert_eq!( @@ -345,6 +362,7 @@ mod parse_tests { command: version2::ProxyCommand::Proxy, addresses: version2::ProxyAddresses::Unspec, transport_protocol: version2::ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); @@ -378,7 +396,7 @@ mod parse_tests { 1, 1, // TLV - 69, + version2::PP2_TYPE_NOOP, 0, 0, ][..] @@ -393,6 +411,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 257), }, + extensions: Vec::new(), }) ); @@ -439,6 +458,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), destination: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 255 << 8), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -497,7 +517,7 @@ mod parse_tests { 1, 1, // TLV - 69, + version2::PP2_TYPE_NOOP, 0, 0, ][..], @@ -522,6 +542,7 @@ mod parse_tests { 0, ) }, + extensions: Vec::new(), }) ); @@ -602,6 +623,7 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -629,6 +651,9 @@ mod parse_tests { // 3 bytes is clearly too few if we expect 2 IPv4s and ports 0, 3, + 0, + 0, + 0, ][..], ] .concat() diff --git a/src/version1.rs b/src/version1.rs index 19fc38b..c9a0377 100644 --- a/src/version1.rs +++ b/src/version1.rs @@ -190,9 +190,7 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result unreachable!(), }; - Ok(super::ProxyHeader::Version1 { - addresses, - }) + Ok(super::ProxyHeader::Version1 { addresses }) } pub(crate) fn encode(addresses: ProxyAddresses) -> Result { @@ -278,15 +276,15 @@ mod parse_tests { #[test] fn test_valid_ipv4_cases() { fn valid( - (a, b, c, d): (u8, u8, u8, u8), - e: u16, - (f, g, h, i): (u8, u8, u8, u8), - j: u16, + (s0, s1, s2, s3): (u8, u8, u8, u8), + sp: u16, + (d0, d1, d2, d3): (u8, u8, u8, u8), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), e), - destination: SocketAddrV4::new(Ipv4Addr::new(f, g, h, i), j), + source: SocketAddrV4::new(Ipv4Addr::new(s0, s1, s2, s3), sp), + destination: SocketAddrV4::new(Ipv4Addr::new(d0, d1, d2, d3), dp), }, } } @@ -312,15 +310,25 @@ mod parse_tests { #[test] fn test_valid_ipv6_cases() { fn valid( - (a, b, c, d, e, f, g, h): (u16, u16, u16, u16, u16, u16, u16, u16), - i: u16, - (j, k, l, m, n, o, p, q): (u16, u16, u16, u16, u16, u16, u16, u16), - r: u16, + (s0, s1, s2, s3, s4, s5, s6, s7): (u16, u16, u16, u16, u16, u16, u16, u16), + sp: u16, + (d0, d1, d2, d3, d4, d5, d6, d7): (u16, u16, u16, u16, u16, u16, u16, u16), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(Ipv6Addr::new(a, b, c, d, e, f, g, h), i, 0, 0), - destination: SocketAddrV6::new(Ipv6Addr::new(j, k, l, m, n, o, p, q), r, 0, 0), + source: SocketAddrV6::new( + Ipv6Addr::new(s0, s1, s2, s3, s4, s5, s6, s7), + sp, + 0, + 0, + ), + destination: SocketAddrV6::new( + Ipv6Addr::new(d0, d1, d2, d3, d4, d5, d6, d7), + dp, + 0, + 0, + ), }, } } diff --git a/src/version2.rs b/src/version2.rs index 6148e62..b3d6d3e 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -1,5 +1,6 @@ use bytes::{Buf, BufMut as _, BytesMut}; use snafu::{ensure, Snafu}; +use std::convert::TryInto; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; #[derive(Debug, Snafu)] @@ -19,6 +20,28 @@ pub enum ParseError { #[snafu(display("insufficient length specified: {}, requires minimum {}", given, needs))] InsufficientLengthSpecified { given: usize, needs: usize }, + + #[snafu(display("invalid length specified: {}, causes overflow", given))] + LengthOverflow { given: usize }, + + #[snafu(display("invalid TLV type id specified: {}", type_id))] + InvalidTlvTypeId { type_id: u8 }, + + #[snafu(display("invalid UTF-8: {:?}", bytes))] + InvalidUtf8 { bytes: Vec }, + + #[snafu(display("invalid ASCII: {:?}", bytes))] + InvalidAscii { bytes: Vec }, + + #[snafu(display("trailing data: {:?}", len))] + TrailingData { len: usize }, +} + +#[derive(Debug, Snafu)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub enum EncodeError { + #[snafu(display("value is too large to encode"))] + ValueTooLarge, } #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] @@ -59,6 +82,310 @@ enum ProxyAddressFamily { Unix, } +trait Tlv: Sized { + /// Identifies the type + fn type_id(&self) -> u8; + + /// The byte size of the value if encoded, or None if too big + fn value_len(&self) -> Result; + + /// Write the value to the provided buffer + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError>; + + fn encoded_len(&self) -> Result { + self.value_len()? + .checked_add(3) + .ok_or(EncodeError::ValueTooLarge) + } + + fn encode(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + let vlen = self.value_len()?; + if vlen + .checked_add(3) + .map_or(true, |tlv_len| buf.remaining_mut() < tlv_len.into()) + { + return Err(EncodeError::ValueTooLarge); + } + buf.put_u8(self.type_id()); + buf.put_u16(vlen); + self.encode_value(buf) + } + + // API note: + // We have to pass the len instead of using a view. + // Buf doesn't have a good view / subslice abstraction + // unlike plain slices or even the Bytes implementation + // IMHO (@g2p) it would be better for parse to receive a + // slice or a concrete type. + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result; + + fn parse(buf: &mut impl Buf) -> Result { + if buf.remaining() < 3 { + return Err(ParseError::UnexpectedEof); + } + let type_id = buf.get_u8(); + let vlen = buf.get_u16(); + let expected_rem = buf + .remaining() + .checked_sub(vlen.into()) + .ok_or(ParseError::UnexpectedEof)?; + let r = Self::parse_parts(type_id, vlen, buf)?; + // Assert, because it would be an internal error + assert_eq!(buf.remaining(), expected_rem); + Ok(r) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SslClientFlags(u8); + +impl SslClientFlags { + pub fn is_ssl_or_tls(&self) -> bool { + (self.0 & 1) == 1 + } + + pub fn client_authenticated_connection(&self) -> bool { + (self.0 & 2) == 2 + } + pub fn client_authenticated_session(&self) -> bool { + (self.0 & 4) == 4 + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SslVerifyStatus(u32); + +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(not(feature = "always_exhaustive"), non_exhaustive)] // Extensions may be added +pub enum SslExtensionTlv { + /// TLS or SSL version in ASCII + Version(String), + /// TLS or SSL cipher suite in ASCII, for example "ECDHE-RSA-AES128-GCM-SHA256" + Cipher(String), + /// TLS or SSL signature algorithm in ASCII + SigAlg(String), + /// TLS or SSL key algorithm in ASCII + KeyAlg(String), + /// With client authentication, the common name for the client certificate in UTF-8 + ClientCN(String), +} + +impl SslExtensionTlv { + fn as_str(&self) -> &str { + match self { + Self::Version(version) => version, + Self::Cipher(cipher) => cipher, + Self::SigAlg(sigalg) => sigalg, + Self::KeyAlg(keyalg) => keyalg, + Self::ClientCN(cn) => cn, + } + } +} + +impl Tlv for SslExtensionTlv { + fn type_id(&self) -> u8 { + match self { + Self::Version(_) => PP2_SUBTYPE_SSL_VERSION, + Self::ClientCN(_) => PP2_SUBTYPE_SSL_CN, + Self::Cipher(_) => PP2_SUBTYPE_SSL_CIPHER, + Self::SigAlg(_) => PP2_SUBTYPE_SSL_SIG_ALG, + Self::KeyAlg(_) => PP2_SUBTYPE_SSL_KEY_ALG, + } + } + + fn value_len(&self) -> Result { + self.as_str() + .len() + .try_into() + .map_err(|_| EncodeError::ValueTooLarge) + } + + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + buf.put_slice(self.as_str().as_bytes()); + Ok(()) + } + + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result { + Ok(match type_id { + PP2_SUBTYPE_SSL_VERSION => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_CIPHER => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_SIG_ALG => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_KEY_ALG => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_CN => Self::Version(str_from_buf(buf, len)?), + _ => return Err(ParseError::InvalidTlvTypeId { type_id }), + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ssl { + client: SslClientFlags, + verify: SslVerifyStatus, + extensions: Vec, +} + +impl Ssl { + fn parse(buf: &mut impl Buf, len: u16) -> Result { + if buf.remaining() < len.into() { + return Err(ParseError::UnexpectedEof); + } + let mut ext_len = len + .checked_sub(5) + .ok_or(ParseError::InsufficientLengthSpecified { + given: len.into(), + needs: 5, + })?; + let client = SslClientFlags(buf.get_u8()); + let verify = SslVerifyStatus(buf.get_u32()); + let mut extensions = Vec::new(); + while ext_len > 0 { + let rem0 = buf.remaining(); + extensions.push(SslExtensionTlv::parse(buf)?); + let rem = buf.remaining(); + // The assert enforces that Buf is implemented sanely + // and not rewound. + let parsed = rem0.checked_sub(rem).expect("Buf error"); + // We don't enforce u16-sized buffers. + // Since we don't pass a bound on how much to parse, + // we can't enforce that the extension parser won't read + // (slightly) more than 64k. + // The assert is safe since ext_len was already u16 and the + // new value is lower. + ext_len = usize::from(ext_len) + .checked_sub(parsed) + .ok_or(ParseError::InsufficientLengthSpecified { + given: ext_len.into(), + needs: parsed, + })? + .try_into() + .expect("Math error"); + } + Ok(Self { + client, + verify, + extensions, + }) + } + + fn encoded_len(&self) -> Result { + // 1 for flags, 4 for verify status, plus all nested TLVs + self.extensions + .iter() + .try_fold(5u16, |sum, subtlv| { + sum.checked_add(subtlv.encoded_len().ok()?) + }) + .ok_or(EncodeError::ValueTooLarge) + } + + fn encode(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + buf.put_u8(self.client.0); + buf.put_u32(self.verify.0); + for ext in self.extensions.iter() { + ext.encode(buf)?; + } + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(not(feature = "always_exhaustive"), non_exhaustive)] // Extensions may be added +pub enum ExtensionTlv { + Alpn(Vec), + Authority(String), + Crc32c(u32), + UniqueId(Vec), + Ssl(Ssl), + NetNs(String), +} + +pub(crate) const PP2_TYPE_ALPN: u8 = 0x01; +pub(crate) const PP2_TYPE_AUTHORITY: u8 = 0x02; +pub(crate) const PP2_TYPE_CRC32C: u8 = 0x03; +pub(crate) const PP2_TYPE_NOOP: u8 = 0x04; +pub(crate) const PP2_TYPE_UNIQUE_ID: u8 = 0x05; +pub(crate) const PP2_TYPE_SSL: u8 = 0x20; +pub(crate) const PP2_SUBTYPE_SSL_VERSION: u8 = 0x21; +pub(crate) const PP2_SUBTYPE_SSL_CN: u8 = 0x22; +pub(crate) const PP2_SUBTYPE_SSL_CIPHER: u8 = 0x23; +pub(crate) const PP2_SUBTYPE_SSL_SIG_ALG: u8 = 0x24; +pub(crate) const PP2_SUBTYPE_SSL_KEY_ALG: u8 = 0x25; +pub(crate) const PP2_TYPE_NETNS: u8 = 0x30; + +fn vec_from_buf(buf: &mut impl Buf, len: u16) -> Vec { + let mut r = vec![0; len.into()]; + buf.copy_to_slice(&mut r); + r +} + +fn str_from_buf(buf: &mut impl Buf, len: u16) -> Result { + let v = vec_from_buf(buf, len); + let r = String::from_utf8(v).map_err(|e| ParseError::InvalidUtf8 { + bytes: e.into_bytes(), + })?; + Ok(r) +} + +fn ascii_from_buf(buf: &mut impl Buf, len: u16) -> Result { + let s = str_from_buf(buf, len)?; + if !s.is_ascii() { + Err(ParseError::InvalidAscii { + bytes: s.into_bytes(), + }) + } else { + Ok(s) + } +} + +impl Tlv for ExtensionTlv { + fn type_id(&self) -> u8 { + match self { + Self::Alpn(_) => PP2_TYPE_ALPN, + Self::Authority(_) => PP2_TYPE_AUTHORITY, + Self::Crc32c(_) => PP2_TYPE_CRC32C, + Self::UniqueId(_) => PP2_TYPE_UNIQUE_ID, + Self::Ssl(_) => PP2_TYPE_SSL, + Self::NetNs(_) => PP2_TYPE_NETNS, + } + } + + fn value_len(&self) -> Result { + match self { + Self::Alpn(alpn) => alpn.len(), + Self::Authority(authority) => authority.len(), + Self::Crc32c(_) => 4, + Self::UniqueId(id) => id.len(), + Self::Ssl(data) => data.encoded_len()?.into(), + Self::NetNs(netns) => netns.len(), + } + .try_into() + .map_err(|_| EncodeError::ValueTooLarge) + } + + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + match self { + Self::Alpn(by) | Self::UniqueId(by) => buf.put_slice(by), + Self::Authority(st) | Self::NetNs(st) => buf.put_slice(st.as_bytes()), + Self::Crc32c(crc) => buf.put_u32(*crc), + Self::Ssl(ssl) => ssl.encode(buf)?, + }; + Ok(()) + } + + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result { + Ok(match type_id { + PP2_TYPE_ALPN => Self::Alpn(vec_from_buf(buf, len)), + PP2_TYPE_AUTHORITY => Self::Authority(str_from_buf(buf, len)?), + PP2_TYPE_CRC32C => Self::Crc32c(buf.get_u32()), + PP2_TYPE_UNIQUE_ID => Self::UniqueId(vec_from_buf(buf, len)), + PP2_TYPE_SSL => Self::Ssl(Ssl::parse(buf, len)?), + PP2_TYPE_NETNS => Self::NetNs(ascii_from_buf(buf, len)?), + _ => return Err(ParseError::InvalidTlvTypeId { type_id }), + }) + } +} + +// Note: this is internal, assumes the first 12 bytes were parsed, +// and ignores the version half of the first byte. pub(crate) fn parse(buf: &mut impl Buf) -> Result { // We need to parse the following: // @@ -71,9 +398,10 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result> 4; + let command = st << 4 >> 4; let command = match command { 0 => ProxyCommand::Local, 1 => ProxyCommand::Proxy, @@ -100,18 +428,7 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result= length, UnexpectedEof); - buf.advance(length); - - return Ok(super::ProxyHeader::Version2 { - command, - transport_protocol, - addresses: ProxyAddresses::Unspec, - }); - } + ensure!(buf.remaining() >= length, UnexpectedEof); // Time to parse the following: // @@ -134,102 +451,138 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result } unix_addr; // > }; - if address_family == ProxyAddressFamily::Unix { - ensure!( - length >= 108 * 2, - InsufficientLengthSpecified { + // The full length of address data, + // including two addresses and two ports + let address_len = match address_family { + ProxyAddressFamily::Inet => (4 + 2) * 2, + ProxyAddressFamily::Inet6 => (16 + 2) * 2, + ProxyAddressFamily::Unix => 108 * 2, + ProxyAddressFamily::Unspec => 0, + }; + + let mut ext_len = + length + .checked_sub(address_len) + .ok_or(ParseError::InsufficientLengthSpecified { given: length, - needs: 108usize * 2, - }, - ); - ensure!(buf.remaining() >= 108 * 2, UnexpectedEof); - let mut source = [0u8; 108]; - let mut destination = [0u8; 108]; - buf.copy_to_slice(&mut source[..]); - buf.copy_to_slice(&mut destination[..]); - // TODO(Mariell Hoversholm): Support TLVs - if length > 108 * 2 { - buf.advance(length - (108 * 2)); - } + needs: address_len, + })?; + ensure!(buf.remaining() >= address_len, UnexpectedEof,); - return Ok(super::ProxyHeader::Version2 { - command, - transport_protocol, - addresses: ProxyAddresses::Unix { + let addresses = match address_family { + ProxyAddressFamily::Unspec => ProxyAddresses::Unspec, + ProxyAddressFamily::Unix => { + let mut source = [0u8; 108]; + let mut destination = [0u8; 108]; + buf.copy_to_slice(&mut source[..]); + buf.copy_to_slice(&mut destination[..]); + ProxyAddresses::Unix { source, destination, - }, - }); - } + } + } + ProxyAddressFamily::Inet => { + let mut data = [0u8; 4]; + buf.copy_to_slice(&mut data[..]); + let source = Ipv4Addr::from(data); - let port_length = 4; - let address_length = match address_family { - ProxyAddressFamily::Inet => 8, - ProxyAddressFamily::Inet6 => 32, - _ => unreachable!(), - }; + buf.copy_to_slice(&mut data); + let destination = Ipv4Addr::from(data); - ensure!( - length >= port_length + address_length, - InsufficientLengthSpecified { - given: length, - needs: port_length + address_length, - }, - ); - ensure!( - buf.remaining() >= port_length + address_length, - UnexpectedEof, - ); - - let addresses = if address_family == ProxyAddressFamily::Inet { - let mut data = [0u8; 4]; - buf.copy_to_slice(&mut data[..]); - let source = Ipv4Addr::from(data); - - buf.copy_to_slice(&mut data); - let destination = Ipv4Addr::from(data); - - let source_port = buf.get_u16(); - let destination_port = buf.get_u16(); + let source_port = buf.get_u16(); + let destination_port = buf.get_u16(); - ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(source, source_port), - destination: SocketAddrV4::new(destination, destination_port), + ProxyAddresses::Ipv4 { + source: SocketAddrV4::new(source, source_port), + destination: SocketAddrV4::new(destination, destination_port), + } } - } else { - let mut data = [0u8; 16]; - buf.copy_to_slice(&mut data); - let source = Ipv6Addr::from(data); + ProxyAddressFamily::Inet6 => { + let mut data = [0u8; 16]; + buf.copy_to_slice(&mut data); + let source = Ipv6Addr::from(data); - buf.copy_to_slice(&mut data); - let destination = Ipv6Addr::from(data); + buf.copy_to_slice(&mut data); + let destination = Ipv6Addr::from(data); - let source_port = buf.get_u16(); - let destination_port = buf.get_u16(); + let source_port = buf.get_u16(); + let destination_port = buf.get_u16(); - ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(source, source_port, 0, 0), - destination: SocketAddrV6::new(destination, destination_port, 0, 0), + ProxyAddresses::Ipv6 { + source: SocketAddrV6::new(source, source_port, 0, 0), + destination: SocketAddrV6::new(destination, destination_port, 0, 0), + } } }; - if length > port_length + address_length { - // TODO(Mariell Hoversholm): Implement TLVs - buf.advance(length - (port_length + address_length)); + let mut extensions = Vec::new(); + while ext_len > 0 { + // At this point, we know that remaining() >= ext_len + if buf.chunk()[0] == PP2_TYPE_NOOP { + if ext_len < 3 { + return Err(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: 3, + }); + } + // Read/skip the type after peeking + buf.get_u8(); + let skip_len = buf.get_u16(); + let noop_len = 3u16 + .checked_add(skip_len) + .ok_or(ParseError::LengthOverflow { + given: skip_len.into(), + })? + .into(); + if noop_len > ext_len { + return Err(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: noop_len, + }); + } + ext_len -= noop_len; + } else { + let rem0 = buf.remaining(); + extensions.push(ExtensionTlv::parse(buf)?); + let rem = buf.remaining(); + let parsed = rem0.checked_sub(rem).expect("Buf error"); + ext_len = + ext_len + .checked_sub(parsed) + .ok_or(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: parsed, + })?; + } } Ok(super::ProxyHeader::Version2 { command, transport_protocol, addresses, + extensions, }) } +// Currently used in tests, has to be internal for the same reasons +// parse() currently is. +#[cfg(test)] +pub(crate) fn parse_fully(buf: &mut impl Buf) -> Result { + let r = parse(buf)?; + if buf.has_remaining() { + return Err(ParseError::TrailingData { + len: buf.remaining(), + }); + } + Ok(r) +} + pub(crate) fn encode( command: ProxyCommand, transport_protocol: ProxyTransportProtocol, addresses: ProxyAddresses, -) -> BytesMut { + extensions: &[ExtensionTlv], +) -> Result { // > struct proxy_hdr_v2 { // > uint8_t sig[12]; /* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */ // > uint8_t ver_cmd; /* protocol version and command */ @@ -341,23 +694,27 @@ pub(crate) fn encode( // > uint8_t dst_addr[108]; // > } unix_addr; // > }; - let len = match addresses { + let address_len: u16 = match addresses { ProxyAddresses::Unspec => 0, - ProxyAddresses::Unix { .. } => { - 108 + 108 - } - ProxyAddresses::Ipv4 { .. } => { - 4 + 4 + 2 + 2 - } - ProxyAddresses::Ipv6 { .. } => { - 16 + 16 + 2 + 2 - } + ProxyAddresses::Unix { .. } => 108 + 108, + ProxyAddresses::Ipv4 { .. } => 4 + 4 + 2 + 2, + ProxyAddresses::Ipv6 { .. } => 16 + 16 + 2 + 2, }; + // With extensions, we need to distinguish len and address_len + let len = extensions + .iter() + .try_fold(address_len, |acc, ext| { + acc.checked_add(ext.encoded_len().ok()?) + }) + .ok_or(EncodeError::ValueTooLarge)?; - let mut buf = BytesMut::with_capacity(16 + len); + let blen = 16usize + .checked_add(len.into()) + .ok_or(EncodeError::ValueTooLarge)?; + let mut buf = BytesMut::with_capacity(blen); buf.put_slice(&SIG[..]); buf.put_slice(&[ver_cmd, fam][..]); - buf.put_u16(len as u16); + buf.put_u16(len); match addresses { ProxyAddresses::Unspec => (), @@ -388,7 +745,13 @@ pub(crate) fn encode( } } - buf + for ext in extensions.iter() { + ext.encode(&mut buf)?; + } + + assert_eq!(buf.len(), blen); + + Ok(buf) } #[cfg(test)] @@ -403,23 +766,23 @@ mod parse_tests { #[test] fn test_unspec() { assert_eq!( - parse(&mut &[0u8; 16][..]), + parse_fully(&mut &[0u8; 4][..]), Ok(ProxyHeader::Version2 { command: ProxyCommand::Local, addresses: ProxyAddresses::Unspec, transport_protocol: ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); - let mut prefix = BytesMut::from(&[1u8][..]); - prefix.reserve(16); - prefix.extend_from_slice(&[0u8; 16][..]); + let mut prefix = BytesMut::from(&[1u8, 0, 0, 0][..]); assert_eq!( - parse(&mut prefix), + parse_fully(&mut prefix), Ok(ProxyHeader::Version2 { command: ProxyCommand::Proxy, addresses: ProxyAddresses::Unspec, transport_protocol: ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); } @@ -427,7 +790,7 @@ mod parse_tests { #[test] fn test_ipv4() { assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -456,7 +819,7 @@ mod parse_tests { 1, 1, // TLV - 69, + PP2_TYPE_NOOP, 0, 0, ][..] @@ -468,6 +831,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 257), }, + extensions: Vec::new(), }) ); @@ -512,6 +876,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), destination: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 255 << 8), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -520,7 +885,7 @@ mod parse_tests { #[test] fn test_ipv6() { assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -573,7 +938,7 @@ mod parse_tests { 1, 1, // TLV - 69, + PP2_TYPE_NOOP, 0, 0, ][..] @@ -595,6 +960,7 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); @@ -673,6 +1039,7 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -683,12 +1050,12 @@ mod parse_tests { let mut data = [0u8; 200]; rand::thread_rng().fill_bytes(&mut data); data[0] = 99; // Make 100% sure it's invalid. - assert!(parse(&mut &data[..]).is_err()); + assert!(parse_fully(&mut &data[..]).is_err()); - assert_eq!(parse(&mut &[0][..]), Err(ParseError::UnexpectedEof)); + assert_eq!(parse_fully(&mut &[0][..]), Err(ParseError::UnexpectedEof)); assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -698,6 +1065,9 @@ mod parse_tests { // 3 bytes is clearly too few if we expect 2 IPv4s and ports 0, 3, + 0, + 0, + 0, ][..] ), Err(ParseError::InsufficientLengthSpecified { @@ -706,12 +1076,86 @@ mod parse_tests { }), ); } + + #[test] + fn test_tlv() { + use super::ExtensionTlv::*; + use super::SslExtensionTlv::*; + + assert_eq!( + parse_fully( + &mut &[ + // Proxy command + 1u8, + // Connection type: Unknown + 0, + // TLV length: 3 + 2 + 3 + 11 + 3 + 15 + 0, + 37, + PP2_TYPE_ALPN, + 0, + 2, + // h2 + 0x68, + 0x32, + PP2_TYPE_AUTHORITY, + 0, + 11, + // example.org + 0x65, + 0x78, + 0x61, + 0x6d, + 0x70, + 0x6c, + 0x65, + 0x2e, + 0x6f, + 0x72, + 0x67, + PP2_TYPE_SSL, + 0, + 15, + 0x07, + 0, + 0, + 0, + 0, + PP2_SUBTYPE_SSL_VERSION, + 0, + 7, + // TLSv1.3. This is from OpenSSL, to match haproxy. + 0x54, + 0x4c, + 0x53, + 0x76, + 0x31, + 0x2e, + 0x33, + ][..] + ), + Ok(ProxyHeader::Version2 { + command: ProxyCommand::Proxy, + addresses: ProxyAddresses::Unspec, + transport_protocol: ProxyTransportProtocol::Unspec, + extensions: vec![ + Alpn(b"h2".to_vec()), + Authority("example.org".to_string()), + Ssl(super::Ssl { + client: SslClientFlags(7), + verify: SslVerifyStatus(0), + extensions: vec![Version("TLSv1.3".to_string()),], + }), + ], + }), + ); + } } #[cfg(test)] mod encode_tests { use super::*; - use bytes::{Bytes, BytesMut}; + use bytes::BytesMut; use pretty_assertions::assert_eq; use std::net::{Ipv4Addr, Ipv6Addr}; @@ -719,10 +1163,10 @@ mod encode_tests { 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, ]; - fn signed(buf: &[u8]) -> Bytes { + fn signed(buf: &[u8]) -> BytesMut { let mut bytes = BytesMut::from(&SIG[..]); bytes.extend_from_slice(buf); - bytes.freeze() + bytes } #[test] @@ -732,8 +1176,9 @@ mod encode_tests { ProxyCommand::Local, ProxyTransportProtocol::Unspec, ProxyAddresses::Unspec, + &[], ), - signed(&[(2 << 4) | 0, 0, 0, 0][..]), + Ok(signed(&[2 << 4, 0, 0, 0][..])), ); assert_eq!( @@ -741,8 +1186,9 @@ mod encode_tests { ProxyCommand::Proxy, ProxyTransportProtocol::Unspec, ProxyAddresses::Unspec, + &[], ), - signed(&[(2 << 4) | 1, 0, 0, 0][..]), + Ok(signed(&[(2 << 4) | 1, 0, 0, 0][..])), ); assert_eq!( encode( @@ -752,11 +1198,12 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 9012), }, + &[], ), - signed( + Ok(signed( &[ (2 << 4) | 1, - (1 << 4) | 0, + 1 << 4, 0, 12, 1, @@ -772,7 +1219,7 @@ mod encode_tests { (9012u16 >> 8) as u8, 9012u16 as u8, ][..] - ), + )), ); } @@ -786,8 +1233,9 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 9012), }, + &[], ), - signed( + Ok(signed( &[ (2 << 4) | 1, (1 << 4) | 1, @@ -806,7 +1254,7 @@ mod encode_tests { (9012u16 >> 8) as u8, 9012u16 as u8, ][..] - ), + )), ); assert_eq!( encode( @@ -816,10 +1264,11 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 324), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 2187), }, + &[], ), - signed( + Ok(signed( &[ - (2 << 4) | 0, + 2 << 4, (1 << 4) | 2, 0, 12, @@ -836,7 +1285,7 @@ mod encode_tests { (2187 >> 8) as u8, 2187u16 as u8, ][..] - ), + )), ); } @@ -847,23 +1296,19 @@ mod encode_tests { ProxyCommand::Local, ProxyTransportProtocol::Datagram, ProxyAddresses::Ipv6 { - source: SocketAddrV6::new( - Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), - 8192, - 0, - 0, - ), + source: SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 8192, 0, 0,), destination: SocketAddrV6::new( Ipv6Addr::new(65535, 65535, 32767, 32766, 111, 222, 333, 444), 0, 0, 0, ), - } + }, + &[], ), - signed( + Ok(signed( &[ - (2 << 4) | 0, + 2 << 4, (2 << 4) | 2, 0, 36, @@ -904,7 +1349,82 @@ mod encode_tests { 0, 0, ][..] + )), + ); + } + + #[test] + fn test_tlv() { + use super::ExtensionTlv::*; + use super::SslExtensionTlv::*; + + assert_eq!( + encode( + ProxyCommand::Proxy, + ProxyTransportProtocol::Unspec, + ProxyAddresses::Unspec, + &[ + Alpn(b"h2".to_vec()), + Authority("example.org".to_string()), + Ssl(super::Ssl { + client: SslClientFlags(7), + verify: SslVerifyStatus(0), + extensions: vec![Version("TLSv1.3".to_string()),], + }), + ], ), + Ok(signed( + &[ + // Version 2, + // Proxy command + 0x21u8, + // Connection type: Unknown + 0, + // TLV length: 3 + 2 + 3 + 11 + 3 + 15 + 0, + 37, + PP2_TYPE_ALPN, + 0, + 2, + // h2 + 0x68, + 0x32, + PP2_TYPE_AUTHORITY, + 0, + 11, + // example.org + 0x65, + 0x78, + 0x61, + 0x6d, + 0x70, + 0x6c, + 0x65, + 0x2e, + 0x6f, + 0x72, + 0x67, + PP2_TYPE_SSL, + 0, + 15, + 0x07, + 0, + 0, + 0, + 0, + PP2_SUBTYPE_SSL_VERSION, + 0, + 7, + // TLSv1.3. This is from OpenSSL, to match haproxy. + 0x54, + 0x4c, + 0x53, + 0x76, + 0x31, + 0x2e, + 0x33, + ][..] + )), ); } }