Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions starky/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ pub mod pil2circom;
pub mod pilcom;
pub mod prove;
pub mod serializer;
pub(crate) mod utils;
pub mod zkin_join;

pub mod dev;
Expand Down
10 changes: 5 additions & 5 deletions starky/src/linearhash_bls12381.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#![allow(non_snake_case)]
use crate::constant::{OFFSET_BLS12381_2_128, OFFSET_BLS12381_2_64};
use crate::errors::Result;
use crate::field_bls12381::{Fr, FrRepr};
use crate::poseidon_bls12381_opt::Poseidon;
use crate::traits::MTNodeType;
use crate::ElementDigest;
use ff::*;
//use rayon::prelude::*;
use crate::constant::{OFFSET_BLS12381_2_128, OFFSET_BLS12381_2_64};
use plonky::field_gl::Fr as FGL;
use rayon::prelude::*;

#[derive(Default)]
pub struct LinearHashBLS12381 {
Expand Down Expand Up @@ -99,7 +99,7 @@ impl LinearHashBLS12381 {
) -> Result<ElementDigest<4>> {
assert_eq!(elems.len(), 16);
let elems = elems
.iter()
.par_iter()
.map(|e| Fr((*e).as_scalar::<Fr>()))
.collect::<Vec<Fr>>();
let digest = self.h.hash(&elems, init_state)?;
Expand All @@ -119,8 +119,8 @@ impl LinearHashBLS12381 {

// group into 3 * 4
let mut tmp_buf = vec![Fr::zero(); (vals.len() - 1) / 3 + 1];
vals.chunks(3)
.zip(tmp_buf.iter_mut())
vals.par_chunks(3)
.zip(tmp_buf.par_iter_mut())
.for_each(|(ein, eout)| {
// padding zero to 4
let mut ein_4 = [FGL::ZERO; 4];
Expand Down
11 changes: 6 additions & 5 deletions starky/src/linearhash_bn128.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#![allow(non_snake_case)]

use crate::constant::{OFFSET_2_128, OFFSET_2_64};
use crate::errors::Result;
use crate::field_bn128::{Fr, FrRepr};
use crate::poseidon_bn128_opt::Poseidon;
use crate::traits::MTNodeType;
use crate::ElementDigest;
use ff::*;
//use rayon::prelude::*;
use crate::constant::{OFFSET_2_128, OFFSET_2_64};
use plonky::field_gl::Fr as FGL;
use rayon::prelude::*;

#[derive(Default)]
pub struct LinearHashBN128 {
Expand Down Expand Up @@ -100,7 +101,7 @@ impl LinearHashBN128 {
) -> Result<ElementDigest<4>> {
assert_eq!(elems.len(), 16);
let elems = elems
.iter()
.par_iter()
.map(|e| Fr((*e).as_scalar::<Fr>()))
.collect::<Vec<Fr>>();
let digest = self.h.hash(&elems, init_state)?;
Expand All @@ -120,8 +121,8 @@ impl LinearHashBN128 {

// group into 3 * 4
let mut tmp_buf = vec![Fr::zero(); (vals.len() - 1) / 3 + 1];
vals.chunks(3)
.zip(tmp_buf.iter_mut())
vals.par_chunks(3)
.zip(tmp_buf.par_iter_mut())
.for_each(|(ein, eout)| {
// padding zero to 4
let mut ein_4 = [FGL::ZERO; 4];
Expand Down
31 changes: 18 additions & 13 deletions starky/src/merklehash_bls12381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ impl MerkleTreeBLS12381 {
);

let out = &mut self.nodes[p_out..(p_out + n_ops)];
out.iter_mut()
.zip(nodes)
.for_each(|(nout, nin)| *nout = nin);
out.par_iter_mut()
.zip(nodes.par_iter())
.for_each(|(nout, nin)| *nout = *nin);
Ok(())
}

Expand All @@ -81,12 +81,15 @@ impl MerkleTreeBLS12381 {
);
let n_ops = buff_in.len() / 16;
let mut buff_out64: Vec<ElementDigest<4>> = vec![ElementDigest::<4>::default(); n_ops];
buff_out64.iter_mut().zip(0..n_ops).for_each(|(out, i)| {
*out = self
.h
.hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero())
.unwrap();
});
buff_out64
.par_iter_mut()
.zip(0..n_ops)
.for_each(|(out, i)| {
*out = self
.h
.hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero())
.unwrap();
});
Ok(buff_out64)
}

Expand Down Expand Up @@ -192,10 +195,12 @@ impl MerkleTree for MerkleTreeBLS12381 {
.zip(buff.par_chunks(n_per_thread_f * width))
.for_each(|(out, bb)| {
let cur_n = bb.len() / width;
out.iter_mut().zip(0..cur_n).for_each(|(row_out, j)| {
let batch = &bb[(j * width)..((j + 1) * width)];
*row_out = self.h.hash_element_array(batch).unwrap();
});
out.par_iter_mut()
.zip((0..cur_n).into_par_iter())
.for_each(|(row_out, j)| {
let batch = &bb[(j * width)..((j + 1) * width)];
*row_out = self.h.hash_element_array(batch).unwrap();
});
});
}
log::debug!("linearhash time cost: {}", now.elapsed().as_secs_f64());
Expand Down
36 changes: 21 additions & 15 deletions starky/src/merklehash_bn128.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ impl MerkleTreeBN128 {
);

let out = &mut self.nodes[p_out..(p_out + n_ops)];
out.iter_mut()
.zip(nodes)
.for_each(|(nout, nin)| *nout = nin);
out.par_iter_mut()
.zip(nodes.par_iter())
.for_each(|(nout, nin)| *nout = *nin);
Ok(())
}

Expand All @@ -81,12 +81,15 @@ impl MerkleTreeBN128 {
);
let n_ops = buff_in.len() / 16;
let mut buff_out64: Vec<ElementDigest<4>> = vec![ElementDigest::<4>::default(); n_ops];
buff_out64.iter_mut().zip(0..n_ops).for_each(|(out, i)| {
*out = self
.h
.hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero())
.unwrap();
});
buff_out64
.par_iter_mut()
.zip(0..n_ops)
.for_each(|(out, i)| {
*out = self
.h
.hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero())
.unwrap();
});
Ok(buff_out64)
}

Expand Down Expand Up @@ -187,20 +190,22 @@ impl MerkleTree for MerkleTreeBN128 {
}
// calculate the nodes of the specific height Merkle tree
let mut nodes = vec![ElementDigest::<4>::default(); get_n_nodes(height)];
let now = Instant::now();
if !buff.is_empty() {
let now = Instant::now();
nodes
.par_chunks_mut(n_per_thread_f)
.zip(buff.par_chunks(n_per_thread_f * width))
.for_each(|(out, bb)| {
let cur_n = bb.len() / width;
out.iter_mut().zip(0..cur_n).for_each(|(row_out, j)| {
let batch = &bb[(j * width)..((j + 1) * width)];
*row_out = self.h.hash_element_array(batch).unwrap();
});
out.par_iter_mut()
.zip((0..cur_n).into_par_iter())
.for_each(|(row_out, j)| {
let batch = &bb[(j * width)..((j + 1) * width)];
*row_out = self.h.hash_element_array(batch).unwrap();
});
});
log::debug!("linearhash time cost: {}", now.elapsed().as_secs_f64());
}
log::debug!("linearhash time cost: {}", now.elapsed().as_secs_f64());

// merklize level
self.nodes = nodes;
Expand All @@ -212,6 +217,7 @@ impl MerkleTree for MerkleTreeBN128 {
let mut next_n256: usize = (n256 - 1) / 16 + 1;
let mut p_in: usize = 0;
let mut p_out: usize = p_in + next_n256 * 16;

while n256 > 1 {
let now = Instant::now();
self.merklize_level(p_in, next_n256, p_out)?;
Expand Down
37 changes: 19 additions & 18 deletions starky/src/poseidon_bls12381_opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::field_bls12381::Fr;
use crate::poseidon_bls12381::Constants;
use crate::poseidon_bls12381_constants_opt as constants;
use ff::{from_hex, Field};
use rayon::prelude::*;

pub fn load_constants() -> Constants {
let (c_str, m_str, p_str, s_str) = constants::constants();
Expand Down Expand Up @@ -125,21 +126,21 @@ impl Poseidon {
let mut state = vec![*init_state; t];
state[1..].clone_from_slice(inp);
state
.iter_mut()
.par_iter_mut()
.enumerate()
.for_each(|(i, a)| a.add_assign(&C[i]));

for r in 0..(n_rounds_f / 2 - 1) {
state.iter_mut().for_each(Self::pow5);
state.iter_mut().enumerate().for_each(|(i, a)| {
state.par_iter_mut().for_each(Self::pow5);
state.par_iter_mut().enumerate().for_each(|(i, a)| {
a.add_assign(&C[(r + 1) * t + i]);
});

//state = state.map((_, i) =>
// state.reduce((acc, a, j) => F.add(acc, F.mul(M[j][i], a)), F.zero)
//);
let sz = state.len();
tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
let mut acc = Fr::zero();
for j in 0..sz {
let mut tmp = M[j][i];
Expand All @@ -149,20 +150,20 @@ impl Poseidon {
*out = acc;
});
state
.iter_mut()
.zip(tmp_state.iter())
.par_iter_mut()
.zip(tmp_state.par_iter())
.for_each(|(out, inp)| {
*out = *inp;
});
}

state.iter_mut().for_each(Self::pow5);
state.iter_mut().enumerate().for_each(|(i, a)| {
state.par_iter_mut().for_each(Self::pow5);
state.par_iter_mut().enumerate().for_each(|(i, a)| {
a.add_assign(&C[(n_rounds_f / 2 - 1 + 1) * t + i]);
}); //opt

let sz = state.len();
tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
let mut acc = Fr::zero();
for j in 0..sz {
let mut tmp = P[j][i];
Expand All @@ -172,8 +173,8 @@ impl Poseidon {
*out = acc;
});
state
.iter_mut()
.zip(tmp_state.iter())
.par_iter_mut()
.zip(tmp_state.par_iter())
.for_each(|(out, inp)| {
*out = *inp;
});
Expand All @@ -199,13 +200,13 @@ impl Poseidon {
}

for r in 0..(n_rounds_f / 2 - 1) {
state.iter_mut().for_each(Self::pow5);
state.iter_mut().enumerate().for_each(|(i, a)| {
state.par_iter_mut().for_each(Self::pow5);
state.par_iter_mut().enumerate().for_each(|(i, a)| {
a.add_assign(&C[(n_rounds_f / 2 + 1) * t + n_rounds_p + r * t + i]);
});

let sz = state.len();
tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
let mut acc = Fr::zero();
for j in 0..sz {
let mut tmp = M[j][i];
Expand All @@ -215,16 +216,16 @@ impl Poseidon {
*out = acc;
});
state
.iter_mut()
.zip(tmp_state.iter())
.par_iter_mut()
.zip(tmp_state.par_iter())
.for_each(|(out, inp)| {
*out = *inp;
});
}

state.iter_mut().for_each(Self::pow5);
state.par_iter_mut().for_each(Self::pow5);
let sz = state.len();
tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
let mut acc = Fr::zero();
for j in 0..sz {
let mut tmp = M[j][i];
Expand Down
Loading