diff --git a/starky/src/lib.rs b/starky/src/lib.rs
index 0e9540ed..9027ecf3 100644
--- a/starky/src/lib.rs
+++ b/starky/src/lib.rs
@@ -70,6 +70,7 @@
 pub mod pil2circom;
 pub mod pilcom;
 pub mod prove;
 pub mod serializer;
+pub(crate) mod utils;
 pub mod zkin_join;
 pub mod dev;
diff --git a/starky/src/linearhash_bls12381.rs b/starky/src/linearhash_bls12381.rs
index 6263038c..2597f0fb 100644
--- a/starky/src/linearhash_bls12381.rs
+++ b/starky/src/linearhash_bls12381.rs
@@ -1,13 +1,13 @@
 #![allow(non_snake_case)]
+use crate::constant::{OFFSET_BLS12381_2_128, OFFSET_BLS12381_2_64};
 use crate::errors::Result;
 use crate::field_bls12381::{Fr, FrRepr};
 use crate::poseidon_bls12381_opt::Poseidon;
 use crate::traits::MTNodeType;
 use crate::ElementDigest;
 use ff::*;
-//use rayon::prelude::*;
-use crate::constant::{OFFSET_BLS12381_2_128, OFFSET_BLS12381_2_64};
 use plonky::field_gl::Fr as FGL;
+use rayon::prelude::*;
 
 #[derive(Default)]
 pub struct LinearHashBLS12381 {
@@ -99,7 +99,7 @@ impl LinearHashBLS12381 {
     ) -> Result> {
         assert_eq!(elems.len(), 16);
         let elems = elems
-            .iter()
+            .par_iter()
             .map(|e| Fr((*e).as_scalar::()))
             .collect::>();
         let digest = self.h.hash(&elems, init_state)?;
@@ -119,8 +119,8 @@ impl LinearHashBLS12381 {
 
         // group into 3 * 4
         let mut tmp_buf = vec![Fr::zero(); (vals.len() - 1) / 3 + 1];
-        vals.chunks(3)
-            .zip(tmp_buf.iter_mut())
+        vals.par_chunks(3)
+            .zip(tmp_buf.par_iter_mut())
             .for_each(|(ein, eout)| {
                 // padding zero to 4
                 let mut ein_4 = [FGL::ZERO; 4];
diff --git a/starky/src/linearhash_bn128.rs b/starky/src/linearhash_bn128.rs
index d5fdd6bc..08ab5186 100644
--- a/starky/src/linearhash_bn128.rs
+++ b/starky/src/linearhash_bn128.rs
@@ -1,13 +1,14 @@
 #![allow(non_snake_case)]
+
+use crate::constant::{OFFSET_2_128, OFFSET_2_64};
 use crate::errors::Result;
 use crate::field_bn128::{Fr, FrRepr};
 use crate::poseidon_bn128_opt::Poseidon;
 use crate::traits::MTNodeType;
 use crate::ElementDigest;
 use ff::*;
-//use rayon::prelude::*;
-use crate::constant::{OFFSET_2_128, OFFSET_2_64};
 use plonky::field_gl::Fr as FGL;
+use rayon::prelude::*;
 
 #[derive(Default)]
 pub struct LinearHashBN128 {
@@ -100,7 +101,7 @@ impl LinearHashBN128 {
     ) -> Result> {
         assert_eq!(elems.len(), 16);
         let elems = elems
-            .iter()
+            .par_iter()
             .map(|e| Fr((*e).as_scalar::()))
             .collect::>();
         let digest = self.h.hash(&elems, init_state)?;
@@ -120,8 +121,8 @@ impl LinearHashBN128 {
 
         // group into 3 * 4
         let mut tmp_buf = vec![Fr::zero(); (vals.len() - 1) / 3 + 1];
-        vals.chunks(3)
-            .zip(tmp_buf.iter_mut())
+        vals.par_chunks(3)
+            .zip(tmp_buf.par_iter_mut())
             .for_each(|(ein, eout)| {
                 // padding zero to 4
                 let mut ein_4 = [FGL::ZERO; 4];
diff --git a/starky/src/merklehash_bls12381.rs b/starky/src/merklehash_bls12381.rs
index 0e6f8981..2b9d5af0 100644
--- a/starky/src/merklehash_bls12381.rs
+++ b/starky/src/merklehash_bls12381.rs
@@ -61,9 +61,9 @@ impl MerkleTreeBLS12381 {
         );
 
         let out = &mut self.nodes[p_out..(p_out + n_ops)];
-        out.iter_mut()
-            .zip(nodes)
-            .for_each(|(nout, nin)| *nout = nin);
+        out.par_iter_mut()
+            .zip(nodes.par_iter())
+            .for_each(|(nout, nin)| *nout = *nin);
         Ok(())
     }
 
@@ -81,12 +81,15 @@ impl MerkleTreeBLS12381 {
         );
         let n_ops = buff_in.len() / 16;
         let mut buff_out64: Vec> = vec![ElementDigest::<4>::default(); n_ops];
-        buff_out64.iter_mut().zip(0..n_ops).for_each(|(out, i)| {
-            *out = self
-                .h
-                .hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero())
-                .unwrap();
-        });
+        buff_out64
+            .par_iter_mut()
+            .zip(0..n_ops)
+            .for_each(|(out, i)| {
+                *out = self
+                    .h
+                    .hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero())
+                    .unwrap();
+            });
         Ok(buff_out64)
     }
 
@@ -192,10 +195,12 @@ impl MerkleTree for MerkleTreeBLS12381 {
                 .zip(buff.par_chunks(n_per_thread_f * width))
                 .for_each(|(out, bb)| {
                     let cur_n = bb.len() / width;
-                    out.iter_mut().zip(0..cur_n).for_each(|(row_out, j)| {
-                        let batch = &bb[(j * width)..((j + 1) * width)];
-                        *row_out = self.h.hash_element_array(batch).unwrap();
-                    });
+                    out.par_iter_mut()
+                        .zip((0..cur_n).into_par_iter())
+                        .for_each(|(row_out, j)| {
+                            let batch = &bb[(j * width)..((j + 1) * width)];
+                            *row_out = self.h.hash_element_array(batch).unwrap();
+                        });
                 });
         }
         log::debug!("linearhash time cost: {}", now.elapsed().as_secs_f64());
diff --git a/starky/src/merklehash_bn128.rs b/starky/src/merklehash_bn128.rs
index a0529f80..e4d9571d 100644
--- a/starky/src/merklehash_bn128.rs
+++ b/starky/src/merklehash_bn128.rs
@@ -61,9 +61,9 @@ impl MerkleTreeBN128 {
         );
 
         let out = &mut self.nodes[p_out..(p_out + n_ops)];
-        out.iter_mut()
-            .zip(nodes)
-            .for_each(|(nout, nin)| *nout = nin);
+        out.par_iter_mut()
+            .zip(nodes.par_iter())
+            .for_each(|(nout, nin)| *nout = *nin);
         Ok(())
     }
 
@@ -81,12 +81,15 @@ impl MerkleTreeBN128 {
         );
         let n_ops = buff_in.len() / 16;
         let mut buff_out64: Vec> = vec![ElementDigest::<4>::default(); n_ops];
-        buff_out64.iter_mut().zip(0..n_ops).for_each(|(out, i)| {
-            *out = self
-                .h
-                .hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero())
-                .unwrap();
-        });
+        buff_out64
+            .par_iter_mut()
+            .zip(0..n_ops)
+            .for_each(|(out, i)| {
+                *out = self
+                    .h
+                    .hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero())
+                    .unwrap();
+            });
         Ok(buff_out64)
     }
 
@@ -187,20 +190,22 @@ impl MerkleTree for MerkleTreeBN128 {
         }
         // calculate the nodes of the specific height Merkle tree
         let mut nodes = vec![ElementDigest::<4>::default(); get_n_nodes(height)];
-        let now = Instant::now();
         if !buff.is_empty() {
+            let now = Instant::now();
             nodes
                 .par_chunks_mut(n_per_thread_f)
                 .zip(buff.par_chunks(n_per_thread_f * width))
                 .for_each(|(out, bb)| {
                     let cur_n = bb.len() / width;
-                    out.iter_mut().zip(0..cur_n).for_each(|(row_out, j)| {
-                        let batch = &bb[(j * width)..((j + 1) * width)];
-                        *row_out = self.h.hash_element_array(batch).unwrap();
-                    });
+                    out.par_iter_mut()
+                        .zip((0..cur_n).into_par_iter())
+                        .for_each(|(row_out, j)| {
+                            let batch = &bb[(j * width)..((j + 1) * width)];
+                            *row_out = self.h.hash_element_array(batch).unwrap();
+                        });
                 });
+            log::debug!("linearhash time cost: {}", now.elapsed().as_secs_f64());
         }
-        log::debug!("linearhash time cost: {}", now.elapsed().as_secs_f64());
 
         // merklize level
         self.nodes = nodes;
@@ -212,6 +217,7 @@ impl MerkleTree for MerkleTreeBN128 {
         let mut next_n256: usize = (n256 - 1) / 16 + 1;
         let mut p_in: usize = 0;
         let mut p_out: usize = p_in + next_n256 * 16;
+
         while n256 > 1 {
             let now = Instant::now();
             self.merklize_level(p_in, next_n256, p_out)?;
diff --git a/starky/src/poseidon_bls12381_opt.rs b/starky/src/poseidon_bls12381_opt.rs
index 0a97931a..38a19799 100644
--- a/starky/src/poseidon_bls12381_opt.rs
+++ b/starky/src/poseidon_bls12381_opt.rs
@@ -4,6 +4,7 @@ use crate::field_bls12381::Fr;
 use crate::poseidon_bls12381::Constants;
 use crate::poseidon_bls12381_constants_opt as constants;
 use ff::{from_hex, Field};
+use rayon::prelude::*;
 
 pub fn load_constants() -> Constants {
     let (c_str, m_str, p_str, s_str) = constants::constants();
@@ -125,13 +126,13 @@ impl Poseidon {
         let mut state = vec![*init_state; t];
         state[1..].clone_from_slice(inp);
         state
-            .iter_mut()
+            .par_iter_mut()
             .enumerate()
             .for_each(|(i, a)| a.add_assign(&C[i]));
 
         for r in 0..(n_rounds_f / 2 - 1) {
-            state.iter_mut().for_each(Self::pow5);
-            state.iter_mut().enumerate().for_each(|(i, a)| {
+            state.par_iter_mut().for_each(Self::pow5);
+            state.par_iter_mut().enumerate().for_each(|(i, a)| {
                 a.add_assign(&C[(r + 1) * t + i]);
             });
 
@@ -139,7 +140,7 @@ impl Poseidon {
             //     state.reduce((acc, a, j) => F.add(acc, F.mul(M[j][i], a)), F.zero)
             //);
             let sz = state.len();
-            tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
+            tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
                 let mut acc = Fr::zero();
                 for j in 0..sz {
                     let mut tmp = M[j][i];
@@ -149,20 +150,20 @@ impl Poseidon {
                 *out = acc;
             });
             state
-                .iter_mut()
-                .zip(tmp_state.iter())
+                .par_iter_mut()
+                .zip(tmp_state.par_iter())
                 .for_each(|(out, inp)| {
                     *out = *inp;
                 });
         }
-        state.iter_mut().for_each(Self::pow5);
-        state.iter_mut().enumerate().for_each(|(i, a)| {
+        state.par_iter_mut().for_each(Self::pow5);
+        state.par_iter_mut().enumerate().for_each(|(i, a)| {
             a.add_assign(&C[(n_rounds_f / 2 - 1 + 1) * t + i]);
         });
 
         //opt
         let sz = state.len();
-        tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
+        tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
            let mut acc = Fr::zero();
            for j in 0..sz {
                let mut tmp = P[j][i];
@@ -172,8 +173,8 @@ impl Poseidon {
             *out = acc;
         });
         state
-            .iter_mut()
-            .zip(tmp_state.iter())
+            .par_iter_mut()
+            .zip(tmp_state.par_iter())
             .for_each(|(out, inp)| {
                 *out = *inp;
             });
@@ -199,13 +200,13 @@ impl Poseidon {
         }
 
         for r in 0..(n_rounds_f / 2 - 1) {
-            state.iter_mut().for_each(Self::pow5);
-            state.iter_mut().enumerate().for_each(|(i, a)| {
+            state.par_iter_mut().for_each(Self::pow5);
+            state.par_iter_mut().enumerate().for_each(|(i, a)| {
                 a.add_assign(&C[(n_rounds_f / 2 + 1) * t + n_rounds_p + r * t + i]);
             });
 
             let sz = state.len();
-            tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
+            tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
                 let mut acc = Fr::zero();
                 for j in 0..sz {
                     let mut tmp = M[j][i];
@@ -215,16 +216,16 @@ impl Poseidon {
                 *out = acc;
             });
             state
-                .iter_mut()
-                .zip(tmp_state.iter())
+                .par_iter_mut()
+                .zip(tmp_state.par_iter())
                 .for_each(|(out, inp)| {
                     *out = *inp;
                 });
         }
-        state.iter_mut().for_each(Self::pow5);
+        state.par_iter_mut().for_each(Self::pow5);
 
         let sz = state.len();
-        tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
+        tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
             let mut acc = Fr::zero();
             for j in 0..sz {
                 let mut tmp = M[j][i];
diff --git a/starky/src/poseidon_bn128_opt.rs b/starky/src/poseidon_bn128_opt.rs
index 0e57f571..27b3bbc0 100644
--- a/starky/src/poseidon_bn128_opt.rs
+++ b/starky/src/poseidon_bn128_opt.rs
@@ -5,6 +5,8 @@
 use crate::poseidon_bn128::Constants;
 use crate::poseidon_bn128_constants_opt as constants;
 use ff::{from_hex, Field};
+use rayon::prelude::*;
+
 pub fn load_constants() -> Constants {
     let (c_str, m_str, p_str, s_str) = constants::constants();
     let mut c: Vec> = Vec::new();
@@ -119,13 +121,13 @@ impl Poseidon {
         let mut state = vec![*init_state; t];
         state[1..].clone_from_slice(inp);
         state
-            .iter_mut()
+            .par_iter_mut()
             .enumerate()
             .for_each(|(i, a)| a.add_assign(&C[i]));
 
         for r in 0..(n_rounds_f / 2 - 1) {
-            state.iter_mut().for_each(Self::pow5);
-            state.iter_mut().enumerate().for_each(|(i, a)| {
+            state.par_iter_mut().for_each(Self::pow5);
+            state.par_iter_mut().enumerate().for_each(|(i, a)| {
                 a.add_assign(&C[(r + 1) * t + i]);
             });
 
@@ -133,7 +135,7 @@ impl Poseidon {
             //     state.reduce((acc, a, j) => F.add(acc, F.mul(M[j][i], a)), F.zero)
             //);
             let sz = state.len();
-            tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
+            tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
                 let mut acc = Fr::zero();
                 for j in 0..sz {
                     let mut tmp = M[j][i];
@@ -143,20 +145,20 @@ impl Poseidon {
                 *out = acc;
             });
             state
-                .iter_mut()
-                .zip(tmp_state.iter())
+                .par_iter_mut()
+                .zip(tmp_state.par_iter())
                 .for_each(|(out, inp)| {
                     *out = *inp;
                 });
         }
-        state.iter_mut().for_each(Self::pow5);
-        state.iter_mut().enumerate().for_each(|(i, a)| {
+        state.par_iter_mut().for_each(Self::pow5);
+        state.par_iter_mut().enumerate().for_each(|(i, a)| {
             a.add_assign(&C[(n_rounds_f / 2 - 1 + 1) * t + i]);
         });
 
         //opt
         let sz = state.len();
-        tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
+        tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
             let mut acc = Fr::zero();
             for j in 0..sz {
                 let mut tmp = P[j][i];
@@ -166,8 +168,8 @@ impl Poseidon {
             *out = acc;
         });
         state
-            .iter_mut()
-            .zip(tmp_state.iter())
+            .par_iter_mut()
+            .zip(tmp_state.par_iter())
             .for_each(|(out, inp)| {
                 *out = *inp;
             });
@@ -194,13 +196,13 @@ impl Poseidon {
         }
 
         for r in 0..(n_rounds_f / 2 - 1) {
-            state.iter_mut().for_each(Self::pow5);
-            state.iter_mut().enumerate().for_each(|(i, a)| {
+            state.par_iter_mut().for_each(Self::pow5);
+            state.par_iter_mut().enumerate().for_each(|(i, a)| {
                 a.add_assign(&C[(n_rounds_f / 2 + 1) * t + n_rounds_p + r * t + i]);
             });
 
             let sz = state.len();
-            tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
+            tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
                 let mut acc = Fr::zero();
                 for j in 0..sz {
                     let mut tmp = M[j][i];
@@ -210,16 +212,16 @@ impl Poseidon {
                 *out = acc;
             });
             state
-                .iter_mut()
-                .zip(tmp_state.iter())
+                .par_iter_mut()
+                .zip(tmp_state.par_iter())
                 .for_each(|(out, inp)| {
                     *out = *inp;
                 });
         }
-        state.iter_mut().for_each(Self::pow5);
+        state.par_iter_mut().for_each(Self::pow5);
 
         let sz = state.len();
-        tmp_state.iter_mut().enumerate().for_each(|(i, out)| {
+        tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| {
             let mut acc = Fr::zero();
             for j in 0..sz {
                 let mut tmp = M[j][i];
diff --git a/starky/src/stark_gen.rs b/starky/src/stark_gen.rs
index 03757153..bcffd672 100644
--- a/starky/src/stark_gen.rs
+++ b/starky/src/stark_gen.rs
@@ -170,6 +170,58 @@ impl StarkContext {
             }
         }
     }
+
+    pub fn init(
+        cm_pols: &PolsArray,
+        const_tree_elem_size: usize,
+        starkinfo: &StarkInfo,
+        stark_struct: &StarkStruct,
+    ) -> Result {
+        let mut ctx = StarkContext:: {
+            nbits: stark_struct.nBits,
+            nbits_ext: stark_struct.nBitsExt,
+            N: 1 << stark_struct.nBits,
+            Next: 1 << stark_struct.nBitsExt,
+            cm1_n: cm_pols.write_buff(),
+            ..Default::default()
+        };
+
+        ctx.cm2_n = vec![F::ZERO; (starkinfo.map_sectionsN.cm2_n) * ctx.N];
+        ctx.cm3_n = vec![F::ZERO; (starkinfo.map_sectionsN.cm3_n) * ctx.N];
+        ctx.tmpexp_n = vec![F::ZERO; (starkinfo.map_sectionsN.tmpexp_n) * ctx.N];
+
+        ctx.cm1_2ns = vec![F::ZERO; starkinfo.map_sectionsN.cm1_n * ctx.Next];
+        ctx.cm2_2ns = vec![F::ZERO; starkinfo.map_sectionsN.cm2_n * ctx.Next];
+        ctx.cm3_2ns = vec![F::ZERO; starkinfo.map_sectionsN.cm3_n * ctx.Next];
+        ctx.cm4_2ns = vec![F::ZERO; starkinfo.map_sectionsN.cm4_n * ctx.Next];
+        ctx.const_2ns = vec![F::ZERO; const_tree_elem_size];
+
+        ctx.q_2ns = vec![F::ZERO; starkinfo.q_dim * ctx.Next];
+        ctx.f_2ns = vec![F::ZERO; 3 * ctx.Next];
+
+        ctx.x_n = vec![F::ZERO; ctx.N];
+
+        let mut xx = F::ONE;
+        // Using the precomputing value
+        let w_nbits: F = F::from(MG.0[ctx.nbits]);
+        for i in 0..ctx.N {
+            ctx.x_n[i] = xx;
+            xx *= w_nbits;
+        }
+
+        let extend_bits = ctx.nbits_ext - ctx.nbits;
+        ctx.x_2ns = vec![F::ZERO; ctx.N << extend_bits];
+
+        let mut xx: F = F::from(*SHIFT);
+        for i in 0..(ctx.N << extend_bits) {
+            ctx.x_2ns[i] = xx;
+            xx *= F::from(MG.0[ctx.nbits_ext]);
+        }
+
+        ctx.Zi = build_Zh_Inv::(ctx.nbits, extend_bits, 0);
+
+        Ok(ctx)
+    }
 }
 
 pub struct StarkProof {
@@ -197,73 +249,30 @@ impl<'a, M: MerkleTree> StarkProof {
         stark_struct: &StarkStruct,
         prover_addr: &str,
     ) -> Result> {
-        let mut ctx = StarkContext::::default();
-        //log::debug!("starkinfo: {}", starkinfo);
-        //log::debug!("program: {}", program);
-
-        let mut fftobj = FFT::new();
-        ctx.nbits = stark_struct.nBits;
-        ctx.nbits_ext = stark_struct.nBitsExt;
-        ctx.N = 1 << stark_struct.nBits;
-        ctx.Next = 1 << stark_struct.nBitsExt;
-        assert_eq!(1 << ctx.nbits, ctx.N, "N must be a power of 2");
-
-        let mut n_cm = starkinfo.n_cm1;
-
-        ctx.cm1_n = cm_pols.write_buff();
-        ctx.cm2_n = vec![M::ExtendField::ZERO; (starkinfo.map_sectionsN.cm2_n) * ctx.N];
-        ctx.cm3_n = vec![M::ExtendField::ZERO; (starkinfo.map_sectionsN.cm3_n) * ctx.N];
-        ctx.tmpexp_n = vec![M::ExtendField::ZERO; (starkinfo.map_sectionsN.tmpexp_n) * ctx.N];
-
-        ctx.cm1_2ns = vec![M::ExtendField::ZERO; starkinfo.map_sectionsN.cm1_n * ctx.Next];
-        ctx.cm2_2ns = vec![M::ExtendField::ZERO; starkinfo.map_sectionsN.cm2_n * ctx.Next];
-        ctx.cm3_2ns = vec![M::ExtendField::ZERO; starkinfo.map_sectionsN.cm3_n * ctx.Next];
-        ctx.cm4_2ns = vec![M::ExtendField::ZERO; starkinfo.map_sectionsN.cm4_n * ctx.Next];
-        ctx.const_2ns = vec![M::ExtendField::ZERO; const_tree.element_size()];
-
-        ctx.q_2ns = vec![M::ExtendField::ZERO; starkinfo.q_dim * ctx.Next];
-        ctx.f_2ns = vec![M::ExtendField::ZERO; 3 * ctx.Next];
-
-        ctx.x_n = vec![M::ExtendField::ZERO; ctx.N];
-
-        let mut xx = M::ExtendField::ONE;
-        // Using the precomputing value
-        let w_nbits: M::ExtendField = M::ExtendField::from(MG.0[ctx.nbits]);
-        for i in 0..ctx.N {
-            ctx.x_n[i] = xx;
-            xx *= w_nbits;
-        }
-
-        let extend_bits = ctx.nbits_ext - ctx.nbits;
-        ctx.x_2ns = vec![M::ExtendField::ZERO; ctx.N << extend_bits];
-
-        let mut xx: M::ExtendField = M::ExtendField::from(*SHIFT);
-        for i in 0..(ctx.N << extend_bits) {
-            ctx.x_2ns[i] = xx;
-            xx *= M::ExtendField::from(MG.0[ctx.nbits_ext]);
-        }
-
-        ctx.Zi = build_Zh_Inv::(ctx.nbits, extend_bits, 0);
+        let mut ctx =
+            StarkContext::init(cm_pols, const_tree.element_size(), starkinfo, stark_struct)?;
 
         ctx.const_n = const_pols.write_buff();
-        const_tree.to_extend(&mut ctx.const_2ns);
         ctx.publics = vec![M::ExtendField::ZERO; starkinfo.publics.len()];
+        // todo refact code to parallel
         for (i, pe) in starkinfo.publics.iter().enumerate() {
-            if pe.polType.as_str() == "cmP" {
-                ctx.publics[i] = ctx.cm1_n[pe.idx * starkinfo.map_sectionsN.cm1_n + pe.polId];
+            ctx.publics[i] = if pe.polType.as_str() == "cmP" {
+                ctx.cm1_n[pe.idx * starkinfo.map_sectionsN.cm1_n + pe.polId]
             } else if pe.polType.as_str() == "imP" {
-                ctx.publics[i] = Self::calculate_exp_at_point::(
+                Self::calculate_exp_at_point::(
                     &mut ctx,
                     starkinfo,
                     &program.publics_code[i],
                     pe.idx,
-                );
+                )
             } else {
                 panic!("Invalid public type {}", pe.polType);
-            }
+            };
         }
+        let extend_bits = ctx.nbits_ext - ctx.nbits;
+        const_tree.to_extend(&mut ctx.const_2ns);
+
         let mut transcript = T::new();
         for i in 0..starkinfo.publics.len() {
             let b = ctx.publics[i]
@@ -278,12 +287,9 @@ impl<'a, M: MerkleTree> StarkProof {
 
         let tree1 = extend_and_merkelize::(&mut ctx, starkinfo, "cm1_n")?;
         tree1.to_extend(&mut ctx.cm1_2ns);
-        log::debug!(
-            "tree1 root: {}",
-            //crate::helper::fr_to_biguint(&tree1.root().into())
-            tree1.root(),
-        );
+        log::debug!("tree1 root: {}", tree1.root(),);
         transcript.put(&[tree1.root().as_elements().to_vec()])?;
+
         // 2.- Caluculate plookups h1 and h2
         ctx.challenge[0] = transcript.get_field(); //u
         ctx.challenge[1] = transcript.get_field(); //defVal
@@ -293,6 +299,8 @@ impl<'a, M: MerkleTree> StarkProof {
 
         calculate_exps_parallel(&mut ctx, starkinfo, &program.step2prev, "n", "step2prev");
 
+        let mut n_cm = starkinfo.n_cm1;
+
         for pu in starkinfo.pu_ctx.iter() {
             let f_pol = get_pol(&mut ctx, starkinfo, starkinfo.exp2pol[&pu.f_exp_id]);
             let t_pol = get_pol(&mut ctx, starkinfo, starkinfo.exp2pol[&pu.t_exp_id]);
@@ -423,6 +431,7 @@ impl<'a, M: MerkleTree> StarkProof {
             LpEv[i] = LpEv[i - 1] * wxis;
         }
 
+        let mut fftobj = FFT::new();
         let LEv = fftobj.ifft(&LEv);
         let LpEv = fftobj.ifft(&LpEv);
 
@@ -715,15 +724,19 @@ pub fn extend_and_merkelize(
     let nBitsExt = ctx.nbits_ext;
     let nBits = ctx.nbits;
     let n_pols = starkinfo.map_sectionsN.get(section_name);
-    let mut result = vec![M::ExtendField::ZERO; (1 << nBitsExt) * n_pols];
+
+    let len = (1 << nBitsExt) * n_pols;
+    let mut result = vec![M::ExtendField::ZERO; len];
     let p = ctx.get_mut(section_name);
     interpolate(p, n_pols, nBits, &mut result, nBitsExt);
+
     let mut p_be = vec![FGL::ZERO; result.len()];
     p_be.par_iter_mut()
         .zip(result)
         .for_each(|(be_out, f3g_in)| {
             *be_out = f3g_in.to_be();
         });
+
     let mut tree = M::new();
     tree.merkelize(p_be, n_pols, 1 << nBitsExt)?;
     Ok(tree)
diff --git a/starky/src/utils/mod.rs b/starky/src/utils/mod.rs
new file mode 100644
index 00000000..8c1eaa52
--- /dev/null
+++ b/starky/src/utils/mod.rs
@@ -0,0 +1 @@
+pub mod parallells;
diff --git a/starky/src/utils/parallells.rs b/starky/src/utils/parallells.rs
new file mode 100644
index 00000000..5fcfcb67
--- /dev/null
+++ b/starky/src/utils/parallells.rs
@@ -0,0 +1,28 @@
+use rayon::{current_num_threads, scope};
+
+// This simple utility function will parallelize an operation that is to be
+// performed over a mutable slice.
+// parallelize(&mut p_be, |values, start| {
+//     for (i, v) in values.iter_mut().enumerate() {
+//         let idx = start + i;
+//         *v = result[idx].to_be();
+//     }
+// });
+pub fn parallelize(v: &mut [T], f: F) {
+    let n = v.len();
+    let num_threads = current_num_threads();
+    let mut chunk = n / num_threads;
+    if chunk < num_threads {
+        chunk = 1;
+    }
+
+    scope(|scope| {
+        for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
+            let f = f.clone();
+            scope.spawn(move |_| {
+                let start = chunk_num * chunk;
+                f(v, start);
+            });
+        }
+    });
+}