From 1fd68f5b3e729140c807892b31b6104fa8bffb5a Mon Sep 17 00:00:00 2001 From: Paul Cheng Date: Tue, 24 Oct 2023 00:27:30 +0800 Subject: [PATCH 1/5] feat: add parellel util --- starky/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/starky/src/lib.rs b/starky/src/lib.rs index 1e3c24da..0443eb28 100644 --- a/starky/src/lib.rs +++ b/starky/src/lib.rs @@ -68,6 +68,7 @@ pub mod pil2circom; pub mod pilcom; pub mod prove; pub mod serializer; +pub(crate) mod utils; pub mod zkin_join; #[macro_use] From a9441a1077d72117f5b5fdeb8345e8be43442b7c Mon Sep 17 00:00:00 2001 From: Paul Cheng Date: Tue, 24 Oct 2023 00:27:48 +0800 Subject: [PATCH 2/5] feat: add parellel util --- starky/src/utils/mod.rs | 1 + starky/src/utils/parallells.rs | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 starky/src/utils/mod.rs create mode 100644 starky/src/utils/parallells.rs diff --git a/starky/src/utils/mod.rs b/starky/src/utils/mod.rs new file mode 100644 index 00000000..8c1eaa52 --- /dev/null +++ b/starky/src/utils/mod.rs @@ -0,0 +1 @@ +pub mod parallells; diff --git a/starky/src/utils/parallells.rs b/starky/src/utils/parallells.rs new file mode 100644 index 00000000..481418e3 --- /dev/null +++ b/starky/src/utils/parallells.rs @@ -0,0 +1,22 @@ +use rayon::{current_num_threads, scope}; + +/// This simple utility function will parallelize an operation that is to be +/// performed over a mutable slice. +pub fn parallelize(v: &mut [T], f: F) { + let n = v.len(); + let num_threads = current_num_threads(); + let mut chunk = n / num_threads; + if chunk < num_threads { + chunk = 1; + } + + scope(|scope| { + for (chunk_num, v) in v.chunks_mut(chunk).enumerate() { + let f = f.clone(); + scope.spawn(move |_| { + let start = chunk_num * chunk; + f(v, start); + }); + } + }); +} From ca7e18e20b4d34e48f5cf7c59e0ca2f842e0cb08 Mon Sep 17 00:00:00 2001 From: Paul Cheng Date: Tue, 24 Oct 2023 01:01:00 +0800 Subject: [PATCH 3/5] feat: use cache BR for bit_reverse feat: parallel bit_reverse --- starky/src/fft_p.rs | 95 ++++++++++++++++++++++++++++++++++++--------- starky/src/lib.rs | 2 + 2 files changed, 78 insertions(+), 19 deletions(-) diff --git a/starky/src/fft_p.rs b/starky/src/fft_p.rs index 7887ac44..4f25b6fa 100644 --- a/starky/src/fft_p.rs +++ b/starky/src/fft_p.rs @@ -3,17 +3,67 @@ use crate::constant::{get_max_workers, MAX_OPS_PER_THREAD, MIN_OPS_PER_THREAD, S use crate::fft_worker::{fft_block, interpolate_prepare_block}; use crate::helper::log2_any; use crate::traits::FieldExtension; +use crate::utils::parallells::parallelize; use core::cmp::min; +use lazy_static::lazy_static; use rayon::prelude::*; +use std::collections::HashMap; +use std::sync::Mutex; +lazy_static! { + static ref BR_CACHE: Mutex>> = Mutex::new(HashMap::new()); +} pub fn BR(x: usize, domain_pow: usize) -> usize { assert!(domain_pow <= 32); - let mut x = x; - x = (x >> 16) | (x << 16); - x = ((x & 0xFF00FF00) >> 8) | ((x & 0x00FF00FF) << 8); - x = ((x & 0xF0F0F0F0) >> 4) | ((x & 0x0F0F0F0F) << 4); - x = ((x & 0xCCCCCCCC) >> 2) | ((x & 0x33333333) << 2); - (((x & 0xAAAAAAAA) >> 1) | ((x & 0x55555555) << 1)) >> (32 - domain_pow) + let cal = |x: usize, domain_pow: usize| -> usize { + let mut x = x; + x = (x >> 16) | (x << 16); + x = ((x & 0xFF00FF00) >> 8) | ((x & 0x00FF00FF) << 8); + x = ((x & 0xF0F0F0F0) >> 4) | ((x & 0x0F0F0F0F) << 4); + x = ((x & 0xCCCCCCCC) >> 2) | ((x & 0x33333333) << 2); + (((x & 0xAAAAAAAA) >> 1) | ((x & 0x55555555) << 1)) >> (32 - domain_pow) + }; + + // get cache by domain_pow + let mut map = BR_CACHE.lock().unwrap(); + let mut cache = if map.contains_key(&domain_pow) { + map.remove(&domain_pow).unwrap() // get and remove the old values. + } else { + vec![] + }; + // check if need append more to cache + let cache_len = cache.len(); + let n = 1 << domain_pow; + if cache_len <= n || cache_len < x { + let end = if n >= x { n } else { x }; + // todo parallel + for i in cache_len..=end { + let a = cal(i, domain_pow); + cache.push(a); + } + } + let res = cache[x]; + // update map with cache + map.insert(domain_pow, cache); + res +} +fn BRs(start: usize, end: usize, domain_pow: usize) -> Vec { + assert!(end > start); + // 1. obtain a useless one to precompute the cache. + // to make sure the cache existed and its len >= end. + BR(end, domain_pow); + + // 2. get cache by domain_pow + let map = BR_CACHE.lock().unwrap(); + let cache = if map.contains_key(&domain_pow) { + map.get(&domain_pow).unwrap() + } else { + // double check + BR(end, domain_pow); + map.get(&domain_pow).unwrap() + }; + + (start..end).map(|i| cache[i]).collect() } pub fn transpose( @@ -44,11 +94,14 @@ pub fn bit_reverse( nbits: usize, ) { let n = 1 << nbits; - for i in 0..n { - let ri = BR(i, nbits); - for k in 0..n_pols { - buffdst[i * n_pols + k] = buffsrc[ri * n_pols + k]; - } + let ris = BRs(0, n, nbits); // move it outside the loop. obtain it from cache. + + let len = n * n_pols; + assert_eq!(len, buffdst.len()); + for j in 0..len { + let i = j / n_pols; + let k = j % n_pols; + buffdst[j] = buffsrc[ris[i] * n_pols + k]; } } @@ -59,9 +112,10 @@ pub fn interpolate_bit_reverse( nbits: usize, ) { let n = 1 << nbits; + let ris = BRs(0, n, nbits); // move it outside the loop. obtain it from cache. + for i in 0..n { - let ri = BR(i, nbits); - let rii = (n - ri) % n; + let rii = (n - ris[i]) % n; for k in 0..n_pols { buffdst[i * n_pols + k] = buffsrc[rii * n_pols + k]; } @@ -76,12 +130,15 @@ pub fn inv_bit_reverse( ) { let n = 1 << nbits; let n_inv = F::inv(&F::from(n)); - for i in 0..n { - let ri = BR(i, nbits); - let rii = (n - ri) % n; - for p in 0..n_pols { - buffdst[i * n_pols + p] = buffsrc[rii * n_pols + p] * n_inv; - } + let ris = BRs(0, n, nbits); // move it outside the loop. obtain it from cache. + + let len = n * n_pols; + assert_eq!(len, buffdst.len()); + for j in 0..len { + let i = j / n_pols; + let k = j % n_pols; + let rii = (n - ris[i]) % n; + buffdst[j] = buffsrc[rii * n_pols + k] * n_inv; } } diff --git a/starky/src/lib.rs b/starky/src/lib.rs index 0443eb28..65c9cb82 100644 --- a/starky/src/lib.rs +++ b/starky/src/lib.rs @@ -1,4 +1,5 @@ #![allow(clippy::needless_range_loop)] +#![allow(dead_code)] pub mod errors; pub mod polsarray; @@ -31,6 +32,7 @@ pub mod poseidon_bls12381_opt; pub mod merklehash; pub mod merklehash_bls12381; + pub mod merklehash_bn128; mod digest; From 11ae30cb580c413bb66037874fcc5b701d870217 Mon Sep 17 00:00:00 2001 From: Paul Cheng Date: Wed, 25 Oct 2023 11:34:07 +0800 Subject: [PATCH 4/5] opti poseidon hash --- starky/src/fft_p.rs | 1 - starky/src/linearhash_bls12381.rs | 10 +-- starky/src/linearhash_bn128.rs | 11 +-- starky/src/merklehash_bls12381.rs | 31 ++++--- starky/src/merklehash_bn128.rs | 36 ++++---- starky/src/poseidon_bls12381_opt.rs | 37 ++++---- starky/src/poseidon_bn128_opt.rs | 38 ++++---- starky/src/stark_gen.rs | 133 +++++++++++++++------------- starky/src/utils/parallells.rs | 8 ++ 9 files changed, 170 insertions(+), 135 deletions(-) diff --git a/starky/src/fft_p.rs b/starky/src/fft_p.rs index 4f25b6fa..39887f33 100644 --- a/starky/src/fft_p.rs +++ b/starky/src/fft_p.rs @@ -3,7 +3,6 @@ use crate::constant::{get_max_workers, MAX_OPS_PER_THREAD, MIN_OPS_PER_THREAD, S use crate::fft_worker::{fft_block, interpolate_prepare_block}; use crate::helper::log2_any; use crate::traits::FieldExtension; -use crate::utils::parallells::parallelize; use core::cmp::min; use lazy_static::lazy_static; use rayon::prelude::*; diff --git a/starky/src/linearhash_bls12381.rs b/starky/src/linearhash_bls12381.rs index 6263038c..2597f0fb 100644 --- a/starky/src/linearhash_bls12381.rs +++ b/starky/src/linearhash_bls12381.rs @@ -1,13 +1,13 @@ #![allow(non_snake_case)] +use crate::constant::{OFFSET_BLS12381_2_128, OFFSET_BLS12381_2_64}; use crate::errors::Result; use crate::field_bls12381::{Fr, FrRepr}; use crate::poseidon_bls12381_opt::Poseidon; use crate::traits::MTNodeType; use crate::ElementDigest; use ff::*; -//use rayon::prelude::*; -use crate::constant::{OFFSET_BLS12381_2_128, OFFSET_BLS12381_2_64}; use plonky::field_gl::Fr as FGL; +use rayon::prelude::*; #[derive(Default)] pub struct LinearHashBLS12381 { @@ -99,7 +99,7 @@ impl LinearHashBLS12381 { ) -> Result> { assert_eq!(elems.len(), 16); let elems = elems - .iter() + .par_iter() .map(|e| Fr((*e).as_scalar::())) .collect::>(); let digest = self.h.hash(&elems, init_state)?; @@ -119,8 +119,8 @@ impl LinearHashBLS12381 { // group into 3 * 4 let mut tmp_buf = vec![Fr::zero(); (vals.len() - 1) / 3 + 1]; - vals.chunks(3) - .zip(tmp_buf.iter_mut()) + vals.par_chunks(3) + .zip(tmp_buf.par_iter_mut()) .for_each(|(ein, eout)| { // padding zero to 4 let mut ein_4 = [FGL::ZERO; 4]; diff --git a/starky/src/linearhash_bn128.rs b/starky/src/linearhash_bn128.rs index d5fdd6bc..08ab5186 100644 --- a/starky/src/linearhash_bn128.rs +++ b/starky/src/linearhash_bn128.rs @@ -1,13 +1,14 @@ #![allow(non_snake_case)] + +use crate::constant::{OFFSET_2_128, OFFSET_2_64}; use crate::errors::Result; use crate::field_bn128::{Fr, FrRepr}; use crate::poseidon_bn128_opt::Poseidon; use crate::traits::MTNodeType; use crate::ElementDigest; use ff::*; -//use rayon::prelude::*; -use crate::constant::{OFFSET_2_128, OFFSET_2_64}; use plonky::field_gl::Fr as FGL; +use rayon::prelude::*; #[derive(Default)] pub struct LinearHashBN128 { @@ -100,7 +101,7 @@ impl LinearHashBN128 { ) -> Result> { assert_eq!(elems.len(), 16); let elems = elems - .iter() + .par_iter() .map(|e| Fr((*e).as_scalar::())) .collect::>(); let digest = self.h.hash(&elems, init_state)?; @@ -120,8 +121,8 @@ impl LinearHashBN128 { // group into 3 * 4 let mut tmp_buf = vec![Fr::zero(); (vals.len() - 1) / 3 + 1]; - vals.chunks(3) - .zip(tmp_buf.iter_mut()) + vals.par_chunks(3) + .zip(tmp_buf.par_iter_mut()) .for_each(|(ein, eout)| { // padding zero to 4 let mut ein_4 = [FGL::ZERO; 4]; diff --git a/starky/src/merklehash_bls12381.rs b/starky/src/merklehash_bls12381.rs index 0e6f8981..2b9d5af0 100644 --- a/starky/src/merklehash_bls12381.rs +++ b/starky/src/merklehash_bls12381.rs @@ -61,9 +61,9 @@ impl MerkleTreeBLS12381 { ); let out = &mut self.nodes[p_out..(p_out + n_ops)]; - out.iter_mut() - .zip(nodes) - .for_each(|(nout, nin)| *nout = nin); + out.par_iter_mut() + .zip(nodes.par_iter()) + .for_each(|(nout, nin)| *nout = *nin); Ok(()) } @@ -81,12 +81,15 @@ impl MerkleTreeBLS12381 { ); let n_ops = buff_in.len() / 16; let mut buff_out64: Vec> = vec![ElementDigest::<4>::default(); n_ops]; - buff_out64.iter_mut().zip(0..n_ops).for_each(|(out, i)| { - *out = self - .h - .hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero()) - .unwrap(); - }); + buff_out64 + .par_iter_mut() + .zip(0..n_ops) + .for_each(|(out, i)| { + *out = self + .h + .hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero()) + .unwrap(); + }); Ok(buff_out64) } @@ -192,10 +195,12 @@ impl MerkleTree for MerkleTreeBLS12381 { .zip(buff.par_chunks(n_per_thread_f * width)) .for_each(|(out, bb)| { let cur_n = bb.len() / width; - out.iter_mut().zip(0..cur_n).for_each(|(row_out, j)| { - let batch = &bb[(j * width)..((j + 1) * width)]; - *row_out = self.h.hash_element_array(batch).unwrap(); - }); + out.par_iter_mut() + .zip((0..cur_n).into_par_iter()) + .for_each(|(row_out, j)| { + let batch = &bb[(j * width)..((j + 1) * width)]; + *row_out = self.h.hash_element_array(batch).unwrap(); + }); }); } log::debug!("linearhash time cost: {}", now.elapsed().as_secs_f64()); diff --git a/starky/src/merklehash_bn128.rs b/starky/src/merklehash_bn128.rs index a0529f80..e4d9571d 100644 --- a/starky/src/merklehash_bn128.rs +++ b/starky/src/merklehash_bn128.rs @@ -61,9 +61,9 @@ impl MerkleTreeBN128 { ); let out = &mut self.nodes[p_out..(p_out + n_ops)]; - out.iter_mut() - .zip(nodes) - .for_each(|(nout, nin)| *nout = nin); + out.par_iter_mut() + .zip(nodes.par_iter()) + .for_each(|(nout, nin)| *nout = *nin); Ok(()) } @@ -81,12 +81,15 @@ impl MerkleTreeBN128 { ); let n_ops = buff_in.len() / 16; let mut buff_out64: Vec> = vec![ElementDigest::<4>::default(); n_ops]; - buff_out64.iter_mut().zip(0..n_ops).for_each(|(out, i)| { - *out = self - .h - .hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero()) - .unwrap(); - }); + buff_out64 + .par_iter_mut() + .zip(0..n_ops) + .for_each(|(out, i)| { + *out = self + .h + .hash_node(&buff_in[(i * 16)..(i * 16 + 16)], &Fr::zero()) + .unwrap(); + }); Ok(buff_out64) } @@ -187,20 +190,22 @@ impl MerkleTree for MerkleTreeBN128 { } // calculate the nodes of the specific height Merkle tree let mut nodes = vec![ElementDigest::<4>::default(); get_n_nodes(height)]; - let now = Instant::now(); if !buff.is_empty() { + let now = Instant::now(); nodes .par_chunks_mut(n_per_thread_f) .zip(buff.par_chunks(n_per_thread_f * width)) .for_each(|(out, bb)| { let cur_n = bb.len() / width; - out.iter_mut().zip(0..cur_n).for_each(|(row_out, j)| { - let batch = &bb[(j * width)..((j + 1) * width)]; - *row_out = self.h.hash_element_array(batch).unwrap(); - }); + out.par_iter_mut() + .zip((0..cur_n).into_par_iter()) + .for_each(|(row_out, j)| { + let batch = &bb[(j * width)..((j + 1) * width)]; + *row_out = self.h.hash_element_array(batch).unwrap(); + }); }); + log::debug!("linearhash time cost: {}", now.elapsed().as_secs_f64()); } - log::debug!("linearhash time cost: {}", now.elapsed().as_secs_f64()); // merklize level self.nodes = nodes; @@ -212,6 +217,7 @@ impl MerkleTree for MerkleTreeBN128 { let mut next_n256: usize = (n256 - 1) / 16 + 1; let mut p_in: usize = 0; let mut p_out: usize = p_in + next_n256 * 16; + while n256 > 1 { let now = Instant::now(); self.merklize_level(p_in, next_n256, p_out)?; diff --git a/starky/src/poseidon_bls12381_opt.rs b/starky/src/poseidon_bls12381_opt.rs index 0a97931a..38a19799 100644 --- a/starky/src/poseidon_bls12381_opt.rs +++ b/starky/src/poseidon_bls12381_opt.rs @@ -4,6 +4,7 @@ use crate::field_bls12381::Fr; use crate::poseidon_bls12381::Constants; use crate::poseidon_bls12381_constants_opt as constants; use ff::{from_hex, Field}; +use rayon::prelude::*; pub fn load_constants() -> Constants { let (c_str, m_str, p_str, s_str) = constants::constants(); @@ -125,13 +126,13 @@ impl Poseidon { let mut state = vec![*init_state; t]; state[1..].clone_from_slice(inp); state - .iter_mut() + .par_iter_mut() .enumerate() .for_each(|(i, a)| a.add_assign(&C[i])); for r in 0..(n_rounds_f / 2 - 1) { - state.iter_mut().for_each(Self::pow5); - state.iter_mut().enumerate().for_each(|(i, a)| { + state.par_iter_mut().for_each(Self::pow5); + state.par_iter_mut().enumerate().for_each(|(i, a)| { a.add_assign(&C[(r + 1) * t + i]); }); @@ -139,7 +140,7 @@ impl Poseidon { // state.reduce((acc, a, j) => F.add(acc, F.mul(M[j][i], a)), F.zero) //); let sz = state.len(); - tmp_state.iter_mut().enumerate().for_each(|(i, out)| { + tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| { let mut acc = Fr::zero(); for j in 0..sz { let mut tmp = M[j][i]; @@ -149,20 +150,20 @@ impl Poseidon { *out = acc; }); state - .iter_mut() - .zip(tmp_state.iter()) + .par_iter_mut() + .zip(tmp_state.par_iter()) .for_each(|(out, inp)| { *out = *inp; }); } - state.iter_mut().for_each(Self::pow5); - state.iter_mut().enumerate().for_each(|(i, a)| { + state.par_iter_mut().for_each(Self::pow5); + state.par_iter_mut().enumerate().for_each(|(i, a)| { a.add_assign(&C[(n_rounds_f / 2 - 1 + 1) * t + i]); }); //opt let sz = state.len(); - tmp_state.iter_mut().enumerate().for_each(|(i, out)| { + tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| { let mut acc = Fr::zero(); for j in 0..sz { let mut tmp = P[j][i]; @@ -172,8 +173,8 @@ impl Poseidon { *out = acc; }); state - .iter_mut() - .zip(tmp_state.iter()) + .par_iter_mut() + .zip(tmp_state.par_iter()) .for_each(|(out, inp)| { *out = *inp; }); @@ -199,13 +200,13 @@ impl Poseidon { } for r in 0..(n_rounds_f / 2 - 1) { - state.iter_mut().for_each(Self::pow5); - state.iter_mut().enumerate().for_each(|(i, a)| { + state.par_iter_mut().for_each(Self::pow5); + state.par_iter_mut().enumerate().for_each(|(i, a)| { a.add_assign(&C[(n_rounds_f / 2 + 1) * t + n_rounds_p + r * t + i]); }); let sz = state.len(); - tmp_state.iter_mut().enumerate().for_each(|(i, out)| { + tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| { let mut acc = Fr::zero(); for j in 0..sz { let mut tmp = M[j][i]; @@ -215,16 +216,16 @@ impl Poseidon { *out = acc; }); state - .iter_mut() - .zip(tmp_state.iter()) + .par_iter_mut() + .zip(tmp_state.par_iter()) .for_each(|(out, inp)| { *out = *inp; }); } - state.iter_mut().for_each(Self::pow5); + state.par_iter_mut().for_each(Self::pow5); let sz = state.len(); - tmp_state.iter_mut().enumerate().for_each(|(i, out)| { + tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| { let mut acc = Fr::zero(); for j in 0..sz { let mut tmp = M[j][i]; diff --git a/starky/src/poseidon_bn128_opt.rs b/starky/src/poseidon_bn128_opt.rs index 0e57f571..27b3bbc0 100644 --- a/starky/src/poseidon_bn128_opt.rs +++ b/starky/src/poseidon_bn128_opt.rs @@ -5,6 +5,8 @@ use crate::poseidon_bn128::Constants; use crate::poseidon_bn128_constants_opt as constants; use ff::{from_hex, Field}; +use rayon::prelude::*; + pub fn load_constants() -> Constants { let (c_str, m_str, p_str, s_str) = constants::constants(); let mut c: Vec> = Vec::new(); @@ -119,13 +121,13 @@ impl Poseidon { let mut state = vec![*init_state; t]; state[1..].clone_from_slice(inp); state - .iter_mut() + .par_iter_mut() .enumerate() .for_each(|(i, a)| a.add_assign(&C[i])); for r in 0..(n_rounds_f / 2 - 1) { - state.iter_mut().for_each(Self::pow5); - state.iter_mut().enumerate().for_each(|(i, a)| { + state.par_iter_mut().for_each(Self::pow5); + state.par_iter_mut().enumerate().for_each(|(i, a)| { a.add_assign(&C[(r + 1) * t + i]); }); @@ -133,7 +135,7 @@ impl Poseidon { // state.reduce((acc, a, j) => F.add(acc, F.mul(M[j][i], a)), F.zero) //); let sz = state.len(); - tmp_state.iter_mut().enumerate().for_each(|(i, out)| { + tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| { let mut acc = Fr::zero(); for j in 0..sz { let mut tmp = M[j][i]; @@ -143,20 +145,20 @@ impl Poseidon { *out = acc; }); state - .iter_mut() - .zip(tmp_state.iter()) + .par_iter_mut() + .zip(tmp_state.par_iter()) .for_each(|(out, inp)| { *out = *inp; }); } - state.iter_mut().for_each(Self::pow5); - state.iter_mut().enumerate().for_each(|(i, a)| { + state.par_iter_mut().for_each(Self::pow5); + state.par_iter_mut().enumerate().for_each(|(i, a)| { a.add_assign(&C[(n_rounds_f / 2 - 1 + 1) * t + i]); }); //opt let sz = state.len(); - tmp_state.iter_mut().enumerate().for_each(|(i, out)| { + tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| { let mut acc = Fr::zero(); for j in 0..sz { let mut tmp = P[j][i]; @@ -166,8 +168,8 @@ impl Poseidon { *out = acc; }); state - .iter_mut() - .zip(tmp_state.iter()) + .par_iter_mut() + .zip(tmp_state.par_iter()) .for_each(|(out, inp)| { *out = *inp; }); @@ -194,13 +196,13 @@ impl Poseidon { } for r in 0..(n_rounds_f / 2 - 1) { - state.iter_mut().for_each(Self::pow5); - state.iter_mut().enumerate().for_each(|(i, a)| { + state.par_iter_mut().for_each(Self::pow5); + state.par_iter_mut().enumerate().for_each(|(i, a)| { a.add_assign(&C[(n_rounds_f / 2 + 1) * t + n_rounds_p + r * t + i]); }); let sz = state.len(); - tmp_state.iter_mut().enumerate().for_each(|(i, out)| { + tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| { let mut acc = Fr::zero(); for j in 0..sz { let mut tmp = M[j][i]; @@ -210,16 +212,16 @@ impl Poseidon { *out = acc; }); state - .iter_mut() - .zip(tmp_state.iter()) + .par_iter_mut() + .zip(tmp_state.par_iter()) .for_each(|(out, inp)| { *out = *inp; }); } - state.iter_mut().for_each(Self::pow5); + state.par_iter_mut().for_each(Self::pow5); let sz = state.len(); - tmp_state.iter_mut().enumerate().for_each(|(i, out)| { + tmp_state.par_iter_mut().enumerate().for_each(|(i, out)| { let mut acc = Fr::zero(); for j in 0..sz { let mut tmp = M[j][i]; diff --git a/starky/src/stark_gen.rs b/starky/src/stark_gen.rs index 03757153..bcffd672 100644 --- a/starky/src/stark_gen.rs +++ b/starky/src/stark_gen.rs @@ -170,6 +170,58 @@ impl StarkContext { } } } + + pub fn init( + cm_pols: &PolsArray, + const_tree_elem_size: usize, + starkinfo: &StarkInfo, + stark_struct: &StarkStruct, + ) -> Result { + let mut ctx = StarkContext:: { + nbits: stark_struct.nBits, + nbits_ext: stark_struct.nBitsExt, + N: 1 << stark_struct.nBits, + Next: 1 << stark_struct.nBitsExt, + cm1_n: cm_pols.write_buff(), + ..Default::default() + }; + + ctx.cm2_n = vec![F::ZERO; (starkinfo.map_sectionsN.cm2_n) * ctx.N]; + ctx.cm3_n = vec![F::ZERO; (starkinfo.map_sectionsN.cm3_n) * ctx.N]; + ctx.tmpexp_n = vec![F::ZERO; (starkinfo.map_sectionsN.tmpexp_n) * ctx.N]; + + ctx.cm1_2ns = vec![F::ZERO; starkinfo.map_sectionsN.cm1_n * ctx.Next]; + ctx.cm2_2ns = vec![F::ZERO; starkinfo.map_sectionsN.cm2_n * ctx.Next]; + ctx.cm3_2ns = vec![F::ZERO; starkinfo.map_sectionsN.cm3_n * ctx.Next]; + ctx.cm4_2ns = vec![F::ZERO; starkinfo.map_sectionsN.cm4_n * ctx.Next]; + ctx.const_2ns = vec![F::ZERO; const_tree_elem_size]; + + ctx.q_2ns = vec![F::ZERO; starkinfo.q_dim * ctx.Next]; + ctx.f_2ns = vec![F::ZERO; 3 * ctx.Next]; + + ctx.x_n = vec![F::ZERO; ctx.N]; + + let mut xx = F::ONE; + // Using the precomputing value + let w_nbits: F = F::from(MG.0[ctx.nbits]); + for i in 0..ctx.N { + ctx.x_n[i] = xx; + xx *= w_nbits; + } + + let extend_bits = ctx.nbits_ext - ctx.nbits; + ctx.x_2ns = vec![F::ZERO; ctx.N << extend_bits]; + + let mut xx: F = F::from(*SHIFT); + for i in 0..(ctx.N << extend_bits) { + ctx.x_2ns[i] = xx; + xx *= F::from(MG.0[ctx.nbits_ext]); + } + + ctx.Zi = build_Zh_Inv::(ctx.nbits, extend_bits, 0); + + Ok(ctx) + } } pub struct StarkProof { @@ -197,73 +249,30 @@ impl<'a, M: MerkleTree> StarkProof { stark_struct: &StarkStruct, prover_addr: &str, ) -> Result> { - let mut ctx = StarkContext::::default(); - //log::debug!("starkinfo: {}", starkinfo); - //log::debug!("program: {}", program); - - let mut fftobj = FFT::new(); - ctx.nbits = stark_struct.nBits; - ctx.nbits_ext = stark_struct.nBitsExt; - ctx.N = 1 << stark_struct.nBits; - ctx.Next = 1 << stark_struct.nBitsExt; - assert_eq!(1 << ctx.nbits, ctx.N, "N must be a power of 2"); - - let mut n_cm = starkinfo.n_cm1; - - ctx.cm1_n = cm_pols.write_buff(); - ctx.cm2_n = vec![M::ExtendField::ZERO; (starkinfo.map_sectionsN.cm2_n) * ctx.N]; - ctx.cm3_n = vec![M::ExtendField::ZERO; (starkinfo.map_sectionsN.cm3_n) * ctx.N]; - ctx.tmpexp_n = vec![M::ExtendField::ZERO; (starkinfo.map_sectionsN.tmpexp_n) * ctx.N]; - - ctx.cm1_2ns = vec![M::ExtendField::ZERO; starkinfo.map_sectionsN.cm1_n * ctx.Next]; - ctx.cm2_2ns = vec![M::ExtendField::ZERO; starkinfo.map_sectionsN.cm2_n * ctx.Next]; - ctx.cm3_2ns = vec![M::ExtendField::ZERO; starkinfo.map_sectionsN.cm3_n * ctx.Next]; - ctx.cm4_2ns = vec![M::ExtendField::ZERO; starkinfo.map_sectionsN.cm4_n * ctx.Next]; - ctx.const_2ns = vec![M::ExtendField::ZERO; const_tree.element_size()]; - - ctx.q_2ns = vec![M::ExtendField::ZERO; starkinfo.q_dim * ctx.Next]; - ctx.f_2ns = vec![M::ExtendField::ZERO; 3 * ctx.Next]; - - ctx.x_n = vec![M::ExtendField::ZERO; ctx.N]; - - let mut xx = M::ExtendField::ONE; - // Using the precomputing value - let w_nbits: M::ExtendField = M::ExtendField::from(MG.0[ctx.nbits]); - for i in 0..ctx.N { - ctx.x_n[i] = xx; - xx *= w_nbits; - } - - let extend_bits = ctx.nbits_ext - ctx.nbits; - ctx.x_2ns = vec![M::ExtendField::ZERO; ctx.N << extend_bits]; - - let mut xx: M::ExtendField = M::ExtendField::from(*SHIFT); - for i in 0..(ctx.N << extend_bits) { - ctx.x_2ns[i] = xx; - xx *= M::ExtendField::from(MG.0[ctx.nbits_ext]); - } - - ctx.Zi = build_Zh_Inv::(ctx.nbits, extend_bits, 0); + let mut ctx = + StarkContext::init(cm_pols, const_tree.element_size(), starkinfo, stark_struct)?; ctx.const_n = const_pols.write_buff(); - const_tree.to_extend(&mut ctx.const_2ns); - ctx.publics = vec![M::ExtendField::ZERO; starkinfo.publics.len()]; + // todo refact code to parallel for (i, pe) in starkinfo.publics.iter().enumerate() { - if pe.polType.as_str() == "cmP" { - ctx.publics[i] = ctx.cm1_n[pe.idx * starkinfo.map_sectionsN.cm1_n + pe.polId]; + ctx.publics[i] = if pe.polType.as_str() == "cmP" { + ctx.cm1_n[pe.idx * starkinfo.map_sectionsN.cm1_n + pe.polId] } else if pe.polType.as_str() == "imP" { - ctx.publics[i] = Self::calculate_exp_at_point::( + Self::calculate_exp_at_point::( &mut ctx, starkinfo, &program.publics_code[i], pe.idx, - ); + ) } else { panic!("Invalid public type {}", pe.polType); - } + }; } + let extend_bits = ctx.nbits_ext - ctx.nbits; + const_tree.to_extend(&mut ctx.const_2ns); + let mut transcript = T::new(); for i in 0..starkinfo.publics.len() { let b = ctx.publics[i] @@ -278,12 +287,9 @@ impl<'a, M: MerkleTree> StarkProof { let tree1 = extend_and_merkelize::(&mut ctx, starkinfo, "cm1_n")?; tree1.to_extend(&mut ctx.cm1_2ns); - log::debug!( - "tree1 root: {}", - //crate::helper::fr_to_biguint(&tree1.root().into()) - tree1.root(), - ); + log::debug!("tree1 root: {}", tree1.root(),); transcript.put(&[tree1.root().as_elements().to_vec()])?; + // 2.- Caluculate plookups h1 and h2 ctx.challenge[0] = transcript.get_field(); //u ctx.challenge[1] = transcript.get_field(); //defVal @@ -293,6 +299,8 @@ impl<'a, M: MerkleTree> StarkProof { calculate_exps_parallel(&mut ctx, starkinfo, &program.step2prev, "n", "step2prev"); + let mut n_cm = starkinfo.n_cm1; + for pu in starkinfo.pu_ctx.iter() { let f_pol = get_pol(&mut ctx, starkinfo, starkinfo.exp2pol[&pu.f_exp_id]); let t_pol = get_pol(&mut ctx, starkinfo, starkinfo.exp2pol[&pu.t_exp_id]); @@ -423,6 +431,7 @@ impl<'a, M: MerkleTree> StarkProof { LpEv[i] = LpEv[i - 1] * wxis; } + let mut fftobj = FFT::new(); let LEv = fftobj.ifft(&LEv); let LpEv = fftobj.ifft(&LpEv); @@ -715,15 +724,19 @@ pub fn extend_and_merkelize( let nBitsExt = ctx.nbits_ext; let nBits = ctx.nbits; let n_pols = starkinfo.map_sectionsN.get(section_name); - let mut result = vec![M::ExtendField::ZERO; (1 << nBitsExt) * n_pols]; + + let len = (1 << nBitsExt) * n_pols; + let mut result = vec![M::ExtendField::ZERO; len]; let p = ctx.get_mut(section_name); interpolate(p, n_pols, nBits, &mut result, nBitsExt); + let mut p_be = vec![FGL::ZERO; result.len()]; p_be.par_iter_mut() .zip(result) .for_each(|(be_out, f3g_in)| { *be_out = f3g_in.to_be(); }); + let mut tree = M::new(); tree.merkelize(p_be, n_pols, 1 << nBitsExt)?; Ok(tree) diff --git a/starky/src/utils/parallells.rs b/starky/src/utils/parallells.rs index 481418e3..cb9f6486 100644 --- a/starky/src/utils/parallells.rs +++ b/starky/src/utils/parallells.rs @@ -2,6 +2,14 @@ use rayon::{current_num_threads, scope}; /// This simple utility function will parallelize an operation that is to be /// performed over a mutable slice. +/// ``` +/// parallelize(&mut p_be, |values, start| { +// for (i, v) in values.iter_mut().enumerate() { +// let idx = start + i; +// *v = result[idx].to_be(); +// } +// }); +/// ``` pub fn parallelize(v: &mut [T], f: F) { let n = v.len(); let num_threads = current_num_threads(); From 10e6dbb14d903f4f8b310a883d844f1b08f785fb Mon Sep 17 00:00:00 2001 From: Paul Cheng Date: Wed, 25 Oct 2023 22:31:21 +0800 Subject: [PATCH 5/5] opti docs --- starky/src/utils/parallells.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/starky/src/utils/parallells.rs b/starky/src/utils/parallells.rs index cb9f6486..5fcfcb67 100644 --- a/starky/src/utils/parallells.rs +++ b/starky/src/utils/parallells.rs @@ -1,15 +1,13 @@ use rayon::{current_num_threads, scope}; -/// This simple utility function will parallelize an operation that is to be -/// performed over a mutable slice. -/// ``` -/// parallelize(&mut p_be, |values, start| { +// This simple utility function will parallelize an operation that is to be +// performed over a mutable slice. +// parallelize(&mut p_be, |values, start| { // for (i, v) in values.iter_mut().enumerate() { // let idx = start + i; // *v = result[idx].to_be(); // } // }); -/// ``` pub fn parallelize(v: &mut [T], f: F) { let n = v.len(); let num_threads = current_num_threads();