diff --git a/light-poseidon/src/lib.rs b/light-poseidon/src/lib.rs index 6b51336..b18ff9a 100644 --- a/light-poseidon/src/lib.rs +++ b/light-poseidon/src/lib.rs @@ -313,6 +313,7 @@ pub struct Poseidon { params: PoseidonParameters, domain_tag: F, state: Vec, + scratch: Vec, } impl Poseidon { @@ -330,45 +331,68 @@ impl Poseidon { domain_tag, params, state: Vec::with_capacity(width), + scratch: Vec::with_capacity(width), } } #[inline(always)] fn apply_ark(&mut self, round: usize) { - self.state.iter_mut().enumerate().for_each(|(i, a)| { - let c = self.params.ark[round * self.params.width + i]; - *a += c; - }); + let width = self.params.width; + let base = round * width; + for i in 0..width { + self.state[i] += self.params.ark[base + i]; + } } #[inline(always)] fn apply_sbox_full(&mut self) { + let alpha = self.params.alpha; self.state.iter_mut().for_each(|a| { - *a = a.pow([self.params.alpha]); + let value = *a; + *a = if alpha == 5 { + pow5(value) + } else { + value.pow([alpha]) + }; }); } #[inline(always)] fn apply_sbox_partial(&mut self) { - self.state[0] = self.state[0].pow([self.params.alpha]); + let alpha = self.params.alpha; + let value = self.state[0]; + self.state[0] = if alpha == 5 { + pow5(value) + } else { + value.pow([alpha]) + }; } #[inline(always)] fn apply_mds(&mut self) { - self.state = self - .state - .iter() - .enumerate() - .map(|(i, _)| { - self.state - .iter() - .enumerate() - .fold(F::zero(), |acc, (j, a)| acc + *a * self.params.mds[i][j]) - }) - .collect(); + let width = self.params.width; + if self.scratch.len() != width { + self.scratch.resize(width, F::zero()); + } + for i in 0..width { + let mut acc = F::zero(); + let row = &self.params.mds[i]; + for (j, value) in row.iter().enumerate() { + acc += self.state[j] * *value; + } + self.scratch[i] = acc; + } + std::mem::swap(&mut self.state, &mut self.scratch); } } +#[inline(always)] +fn pow5(value: F) -> F { + let square = value.square(); + let fourth = square.square(); + fourth * value +} + impl PoseidonHasher for Poseidon { fn hash(&mut self, inputs: &[F]) -> Result { if inputs.len() != self.params.width - 1 { diff --git a/xtask/src/generate_parameters.rs b/xtask/src/generate_parameters.rs index 2fccf18..540bf29 100644 --- a/xtask/src/generate_parameters.rs +++ b/xtask/src/generate_parameters.rs @@ -65,7 +65,7 @@ pub fn generate_parameters(_opts: Options) -> Result<(), anyhow::Error> { .stdout(Stdio::inherit()) .stderr(Stdio::inherit()) .output() - .map_err(|e| anyhow::format_err!("git clone failed: {}", e.to_string()))?; + .map_err(|e| anyhow::format_err!("git clone failed: {}", e))?; } if !Path::new("./target/params").exists() { let _mkdir_result = std::process::Command::new("mkdir") @@ -73,7 +73,7 @@ pub fn generate_parameters(_opts: Options) -> Result<(), anyhow::Error> { .stdout(Stdio::inherit()) .stderr(Stdio::inherit()) .output() - .map_err(|e| anyhow::format_err!("mkdir failed: {}", e.to_string()))?; + .map_err(|e| anyhow::format_err!("mkdir failed: {}", e))?; } for i in 2..14 { let path = format!("./target/params/poseidon_params_bn254_x5_{}", i); @@ -240,7 +240,7 @@ pub fn generate_parameters(_opts: Options) -> Result<(), anyhow::Error> { std::process::Command::new("cargo") .arg("fmt") .output() - .map_err(|e| anyhow::format_err!("cargo fmt failed: {}", e.to_string()))?; + .map_err(|e| anyhow::format_err!("cargo fmt failed: {}", e))?; Ok(()) }