Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions light-poseidon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ pub struct Poseidon<F: PrimeField> {
params: PoseidonParameters<F>,
domain_tag: F,
state: Vec<F>,
scratch: Vec<F>,
}

impl<F: PrimeField> Poseidon<F> {
Expand All @@ -330,45 +331,68 @@ impl<F: PrimeField> Poseidon<F> {
domain_tag,
params,
state: Vec::with_capacity(width),
scratch: Vec::with_capacity(width),
}
}

#[inline(always)]
fn apply_ark(&mut self, round: usize) {
self.state.iter_mut().enumerate().for_each(|(i, a)| {
let c = self.params.ark[round * self.params.width + i];
*a += c;
});
let width = self.params.width;
let base = round * width;
for i in 0..width {
self.state[i] += self.params.ark[base + i];
}
}

#[inline(always)]
fn apply_sbox_full(&mut self) {
let alpha = self.params.alpha;
self.state.iter_mut().for_each(|a| {
*a = a.pow([self.params.alpha]);
let value = *a;
*a = if alpha == 5 {
pow5(value)
} else {
value.pow([alpha])
};
});
}

#[inline(always)]
fn apply_sbox_partial(&mut self) {
self.state[0] = self.state[0].pow([self.params.alpha]);
let alpha = self.params.alpha;
let value = self.state[0];
self.state[0] = if alpha == 5 {
pow5(value)
} else {
value.pow([alpha])
};
}

#[inline(always)]
fn apply_mds(&mut self) {
self.state = self
.state
.iter()
.enumerate()
.map(|(i, _)| {
self.state
.iter()
.enumerate()
.fold(F::zero(), |acc, (j, a)| acc + *a * self.params.mds[i][j])
})
.collect();
let width = self.params.width;
if self.scratch.len() != width {
self.scratch.resize(width, F::zero());
}
for i in 0..width {
let mut acc = F::zero();
let row = &self.params.mds[i];
for (j, value) in row.iter().enumerate() {
acc += self.state[j] * *value;
}
self.scratch[i] = acc;
}
std::mem::swap(&mut self.state, &mut self.scratch);
}
}

#[inline(always)]
fn pow5<F: PrimeField>(value: F) -> F {
let square = value.square();
let fourth = square.square();
fourth * value
}

impl<F: PrimeField> PoseidonHasher<F> for Poseidon<F> {
fn hash(&mut self, inputs: &[F]) -> Result<F, PoseidonError> {
if inputs.len() != self.params.width - 1 {
Expand Down
6 changes: 3 additions & 3 deletions xtask/src/generate_parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ pub fn generate_parameters(_opts: Options) -> Result<(), anyhow::Error> {
.stdout(Stdio::inherit())
.stderr(Stdio::inherit())
.output()
.map_err(|e| anyhow::format_err!("git clone failed: {}", e.to_string()))?;
.map_err(|e| anyhow::format_err!("git clone failed: {}", e))?;
}
if !Path::new("./target/params").exists() {
let _mkdir_result = std::process::Command::new("mkdir")
.arg("./target/params")
.stdout(Stdio::inherit())
.stderr(Stdio::inherit())
.output()
.map_err(|e| anyhow::format_err!("mkdir failed: {}", e.to_string()))?;
.map_err(|e| anyhow::format_err!("mkdir failed: {}", e))?;
}
for i in 2..14 {
let path = format!("./target/params/poseidon_params_bn254_x5_{}", i);
Expand Down Expand Up @@ -240,7 +240,7 @@ pub fn generate_parameters(_opts: Options) -> Result<(), anyhow::Error> {
std::process::Command::new("cargo")
.arg("fmt")
.output()
.map_err(|e| anyhow::format_err!("cargo fmt failed: {}", e.to_string()))?;
.map_err(|e| anyhow::format_err!("cargo fmt failed: {}", e))?;
Ok(())
}

Expand Down
Loading