diff --git a/Cargo.lock b/Cargo.lock index 979564c7..dd9c4613 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,12 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 + +[[package]] +name = "allocator-api2" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" [[package]] name = "anstream" @@ -51,6 +57,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + [[package]] name = "byteorder" version = "1.5.0" @@ -65,9 +77,9 @@ checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "cc" -version = "1.1.36" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baee610e9452a8f6f0a1b6194ec09ff9e2d85dea54432acdae41aa0761c95d70" +checksum = "1aeb932158bd710538c73702db6945cb68a8fb08c519e6e12706b94263b36db8" dependencies = [ "shlex", ] @@ -142,20 +154,39 @@ dependencies = [ "lexopt", "octseq", "ring", + "tempfile", ] [[package]] name = "domain" version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64008666d9f3b6a88a63cd28ad8f3a5a859b8037e11bfb680c1b24945ea1c28d" +source = "git+https://github.com/NLnetLabs/domain.git#39a04b6f6af9496a2ec91b6e5707ecf255fe0c38" dependencies = [ "bytes", + "hashbrown", "octseq", "rand", + "ring", + "serde", "time", ] +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "fastrand" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" + [[package]] name = "getrandom" version = "0.2.15" @@ -167,6 +198,15 @@ dependencies = [ "wasi", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "allocator-api2", +] + [[package]] name = "heck" version = "0.5.0" @@ -191,6 +231,12 @@ version = "0.2.162" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + [[package]] name = "num-conv" version = "0.1.0" @@ -204,8 +250,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "126c3ca37c9c44cec575247f43a3e4374d8927684f129d2beeb0d2cef262fe12" dependencies = [ "bytes", + "serde", ] +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + [[package]] name = "powerfmt" version = "0.2.0" @@ -284,20 +337,33 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustix" +version = "0.38.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + [[package]] name = "serde" -version = "1.0.214" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.214" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", @@ -333,6 +399,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempfile" +version = "3.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "time" version = "0.3.36" diff --git a/Cargo.toml b/Cargo.toml index 6ba3667b..8a3a3e86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,16 @@ path = "src/bin/ldns.rs" [dependencies] clap = { version = "4.3.4", features = ["derive"] } -domain = "0.10.1" +domain = { version = "0.10.3", git = "https://github.com/NLnetLabs/domain.git", features = [ + "zonefile", + "bytes", + "unstable-validate", +] } lexopt = "0.3.0" # for implementation of nsec3 hash until domain has it stabilized -octseq = { version = "0.5.1", features = ["std"] } +octseq = { version = "0.5.2", features = ["std"] } ring = { version = "0.17" } + +[dev-dependencies] +tempfile = "3.14.0" diff --git a/src/args.rs b/src/args.rs index 3883f518..7c868b5d 100644 --- a/src/args.rs +++ b/src/args.rs @@ -7,7 +7,7 @@ use super::error::Error; #[command(version, disable_help_subcommand = true)] pub struct Args { #[command(subcommand)] - command: Command, + pub command: Command, } impl Args { diff --git a/src/commands/key2ds.rs b/src/commands/key2ds.rs new file mode 100644 index 00000000..b10f7bcf --- /dev/null +++ b/src/commands/key2ds.rs @@ -0,0 +1,498 @@ +use std::ffi::OsString; +use std::fs::File; +use std::io::{self, Write as _}; +use std::path::PathBuf; + +use clap::builder::ValueParser; +use clap::Parser; +use domain::base::iana::{DigestAlg, SecAlg}; +use domain::base::zonefile_fmt::ZonefileFmt; +use domain::base::Record; +use domain::rdata::Ds; +use domain::validate::DnskeyExt; +use domain::zonefile::inplace::{Entry, ScannedRecordData}; +use lexopt::Arg; + +use crate::env::Env; +use crate::error::Error; + +use super::LdnsCommand; + +#[derive(Clone, Debug, Parser, PartialEq, Eq)] +#[command(version)] +pub struct Key2ds { + /// ignore SEP flag (i.e. make DS records for any key) + #[arg(long = "ignore-sep")] + ignore_sep: bool, + + /// do not write DS records to file(s) but to stdout + #[arg(short = 'n')] + write_to_stdout: bool, + + /// Overwrite existing DS files + #[arg(short = 'f', long = "force")] + force_overwrite: bool, + + /// algorithm to use for digest + #[arg( + short = 'a', + long = "algorithm", + value_parser = ValueParser::new(parse_digest_alg) + )] + algorithm: Option, + + /// Keyfile to read + #[arg()] + keyfile: PathBuf, +} + +pub fn parse_digest_alg(arg: &str) -> Result { + if let Ok(num) = arg.parse() { + let alg = DigestAlg::from_int(num); + if alg.to_mnemonic().is_some() { + Ok(alg) + } else { + Err(Error::from("unknown algorithm number")) + } + } else { + DigestAlg::from_mnemonic(arg.as_bytes()).ok_or(Error::from("unknown algorithm mnemonic")) + } +} + +const LDNS_HELP: &str = "\ +ldns-key2ds [-fn] [-1|-2|-4] keyfile + Generate a DS RR from the DNSKEYS in keyfile + The following file will be created for each key: + `K++.ds`. The base name `K++` + will be printed to stdout. + +Options: + -f: ignore SEP flag (i.e. make DS records for any key) + -n: do not write DS records to file(s) but to stdout + (default) use similar hash to the key algorithm + -1: use SHA1 for the DS hash + -2: use SHA256 for the DS hash + -4: use SHA384 for the DS hash\ +"; + +impl LdnsCommand for Key2ds { + const HELP: &'static str = LDNS_HELP; + + fn parse_ldns>(args: I) -> Result { + let mut ignore_sep = false; + let mut write_to_stdout = false; + let mut algorithm = None; + let mut keyfile = None; + + let mut parser = lexopt::Parser::from_args(args); + + while let Some(arg) = parser.next()? { + match arg { + Arg::Short('1') => algorithm = Some(DigestAlg::SHA1), + Arg::Short('2') => algorithm = Some(DigestAlg::SHA256), + Arg::Short('4') => algorithm = Some(DigestAlg::SHA384), + Arg::Short('f') => ignore_sep = true, + Arg::Short('n') => write_to_stdout = true, + Arg::Value(val) => { + if keyfile.is_some() { + return Err("Only one keyfile is allowed".into()); + } + keyfile = Some(val); + } + Arg::Short(x) => return Err(format!("Invalid short option: -{x}").into()), + Arg::Long(x) => { + return Err(format!("Long options are not supported, but `--{x}` given").into()) + } + } + } + + let Some(keyfile) = keyfile else { + return Err("No keyfile given".into()); + }; + + Ok(Self { + ignore_sep, + write_to_stdout, + algorithm, + // Preventing overwriting files is a dnst feature that is not + // present in the ldns version of this command. + force_overwrite: true, + keyfile: keyfile.into(), + }) + } +} + +impl Key2ds { + pub fn execute(self, env: impl Env) -> Result<(), Error> { + let mut file = File::open(env.in_cwd(&self.keyfile)).map_err(|e| { + format!( + "Failed to open public key file \"{}\": {e}", + self.keyfile.display() + ) + })?; + let zonefile = domain::zonefile::inplace::Zonefile::load(&mut file).unwrap(); + for entry in zonefile { + let entry = entry.map_err(|e| { + format!( + "Error while reading public key from file \"{}\": {e}", + self.keyfile.display() + ) + })?; + + // We only care about records in a zonefile + let Entry::Record(record) = entry else { + continue; + }; + + let class = record.class(); + let ttl = record.ttl(); + let owner = record.owner(); + + // Of the records that we see, we only care about DNSKEY records + let ScannedRecordData::Dnskey(dnskey) = record.data() else { + continue; + }; + + // if ignore_sep is specified, we accept any key + // otherwise, we only want SEP keys + if !self.ignore_sep && !dnskey.is_secure_entry_point() { + continue; + } + + let key_tag = dnskey.key_tag(); + let sec_alg = dnskey.algorithm(); + let digest_alg = self + .algorithm + .unwrap_or_else(|| determine_hash_from_sec_alg(sec_alg)); + + if digest_alg == DigestAlg::GOST { + return Err("Error: the GOST algorithm is deprecated and must not be used. Try a different algorithm.".into()); + } + + let digest = dnskey + .digest(&owner, digest_alg) + .map_err(|e| format!("Error computing digest: {e}"))?; + + let ds = Ds::new(key_tag, sec_alg, digest_alg, digest).expect( + "Infallible because the digest won't be too long since it's a valid digest", + ); + + let rr = Record::new(owner, class, ttl, ds); + + if self.write_to_stdout { + writeln!(env.stdout(), "{}", rr.display_zonefile(false)); + } else { + let owner = owner.fmt_with_dot(); + let sec_alg = sec_alg.to_int(); + + let keyname = format!("K{owner}+{sec_alg:03}+{key_tag:05}"); + let filename = format!("{keyname}.ds"); + + let res = if self.force_overwrite { + File::create(env.in_cwd(&filename)) + } else { + let res = File::create_new(env.in_cwd(&filename)); + + // Create a bit of a nicer message than a "File exists" IO + // error. + if let Err(e) = &res { + if e.kind() == io::ErrorKind::AlreadyExists { + return Err(format!( + "The file '{filename}' already exists, use the --force to overwrite" + ) + .into()); + } + } + + res + }; + + let mut out_file = + res.map_err(|e| format!("Could not create file \"{filename}\": {e}"))?; + + writeln!(out_file, "{}", rr.display_zonefile(false)) + .map_err(|e| format!("Could not write to file \"{filename}\": {e}"))?; + + writeln!(env.stdout(), "{keyname}"); + } + } + + Ok(()) + } +} + +fn determine_hash_from_sec_alg(sec_alg: SecAlg) -> DigestAlg { + match sec_alg { + SecAlg::RSASHA256 + | SecAlg::RSASHA512 + | SecAlg::ED25519 + | SecAlg::ED448 + | SecAlg::ECDSAP256SHA256 => DigestAlg::SHA256, + SecAlg::ECDSAP384SHA384 => DigestAlg::SHA384, + SecAlg::ECC_GOST => DigestAlg::GOST, + _ => DigestAlg::SHA1, + } +} + +#[cfg(test)] +mod test { + use domain::base::iana::DigestAlg; + use tempfile::TempDir; + + use crate::commands::Command; + use crate::env::fake::FakeCmd; + use std::fs::File; + use std::io::Write; + use std::path::PathBuf; + + use super::Key2ds; + + #[track_caller] + fn parse(args: FakeCmd) -> Key2ds { + let res = args.parse(); + let Command::Key2ds(x) = res.unwrap().command else { + panic!("Not a Key2ds!"); + }; + x + } + + #[test] + fn dnst_parse() { + let cmd = FakeCmd::new(["dnst", "key2ds"]); + + cmd.parse().unwrap_err(); + cmd.args(["keyfile1.key", "keyfile2.key"]) + .parse() + .unwrap_err(); + + let base = Key2ds { + ignore_sep: false, + write_to_stdout: false, + force_overwrite: false, + algorithm: None, + keyfile: PathBuf::from("keyfile1.key"), + }; + + // Check the defaults + let res = parse(cmd.args(["keyfile1.key"])); + assert_eq!(res, base); + + let res = parse(cmd.args(["keyfile1.key", "-f"])); + assert_eq!( + res, + Key2ds { + force_overwrite: true, + ..base.clone() + } + ); + + let res = parse(cmd.args(["keyfile1.key", "--force"])); + assert_eq!( + res, + Key2ds { + force_overwrite: true, + ..base.clone() + } + ); + + let res = parse(cmd.args(["keyfile1.key", "--ignore-sep"])); + assert_eq!( + res, + Key2ds { + ignore_sep: true, + ..base.clone() + } + ); + + let res = parse(cmd.args(["keyfile1.key", "-n"])); + assert_eq!( + res, + Key2ds { + write_to_stdout: true, + ..base.clone() + } + ); + + let res = parse(cmd.args(["keyfile1.key", "-a", "SHA-1"])); + assert_eq!( + res, + Key2ds { + algorithm: Some(DigestAlg::SHA1), + ..base.clone() + } + ); + + let res = parse(cmd.args(["keyfile1.key", "--algorithm", "SHA-1"])); + assert_eq!( + res, + Key2ds { + algorithm: Some(DigestAlg::SHA1), + ..base.clone() + } + ); + + let res = parse(cmd.args(["keyfile1.key", "--algorithm", "1"])); + assert_eq!( + res, + Key2ds { + algorithm: Some(DigestAlg::SHA1), + ..base.clone() + } + ); + } + + #[test] + fn ldns_parse() { + let cmd = FakeCmd::new(["ldns-key2ds"]); + + cmd.parse().unwrap_err(); + cmd.args(["keyfile1.key", "keyfile2.key"]) + .parse() + .unwrap_err(); + cmd.args(["-a", "keyfile2.key"]).parse().unwrap_err(); + cmd.args(["-fdoesnottakeavalue", "keyfile2.key"]) + .parse() + .unwrap_err(); + + let base = Key2ds { + ignore_sep: false, + write_to_stdout: false, + force_overwrite: true, // note that this is true + algorithm: None, + keyfile: PathBuf::from("keyfile1.key"), + }; + + // Check the defaults + let res = parse(cmd.args(["keyfile1.key"])); + assert_eq!(res, base,); + + let res = parse(cmd.args(["keyfile1.key", "-f"])); + assert_eq!( + res, + Key2ds { + ignore_sep: true, + ..base.clone() + } + ); + + let res = parse(cmd.args(["keyfile1.key", "-fn"])); + assert_eq!( + res, + Key2ds { + ignore_sep: true, + write_to_stdout: true, + ..base.clone() + } + ); + + let res = parse(cmd.args(["keyfile1.key", "-1"])); + assert_eq!( + res, + Key2ds { + algorithm: Some(DigestAlg::SHA1), + ..base.clone() + } + ); + + let res = parse(cmd.args(["keyfile1.key", "-fnfn421"])); + assert_eq!( + res, + Key2ds { + ignore_sep: true, + write_to_stdout: true, + algorithm: Some(DigestAlg::SHA1), + ..base.clone() + } + ); + } + + fn run_setup() -> TempDir { + let dir = tempfile::TempDir::new().unwrap(); + let mut file = File::create(dir.path().join("key1.key")).unwrap(); + file + .write_all(b"example.test. IN DNSKEY 257 3 15 8AWQIqSo35guqX6WPIFsUlOnbiqGC5sydeBTVMdLGMs= ;{id = 60136 (ksk), size = 256b}\n") + .unwrap(); + + let mut file = File::create(dir.path().join("key2.key")).unwrap(); + file.write_all( + b"\ + one.test. IN DNSKEY 257 3 15 JKVltzkO0wxbjrY1dNKjEHrXvPqahmbmqwXaNrSwXsI=\n\ + two.test. IN DNSKEY 257 3 15 F0jH0dfoYXe9/tKqoghlZTY5+K/uRQReTkjvBmr7gy8=\n\ + ", + ) + .unwrap(); + + dir + } + + #[test] + fn file_with_single_key() { + let dir = run_setup(); + + let res = FakeCmd::new(["dnst", "key2ds", "key1.key"]).cwd(&dir).run(); + + assert_eq!(res.exit_code, 0, "{res:?}"); + assert_eq!(res.stdout, "Kexample.test.+015+60136\n"); + assert_eq!(res.stderr, ""); + + let out = std::fs::read_to_string(dir.path().join("Kexample.test.+015+60136.ds")).unwrap(); + assert_eq!(out, "example.test. 3600 IN DS 60136 15 2 52BD3BF40C8220BF1A3E2A3751C423BC4B69BCD7F328D38C4CD021A85DE65AD4\n"); + } + + #[test] + fn file_with_two_keys() { + let dir = run_setup(); + + let res = FakeCmd::new(["dnst", "key2ds", "key2.key"]).cwd(&dir).run(); + + assert_eq!(res.exit_code, 0, "{res:?}"); + assert_eq!(res.stdout, "Kone.test.+015+38429\nKtwo.test.+015+00425\n",); + assert_eq!(res.stderr, ""); + + let out = std::fs::read_to_string(dir.path().join("Kone.test.+015+38429.ds")).unwrap(); + assert_eq!(out, "one.test. 3600 IN DS 38429 15 2 B85F7D27C48A7B84D633C7A41C3022EA0F7FC80896227B61AE7BFC59BF5F0256\n"); + + let out = std::fs::read_to_string(dir.path().join("Ktwo.test.+015+00425.ds")).unwrap(); + assert_eq!(out, "two.test. 3600 IN DS 425 15 2 AA2030287A7C5C56CB3C0E9C64BE55616729C0C78DE2B83613D03B10C0F1EA93\n"); + } + + #[test] + fn print_to_stdout() { + let dir = run_setup(); + + let res = FakeCmd::new(["dnst", "key2ds", "-n", "key1.key"]) + .cwd(&dir) + .run(); + + assert_eq!(res.exit_code, 0); + assert_eq!( + res.stdout, + "example.test. 3600 IN DS 60136 15 2 52BD3BF40C8220BF1A3E2A3751C423BC4B69BCD7F328D38C4CD021A85DE65AD4\n" + ); + assert_eq!(res.stderr, ""); + } + + #[test] + fn overwrite_file() { + let dir = run_setup(); + + // Make sure the file already exists + File::create(dir.path().join("Kexample.test.+015+60136.ds")).unwrap(); + + let res = FakeCmd::new(["dnst", "key2ds", "key1.key"]).cwd(&dir).run(); + + assert_eq!(res.exit_code, 1); + assert_eq!(res.stdout, ""); + assert!(res.stderr.contains( + "The file 'Kexample.test.+015+60136.ds' already exists, use the --force to overwrite" + )); + + let res = FakeCmd::new(["dnst", "key2ds", "--force", "key1.key"]) + .cwd(&dir) + .run(); + + assert_eq!(res.exit_code, 0); + assert_eq!(res.stdout, "Kexample.test.+015+60136\n"); + assert_eq!(res.stderr, ""); + } +} diff --git a/src/commands/mod.rs b/src/commands/mod.rs index b7dbb3d8..53ac7664 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -1,11 +1,13 @@ //! The command of _dnst_. pub mod help; +pub mod key2ds; pub mod nsec3hash; use std::ffi::{OsStr, OsString}; use std::str::FromStr; +use key2ds::Key2ds; use nsec3hash::Nsec3Hash; use crate::env::Env; @@ -19,6 +21,14 @@ pub enum Command { #[command(name = "nsec3-hash")] Nsec3Hash(self::nsec3hash::Nsec3Hash), + /// Generate a DS RR from the DNSKEYS in keyfile + /// + /// The following file will be created for each key: + /// `K++.ds`. The base name `K++` + /// will be printed to stdout. + #[command(name = "key2ds")] + Key2ds(key2ds::Key2ds), + /// Show the manual pages Help(self::help::Help), } @@ -27,6 +37,7 @@ impl Command { pub fn execute(self, env: impl Env) -> Result<(), Error> { match self { Self::Nsec3Hash(nsec3hash) => nsec3hash.execute(env), + Self::Key2ds(key2ds) => key2ds.execute(env), Self::Help(help) => help.execute(), } } @@ -59,6 +70,12 @@ impl From for Command { } } +impl From for Command { + fn from(val: Key2ds) -> Self { + Command::Key2ds(val) + } +} + /// Utility function to parse an [`OsStr`] with a custom function fn parse_os_with(opt: &str, val: &OsStr, f: impl Fn(&str) -> Result) -> Result where diff --git a/src/env/fake.rs b/src/env/fake.rs index f9d4cde3..0715ee59 100644 --- a/src/env/fake.rs +++ b/src/env/fake.rs @@ -1,5 +1,7 @@ +use std::borrow::Cow; use std::ffi::OsString; use std::fmt; +use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::Mutex; @@ -16,11 +18,13 @@ use super::Stream; pub struct FakeCmd { /// The command to run, including `argv[0]` cmd: Vec, + cwd: Option, } /// The result of running a [`FakeCmd`] /// /// The fields are public to allow for easy assertions in tests. +#[derive(Debug)] pub struct FakeResult { pub exit_code: u8, pub stdout: String, @@ -53,6 +57,13 @@ impl Env for FakeEnv { fn stderr(&self) -> Stream { Stream(self.stderr.clone()) } + + fn in_cwd<'a>(&self, path: &'a impl AsRef) -> Cow<'a, Path> { + match &self.cmd.cwd { + Some(cwd) => cwd.join(path).into(), + None => path.as_ref().into(), + } + } } impl FakeCmd { @@ -62,6 +73,14 @@ impl FakeCmd { pub fn new>(cmd: impl IntoIterator) -> Self { Self { cmd: cmd.into_iter().map(Into::into).collect(), + cwd: None, + } + } + + pub fn cwd(&self, path: impl AsRef) -> Self { + Self { + cwd: Some(path.as_ref().to_path_buf()), + ..self.clone() } } diff --git a/src/env/mod.rs b/src/env/mod.rs index b00b57d3..dd62dcb1 100644 --- a/src/env/mod.rs +++ b/src/env/mod.rs @@ -1,5 +1,7 @@ +use std::borrow::Cow; use std::ffi::OsString; use std::fmt; +use std::path::Path; mod real; @@ -32,6 +34,8 @@ pub trait Env { // /// Get a reference to stdin // fn stdin(&self) -> impl io::Read; + + fn in_cwd<'a>(&self, path: &'a impl AsRef) -> Cow<'a, Path>; } /// A type with an infallible `write_fmt` method for use with [`write!`] macros @@ -73,4 +77,8 @@ impl Env for &E { fn stderr(&self) -> Stream { (**self).stderr() } + + fn in_cwd<'a>(&self, path: &'a impl AsRef) -> Cow<'a, Path> { + (**self).in_cwd(path) + } } diff --git a/src/env/real.rs b/src/env/real.rs index c854db43..26c01aa5 100644 --- a/src/env/real.rs +++ b/src/env/real.rs @@ -1,6 +1,7 @@ use std::ffi::OsString; use std::fmt; use std::io; +use std::path::Path; use super::Env; use super::Stream; @@ -20,6 +21,10 @@ impl Env for RealEnv { fn stderr(&self) -> Stream { Stream(FmtWriter(io::stderr())) } + + fn in_cwd<'a>(&self, path: &'a impl AsRef) -> std::borrow::Cow<'a, std::path::Path> { + path.as_ref().into() + } } struct FmtWriter(T); diff --git a/src/lib.rs b/src/lib.rs index 8a149998..48e40759 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,7 @@ use std::ffi::OsString; use std::path::Path; use clap::Parser; -use commands::{nsec3hash::Nsec3Hash, LdnsCommand}; +use commands::{key2ds::Key2ds, nsec3hash::Nsec3Hash, LdnsCommand}; use env::Env; use error::Error; @@ -26,6 +26,7 @@ pub fn try_ldns_compatibility>( .ok_or("Binary file name is not valid unicode")?; let res = match binary_name { + "ldns-key2ds" => Key2ds::parse_ldns_args(args_iter), "ldns-nsec3-hash" => Nsec3Hash::parse_ldns_args(args_iter), _ => return Ok(None), };