diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..84fca58 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,15 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "cargo" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "daily" + - package-ecosystem: "github-actions" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "daily" diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b11eae6..afac59c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,7 +6,7 @@ jobs: name: Build, test, check runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v6 - uses: actions-rs/toolchain@v1 with: profile: minimal @@ -14,7 +14,7 @@ jobs: default: true components: rustfmt, clippy - name: Cache target - uses: actions/cache@v2 + uses: actions/cache@v5 with: path: | ~/.cargo/git/ @@ -22,6 +22,8 @@ jobs: target/ key: ${{ runner.os }}-proxy-protocol-${{ hashFiles('**/Cargo.toml') }} restore-keys: ${{ runner.os }}-proxy-protocol + - run: cargo fmt --check + - run: cargo test -- --include-ignored - run: cargo build - uses: actions-rs/cargo@v1 with: diff --git a/Cargo.toml b/Cargo.toml index fe7bd77..2652210 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,12 +10,12 @@ documentation = "https://docs.rs/proxy-protocol/" repository = "https://github.com/Proximyst/proxy-protocol.git" [dependencies] -snafu = "~0.6" +snafu = "~0.9" bytes = "~1" [dev-dependencies] -pretty_assertions = "^0.7" -rand = "~0.8" +pretty_assertions = "^1.4" +rand = "~0.10" [features] default = [] diff --git a/proxy-protocol.txt b/proxy-protocol.txt index 863eac6..196a445 100644 --- a/proxy-protocol.txt +++ b/proxy-protocol.txt @@ -33,7 +33,7 @@ Revision history reserved TLV type ranges, added TLV documentation, clarified string encoding. With contributions from Andriy Palamarchuk (Amazon.com). - 2020/03/05 - added the unique ID TLV type (Tim Düsterhus) + 2020/03/05 - added the unique ID TLV type (Tim Düsterhus) 1. Background diff --git a/src/lib.rs b/src/lib.rs index a1fe67d..7197dbd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,9 @@ pub enum EncodeError { /// An error occurred while encoding version 1. #[snafu(display("there was an error while encoding the v1 header: {}", source))] WriteVersion1 { source: version1::EncodeError }, + /// An error occurred while encoding version 2. + #[snafu(display("there was an error while encoding the v2 header: {}", source))] + WriteVersion2 { source: version2::EncodeError }, } /// The PROXY header emitted at most once at the start of a new connection. @@ -70,12 +73,13 @@ pub enum ProxyHeader { /// The addresses used to connect to the proxy. addresses: version2::ProxyAddresses, + extensions: Vec, }, } fn parse_version(buf: &mut impl Buf) -> Result { // There is a 6 byte header to v1, 12 byte to all binary versions. - ensure!(buf.remaining() >= 6, NotProxyHeader); + ensure!(buf.remaining() >= 6, NotProxyHeaderSnafu); // V1 is the only version that starts with "PROXY" (0x50 0x52 0x4F 0x58 // 0x59), and we can therefore decide version based on that. @@ -87,11 +91,11 @@ fn parse_version(buf: &mut impl Buf) -> Result { } // Now we require 13: 12 for the prefix, 1 for the version + command - ensure!(buf.remaining() >= 13, NotProxyHeader); + ensure!(buf.remaining() >= 13, NotProxyHeaderSnafu); ensure!( buf.chunk()[..12] == [0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A], - NotProxyHeader + NotProxyHeaderSnafu ); buf.advance(12); @@ -113,7 +117,7 @@ fn parse_version(buf: &mut impl Buf) -> Result { // Interesting edge-case! This is the only time version 1 would be invalid. if version == 1 { - return InvalidVersion { version: 1u32 }.fail(); + return InvalidVersionSnafu { version: 1u32 }.fail(); } Ok(version as u32) @@ -125,15 +129,12 @@ fn parse_version(buf: &mut impl Buf) -> Result { /// available through [Buf::chunk], at the very least for the header. Data that /// follows may be chunked as you wish. pub fn parse(buf: &mut impl Buf) -> Result { - let version = match parse_version(buf) { - Ok(ver) => ver, - Err(e) => return Err(e), - }; + let version = parse_version(buf)?; Ok(match version { - 1 => self::version1::parse(buf).context(Version1)?, - 2 => self::version2::parse(buf).context(Version2)?, - _ => return InvalidVersion { version }.fail(), + 1 => self::version1::parse(buf).context(Version1Snafu)?, + 2 => self::version2::parse(buf).context(Version2Snafu)?, + _ => return InvalidVersionSnafu { version }.fail(), }) } @@ -144,13 +145,15 @@ pub fn parse(buf: &mut impl Buf) -> Result { pub fn encode(header: ProxyHeader) -> Result { Ok(match header { ProxyHeader::Version1 { addresses, .. } => { - version1::encode(addresses).context(WriteVersion1)? + version1::encode(addresses).context(WriteVersion1Snafu)? } ProxyHeader::Version2 { command, transport_protocol, addresses, - } => version2::encode(command, transport_protocol, addresses), + extensions, + } => version2::encode(command, transport_protocol, addresses, &extensions[..]) + .context(WriteVersion2Snafu)?, #[allow(unreachable_patterns)] // May be required to be exhaustive. _ => unimplemented!("Unimplemented version?"), @@ -182,7 +185,7 @@ mod parse_tests { ); let mut random = [0u8; 128]; - rand::thread_rng().fill_bytes(&mut random); + rand::rng().fill_bytes(&mut random); let mut header = b"PROXY UNKNOWN ".to_vec(); header.extend(&random[..]); header.extend(b"\r\n"); @@ -191,15 +194,15 @@ mod parse_tests { assert!(!buf.has_remaining()); // Consume the ENTIRE header! fn valid_v4( - (a, b, c, d): (u8, u8, u8, u8), - e: u16, - (f, g, h, i): (u8, u8, u8, u8), - j: u16, + (s0, s1, s2, s3): (u8, u8, u8, u8), + sp: u16, + (d0, d1, d2, d3): (u8, u8, u8, u8), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: version1::ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), e), - destination: SocketAddrV4::new(Ipv4Addr::new(f, g, h, i), j), + source: SocketAddrV4::new(Ipv4Addr::new(s0, s1, s2, s3), sp), + destination: SocketAddrV4::new(Ipv4Addr::new(d0, d1, d2, d3), dp), }, } } @@ -223,15 +226,25 @@ mod parse_tests { ); fn valid_v6( - (a, b, c, d, e, f, g, h): (u16, u16, u16, u16, u16, u16, u16, u16), - i: u16, - (j, k, l, m, n, o, p, q): (u16, u16, u16, u16, u16, u16, u16, u16), - r: u16, + (s0, s1, s2, s3, s4, s5, s6, s7): (u16, u16, u16, u16, u16, u16, u16, u16), + sp: u16, + (d0, d1, d2, d3, d4, d5, d6, d7): (u16, u16, u16, u16, u16, u16, u16, u16), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: version1::ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(Ipv6Addr::new(a, b, c, d, e, f, g, h), i, 0, 0), - destination: SocketAddrV6::new(Ipv6Addr::new(j, k, l, m, n, o, p, q), r, 0, 0), + source: SocketAddrV6::new( + Ipv6Addr::new(s0, s1, s2, s3, s4, s5, s6, s7), + sp, + 0, + 0, + ), + destination: SocketAddrV6::new( + Ipv6Addr::new(d0, d1, d2, d3, d4, d5, d6, d7), + dp, + 0, + 0, + ), }, } } @@ -313,7 +326,7 @@ mod parse_tests { 0x49, 0x54, 0x0A, - (2 << 4) | 0, + 2 << 4, ]; const PREFIX_PROXY: [u8; 13] = [ 0x0D, @@ -337,6 +350,7 @@ mod parse_tests { command: version2::ProxyCommand::Local, addresses: version2::ProxyAddresses::Unspec, transport_protocol: version2::ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); assert_eq!( @@ -345,6 +359,7 @@ mod parse_tests { command: version2::ProxyCommand::Proxy, addresses: version2::ProxyAddresses::Unspec, transport_protocol: version2::ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); @@ -378,7 +393,7 @@ mod parse_tests { 1, 1, // TLV - 69, + version2::PP2_TYPE_NOOP, 0, 0, ][..] @@ -393,6 +408,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 257), }, + extensions: Vec::new(), }) ); @@ -439,6 +455,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), destination: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 255 << 8), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -497,7 +514,7 @@ mod parse_tests { 1, 1, // TLV - 69, + version2::PP2_TYPE_NOOP, 0, 0, ][..], @@ -522,6 +539,7 @@ mod parse_tests { 0, ) }, + extensions: Vec::new(), }) ); @@ -602,12 +620,13 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header let mut data = [0u8; 200]; - rand::thread_rng().fill_bytes(&mut data); + rand::rng().fill_bytes(&mut data); data[0] = 99; // Make 100% sure it's invalid. assert!(parse(&mut &data[..]).is_err()); @@ -629,6 +648,9 @@ mod parse_tests { // 3 bytes is clearly too few if we expect 2 IPv4s and ports 0, 3, + 0, + 0, + 0, ][..], ] .concat() diff --git a/src/version1.rs b/src/version1.rs index 19fc38b..bbf198c 100644 --- a/src/version1.rs +++ b/src/version1.rs @@ -64,7 +64,7 @@ fn count_till_first(haystack: &[u8], needle: u8) -> Option { } pub(crate) fn parse(buf: &mut impl Buf) -> Result { - ensure!(buf.remaining() >= 4, UnexpectedEof); + ensure!(buf.remaining() >= 4, UnexpectedEofSnafu); let step = buf.get_u8(); @@ -83,23 +83,23 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result ProxyAddressFamily::Tcp4, b'6' => ProxyAddressFamily::Tcp6, - _ => return IllegalAddressFamily.fail(), + _ => return IllegalAddressFamilySnafu.fail(), } } b'U' => { // Unknown - ensure!(buf.remaining() >= 6, UnexpectedEof); // Not 7, we consumed 1. + ensure!(buf.remaining() >= 6, UnexpectedEofSnafu); // Not 7, we consumed 1. buf.advance(6); ProxyAddressFamily::Unknown } - _ => return IllegalAddressFamily.fail(), + _ => return IllegalAddressFamilySnafu.fail(), }; if address_family == ProxyAddressFamily::Unknown { // Just consume up to the end. let mut cr = false; loop { - ensure!(buf.has_remaining(), UnexpectedEof); + ensure!(buf.has_remaining(), UnexpectedEofSnafu); let b = buf.get_u8(); if cr && b == LF { break; @@ -112,70 +112,76 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result= 8, UnexpectedEof); + ensure!(buf.remaining() >= 8, UnexpectedEofSnafu); buf.advance(1); // Space - let count = count_till_first(buf.chunk(), b' ').context(MissingAddress)?; + let count = count_till_first(buf.chunk(), b' ').context(MissingAddressSnafu)?; let source = &buf.chunk()[..count]; - let source = std::str::from_utf8(source).context(NonAscii)?; + let source = std::str::from_utf8(source).context(NonAsciiSnafu)?; let source = match address_family { - ProxyAddressFamily::Tcp4 => IpAddr::V4(Ipv4Addr::from_str(source).context(InvalidAddress)?), - ProxyAddressFamily::Tcp6 => IpAddr::V6(Ipv6Addr::from_str(source).context(InvalidAddress)?), + ProxyAddressFamily::Tcp4 => { + IpAddr::V4(Ipv4Addr::from_str(source).context(InvalidAddressSnafu)?) + } + ProxyAddressFamily::Tcp6 => { + IpAddr::V6(Ipv6Addr::from_str(source).context(InvalidAddressSnafu)?) + } ProxyAddressFamily::Unknown => unreachable!("unknown should have its own branch"), }; buf.advance(count); // Same as above, another address incoming. - ensure!(buf.remaining() >= 8, UnexpectedEof); + ensure!(buf.remaining() >= 8, UnexpectedEofSnafu); buf.advance(1); // Space - let count = count_till_first(buf.chunk(), b' ').context(MissingAddress)?; + let count = count_till_first(buf.chunk(), b' ').context(MissingAddressSnafu)?; let destination = &buf.chunk()[..count]; - let destination = std::str::from_utf8(destination).context(NonAscii)?; + let destination = std::str::from_utf8(destination).context(NonAsciiSnafu)?; let destination = match address_family { ProxyAddressFamily::Tcp4 => { - IpAddr::V4(Ipv4Addr::from_str(destination).context(InvalidAddress)?) + IpAddr::V4(Ipv4Addr::from_str(destination).context(InvalidAddressSnafu)?) } ProxyAddressFamily::Tcp6 => { - IpAddr::V6(Ipv6Addr::from_str(destination).context(InvalidAddress)?) + IpAddr::V6(Ipv6Addr::from_str(destination).context(InvalidAddressSnafu)?) } ProxyAddressFamily::Unknown => unreachable!("unknown should have its own branch"), }; buf.advance(count); // Space, then a port. 0 is minimum valid port, so 1 byte. - ensure!(buf.remaining() >= 2, UnexpectedEof); + ensure!(buf.remaining() >= 2, UnexpectedEofSnafu); buf.advance(1); - let count = count_till_first(buf.chunk(), b' ').context(InvalidPort)?; + let count = count_till_first(buf.chunk(), b' ').context(InvalidPortSnafu)?; let source_port = &buf.chunk()[..count]; - let source_port = std::str::from_utf8(source_port).context(NonAscii)?; + let source_port = std::str::from_utf8(source_port).context(NonAsciiSnafu)?; ensure!( // The port 0 is itself valid, but 01 is not. !source_port.starts_with('0') || source_port == "0", - InvalidPort, + InvalidPortSnafu, ); - let source_port: u16 = source_port.parse().ok().context(InvalidPort)?; + let source_port: u16 = source_port.parse().ok().context(InvalidPortSnafu)?; buf.advance(count); // Space, then a port, then CRLF. 0 is minimum valid port, so 1 byte. - ensure!(buf.remaining() >= 4, UnexpectedEof); + ensure!(buf.remaining() >= 4, UnexpectedEofSnafu); buf.advance(1); // This is the last member of the string. Read until CR; that's next up. - let count = count_till_first(buf.chunk(), CR).context(InvalidPort)?; + let count = count_till_first(buf.chunk(), CR).context(InvalidPortSnafu)?; let destination_port = &buf.chunk()[..count]; - let destination_port = std::str::from_utf8(destination_port).context(NonAscii)?; + let destination_port = std::str::from_utf8(destination_port).context(NonAsciiSnafu)?; ensure!( // The port 0 is itself valid, but 01 is not. !destination_port.starts_with('0') || destination_port == "0", - InvalidPort, + InvalidPortSnafu, ); - let destination_port: u16 = destination_port.parse().ok().context(InvalidPort)?; + let destination_port: u16 = destination_port.parse().ok().context(InvalidPortSnafu)?; buf.advance(count); - ensure!(buf.get_u8() == CR, IllegalHeaderEnding); - ensure!(buf.get_u8() == LF, IllegalHeaderEnding); + ensure!(buf.get_u8() == CR, IllegalHeaderEndingSnafu); + // We only checked up to the CR + ensure!(buf.remaining() >= 1, UnexpectedEofSnafu); + ensure!(buf.get_u8() == LF, IllegalHeaderEndingSnafu); let addresses = match (source, destination) { (IpAddr::V4(source), IpAddr::V4(destination)) => ProxyAddresses::Ipv4 { @@ -190,9 +196,7 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result unreachable!(), }; - Ok(super::ProxyHeader::Version1 { - addresses, - }) + Ok(super::ProxyHeader::Version1 { addresses }) } pub(crate) fn encode(addresses: ProxyAddresses) -> Result { @@ -202,14 +206,14 @@ pub(crate) fn encode(addresses: ProxyAddresses) -> Result // Reserve as much data as we're gonna need -- at most. let mut buf = BytesMut::with_capacity(107).writer(); - buf.write_all(&b"PROXY TCP"[..]).context(StdIo)?; + buf.write_all(&b"PROXY TCP"[..]).context(StdIoSnafu)?; match addresses { ProxyAddresses::Ipv4 { source, destination, } => { - buf.write(&b"4 "[..]).context(StdIo)?; + buf.write(&b"4 "[..]).context(StdIoSnafu)?; write!( buf, "{} {} {} {}\r\n", @@ -218,13 +222,13 @@ pub(crate) fn encode(addresses: ProxyAddresses) -> Result source.port(), destination.port(), ) - .context(StdIo)?; + .context(StdIoSnafu)?; } ProxyAddresses::Ipv6 { source, destination, } => { - buf.write(&b"6 "[..]).context(StdIo)?; + buf.write(&b"6 "[..]).context(StdIoSnafu)?; write!( buf, "{} {} {} {}\r\n", @@ -233,7 +237,7 @@ pub(crate) fn encode(addresses: ProxyAddresses) -> Result source.port(), destination.port(), ) - .context(StdIo)?; + .context(StdIoSnafu)?; } ProxyAddresses::Unknown => unreachable!(), } @@ -266,7 +270,7 @@ mod parse_tests { ); let mut random = [0u8; 128]; - rand::thread_rng().fill_bytes(&mut random); + rand::rng().fill_bytes(&mut random); let mut header = b"UNKNOWN ".to_vec(); header.extend(&random[..]); header.extend(b"\r\n"); @@ -278,15 +282,15 @@ mod parse_tests { #[test] fn test_valid_ipv4_cases() { fn valid( - (a, b, c, d): (u8, u8, u8, u8), - e: u16, - (f, g, h, i): (u8, u8, u8, u8), - j: u16, + (s0, s1, s2, s3): (u8, u8, u8, u8), + sp: u16, + (d0, d1, d2, d3): (u8, u8, u8, u8), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), e), - destination: SocketAddrV4::new(Ipv4Addr::new(f, g, h, i), j), + source: SocketAddrV4::new(Ipv4Addr::new(s0, s1, s2, s3), sp), + destination: SocketAddrV4::new(Ipv4Addr::new(d0, d1, d2, d3), dp), }, } } @@ -312,15 +316,25 @@ mod parse_tests { #[test] fn test_valid_ipv6_cases() { fn valid( - (a, b, c, d, e, f, g, h): (u16, u16, u16, u16, u16, u16, u16, u16), - i: u16, - (j, k, l, m, n, o, p, q): (u16, u16, u16, u16, u16, u16, u16, u16), - r: u16, + (s0, s1, s2, s3, s4, s5, s6, s7): (u16, u16, u16, u16, u16, u16, u16, u16), + sp: u16, + (d0, d1, d2, d3, d4, d5, d6, d7): (u16, u16, u16, u16, u16, u16, u16, u16), + dp: u16, ) -> ProxyHeader { ProxyHeader::Version1 { addresses: ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(Ipv6Addr::new(a, b, c, d, e, f, g, h), i, 0, 0), - destination: SocketAddrV6::new(Ipv6Addr::new(j, k, l, m, n, o, p, q), r, 0, 0), + source: SocketAddrV6::new( + Ipv6Addr::new(s0, s1, s2, s3, s4, s5, s6, s7), + sp, + 0, + 0, + ), + destination: SocketAddrV6::new( + Ipv6Addr::new(d0, d1, d2, d3, d4, d5, d6, d7), + dp, + 0, + 0, + ), }, } } diff --git a/src/version2.rs b/src/version2.rs index 6148e62..238085c 100644 --- a/src/version2.rs +++ b/src/version2.rs @@ -1,6 +1,7 @@ use bytes::{Buf, BufMut as _, BytesMut}; use snafu::{ensure, Snafu}; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; +use std::convert::TryInto; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; #[derive(Debug, Snafu)] #[cfg_attr(test, derive(PartialEq, Eq))] @@ -19,6 +20,28 @@ pub enum ParseError { #[snafu(display("insufficient length specified: {}, requires minimum {}", given, needs))] InsufficientLengthSpecified { given: usize, needs: usize }, + + #[snafu(display("invalid length specified: {}, causes overflow", given))] + LengthOverflow { given: usize }, + + #[snafu(display("invalid TLV type id specified: {}", type_id))] + InvalidTlvTypeId { type_id: u8 }, + + #[snafu(display("invalid UTF-8: {:?}", bytes))] + InvalidUtf8 { bytes: Vec }, + + #[snafu(display("invalid ASCII: {:?}", bytes))] + InvalidAscii { bytes: Vec }, + + #[snafu(display("trailing data: {:?}", len))] + TrailingData { len: usize }, +} + +#[derive(Debug, Snafu)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub enum EncodeError { + #[snafu(display("value is too large to encode"))] + ValueTooLarge, } #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] @@ -51,6 +74,25 @@ pub enum ProxyAddresses { }, } +impl ProxyAddresses { + /// Construct IPv4 or IPv6 from given source and destination address. + /// + /// Returns `None` if the addresses use a different IP version. + pub fn from_socket_addrs(source: SocketAddr, destination: SocketAddr) -> Option { + match (source, destination) { + (SocketAddr::V4(source), SocketAddr::V4(destination)) => Some(Self::Ipv4 { + source, + destination, + }), + (SocketAddr::V6(source), SocketAddr::V6(destination)) => Some(Self::Ipv6 { + source, + destination, + }), + (_, _) => None, + } + } +} + #[derive(PartialEq, Eq)] enum ProxyAddressFamily { Unspec, @@ -59,6 +101,310 @@ enum ProxyAddressFamily { Unix, } +trait Tlv: Sized { + /// Identifies the type + fn type_id(&self) -> u8; + + /// The byte size of the value if encoded, or None if too big + fn value_len(&self) -> Result; + + /// Write the value to the provided buffer + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError>; + + fn encoded_len(&self) -> Result { + self.value_len()? + .checked_add(3) + .ok_or(EncodeError::ValueTooLarge) + } + + fn encode(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + let vlen = self.value_len()?; + if vlen + .checked_add(3) + .is_none_or(|tlv_len| buf.remaining_mut() < tlv_len.into()) + { + return Err(EncodeError::ValueTooLarge); + } + buf.put_u8(self.type_id()); + buf.put_u16(vlen); + self.encode_value(buf) + } + + // API note: + // We have to pass the len instead of using a view. + // Buf doesn't have a good view / subslice abstraction + // unlike plain slices or even the Bytes implementation + // IMHO (@g2p) it would be better for parse to receive a + // slice or a concrete type. + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result; + + fn parse(buf: &mut impl Buf) -> Result { + if buf.remaining() < 3 { + return UnexpectedEofSnafu.fail(); + } + let type_id = buf.get_u8(); + let vlen = buf.get_u16(); + let expected_rem = buf + .remaining() + .checked_sub(vlen.into()) + .ok_or(ParseError::UnexpectedEof)?; + let r = Self::parse_parts(type_id, vlen, buf)?; + // Assert, because it would be an internal error + assert_eq!(buf.remaining(), expected_rem); + Ok(r) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SslClientFlags(u8); + +impl SslClientFlags { + pub fn is_ssl_or_tls(&self) -> bool { + (self.0 & 1) == 1 + } + + pub fn client_authenticated_connection(&self) -> bool { + (self.0 & 2) == 2 + } + pub fn client_authenticated_session(&self) -> bool { + (self.0 & 4) == 4 + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SslVerifyStatus(u32); + +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(not(feature = "always_exhaustive"), non_exhaustive)] // Extensions may be added +pub enum SslExtensionTlv { + /// TLS or SSL version in ASCII + Version(String), + /// TLS or SSL cipher suite in ASCII, for example "ECDHE-RSA-AES128-GCM-SHA256" + Cipher(String), + /// TLS or SSL signature algorithm in ASCII + SigAlg(String), + /// TLS or SSL key algorithm in ASCII + KeyAlg(String), + /// With client authentication, the common name for the client certificate in UTF-8 + ClientCN(String), +} + +impl SslExtensionTlv { + fn as_str(&self) -> &str { + match self { + Self::Version(version) => version, + Self::Cipher(cipher) => cipher, + Self::SigAlg(sigalg) => sigalg, + Self::KeyAlg(keyalg) => keyalg, + Self::ClientCN(cn) => cn, + } + } +} + +impl Tlv for SslExtensionTlv { + fn type_id(&self) -> u8 { + match self { + Self::Version(_) => PP2_SUBTYPE_SSL_VERSION, + Self::ClientCN(_) => PP2_SUBTYPE_SSL_CN, + Self::Cipher(_) => PP2_SUBTYPE_SSL_CIPHER, + Self::SigAlg(_) => PP2_SUBTYPE_SSL_SIG_ALG, + Self::KeyAlg(_) => PP2_SUBTYPE_SSL_KEY_ALG, + } + } + + fn value_len(&self) -> Result { + self.as_str() + .len() + .try_into() + .map_err(|_| EncodeError::ValueTooLarge) + } + + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + buf.put_slice(self.as_str().as_bytes()); + Ok(()) + } + + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result { + Ok(match type_id { + PP2_SUBTYPE_SSL_VERSION => Self::Version(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_CIPHER => Self::Cipher(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_SIG_ALG => Self::SigAlg(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_KEY_ALG => Self::KeyAlg(ascii_from_buf(buf, len)?), + PP2_SUBTYPE_SSL_CN => Self::ClientCN(str_from_buf(buf, len)?), + _ => return InvalidTlvTypeIdSnafu { type_id }.fail(), + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ssl { + client: SslClientFlags, + verify: SslVerifyStatus, + extensions: Vec, +} + +impl Ssl { + fn parse(buf: &mut impl Buf, len: u16) -> Result { + if buf.remaining() < len.into() { + return UnexpectedEofSnafu.fail(); + } + let mut ext_len = + len.checked_sub(5) + .ok_or_else(|| ParseError::InsufficientLengthSpecified { + given: len.into(), + needs: 5, + })?; + let client = SslClientFlags(buf.get_u8()); + let verify = SslVerifyStatus(buf.get_u32()); + let mut extensions = Vec::new(); + while ext_len > 0 { + let rem0 = buf.remaining(); + extensions.push(SslExtensionTlv::parse(buf)?); + let rem = buf.remaining(); + // The assert enforces that Buf is implemented sanely + // and not rewound. + let parsed = rem0.checked_sub(rem).expect("Buf error"); + // We don't enforce u16-sized buffers. + // Since we don't pass a bound on how much to parse, + // we can't enforce that the extension parser won't read + // (slightly) more than 64k. + // The assert is safe since ext_len was already u16 and the + // new value is lower. + ext_len = usize::from(ext_len) + .checked_sub(parsed) + .ok_or_else(|| ParseError::InsufficientLengthSpecified { + given: ext_len.into(), + needs: parsed, + })? + .try_into() + .expect("Math error"); + } + Ok(Self { + client, + verify, + extensions, + }) + } + + fn encoded_len(&self) -> Result { + // 1 for flags, 4 for verify status, plus all nested TLVs + self.extensions + .iter() + .try_fold(5u16, |sum, subtlv| { + sum.checked_add(subtlv.encoded_len().ok()?) + }) + .ok_or(EncodeError::ValueTooLarge) + } + + fn encode(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + buf.put_u8(self.client.0); + buf.put_u32(self.verify.0); + for ext in self.extensions.iter() { + ext.encode(buf)?; + } + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(not(feature = "always_exhaustive"), non_exhaustive)] // Extensions may be added +pub enum ExtensionTlv { + Alpn(Vec), + Authority(String), + Crc32c(u32), + UniqueId(Vec), + Ssl(Ssl), + NetNs(String), +} + +pub(crate) const PP2_TYPE_ALPN: u8 = 0x01; +pub(crate) const PP2_TYPE_AUTHORITY: u8 = 0x02; +pub(crate) const PP2_TYPE_CRC32C: u8 = 0x03; +pub(crate) const PP2_TYPE_NOOP: u8 = 0x04; +pub(crate) const PP2_TYPE_UNIQUE_ID: u8 = 0x05; +pub(crate) const PP2_TYPE_SSL: u8 = 0x20; +pub(crate) const PP2_SUBTYPE_SSL_VERSION: u8 = 0x21; +pub(crate) const PP2_SUBTYPE_SSL_CN: u8 = 0x22; +pub(crate) const PP2_SUBTYPE_SSL_CIPHER: u8 = 0x23; +pub(crate) const PP2_SUBTYPE_SSL_SIG_ALG: u8 = 0x24; +pub(crate) const PP2_SUBTYPE_SSL_KEY_ALG: u8 = 0x25; +pub(crate) const PP2_TYPE_NETNS: u8 = 0x30; + +fn vec_from_buf(buf: &mut impl Buf, len: u16) -> Vec { + let mut r = vec![0; len.into()]; + buf.copy_to_slice(&mut r); + r +} + +fn str_from_buf(buf: &mut impl Buf, len: u16) -> Result { + let v = vec_from_buf(buf, len); + let r = String::from_utf8(v).map_err(|e| ParseError::InvalidUtf8 { + bytes: e.into_bytes(), + })?; + Ok(r) +} + +fn ascii_from_buf(buf: &mut impl Buf, len: u16) -> Result { + let s = str_from_buf(buf, len)?; + if !s.is_ascii() { + Err(ParseError::InvalidAscii { + bytes: s.into_bytes(), + }) + } else { + Ok(s) + } +} + +impl Tlv for ExtensionTlv { + fn type_id(&self) -> u8 { + match self { + Self::Alpn(_) => PP2_TYPE_ALPN, + Self::Authority(_) => PP2_TYPE_AUTHORITY, + Self::Crc32c(_) => PP2_TYPE_CRC32C, + Self::UniqueId(_) => PP2_TYPE_UNIQUE_ID, + Self::Ssl(_) => PP2_TYPE_SSL, + Self::NetNs(_) => PP2_TYPE_NETNS, + } + } + + fn value_len(&self) -> Result { + match self { + Self::Alpn(alpn) => alpn.len(), + Self::Authority(authority) => authority.len(), + Self::Crc32c(_) => 4, + Self::UniqueId(id) => id.len(), + Self::Ssl(data) => data.encoded_len()?.into(), + Self::NetNs(netns) => netns.len(), + } + .try_into() + .map_err(|_| EncodeError::ValueTooLarge) + } + + fn encode_value(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { + match self { + Self::Alpn(by) | Self::UniqueId(by) => buf.put_slice(by), + Self::Authority(st) | Self::NetNs(st) => buf.put_slice(st.as_bytes()), + Self::Crc32c(crc) => buf.put_u32(*crc), + Self::Ssl(ssl) => ssl.encode(buf)?, + }; + Ok(()) + } + + fn parse_parts(type_id: u8, len: u16, buf: &mut impl Buf) -> Result { + Ok(match type_id { + PP2_TYPE_ALPN => Self::Alpn(vec_from_buf(buf, len)), + PP2_TYPE_AUTHORITY => Self::Authority(str_from_buf(buf, len)?), + PP2_TYPE_CRC32C => Self::Crc32c(buf.get_u32()), + PP2_TYPE_UNIQUE_ID => Self::UniqueId(vec_from_buf(buf, len)), + PP2_TYPE_SSL => Self::Ssl(Ssl::parse(buf, len)?), + PP2_TYPE_NETNS => Self::NetNs(ascii_from_buf(buf, len)?), + _ => return InvalidTlvTypeIdSnafu { type_id }.fail(), + }) + } +} + +// Note: this is internal, assumes the first 12 bytes were parsed, +// and ignores the version half of the first byte. pub(crate) fn parse(buf: &mut impl Buf) -> Result { // We need to parse the following: // @@ -71,18 +417,19 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result> 4; + let command = st << 4 >> 4; let command = match command { 0 => ProxyCommand::Local, 1 => ProxyCommand::Proxy, - cmd => return UnknownCommand { cmd }.fail(), + cmd => return UnknownCommandSnafu { cmd }.fail(), }; // 4 bits for address family, 4 bits for transport protocol, // then 2 bytes for the length. - ensure!(buf.remaining() >= 3, UnexpectedEof); + ensure!(buf.remaining() >= 3, UnexpectedEofSnafu); let byte = buf.get_u8(); let address_family = match byte >> 4 { @@ -90,28 +437,17 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result ProxyAddressFamily::Inet, 2 => ProxyAddressFamily::Inet6, 3 => ProxyAddressFamily::Unix, - family => return UnknownAddressFamily { family }.fail(), + family => return UnknownAddressFamilySnafu { family }.fail(), }; let transport_protocol = match byte << 4 >> 4 { 0 => ProxyTransportProtocol::Unspec, 1 => ProxyTransportProtocol::Stream, 2 => ProxyTransportProtocol::Datagram, - protocol => return UnknownTransportProtocol { protocol }.fail(), + protocol => return UnknownTransportProtocolSnafu { protocol }.fail(), }; let length = buf.get_u16() as usize; - - if address_family == ProxyAddressFamily::Unspec { - // We have no information to parse. - ensure!(buf.remaining() >= length, UnexpectedEof); - buf.advance(length); - - return Ok(super::ProxyHeader::Version2 { - command, - transport_protocol, - addresses: ProxyAddresses::Unspec, - }); - } + ensure!(buf.remaining() >= length, UnexpectedEofSnafu); // Time to parse the following: // @@ -134,15 +470,23 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result } unix_addr; // > }; + // The full length of address data, + // including two addresses and two ports + let address_len = match address_family { + ProxyAddressFamily::Inet => (4 + 2) * 2, + ProxyAddressFamily::Inet6 => (16 + 2) * 2, + ProxyAddressFamily::Unix => 108 * 2, + ProxyAddressFamily::Unspec => 0, + }; if address_family == ProxyAddressFamily::Unix { ensure!( length >= 108 * 2, - InsufficientLengthSpecified { + InsufficientLengthSpecifiedSnafu { given: length, needs: 108usize * 2, - }, + } ); - ensure!(buf.remaining() >= 108 * 2, UnexpectedEof); + ensure!(buf.remaining() >= length, UnexpectedEofSnafu); let mut source = [0u8; 108]; let mut destination = [0u8; 108]; buf.copy_to_slice(&mut source[..]); @@ -151,85 +495,131 @@ pub(crate) fn parse(buf: &mut impl Buf) -> Result 108 * 2 { buf.advance(length - (108 * 2)); } + } + + let mut ext_len = + length + .checked_sub(address_len) + .ok_or(ParseError::InsufficientLengthSpecified { + given: length, + needs: address_len, + })?; + ensure!(buf.remaining() >= address_len, UnexpectedEofSnafu,); - return Ok(super::ProxyHeader::Version2 { - command, - transport_protocol, - addresses: ProxyAddresses::Unix { + let addresses = match address_family { + ProxyAddressFamily::Unspec => ProxyAddresses::Unspec, + ProxyAddressFamily::Unix => { + let mut source = [0u8; 108]; + let mut destination = [0u8; 108]; + buf.copy_to_slice(&mut source[..]); + buf.copy_to_slice(&mut destination[..]); + ProxyAddresses::Unix { source, destination, - }, - }); - } + } + } + ProxyAddressFamily::Inet => { + let mut data = [0u8; 4]; + buf.copy_to_slice(&mut data[..]); + let source = Ipv4Addr::from(data); - let port_length = 4; - let address_length = match address_family { - ProxyAddressFamily::Inet => 8, - ProxyAddressFamily::Inet6 => 32, - _ => unreachable!(), - }; + buf.copy_to_slice(&mut data); + let destination = Ipv4Addr::from(data); - ensure!( - length >= port_length + address_length, - InsufficientLengthSpecified { - given: length, - needs: port_length + address_length, - }, - ); - ensure!( - buf.remaining() >= port_length + address_length, - UnexpectedEof, - ); - - let addresses = if address_family == ProxyAddressFamily::Inet { - let mut data = [0u8; 4]; - buf.copy_to_slice(&mut data[..]); - let source = Ipv4Addr::from(data); - - buf.copy_to_slice(&mut data); - let destination = Ipv4Addr::from(data); - - let source_port = buf.get_u16(); - let destination_port = buf.get_u16(); + let source_port = buf.get_u16(); + let destination_port = buf.get_u16(); - ProxyAddresses::Ipv4 { - source: SocketAddrV4::new(source, source_port), - destination: SocketAddrV4::new(destination, destination_port), + ProxyAddresses::Ipv4 { + source: SocketAddrV4::new(source, source_port), + destination: SocketAddrV4::new(destination, destination_port), + } } - } else { - let mut data = [0u8; 16]; - buf.copy_to_slice(&mut data); - let source = Ipv6Addr::from(data); + ProxyAddressFamily::Inet6 => { + let mut data = [0u8; 16]; + buf.copy_to_slice(&mut data); + let source = Ipv6Addr::from(data); - buf.copy_to_slice(&mut data); - let destination = Ipv6Addr::from(data); + buf.copy_to_slice(&mut data); + let destination = Ipv6Addr::from(data); - let source_port = buf.get_u16(); - let destination_port = buf.get_u16(); + let source_port = buf.get_u16(); + let destination_port = buf.get_u16(); - ProxyAddresses::Ipv6 { - source: SocketAddrV6::new(source, source_port, 0, 0), - destination: SocketAddrV6::new(destination, destination_port, 0, 0), + ProxyAddresses::Ipv6 { + source: SocketAddrV6::new(source, source_port, 0, 0), + destination: SocketAddrV6::new(destination, destination_port, 0, 0), + } } }; - if length > port_length + address_length { - // TODO(Mariell Hoversholm): Implement TLVs - buf.advance(length - (port_length + address_length)); + let mut extensions = Vec::new(); + while ext_len > 0 { + // At this point, we know that remaining() >= ext_len + if buf.chunk()[0] == PP2_TYPE_NOOP { + if ext_len < 3 { + return Err(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: 3, + }); + } + // Read/skip the type after peeking + buf.get_u8(); + let skip_len = buf.get_u16(); + let noop_len = 3u16 + .checked_add(skip_len) + .ok_or_else(|| ParseError::LengthOverflow { + given: skip_len.into(), + })? + .into(); + if noop_len > ext_len { + return Err(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: noop_len, + }); + } + ext_len -= noop_len; + } else { + let rem0 = buf.remaining(); + extensions.push(ExtensionTlv::parse(buf)?); + let rem = buf.remaining(); + let parsed = rem0.checked_sub(rem).expect("Buf error"); + ext_len = + ext_len + .checked_sub(parsed) + .ok_or(ParseError::InsufficientLengthSpecified { + given: ext_len, + needs: parsed, + })?; + } } Ok(super::ProxyHeader::Version2 { command, transport_protocol, addresses, + extensions, }) } +// Currently used in tests, has to be internal for the same reasons +// parse() currently is. +#[cfg(test)] +pub(crate) fn parse_fully(buf: &mut impl Buf) -> Result { + let r = parse(buf)?; + if buf.has_remaining() { + return Err(ParseError::TrailingData { + len: buf.remaining(), + }); + } + Ok(r) +} + pub(crate) fn encode( command: ProxyCommand, transport_protocol: ProxyTransportProtocol, addresses: ProxyAddresses, -) -> BytesMut { + extensions: &[ExtensionTlv], +) -> Result { // > struct proxy_hdr_v2 { // > uint8_t sig[12]; /* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */ // > uint8_t ver_cmd; /* protocol version and command */ @@ -341,23 +731,27 @@ pub(crate) fn encode( // > uint8_t dst_addr[108]; // > } unix_addr; // > }; - let len = match addresses { + let address_len: u16 = match addresses { ProxyAddresses::Unspec => 0, - ProxyAddresses::Unix { .. } => { - 108 + 108 - } - ProxyAddresses::Ipv4 { .. } => { - 4 + 4 + 2 + 2 - } - ProxyAddresses::Ipv6 { .. } => { - 16 + 16 + 2 + 2 - } + ProxyAddresses::Unix { .. } => 108 + 108, + ProxyAddresses::Ipv4 { .. } => 4 + 4 + 2 + 2, + ProxyAddresses::Ipv6 { .. } => 16 + 16 + 2 + 2, }; + // With extensions, we need to distinguish len and address_len + let len = extensions + .iter() + .try_fold(address_len, |acc, ext| { + acc.checked_add(ext.encoded_len().ok()?) + }) + .ok_or(EncodeError::ValueTooLarge)?; - let mut buf = BytesMut::with_capacity(16 + len); + let blen = 16usize + .checked_add(len.into()) + .ok_or(EncodeError::ValueTooLarge)?; + let mut buf = BytesMut::with_capacity(blen); buf.put_slice(&SIG[..]); buf.put_slice(&[ver_cmd, fam][..]); - buf.put_u16(len as u16); + buf.put_u16(len); match addresses { ProxyAddresses::Unspec => (), @@ -388,7 +782,13 @@ pub(crate) fn encode( } } - buf + for ext in extensions.iter() { + ext.encode(&mut buf)?; + } + + assert_eq!(buf.len(), blen); + + Ok(buf) } #[cfg(test)] @@ -403,23 +803,23 @@ mod parse_tests { #[test] fn test_unspec() { assert_eq!( - parse(&mut &[0u8; 16][..]), + parse_fully(&mut &[0u8; 4][..]), Ok(ProxyHeader::Version2 { command: ProxyCommand::Local, addresses: ProxyAddresses::Unspec, transport_protocol: ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); - let mut prefix = BytesMut::from(&[1u8][..]); - prefix.reserve(16); - prefix.extend_from_slice(&[0u8; 16][..]); + let mut prefix = BytesMut::from(&[1u8, 0, 0, 0][..]); assert_eq!( - parse(&mut prefix), + parse_fully(&mut prefix), Ok(ProxyHeader::Version2 { command: ProxyCommand::Proxy, addresses: ProxyAddresses::Unspec, transport_protocol: ProxyTransportProtocol::Unspec, + extensions: Vec::new(), }), ); } @@ -427,7 +827,7 @@ mod parse_tests { #[test] fn test_ipv4() { assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -456,7 +856,7 @@ mod parse_tests { 1, 1, // TLV - 69, + PP2_TYPE_NOOP, 0, 0, ][..] @@ -468,6 +868,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 257), }, + extensions: Vec::new(), }) ); @@ -512,6 +913,7 @@ mod parse_tests { source: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), destination: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 255 << 8), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -520,7 +922,7 @@ mod parse_tests { #[test] fn test_ipv6() { assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -573,7 +975,7 @@ mod parse_tests { 1, 1, // TLV - 69, + PP2_TYPE_NOOP, 0, 0, ][..] @@ -595,6 +997,7 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); @@ -673,6 +1076,7 @@ mod parse_tests { 0, ), }, + extensions: Vec::new(), }) ); assert!(data.remaining() == 4); // Consume the entire header @@ -681,14 +1085,14 @@ mod parse_tests { #[test] fn test_invalid_data() { let mut data = [0u8; 200]; - rand::thread_rng().fill_bytes(&mut data); + rand::rng().fill_bytes(&mut data); data[0] = 99; // Make 100% sure it's invalid. - assert!(parse(&mut &data[..]).is_err()); + assert!(parse_fully(&mut &data[..]).is_err()); - assert_eq!(parse(&mut &[0][..]), Err(ParseError::UnexpectedEof)); + assert_eq!(parse_fully(&mut &[0][..]), Err(ParseError::UnexpectedEof)); assert_eq!( - parse( + parse_fully( &mut &[ // Proxy command 1u8, @@ -698,6 +1102,9 @@ mod parse_tests { // 3 bytes is clearly too few if we expect 2 IPv4s and ports 0, 3, + 0, + 0, + 0, ][..] ), Err(ParseError::InsufficientLengthSpecified { @@ -706,12 +1113,86 @@ mod parse_tests { }), ); } + + #[test] + fn test_tlv() { + use super::ExtensionTlv::*; + use super::SslExtensionTlv::*; + + assert_eq!( + parse_fully( + &mut &[ + // Proxy command + 1u8, + // Connection type: Unknown + 0, + // TLV length: 3 + 2 + 3 + 11 + 3 + 15 + 0, + 37, + PP2_TYPE_ALPN, + 0, + 2, + // h2 + 0x68, + 0x32, + PP2_TYPE_AUTHORITY, + 0, + 11, + // example.org + 0x65, + 0x78, + 0x61, + 0x6d, + 0x70, + 0x6c, + 0x65, + 0x2e, + 0x6f, + 0x72, + 0x67, + PP2_TYPE_SSL, + 0, + 15, + 0x07, + 0, + 0, + 0, + 0, + PP2_SUBTYPE_SSL_VERSION, + 0, + 7, + // TLSv1.3. This is from OpenSSL, to match haproxy. + 0x54, + 0x4c, + 0x53, + 0x76, + 0x31, + 0x2e, + 0x33, + ][..] + ), + Ok(ProxyHeader::Version2 { + command: ProxyCommand::Proxy, + addresses: ProxyAddresses::Unspec, + transport_protocol: ProxyTransportProtocol::Unspec, + extensions: vec![ + Alpn(b"h2".to_vec()), + Authority("example.org".to_string()), + Ssl(super::Ssl { + client: SslClientFlags(7), + verify: SslVerifyStatus(0), + extensions: vec![Version("TLSv1.3".to_string()),], + }), + ], + }), + ); + } } #[cfg(test)] mod encode_tests { use super::*; - use bytes::{Bytes, BytesMut}; + use bytes::BytesMut; use pretty_assertions::assert_eq; use std::net::{Ipv4Addr, Ipv6Addr}; @@ -719,10 +1200,10 @@ mod encode_tests { 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, ]; - fn signed(buf: &[u8]) -> Bytes { + fn signed(buf: &[u8]) -> BytesMut { let mut bytes = BytesMut::from(&SIG[..]); bytes.extend_from_slice(buf); - bytes.freeze() + bytes } #[test] @@ -732,8 +1213,9 @@ mod encode_tests { ProxyCommand::Local, ProxyTransportProtocol::Unspec, ProxyAddresses::Unspec, + &[], ), - signed(&[(2 << 4) | 0, 0, 0, 0][..]), + Ok(signed(&[2 << 4, 0, 0, 0][..])), ); assert_eq!( @@ -741,8 +1223,9 @@ mod encode_tests { ProxyCommand::Proxy, ProxyTransportProtocol::Unspec, ProxyAddresses::Unspec, + &[], ), - signed(&[(2 << 4) | 1, 0, 0, 0][..]), + Ok(signed(&[(2 << 4) | 1, 0, 0, 0][..])), ); assert_eq!( encode( @@ -752,11 +1235,12 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 9012), }, + &[], ), - signed( + Ok(signed( &[ (2 << 4) | 1, - (1 << 4) | 0, + 1 << 4, 0, 12, 1, @@ -772,7 +1256,7 @@ mod encode_tests { (9012u16 >> 8) as u8, 9012u16 as u8, ][..] - ), + )), ); } @@ -786,8 +1270,9 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 65535), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 9012), }, + &[], ), - signed( + Ok(signed( &[ (2 << 4) | 1, (1 << 4) | 1, @@ -806,7 +1291,7 @@ mod encode_tests { (9012u16 >> 8) as u8, 9012u16 as u8, ][..] - ), + )), ); assert_eq!( encode( @@ -816,10 +1301,11 @@ mod encode_tests { source: SocketAddrV4::new(Ipv4Addr::new(255, 255, 255, 255), 324), destination: SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 2187), }, + &[], ), - signed( + Ok(signed( &[ - (2 << 4) | 0, + 2 << 4, (1 << 4) | 2, 0, 12, @@ -836,7 +1322,7 @@ mod encode_tests { (2187 >> 8) as u8, 2187u16 as u8, ][..] - ), + )), ); } @@ -847,23 +1333,19 @@ mod encode_tests { ProxyCommand::Local, ProxyTransportProtocol::Datagram, ProxyAddresses::Ipv6 { - source: SocketAddrV6::new( - Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), - 8192, - 0, - 0, - ), + source: SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 8192, 0, 0,), destination: SocketAddrV6::new( Ipv6Addr::new(65535, 65535, 32767, 32766, 111, 222, 333, 444), 0, 0, 0, ), - } + }, + &[], ), - signed( + Ok(signed( &[ - (2 << 4) | 0, + 2 << 4, (2 << 4) | 2, 0, 36, @@ -904,7 +1386,82 @@ mod encode_tests { 0, 0, ][..] + )), + ); + } + + #[test] + fn test_tlv() { + use super::ExtensionTlv::*; + use super::SslExtensionTlv::*; + + assert_eq!( + encode( + ProxyCommand::Proxy, + ProxyTransportProtocol::Unspec, + ProxyAddresses::Unspec, + &[ + Alpn(b"h2".to_vec()), + Authority("example.org".to_string()), + Ssl(super::Ssl { + client: SslClientFlags(7), + verify: SslVerifyStatus(0), + extensions: vec![Version("TLSv1.3".to_string()),], + }), + ], ), + Ok(signed( + &[ + // Version 2, + // Proxy command + 0x21u8, + // Connection type: Unknown + 0, + // TLV length: 3 + 2 + 3 + 11 + 3 + 15 + 0, + 37, + PP2_TYPE_ALPN, + 0, + 2, + // h2 + 0x68, + 0x32, + PP2_TYPE_AUTHORITY, + 0, + 11, + // example.org + 0x65, + 0x78, + 0x61, + 0x6d, + 0x70, + 0x6c, + 0x65, + 0x2e, + 0x6f, + 0x72, + 0x67, + PP2_TYPE_SSL, + 0, + 15, + 0x07, + 0, + 0, + 0, + 0, + PP2_SUBTYPE_SSL_VERSION, + 0, + 7, + // TLSv1.3. This is from OpenSSL, to match haproxy. + 0x54, + 0x4c, + 0x53, + 0x76, + 0x31, + 0x2e, + 0x33, + ][..] + )), ); } }