diff --git a/Cargo.toml b/Cargo.toml index 0684fbca4..539afa065 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,7 @@ alloc = [] bumpalo = ["dep:bumpalo", "std"] bytes = ["dep:bytes", "octseq/bytes"] heapless = ["dep:heapless", "octseq/heapless"] -serde = ["dep:serde", "octseq/serde"] +serde = ["std", "dep:serde", "octseq/serde"] smallvec = ["dep:smallvec", "octseq/smallvec"] std = ["alloc", "dep:hashbrown", "bumpalo?/std", "bytes?/std", "octseq/std", "time/std"] tracing = ["dep:log", "dep:tracing"] @@ -74,7 +74,7 @@ net = ["bytes", "futures-util", "rand", "std", "tokio"] resolv = ["net", "smallvec", "unstable-client-transport"] resolv-sync = ["resolv", "tokio/rt"] tsig = ["bytes", "constant_time_eq", "ring", "smallvec"] -zonefile = ["bytes", "serde", "std"] +zonefile = ["bytes", "serde", "std", "bumpalo"] # new: ["std", "bumpalo", "dep:time"] # Unstable features unstable-new = [] @@ -148,7 +148,7 @@ required-features = ["net", "tokio-stream", "tracing-subscriber", "unstable-clie [[example]] name = "read-zone" -required-features = ["zonefile"] +required-features = ["zonefile", "unstable-new"] [[example]] name = "query-zone" diff --git a/examples/read-zone.rs b/examples/read-zone.rs index 110aaa804..555a46086 100644 --- a/examples/read-zone.rs +++ b/examples/read-zone.rs @@ -1,12 +1,11 @@ //! Reads a zone file. -use std::env; use std::fs::File; use std::process::exit; use std::time::SystemTime; +use std::{env, io::BufReader}; -use domain::zonefile::inplace::Entry; -use domain::zonefile::inplace::Zonefile; +use domain::new::zonefile::simple::ZonefileScanner; fn main() { let mut args = env::args(); @@ -21,40 +20,17 @@ fn main() { for zone_file in zone_files { print!("Processing {zone_file}: "); let start = SystemTime::now(); - let mut reader = - Zonefile::load(&mut File::open(&zone_file).unwrap()).unwrap(); - println!( - "Data loaded ({:.03}s).", - start.elapsed().unwrap().as_secs_f32() - ); + let file = BufReader::new(File::open(&zone_file).unwrap()); + let mut scanner = ZonefileScanner::new(file, None); let mut i = 0; - let mut last_entry = None; - loop { - match reader.next_entry() { - Ok(entry) if entry.is_some() => { - last_entry = entry; - } - Ok(_) => break, // EOF - Err(err) => { - eprintln!( - "\nAn error occurred while reading {zone_file}:" - ); - eprintln!(" Error: {err}"); - if let Some(entry) = &last_entry { - if let Entry::Record(record) = &entry { - eprintln!( - "\nThe last record read was:\n{record}." - ); - } else { - eprintln!("\nThe last record read was:\n{last_entry:#?}."); - } - eprintln!("\nTry commenting out the line after that record with a leading ; (semi-colon) character.") - } - exit(1); - } - } + while let Some(entry) = scanner.scan().transpose() { i += 1; + if let Err(err) = entry { + eprintln!("Could not parse {zone_file}: {err}"); + exit(1); + } + if i % 100_000_000 == 0 { println!( "Processed {}M records ({:.03}s)", diff --git a/src/new/base/charstr.rs b/src/new/base/charstr.rs index c9cfc3721..6b45d256d 100644 --- a/src/new/base/charstr.rs +++ b/src/new/base/charstr.rs @@ -8,6 +8,9 @@ use core::str::FromStr; use crate::utils::dst::{UnsizedCopy, UnsizedCopyFrom}; +#[cfg(feature = "zonefile")] +use crate::new::zonefile::scanner::{Scan, ScanError, Scanner}; + use super::{ build::{BuildInMessage, NameCompressor}, parse::{ParseMessageBytes, SplitMessageBytes}, @@ -459,6 +462,42 @@ impl fmt::Display for CharStrParseError { } } +//--- Parsing from the zonefile format + +#[cfg(feature = "zonefile")] +impl<'a> Scan<'a> for &'a CharStr { + /// Scan a character string. + /// + /// This parses the `d-word` syntax from [the specification]. + /// + /// [the specification]: crate::new::zonefile#specification + fn scan( + scanner: &mut Scanner<'_>, + alloc: &'a bumpalo::Bump, + buffer: &mut std::vec::Vec, + ) -> Result { + let start = buffer.len(); + match scanner.scan_token(buffer)? { + Some(token) if token.len() > 255 => { + buffer.truncate(start); + Err(ScanError::Custom("overlong character string")) + } + + Some(token) => { + let bytes = alloc.alloc_slice_copy(token); + buffer.truncate(start); + // SAFETY: 'token' consists of up to 255 bytes. + Ok(unsafe { core::mem::transmute::<&[u8], Self>(bytes) }) + } + + None => { + buffer.truncate(start); + Err(ScanError::Incomplete) + } + } + } +} + //============ Tests ========================================================= #[cfg(test)] @@ -486,4 +525,28 @@ mod test { ); assert_eq!(buffer, &bytes[..6]); } + + #[cfg(feature = "zonefile")] + #[test] + fn scan() { + use crate::new::zonefile::scanner::{Scan, ScanError, Scanner}; + + let cases = [ + (b"hello" as &[u8], Ok(b"hello" as &[u8])), + (b"\"hi\"and\"bye\"" as &[u8], Ok(b"hiandbye")), + (b"\"\"" as &[u8], Ok(b"")), + (b"" as &[u8], Err(ScanError::Incomplete)), + ]; + + let alloc = bumpalo::Bump::new(); + let mut buffer = std::vec::Vec::new(); + for (input, expected) in cases { + let mut scanner = Scanner::new(input, None); + assert_eq!( + <&CharStr>::scan(&mut scanner, &alloc, &mut buffer) + .map(|c| &c.octets), + expected + ); + } + } } diff --git a/src/new/base/name/absolute.rs b/src/new/base/name/absolute.rs index eaae23db2..6b9c9f46e 100644 --- a/src/new/base/name/absolute.rs +++ b/src/new/base/name/absolute.rs @@ -21,6 +21,9 @@ use crate::{ utils::dst::{UnsizedCopy, UnsizedCopyFrom}, }; +#[cfg(feature = "zonefile")] +use crate::new::zonefile::scanner::{Scan, ScanError, Scanner}; + use super::{ CanonicalName, Label, LabelBuf, LabelIter, LabelParseError, NameCompressor, @@ -190,6 +193,26 @@ impl BuildInMessage for Name { } } +//--- Parsing from the zonefile format + +#[cfg(feature = "zonefile")] +impl<'a> Scan<'a> for &'a Name { + /// Scan a domain name token. + /// + /// This parses a domain name, following the [specification]. + /// + /// [specification]: crate::new::zonefile#specification + fn scan( + scanner: &mut Scanner<'_>, + alloc: &'a bumpalo::Bump, + buffer: &mut std::vec::Vec, + ) -> Result { + let name = NameBuf::scan(scanner, alloc, buffer)?; + let bytes = alloc.alloc_slice_copy(name.as_bytes()); + Ok(unsafe { Name::from_bytes_unchecked(bytes) }) + } +} + //--- Cloning #[cfg(feature = "alloc")] @@ -332,6 +355,24 @@ impl fmt::Debug for Name { } } +//--- Serialize + +#[cfg(feature = "serde")] +impl serde::Serialize for Name { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use std::string::ToString; + + if serializer.is_human_readable() { + serializer.serialize_newtype_struct("Name", &self.to_string()) + } else { + serializer.serialize_newtype_struct("Name", self.as_bytes()) + } + } +} + //----------- NameBuf -------------------------------------------------------- /// A 256-byte buffer containing a [`Name`]. @@ -555,30 +596,93 @@ impl NameBuf { } } -//--- Parsing from strings - -impl FromStr for NameBuf { - type Err = NameParseError; +//--- Parsing from the zonefile format - /// Parse a name from a string. +#[cfg(feature = "zonefile")] +impl Scan<'_> for NameBuf { + /// Scan a domain name token. /// - /// This is intended for easily constructing hard-coded domain names. The - /// labels in the name should be given in the conventional order (i.e. not - /// reversed), and should be separated by ASCII periods. The labels will - /// be parsed using [`LabelBuf::from_str()`]; see its documentation. This - /// function cannot parse all valid domain names; if an exceptional name - /// needs to be parsed, use [`Name::from_bytes_unchecked()`]. If the - /// input is empty, the root name is returned. - fn from_str(s: &str) -> Result { + /// This parses a domain name, following the [specification]. + /// + /// [specification]: crate::new::zonefile#specification + fn scan( + scanner: &mut Scanner<'_>, + alloc: &'_ bumpalo::Bump, + buffer: &mut std::vec::Vec, + ) -> Result { + // Build up a 'Name'. let mut this = Self::empty(); - for label in s.split('.') { - let label = - label.parse::().map_err(NameParseError::Label)?; + + // Try parsing '@', indicating the origin name. + if let [b'@', b' ' | b'\t' | b'\r' | b'\n', ..] | [b'@'] = + scanner.remaining() + { + scanner.consume(1); + let origin = scanner + .origin() + .ok_or(ScanError::Custom("unknown origin name"))?; + + origin + .build_bytes(&mut this.buffer) + .expect("Valid 'RevName's are at most 255 bytes"); + this.size = origin.len() as u8; + return Ok(this); + } + + loop { + // Parse a label and prepend it to the buffer. + let label = LabelBuf::scan(scanner, alloc, buffer)?; if 255 - this.size < 1 + label.as_bytes().len() as u8 { - return Err(NameParseError::Overlong); + return Err(ScanError::Custom( + "domain name exceeds 255 bytes", + )); } this.append_label(&label); + + // Check if this is the end of the domain name. + match *scanner.remaining() { + [b' ' | b'\t' | b'\r' | b'\n', ..] | [] => { + // This is a relative domain name. + let origin = scanner + .origin() + .ok_or(ScanError::Custom("unknown origin name"))?; + + // Append the origin to this name. + origin + .build_bytes(&mut this.buffer[this.size as usize..]) + .map_err(|_| { + ScanError::Custom( + "relative domain name exceeds 255 bytes", + ) + })?; + // We exclude the root label, which gets added manually. + this.size += origin.len() as u8 - 1; + break; + } + + [b'.', b' ' | b'\t' | b'\r' | b'\n', ..] | [b'.'] => { + // This is an absolute domain name. + scanner.consume(1); + break; + } + + [b'.', ..] => { + scanner.consume(1); + } + + _ => { + return Err(ScanError::Custom( + "irregular character in domain name", + )); + } + } } + + if this.size == 0 { + return Err(ScanError::Incomplete); + } + + // Add a root label and stop. this.append_label(Label::ROOT); Ok(this) } @@ -668,6 +772,174 @@ impl fmt::Debug for NameBuf { } } +//--- Parsing from strings + +impl NameBuf { + /// Parse a domain name from the zonefile format. + pub fn parse_str(mut s: &[u8]) -> Result<(Self, &[u8]), NameParseError> { + // The buffer we'll fill into. + let mut this = Self::empty(); + + // Parse label by label. + loop { + let (label, rest) = LabelBuf::parse_str(s)?; + + if 255 - this.size < 1 + label.as_bytes().len() as u8 { + return Err(NameParseError::Overlong); + } + this.append_label(&label); + + match *rest { + [b' ' | b'\n' | b'\r' | b'\t', ..] | [] => { + s = rest; + break; + } + [b'.', ref rest @ ..] => s = rest, + _ => return Err(NameParseError::InvalidChar), + } + } + this.append_label(Label::ROOT); + + Ok((this, s)) + } +} + +impl FromStr for NameBuf { + type Err = NameParseError; + + /// Parse a name from a string. + fn from_str(s: &str) -> Result { + match Self::parse_str(s.as_bytes()) { + Ok((this, &[])) => Ok(this), + Ok(_) => Err(NameParseError::InvalidChar), + Err(err) => Err(err), + } + } +} + +//--- Serialize, Deserialize + +#[cfg(feature = "serde")] +impl serde::Serialize for NameBuf { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + (**self).serialize(serializer) + } +} + +#[cfg(feature = "serde")] +impl<'a> serde::Deserialize<'a> for NameBuf { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'a>, + { + if deserializer.is_human_readable() { + struct V; + + impl serde::de::Visitor<'_> for V { + type Value = NameBuf; + + fn expecting( + &self, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + f.write_str("a domain name, in the DNS zonefile format") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse().map_err(|err| E::custom(err)) + } + } + + struct NV; + + impl<'a> serde::de::Visitor<'a> for NV { + type Value = NameBuf; + + fn expecting( + &self, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + f.write_str("an absolute domain name") + } + + fn visit_newtype_struct( + self, + deserializer: D, + ) -> Result + where + D: serde::Deserializer<'a>, + { + deserializer.deserialize_str(V) + } + } + + deserializer.deserialize_newtype_struct("Name", NV) + } else { + struct V; + + impl serde::de::Visitor<'_> for V { + type Value = NameBuf; + + fn expecting( + &self, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + f.write_str("a domain name, in the DNS wire format") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + NameBuf::parse_bytes(v).map_err(|_| E::custom("misformatted domain name for the DNS wire format")) + } + } + + struct NV; + + impl<'a> serde::de::Visitor<'a> for NV { + type Value = NameBuf; + + fn expecting( + &self, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + f.write_str("an absolute domain name") + } + + fn visit_newtype_struct( + self, + deserializer: D, + ) -> Result + where + D: serde::Deserializer<'a>, + { + deserializer.deserialize_bytes(V) + } + } + + deserializer.deserialize_newtype_struct("Name", NV) + } + } +} + +#[cfg(feature = "serde")] +impl<'a> serde::Deserialize<'a> for std::boxed::Box { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'a>, + { + NameBuf::deserialize(deserializer) + .map(|this| this.unsized_copy_into()) + } +} + //------------ NameParseError ------------------------------------------------ /// An error in parsing a [`Name`] from a string. @@ -681,10 +953,25 @@ pub enum NameParseError { /// Valid names are between 1 and 255 bytes, inclusive. Overlong, + /// The name contained an invalid character. + /// + /// Valid names contain any of the following characters: + /// - ASCII alphanumeric characters + /// - `-`, `_`, `*` (within labels) + /// - `.` (between labels) + /// - Correctly escaped characters + InvalidChar, + /// A label in the name could not be parsed. Label(LabelParseError), } +impl From for NameParseError { + fn from(value: LabelParseError) -> Self { + Self::Label(value) + } +} + // TODO(1.81.0): Use 'core::error::Error' instead. #[cfg(feature = "std")] impl std::error::Error for NameParseError {} @@ -693,11 +980,75 @@ impl fmt::Display for NameParseError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { Self::Overlong => "the domain name was too long", + Self::InvalidChar | Self::Label(LabelParseError::InvalidChar) => { + "the domain name contained an invalid character" + } Self::Label(LabelParseError::Overlong) => "a label was too long", Self::Label(LabelParseError::Empty) => "a label was empty", - Self::Label(LabelParseError::InvalidChar) => { - "the domain name contained an invalid character" + Self::Label(LabelParseError::PartialEscape) => { + "a label contained an incomplete escape" + } + Self::Label(LabelParseError::InvalidEscape) => { + "a label contained an invalid escape" } }) } } + +//============ Unit tests ==================================================== + +#[cfg(test)] +mod test { + #[cfg(feature = "zonefile")] + #[test] + fn scan() { + use std::vec::Vec; + + use crate::{ + new::base::name::RevNameBuf, + new::zonefile::scanner::{Scan, ScanError, Scanner}, + }; + + use super::NameBuf; + + let cases = [ + (b"".as_slice(), Err(ScanError::Incomplete)), + (b" ".as_slice(), Err(ScanError::Incomplete)), + (b"a", Ok(&[b"a" as &[u8], b"org", b""] as &[&[u8]])), + (b"xn--hello.", Ok(&[b"xn--hello", b""])), + ( + b"hello\\.world.sld", + Ok(&[b"hello.world", b"sld", b"org", b""]), + ), + (b"a\\046b.c.", Ok(&[b"a.b", b"c", b""])), + (b"a.b\\ c.d", Ok(&[b"a", b"b c", b"d", b"org", b""])), + ]; + + let alloc = bumpalo::Bump::new(); + let mut buffer = Vec::new(); + for (input, expected) in cases { + let origin = "org".parse::().unwrap(); + let mut scanner = Scanner::new(input, Some(&origin)); + let mut name_buf = None; + let actual = NameBuf::scan(&mut scanner, &alloc, &mut buffer) + .map(|name| name_buf.insert(name).labels()); + match expected { + Ok(labels) => { + assert!( + actual.clone().is_ok_and(|actual| actual + .map(|l| &l.as_bytes()[1..]) + .eq(labels.iter().copied())), + "{actual:?} == Ok({labels:?})" + ); + } + + Err(err) => { + assert!( + actual.clone().is_err_and(|e| e == err), + "{actual:?} == Err({err:?})" + ); + } + } + } + } +} diff --git a/src/new/base/name/label.rs b/src/new/base/name/label.rs index 125c79a9e..e44d77d7e 100644 --- a/src/new/base/name/label.rs +++ b/src/new/base/name/label.rs @@ -17,6 +17,9 @@ use crate::new::base::wire::{ }; use crate::utils::dst::{UnsizedCopy, UnsizedCopyFrom}; +#[cfg(feature = "zonefile")] +use crate::new::zonefile::scanner::{Scan, ScanError, Scanner}; + //----------- Label ---------------------------------------------------------- /// A label in a domain name. @@ -152,6 +155,26 @@ impl BuildBytes for Label { } } +//--- Parsing from the zonefile format + +#[cfg(feature = "zonefile")] +impl<'a> Scan<'a> for &'a Label { + /// Scan a domain name label. + /// + /// This parses a domain name label, following the [specification]. + /// + /// [specification]: crate::new::zonefile#specification + fn scan( + scanner: &mut Scanner<'_>, + alloc: &'a bumpalo::Bump, + buffer: &mut std::vec::Vec, + ) -> Result { + let label = LabelBuf::scan(scanner, alloc, buffer)?; + let bytes = alloc.alloc_slice_copy(label.as_bytes()); + Ok(unsafe { Label::from_bytes_unchecked(bytes) }) + } +} + //--- Inspection impl Label { @@ -274,14 +297,18 @@ impl fmt::Display for Label { /// /// The label is printed in the conventional zone file format, with bytes /// outside printable ASCII formatted as `\\DDD` (a backslash followed by - /// three zero-padded decimal digits), and `.` and `\\` simply escaped by - /// a backslash. + /// three zero-padded decimal digits), and uncommon ASCII characters just + /// escaped by a backslash. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_wildcard() { + return f.write_str("*"); + } + self.contents().iter().try_for_each(|&byte| { - if b".\\".contains(&byte) { - write!(f, "\\{}", byte as char) - } else if byte.is_ascii_graphic() { + if byte.is_ascii_alphanumeric() || b"-_".contains(&byte) { write!(f, "{}", byte as char) + } else if byte.is_ascii_graphic() { + write!(f, "\\{}", byte as char) } else { write!(f, "\\{:03}", byte) } @@ -292,9 +319,25 @@ impl fmt::Display for Label { impl fmt::Debug for Label { /// Print a label for debugging purposes. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Label") - .field(&format_args!("{}", self)) - .finish() + write!(f, "Label({self})") + } +} + +//--- Serialize + +#[cfg(feature = "serde")] +impl serde::Serialize for Label { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use std::string::ToString; + + if serializer.is_human_readable() { + serializer.serialize_newtype_struct("Label", &self.to_string()) + } else { + serializer.serialize_newtype_struct("Label", self.contents()) + } } } @@ -328,35 +371,86 @@ impl UnsizedCopyFrom for LabelBuf { } } -//--- Parsing from strings +//--- Interaction -impl FromStr for LabelBuf { - type Err = LabelParseError; +impl LabelBuf { + /// Append some bytes to the [`Label`]. + /// + /// If the label would grow too large, [`TruncationError`] is returned. + #[cfg(feature = "zonefile")] + fn append(&mut self, bytes: &[u8]) -> Result<(), TruncationError> { + let len = self.data[0] as usize; + if len + bytes.len() > 63 { + return Err(TruncationError); + } - /// Parse a label from a string. + self.data[1 + len..][..bytes.len()].copy_from_slice(bytes); + self.data[0] += bytes.len() as u8; + Ok(()) + } +} + +//--- Parsing from the zonefile format + +#[cfg(feature = "zonefile")] +impl Scan<'_> for LabelBuf { + /// Scan a domain name label. /// - /// This is intended for easily constructing hard-coded labels. The input - /// is not expected to be in the zonefile format; it should simply contain - /// 1 to 63 characters, each being a plain ASCII alphanumeric or a hyphen. - /// To construct a label containing bytes outside this range, use - /// [`Label::from_bytes_unchecked()`]. To construct a root label, use - /// [`Label::ROOT`]. - fn from_str(s: &str) -> Result { - if s == "*" { - Ok(Self::copy_from(Label::WILDCARD)) - } else if !s.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-') { - Err(LabelParseError::InvalidChar) - } else if s.is_empty() { - Err(LabelParseError::Empty) - } else if s.len() > 63 { - Err(LabelParseError::Overlong) - } else { - let bytes = s.as_bytes(); - let mut data = [0u8; 64]; - data[0] = bytes.len() as u8; - data[1..1 + bytes.len()].copy_from_slice(bytes); - Ok(Self { data }) + /// This parses a domain name label, following the [specification]. + /// + /// [specification]: crate::new::zonefile#specification + fn scan( + scanner: &mut Scanner<'_>, + _alloc: &'_ bumpalo::Bump, + _buffer: &mut std::vec::Vec, + ) -> Result { + // Try parsing a wildcard label. + if let [b'*', b' ' | b'\t' | b'\r' | b'\n' | b'.', ..] | [b'*'] = + scanner.remaining() + { + scanner.consume(1); + return Ok(Self::copy_from(Label::WILDCARD)); } + + // The buffer we'll fill into. + let mut this = Self { data: [0u8; 64] }; + + // Loop through non-special chunks and special sequences. + loop { + let (chunk, first) = scanner.scan_unquoted_chunk(|&c| { + !c.is_ascii_alphanumeric() && !b"-_".contains(&c) + }); + + // Copy the non-special chunk into the buffer. + this.append(chunk).map_err(|_| { + ScanError::Custom("a domain label exceeded 63 bytes") + })?; + + // Determine the nature of the special sequence. + match first { + Some(b'"') => { + return Err(ScanError::Custom( + "a domain label was quoted", + )) + } + + Some(b'\\') => { + // An escape sequence. + scanner.consume(1); + this.append(&[scanner.scan_escape()?]).map_err(|_| { + ScanError::Custom("a domain label exceeded 63 bytes") + })?; + } + + _ => break, + } + } + + // Parse the result as a label. + if this.data[0] == 0 { + return Err(ScanError::Incomplete); + } + Ok(this) } } @@ -424,6 +518,22 @@ impl BuildBytes for LabelBuf { } } +//--- Formatting + +impl fmt::Display for LabelBuf { + /// Print a label. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl fmt::Debug for LabelBuf { + /// Print a label for debugging purposes. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + //--- Access to the underlying 'Label' impl Deref for LabelBuf { @@ -498,6 +608,207 @@ impl Hash for LabelBuf { } } +//--- Parsing from strings + +impl LabelBuf { + /// Parse a label from the zonefile format. + pub fn parse_str(mut s: &[u8]) -> Result<(Self, &[u8]), LabelParseError> { + if let &[b'*', ref rest @ ..] = s { + return Ok((Self::copy_from(Label::WILDCARD), rest)); + } + + // The buffer we'll fill into. + let mut this = Self { data: [0u8; 64] }; + + // Parse character by character. + loop { + let full = s; + let &[b, ref rest @ ..] = s else { break }; + s = rest; + let value = if b.is_ascii_alphanumeric() || b"-_".contains(&b) { + // A regular label character. + b + } else if b == b'\\' { + // An escape character. + let &[b, ref rest @ ..] = s else { break }; + s = rest; + if b.is_ascii_digit() { + let digits = rest + .get(..3) + .ok_or(LabelParseError::PartialEscape)?; + let digits = core::str::from_utf8(digits) + .map_err(|_| LabelParseError::InvalidEscape)?; + digits + .parse() + .map_err(|_| LabelParseError::InvalidEscape)? + } else if b.is_ascii_graphic() { + b + } else { + return Err(LabelParseError::InvalidEscape); + } + } else if b". \n\r\t".contains(&b) { + // The label has ended. + s = full; + break; + } else { + return Err(LabelParseError::InvalidChar); + }; + + let off = this.data[0] as usize + 1; + this.data[0] += 1; + let ptr = + this.data.get_mut(off).ok_or(LabelParseError::Overlong)?; + *ptr = value; + } + + if this.data[0] == 0 { + return Err(LabelParseError::Empty); + } + + Ok((this, s)) + } +} + +impl FromStr for LabelBuf { + type Err = LabelParseError; + + /// Parse a label from a string. + fn from_str(s: &str) -> Result { + match Self::parse_str(s.as_bytes()) { + Ok((this, &[])) => Ok(this), + Ok(_) => Err(LabelParseError::InvalidChar), + Err(err) => Err(err), + } + } +} + +//--- Serialize, Deserialize + +#[cfg(feature = "serde")] +impl serde::Serialize for LabelBuf { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + (**self).serialize(serializer) + } +} + +#[cfg(feature = "serde")] +impl<'a> serde::Deserialize<'a> for LabelBuf { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'a>, + { + if deserializer.is_human_readable() { + struct V; + + impl serde::de::Visitor<'_> for V { + type Value = LabelBuf; + + fn expecting( + &self, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + f.write_str("a label, in the DNS zonefile format") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse().map_err(|err| E::custom(err)) + } + } + + struct NV; + + impl<'a> serde::de::Visitor<'a> for NV { + type Value = LabelBuf; + + fn expecting( + &self, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + f.write_str("a DNS label") + } + + fn visit_newtype_struct( + self, + deserializer: D, + ) -> Result + where + D: serde::Deserializer<'a>, + { + deserializer.deserialize_str(V) + } + } + + deserializer.deserialize_newtype_struct("Label", NV) + } else { + struct V; + + impl serde::de::Visitor<'_> for V { + type Value = LabelBuf; + + fn expecting( + &self, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + f.write_str("a label, in the DNS wire format") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + LabelBuf::parse_bytes(v).map_err(|_| { + E::custom( + "misformatted label for the DNS wire format", + ) + }) + } + } + + struct NV; + + impl<'a> serde::de::Visitor<'a> for NV { + type Value = LabelBuf; + + fn expecting( + &self, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + f.write_str("a DNS label") + } + + fn visit_newtype_struct( + self, + deserializer: D, + ) -> Result + where + D: serde::Deserializer<'a>, + { + deserializer.deserialize_bytes(V) + } + } + + deserializer.deserialize_newtype_struct("Label", NV) + } + } +} + +#[cfg(feature = "serde")] +impl<'a> serde::Deserialize<'a> for std::boxed::Box