diff --git a/CLAUDE.md b/CLAUDE.md index ff455863a2..19bc990f07 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -106,8 +106,8 @@ ProofTranscript: Transcript — Fiat-Shamir transcript (Blake2bTra - `DensePolynomial`: Full field-element coefficients - `CompactPolynomial`: Small scalar coefficients (u8–i128), promoted to field on bind -- `RaPolynomial`: Lazy materialization via Round1→Round2→Round3→RoundN state machine -- `SharedRaPolynomials`: Shares eq tables across N polynomials for memory efficiency +- `RaPolynomial`: Lazy materialization via table-doubling state machine (TableRound → RoundN, materializes at 8 groups) +- `SharedRaPolynomials`: Shares eq tables across N polynomials via table-doubling (TableRound → RoundN, materializes at 16 groups) - `PrefixSuffixDecomposition`: Splits polynomial as `Σ P_i(prefix) · Q_i(suffix)` for efficient sumcheck - `MultilinearPolynomial`: Enum dispatching over all scalar types + OneHot/RLC variants diff --git a/Cargo.lock b/Cargo.lock index b8adb91691..ba671dd379 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2894,6 +2894,7 @@ dependencies = [ "strum_macros 0.28.0", "sysinfo", "thiserror 2.0.18", + "tikv-jemallocator", "tracer", "tracing", "tracing-chrome", @@ -5625,6 +5626,26 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "tikv-jemalloc-sys" +version = "0.6.1+5.3.0-1-ge13ca993e8ccb9ba9847cc330696e02839f328f7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd8aa5b2ab86a2cefa406d889139c162cbb230092f7d1d7cbc1716405d852a3b" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "tikv-jemallocator" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0359b4327f954e0567e69fb191cf1436617748813819c94b8cd4a431422d053a" +dependencies = [ + "libc", + "tikv-jemalloc-sys", +] + [[package]] name = "time" version = "0.3.47" diff --git a/Cargo.toml b/Cargo.toml index 67c8a9364a..9b90dde64e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -211,6 +211,7 @@ num-traits = { version = "0.2.19", default-features = false } sysinfo = "0.38" memory-stats = { version = "1.0.0", features = ["always_use_statm"] } mimalloc = "0.1" +tikv-jemallocator = { version = "0.6", features = ["unprefixed_malloc_on_supported_platforms"] } # Parallel Processing rayon = { version = "^1.8.0" } diff --git a/jolt-core/Cargo.toml b/jolt-core/Cargo.toml index ca30335e13..3d43d0d3ca 100644 --- a/jolt-core/Cargo.toml +++ b/jolt-core/Cargo.toml @@ -38,6 +38,7 @@ prover = [ "dory/backends", "dory/cache", # This includes `parallel` feature "dory/disk-persistence", + "dep:tikv-jemallocator", ] # This is for building jolt-core without prover capabilities, e.g. for recursion # jolt-core needs std and rayon to compile, so these are the minimal set of features. @@ -89,6 +90,7 @@ jolt-inlines-keccak256 = { workspace = true, features = [ "host", ], optional = true } sysinfo = { workspace = true, optional = true } +tikv-jemallocator = { workspace = true, optional = true } pprof = { version = "0.15", features = [ "prost-codec", "flamegraph", diff --git a/jolt-core/src/bin/jolt_core.rs b/jolt-core/src/bin/jolt_core.rs index 9c36aafd3c..f727633d0c 100644 --- a/jolt-core/src/bin/jolt_core.rs +++ b/jolt-core/src/bin/jolt_core.rs @@ -1,3 +1,10 @@ +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + +#[allow(non_upper_case_globals)] +#[unsafe(export_name = "malloc_conf")] +pub static malloc_conf: &[u8] = b"dirty_decay_ms:0,muzzy_decay_ms:0\0"; + use clap::{Args, Parser, Subcommand, ValueEnum}; #[path = "../../benches/e2e_profiling.rs"] diff --git a/jolt-core/src/poly/ra_poly.rs b/jolt-core/src/poly/ra_poly.rs index 2ef8d97c95..f49e8adb75 100644 --- a/jolt-core/src/poly/ra_poly.rs +++ b/jolt-core/src/poly/ra_poly.rs @@ -14,23 +14,27 @@ use crate::{ utils::thread::{drop_in_background_thread, unsafe_allocate_zero_vec}, }; -/// Represents the state of an `ra_i` polynomial during the last log(T) sumcheck rounds. +/// When the table round reaches this many groups, the next bind materializes to dense. +const MATERIALIZE_THRESHOLD: usize = 8; + +/// RA polynomial with lazy materialization via a table-doubling state machine. /// -/// The first two rounds are specialized to reduce the amount of allocated memory. +/// Starts with 1 eq table group. Each bind doubles the tables (splitting on a new +/// challenge bit). After reaching `MATERIALIZE_THRESHOLD` groups, the next bind +/// materializes to a dense `MultilinearPolynomial`. #[derive(Allocative, Clone, Debug, PartialEq)] pub enum RaPolynomial + Copy + Default + Send + Sync + 'static, F: JoltField> { None, - Round1(RaPolynomialRound1), - Round2(RaPolynomialRound2), - Round3(RaPolynomialRound3), + TableRound(RaPolynomialTableRound), RoundN(MultilinearPolynomial), } impl + Copy + Default + Send + Sync + 'static, F: JoltField> RaPolynomial { pub fn new(lookup_indices: Arc>>, eq_evals: Vec) -> Self { - Self::Round1(RaPolynomialRound1 { - F: eq_evals, + Self::TableRound(RaPolynomialTableRound { + tables: vec![eq_evals], lookup_indices, + binding_order: BindingOrder::LowToHigh, }) } @@ -38,9 +42,7 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> RaPo pub fn get_bound_coeff(&self, j: usize) -> F { match self { Self::None => panic!("RaPolynomial::get_bound_coeff called on None"), - Self::Round1(mle) => mle.get_bound_coeff(j), - Self::Round2(mle) => mle.get_bound_coeff(j), - Self::Round3(mle) => mle.get_bound_coeff(j), + Self::TableRound(t) => t.get_bound_coeff(j), Self::RoundN(mle) => mle.get_bound_coeff(j), } } @@ -48,9 +50,7 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> RaPo pub fn len(&self) -> usize { match self { Self::None => panic!("RaPolynomial::len called on None"), - Self::Round1(mle) => mle.len(), - Self::Round2(mle) => mle.len(), - Self::Round3(mle) => mle.len(), + Self::TableRound(t) => t.len(), Self::RoundN(mle) => mle.len(), } } @@ -60,7 +60,11 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> Poly for RaPolynomial { fn is_bound(&self) -> bool { - !matches!(self, Self::Round1(_)) + match self { + Self::TableRound(t) => t.n_groups() > 1, + Self::RoundN(_) => true, + Self::None => false, + } } fn bind(&mut self, _r: F::Challenge, _order: BindingOrder) { @@ -70,9 +74,13 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> Poly fn bind_parallel(&mut self, r: F::Challenge, order: BindingOrder) { match self { Self::None => panic!("RaPolynomial::bind called on None"), - Self::Round1(mle) => *self = Self::Round2(mem::take(mle).bind(r, order)), - Self::Round2(mle) => *self = Self::Round3(mem::take(mle).bind(r, order)), - Self::Round3(mle) => *self = Self::RoundN(mem::take(mle).bind(r, order)), + Self::TableRound(t) => { + if t.n_groups() >= MATERIALIZE_THRESHOLD { + *self = Self::RoundN(mem::take(t).materialize(r, order)); + } else { + *self = Self::TableRound(mem::take(t).bind(r, order)); + } + } Self::RoundN(mle) => mle.bind_parallel(r, order), }; } @@ -141,197 +149,100 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> Poly } } -/// Represents MLE `ra_i` during the 1st round of the last log(T) sumcheck rounds. +/// Generic table round for RaPolynomial with `n_groups` eq table groups. +/// +/// Tables are stored in LowToHigh interleaving order: after k binds, table at +/// index `i` corresponds to the bit pattern where bit_0 = r0_val, bit_1 = r1_val, etc. +/// (LSB-first encoding of the bound challenge values). #[derive(Allocative, Default, Clone, Debug, PartialEq)] -pub struct RaPolynomialRound1 + Copy + Default + Send + Sync + 'static, F: JoltField> -{ - // Index `x` stores `eq(x, r)`. - F: Vec, +pub struct RaPolynomialTableRound< + I: Into + Copy + Default + Send + Sync + 'static, + F: JoltField, +> { + tables: Vec>, lookup_indices: Arc>>, + binding_order: BindingOrder, } impl + Copy + Default + Send + Sync + 'static, F: JoltField> - RaPolynomialRound1 + RaPolynomialTableRound { - fn len(&self) -> usize { - self.lookup_indices.len() - } - - #[tracing::instrument(skip_all, name = "RaPolynomialRound1::bind")] - fn bind(self, r0: F::Challenge, binding_order: BindingOrder) -> RaPolynomialRound2 { - // Construct lookup tables. - let eq_0_r0 = EqPolynomial::mle(&[F::zero()], &[r0]); - let eq_1_r0 = EqPolynomial::mle(&[F::one()], &[r0]); - let F_0 = self.F.iter().map(|v| eq_0_r0 * v).collect(); - let F_1 = self.F.iter().map(|v| eq_1_r0 * v).collect(); - drop_in_background_thread(self.F); - RaPolynomialRound2 { - F_0, - F_1, - lookup_indices: self.lookup_indices, - r0, - binding_order, - } - } - #[inline] - fn get_bound_coeff(&self, j: usize) -> F { - // Lookup ra_i(r, j). - self.lookup_indices - .get(j) - .expect("j out of bounds") - .map_or(F::zero(), |i| self.F[i.into()]) + fn n_groups(&self) -> usize { + self.tables.len() } -} - -/// Represents `ra_i` during the 2nd of the last log(T) sumcheck rounds. -/// -/// i.e. represents MLE `ra_i(r, r0, x)` -#[derive(Allocative, Default, Clone, Debug, PartialEq)] -pub struct RaPolynomialRound2 + Copy + Default + Send + Sync + 'static, F: JoltField> -{ - // Index `x` stores `eq(x, r_address_chunk_i) * eq(0, r0)`. - F_0: Vec, - // Index `x` stores `eq(x, r_address_chunk_i) * eq(1, r0)`. - F_1: Vec, - lookup_indices: Arc>>, - r0: F::Challenge, - binding_order: BindingOrder, -} -impl + Copy + Default + Send + Sync + 'static, F: JoltField> - RaPolynomialRound2 -{ fn len(&self) -> usize { - self.lookup_indices.len() / 2 + self.lookup_indices.len() / self.n_groups() } - #[tracing::instrument(skip_all, name = "RaPolynomialRound2::bind")] - fn bind(self, r1: F::Challenge, binding_order: BindingOrder) -> RaPolynomialRound3 { - assert_eq!(binding_order, self.binding_order); - // Construct lookup tables. - let eq_0_r1 = EqPolynomial::mle(&[F::zero()], &[r1]); - let eq_1_r1 = EqPolynomial::mle(&[F::one()], &[r1]); - let mut F_00: Vec = self.F_0.clone(); - let mut F_01: Vec = self.F_0; - let mut F_10: Vec = self.F_1.clone(); - let mut F_11: Vec = self.F_1; - - F_00.par_iter_mut().for_each(|f| *f *= eq_0_r1); - F_01.par_iter_mut().for_each(|f| *f *= eq_1_r1); - F_10.par_iter_mut().for_each(|f| *f *= eq_0_r1); - F_11.par_iter_mut().for_each(|f| *f *= eq_1_r1); - - RaPolynomialRound3 { - F_00, - F_01, - F_10, - F_11, - lookup_indices: self.lookup_indices, - r1, - binding_order: self.binding_order, + /// Double tables from N to 2N groups by splitting on a new challenge. + /// First N groups get scaled by eq(0, r), second N by eq(1, r). + fn double_tables(tables: Vec>, r: F::Challenge) -> Vec> { + let eq_0 = EqPolynomial::mle(&[F::zero()], &[r]); + let eq_1 = EqPolynomial::mle(&[F::one()], &[r]); + let n = tables.len(); + let mut doubled: Vec> = Vec::with_capacity(2 * n); + for t in &tables { + doubled.push(t.clone()); } - } - - #[inline] - fn get_bound_coeff(&self, j: usize) -> F { - let mid = self.lookup_indices.len() / 2; - match self.binding_order { - BindingOrder::HighToLow => { - let H_0 = self.lookup_indices[j].map_or(F::zero(), |i| self.F_0[i.into()]); - let H_1 = self.lookup_indices[mid + j].map_or(F::zero(), |i| self.F_1[i.into()]); - // Compute ra_i(r, r0, j) = eq(0, r0) * ra_i(r, 0, j) + - // eq(1, r0) * ra_i(r, 1, j) - H_0 + H_1 - } - BindingOrder::LowToHigh => { - let H_0 = self.lookup_indices[2 * j].map_or(F::zero(), |i| self.F_0[i.into()]); - let H_1 = self.lookup_indices[2 * j + 1].map_or(F::zero(), |i| self.F_1[i.into()]); - // Compute ra_i(r, r0, j) = eq(0, r0) * ra_i(r, 0, j) + - // eq(1, r0) * ra_i(r, 1, j) - H_0 + H_1 - } + for t in tables { + doubled.push(t); } + let (lo, hi) = doubled.split_at_mut(n); + rayon::join( + || { + lo.par_iter_mut() + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0)) + }, + || { + hi.par_iter_mut() + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1)) + }, + ); + doubled } -} - -/// Represents `ra_i` during the 3nd of the last log(T) sumcheck rounds. -/// -/// i.e. represents MLE `ra_i(r, r0, x)` -#[derive(Allocative, Default, Clone, Debug, PartialEq)] -pub struct RaPolynomialRound3 + Copy + Default + Send + Sync + 'static, F: JoltField> -{ - // Index `x` stores `eq(x, r_address_chunk_i) * eq(00, r0 r1)`. - F_00: Vec, - // Index `x` stores `eq(x, r_address_chunk_i) * eq(01, r0 r1)`. - F_01: Vec, - // Index `x` stores `eq(x, r_address_chunk_i) * eq(10, r0 r1)`. - F_10: Vec, - // Index `x` stores `eq(x, r_address_chunk_i) * eq(11, r0 r1)`. - F_11: Vec, - lookup_indices: Arc>>, - r1: F::Challenge, - binding_order: BindingOrder, -} -impl + Copy + Default + Send + Sync + 'static, F: JoltField> - RaPolynomialRound3 -{ - fn len(&self) -> usize { - self.lookup_indices.len() / 4 + #[tracing::instrument(skip_all, name = "RaPolynomialTableRound::bind")] + fn bind(self, r: F::Challenge, order: BindingOrder) -> Self { + if self.n_groups() > 1 { + assert_eq!(order, self.binding_order); + } + Self { + tables: Self::double_tables(self.tables, r), + lookup_indices: self.lookup_indices, + binding_order: order, + } } - #[tracing::instrument(skip_all, name = "RaPolynomialRound3::bind")] - fn bind(self, r2: F::Challenge, _binding_order: BindingOrder) -> MultilinearPolynomial { - // Construct lookup tables. - let eq_0_r2 = EqPolynomial::mle(&[F::zero()], &[r2]); - let eq_1_r2 = EqPolynomial::mle(&[F::one()], &[r2]); - let mut F_000: Vec = self.F_00.clone(); - let mut F_001: Vec = self.F_00; - let mut F_010: Vec = self.F_01.clone(); - let mut F_011: Vec = self.F_01; - let mut F_100: Vec = self.F_10.clone(); - let mut F_101: Vec = self.F_10; - let mut F_110: Vec = self.F_11.clone(); - let mut F_111: Vec = self.F_11; - - F_000.par_iter_mut().for_each(|f| *f *= eq_0_r2); - F_010.par_iter_mut().for_each(|f| *f *= eq_0_r2); - F_100.par_iter_mut().for_each(|f| *f *= eq_0_r2); - F_110.par_iter_mut().for_each(|f| *f *= eq_0_r2); - F_001.par_iter_mut().for_each(|f| *f *= eq_1_r2); - F_011.par_iter_mut().for_each(|f| *f *= eq_1_r2); - F_101.par_iter_mut().for_each(|f| *f *= eq_1_r2); - F_111.par_iter_mut().for_each(|f| *f *= eq_1_r2); - + #[tracing::instrument(skip_all, name = "RaPolynomialTableRound::materialize")] + fn materialize(self, r: F::Challenge, order: BindingOrder) -> MultilinearPolynomial { + let binding_order = if self.n_groups() > 1 { + assert_eq!(order, self.binding_order); + self.binding_order + } else { + order + }; + let tables = Self::double_tables(self.tables, r); + let n_groups = tables.len(); let lookup_indices = &self.lookup_indices; - let n = lookup_indices.len() / 8; - let mut res = unsafe_allocate_zero_vec(n); + let n = lookup_indices.len() / n_groups; + let mut res: Vec = unsafe_allocate_zero_vec(n); let chunk_size = 1 << 16; - - // Eval ra_i(r, r0, r1, j) for all j in the hypercube. - match self.binding_order { + match binding_order { BindingOrder::HighToLow => { + let n_bits = n_groups.trailing_zeros() as usize; res.par_chunks_mut(chunk_size).enumerate().for_each( |(chunk_index, evals_chunk)| { for (j, eval) in zip(chunk_index * chunk_size.., evals_chunk) { - let H_000 = lookup_indices[j].map_or(F::zero(), |i| F_000[i.into()]); - let H_001 = - lookup_indices[j + n].map_or(F::zero(), |i| F_001[i.into()]); - let H_010 = - lookup_indices[j + n * 2].map_or(F::zero(), |i| F_010[i.into()]); - let H_011 = - lookup_indices[j + n * 3].map_or(F::zero(), |i| F_011[i.into()]); - let H_100 = - lookup_indices[j + n * 4].map_or(F::zero(), |i| F_100[i.into()]); - let H_101 = - lookup_indices[j + n * 5].map_or(F::zero(), |i| F_101[i.into()]); - let H_110 = - lookup_indices[j + n * 6].map_or(F::zero(), |i| F_110[i.into()]); - let H_111 = - lookup_indices[j + n * 7].map_or(F::zero(), |i| F_111[i.into()]); - *eval = H_000 + H_010 + H_100 + H_110 + H_001 + H_011 + H_101 + H_111; + *eval = (0..n_groups) + .map(|seg| { + let table_idx = bit_reverse(seg, n_bits); + lookup_indices[seg * n + j] + .map_or(F::zero(), |i| tables[table_idx][i.into()]) + }) + .sum(); } }, ); @@ -340,23 +251,12 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> res.par_chunks_mut(chunk_size).enumerate().for_each( |(chunk_index, evals_chunk)| { for (j, eval) in zip(chunk_index * chunk_size.., evals_chunk) { - let H_000 = - lookup_indices[8 * j].map_or(F::zero(), |i| F_000[i.into()]); - let H_100 = - lookup_indices[8 * j + 1].map_or(F::zero(), |i| F_100[i.into()]); - let H_010 = - lookup_indices[8 * j + 2].map_or(F::zero(), |i| F_010[i.into()]); - let H_110 = - lookup_indices[8 * j + 3].map_or(F::zero(), |i| F_110[i.into()]); - let H_001 = - lookup_indices[8 * j + 4].map_or(F::zero(), |i| F_001[i.into()]); - let H_101 = - lookup_indices[8 * j + 5].map_or(F::zero(), |i| F_101[i.into()]); - let H_011 = - lookup_indices[8 * j + 6].map_or(F::zero(), |i| F_011[i.into()]); - let H_111 = - lookup_indices[8 * j + 7].map_or(F::zero(), |i| F_111[i.into()]); - *eval = H_000 + H_010 + H_100 + H_110 + H_001 + H_011 + H_101 + H_111; + *eval = (0..n_groups) + .map(|offset| { + lookup_indices[n_groups * j + offset] + .map_or(F::zero(), |i| tables[offset][i.into()]) + }) + .sum(); } }, ); @@ -364,41 +264,45 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> } drop_in_background_thread(self.lookup_indices); - drop_in_background_thread(F_000); - drop_in_background_thread(F_100); - drop_in_background_thread(F_010); - drop_in_background_thread(F_110); - drop_in_background_thread(F_001); - drop_in_background_thread(F_101); - drop_in_background_thread(F_011); - drop_in_background_thread(F_111); - res.into() } #[inline] fn get_bound_coeff(&self, j: usize) -> F { + let n_groups = self.n_groups(); match self.binding_order { BindingOrder::HighToLow => { - let n = self.lookup_indices.len() / 4; - let H_00 = self.lookup_indices[j].map_or(F::zero(), |i| self.F_00[i.into()]); - let H_01 = self.lookup_indices[j + n].map_or(F::zero(), |i| self.F_01[i.into()]); - let H_10 = - self.lookup_indices[j + n * 2].map_or(F::zero(), |i| self.F_10[i.into()]); - let H_11 = - self.lookup_indices[j + n * 3].map_or(F::zero(), |i| self.F_11[i.into()]); - H_00 + H_10 + H_01 + H_11 - } - BindingOrder::LowToHigh => { - let H_00 = self.lookup_indices[4 * j].map_or(F::zero(), |i| self.F_00[i.into()]); - let H_10 = - self.lookup_indices[4 * j + 1].map_or(F::zero(), |i| self.F_10[i.into()]); - let H_01 = - self.lookup_indices[4 * j + 2].map_or(F::zero(), |i| self.F_01[i.into()]); - let H_11 = - self.lookup_indices[4 * j + 3].map_or(F::zero(), |i| self.F_11[i.into()]); - H_00 + H_10 + H_01 + H_11 + let segment = self.lookup_indices.len() / n_groups; + let n_bits = n_groups.trailing_zeros() as usize; + (0..n_groups) + .map(|seg| { + let table_idx = bit_reverse(seg, n_bits); + self.lookup_indices[seg * segment + j] + .map_or(F::zero(), |i| self.tables[table_idx][i.into()]) + }) + .sum() } + BindingOrder::LowToHigh => (0..n_groups) + .map(|offset| { + self.lookup_indices[n_groups * j + offset] + .map_or(F::zero(), |i| self.tables[offset][i.into()]) + }) + .sum(), } } } + +/// Reverse the lowest `bits` bits of `x`. +#[inline] +pub(crate) fn bit_reverse(x: usize, bits: usize) -> usize { + if bits == 0 { + return 0; + } + let mut result = 0; + let mut x = x; + for _ in 0..bits { + result = (result << 1) | (x & 1); + x >>= 1; + } + result +} diff --git a/jolt-core/src/poly/shared_ra_polys.rs b/jolt-core/src/poly/shared_ra_polys.rs index 2eab8220e8..2a46b4f28d 100644 --- a/jolt-core/src/poly/shared_ra_polys.rs +++ b/jolt-core/src/poly/shared_ra_polys.rs @@ -27,10 +27,11 @@ use allocative::Allocative; use ark_std::Zero; use fixedbitset::FixedBitSet; +use std::sync::Arc; + use crate::field::JoltField; use crate::poly::eq_poly::EqPolynomial; use crate::poly::multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}; -use crate::utils::thread::drop_in_background_thread; use crate::utils::thread::unsafe_allocate_zero_vec; use crate::zkvm::bytecode::BytecodePreprocessing; use crate::zkvm::config::OneHotParams; @@ -232,9 +233,8 @@ pub fn compute_all_G_and_ra_indices( memory_layout: &MemoryLayout, one_hot_params: &OneHotParams, r_cycle: &[F::Challenge], -) -> (Vec>, Vec) { +) -> (Vec>, Arc>) { let T = trace.len(); - // Pre-allocate ra_indices let mut ra_indices: Vec = unsafe_allocate_zero_vec(T); let G = compute_all_G_impl::( @@ -246,7 +246,7 @@ pub fn compute_all_G_and_ra_indices( Some(&mut ra_indices), ); - (G, ra_indices) + (G, Arc::new(ra_indices)) } /// Core implementation for computing G evaluations. @@ -444,75 +444,52 @@ fn compute_all_G_impl( ) } -/// Shared RA polynomials that use a single eq table for all polynomials. +/// When the table round reaches this many groups, the next bind materializes to dense. +const SHARED_MATERIALIZE_THRESHOLD: usize = 16; + +/// Shared RA polynomials using table-doubling state machine. /// /// Instead of N separate `RaPolynomial` each with their own eq table copy, -/// this stores: -/// - ONE (small) eq table per polynomial (or split tables for later rounds) -/// - `Vec` (size T, non-transposed) -/// -/// This saves memory and improves cache locality. +/// stores per-polynomial eq tables and a single shared `Vec`. +/// Each bind doubles the table groups; once the threshold is reached, +/// the next bind materializes to dense `MultilinearPolynomial`s. #[derive(Allocative)] pub enum SharedRaPolynomials { - /// Round 1: Single shared eq table - Round1(SharedRaRound1), - /// Round 2: Split into F_0, F_1 - Round2(SharedRaRound2), - /// Round 3: Split into F_00, F_01, F_10, F_11 - Round3(SharedRaRound3), - /// Round N: Fully materialized multilinear polynomials + TableRound(SharedRaTableRound), RoundN(Vec>), } -/// Round 1 state: single shared eq table -#[derive(Allocative, Default)] -pub struct SharedRaRound1 { - /// Per-polynomial eq tables: tables[poly_idx][k] for k in 0..K - /// - /// In the booleanity sumcheck, these tables may already be pre-scaled by a per-polynomial - /// constant (e.g. a batching coefficient). - tables: Vec>, - /// RA indices for all cycles (non-transposed) - indices: Vec, - /// Number of polynomials - num_polys: usize, - /// OneHotParams for index extraction - #[allocative(skip)] - one_hot_params: OneHotParams, -} - -/// Round 2 state: split eq tables -#[derive(Allocative, Default)] -pub struct SharedRaRound2 { - /// Per-polynomial tables for the 0-branch: tables_0[poly_idx][k] - tables_0: Vec>, - /// Per-polynomial tables for the 1-branch: tables_1[poly_idx][k] - tables_1: Vec>, - /// RA indices for all cycles - indices: Vec, +/// Generic table round for SharedRaPolynomials with `n_groups` eq table groups. +/// +/// `tables[group_idx][poly_idx][k]` — tables are in LowToHigh interleaving order. +#[derive(Allocative)] +pub struct SharedRaTableRound { + tables: Vec>>, + indices: Arc>, num_polys: usize, #[allocative(skip)] one_hot_params: OneHotParams, binding_order: BindingOrder, } -/// Round 3 state: further split eq tables -#[derive(Allocative, Default)] -pub struct SharedRaRound3 { - tables_00: Vec>, - tables_01: Vec>, - tables_10: Vec>, - tables_11: Vec>, - indices: Vec, - num_polys: usize, - #[allocative(skip)] - one_hot_params: OneHotParams, - binding_order: BindingOrder, +impl Default for SharedRaTableRound { + fn default() -> Self { + Self { + tables: Vec::new(), + indices: Arc::new(Vec::new()), + num_polys: 0, + one_hot_params: OneHotParams::default(), + binding_order: BindingOrder::LowToHigh, + } + } } impl SharedRaPolynomials { - /// Create new SharedRaPolynomials from eq table and indices. - pub fn new(tables: Vec>, indices: Vec, one_hot_params: OneHotParams) -> Self { + pub fn new( + tables: Vec>, + indices: Arc>, + one_hot_params: OneHotParams, + ) -> Self { let num_polys = one_hot_params.instruction_d + one_hot_params.bytecode_d + one_hot_params.ram_d; debug_assert!( @@ -521,46 +498,53 @@ impl SharedRaPolynomials { tables.len(), num_polys ); - Self::Round1(SharedRaRound1 { - tables, + Self::TableRound(SharedRaTableRound { + tables: vec![tables], + indices, + num_polys, + one_hot_params, + binding_order: BindingOrder::LowToHigh, + }) + } + + pub fn new_instruction_only( + tables: Vec>, + indices: Arc>, + one_hot_params: OneHotParams, + ) -> Self { + let num_polys = one_hot_params.instruction_d; + debug_assert_eq!(tables.len(), num_polys); + Self::TableRound(SharedRaTableRound { + tables: vec![tables], indices, num_polys, one_hot_params, + binding_order: BindingOrder::LowToHigh, }) } - /// Get the number of polynomials pub fn num_polys(&self) -> usize { match self { - Self::Round1(r) => r.num_polys, - Self::Round2(r) => r.num_polys, - Self::Round3(r) => r.num_polys, + Self::TableRound(t) => t.num_polys, Self::RoundN(polys) => polys.len(), } } - /// Get the current length (number of cycles / 2^rounds_so_far) pub fn len(&self) -> usize { match self { - Self::Round1(r) => r.indices.len(), - Self::Round2(r) => r.indices.len() / 2, - Self::Round3(r) => r.indices.len() / 4, + Self::TableRound(t) => t.len(), Self::RoundN(polys) => polys[0].len(), } } - /// Get bound coefficient for polynomial `poly_idx` at position `j` #[inline] pub fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { match self { - Self::Round1(r) => r.get_bound_coeff(poly_idx, j), - Self::Round2(r) => r.get_bound_coeff(poly_idx, j), - Self::Round3(r) => r.get_bound_coeff(poly_idx, j), + Self::TableRound(t) => t.get_bound_coeff(poly_idx, j), Self::RoundN(polys) => polys[poly_idx].get_bound_coeff(j), } } - /// Get final sumcheck claim for polynomial `poly_idx` pub fn final_sumcheck_claim(&self, poly_idx: usize) -> F { match self { Self::RoundN(polys) => polys[poly_idx].final_sumcheck_claim(), @@ -568,13 +552,15 @@ impl SharedRaPolynomials { } } - /// Bind with a challenge, transitioning to next round state. - /// Consumes self and returns the new state. pub fn bind(self, r: F::Challenge, order: BindingOrder) -> Self { match self { - Self::Round1(r1) => Self::Round2(r1.bind(r, order)), - Self::Round2(r2) => Self::Round3(r2.bind(r, order)), - Self::Round3(r3) => Self::RoundN(r3.bind(r, order)), + Self::TableRound(t) => { + if t.n_groups() >= SHARED_MATERIALIZE_THRESHOLD { + Self::RoundN(t.materialize(r, order)) + } else { + Self::TableRound(t.bind(r, order)) + } + } Self::RoundN(mut polys) => { polys.par_iter_mut().for_each(|p| p.bind_parallel(r, order)); Self::RoundN(polys) @@ -582,13 +568,15 @@ impl SharedRaPolynomials { } } - /// Bind in place with a challenge, transitioning to next round state. pub fn bind_in_place(&mut self, r: F::Challenge, order: BindingOrder) { - // Use mem::take pattern (same as ra_poly.rs) for efficiency match self { - Self::Round1(r1) => *self = Self::Round2(std::mem::take(r1).bind(r, order)), - Self::Round2(r2) => *self = Self::Round3(std::mem::take(r2).bind(r, order)), - Self::Round3(r3) => *self = Self::RoundN(std::mem::take(r3).bind(r, order)), + Self::TableRound(t) => { + if t.n_groups() >= SHARED_MATERIALIZE_THRESHOLD { + *self = Self::RoundN(std::mem::take(t).materialize(r, order)); + } else { + *self = Self::TableRound(std::mem::take(t).bind(r, order)); + } + } Self::RoundN(polys) => { polys.par_iter_mut().for_each(|p| p.bind_parallel(r, order)); } @@ -596,263 +584,105 @@ impl SharedRaPolynomials { } } -impl SharedRaRound1 { +impl SharedRaTableRound { #[inline] - fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { - self.indices[j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables[poly_idx][k as usize]) + fn n_groups(&self) -> usize { + self.tables.len() } - fn bind(self, r0: F::Challenge, order: BindingOrder) -> SharedRaRound2 { - let eq_0_r0 = EqPolynomial::mle(&[F::zero()], &[r0]); - let eq_1_r0 = EqPolynomial::mle(&[F::one()], &[r0]); - let (tables_0, tables_1) = rayon::join( - || { - self.tables - .par_iter() - .map(|t| t.iter().map(|v| eq_0_r0 * v).collect::>()) - .collect::>>() - }, - || { - self.tables - .par_iter() - .map(|t| t.iter().map(|v| eq_1_r0 * v).collect::>()) - .collect::>>() - }, - ); - drop_in_background_thread(self.tables); - - SharedRaRound2 { - tables_0, - tables_1, - indices: self.indices, - num_polys: self.num_polys, - one_hot_params: self.one_hot_params, - binding_order: order, - } + fn len(&self) -> usize { + self.indices.len() / self.n_groups() } -} -impl SharedRaRound2 { - #[inline] - fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { - match self.binding_order { - BindingOrder::HighToLow => { - let mid = self.indices.len() / 2; - let h_0 = self.indices[j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_0[poly_idx][k as usize]); - let h_1 = self.indices[mid + j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_1[poly_idx][k as usize]); - h_0 + h_1 - } - BindingOrder::LowToHigh => { - let h_0 = self.indices[2 * j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_0[poly_idx][k as usize]); - let h_1 = self.indices[2 * j + 1] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_1[poly_idx][k as usize]); - h_0 + h_1 - } + fn double_tables(tables: Vec>>, r: F::Challenge) -> Vec>> { + let eq_0 = EqPolynomial::mle(&[F::zero()], &[r]); + let eq_1 = EqPolynomial::mle(&[F::one()], &[r]); + let n = tables.len(); + let mut doubled: Vec>> = Vec::with_capacity(2 * n); + for t in &tables { + doubled.push(t.clone()); } - } - - fn bind(self, r1: F::Challenge, order: BindingOrder) -> SharedRaRound3 { - assert_eq!(order, self.binding_order); - let eq_0_r1 = EqPolynomial::mle(&[F::zero()], &[r1]); - let eq_1_r1 = EqPolynomial::mle(&[F::one()], &[r1]); - - let mut tables_00 = self.tables_0.clone(); - let mut tables_01 = self.tables_0; - let mut tables_10 = self.tables_1.clone(); - let mut tables_11 = self.tables_1; - - // Scale all four groups in parallel. + for t in tables { + doubled.push(t); + } + let (lo, hi) = doubled.split_at_mut(n); rayon::join( || { - rayon::join( - || { - tables_00 - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r1)) - }, - || { - tables_01 - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r1)) - }, - ) + lo.par_iter_mut().for_each(|group| { + group + .par_iter_mut() + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0)) + }) }, || { - rayon::join( - || { - tables_10 - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r1)) - }, - || { - tables_11 - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r1)) - }, - ) + hi.par_iter_mut().for_each(|group| { + group + .par_iter_mut() + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1)) + }) }, ); + doubled + } - SharedRaRound3 { - tables_00, - tables_01, - tables_10, - tables_11, + #[tracing::instrument(skip_all, name = "SharedRaTableRound::bind")] + fn bind(self, r: F::Challenge, order: BindingOrder) -> Self { + if self.n_groups() > 1 { + assert_eq!(order, self.binding_order); + } + Self { + tables: Self::double_tables(self.tables, r), indices: self.indices, num_polys: self.num_polys, one_hot_params: self.one_hot_params, binding_order: order, } } -} - -impl SharedRaRound3 { - #[inline] - fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { - match self.binding_order { - BindingOrder::HighToLow => { - let quarter = self.indices.len() / 4; - let h_00 = self.indices[j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_00[poly_idx][k as usize]); - let h_01 = self.indices[quarter + j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_01[poly_idx][k as usize]); - let h_10 = self.indices[2 * quarter + j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_10[poly_idx][k as usize]); - let h_11 = self.indices[3 * quarter + j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_11[poly_idx][k as usize]); - h_00 + h_01 + h_10 + h_11 - } - BindingOrder::LowToHigh => { - // Bit pattern for offset: (r1, r0), so offset 1 = r0=1,r1=0 → F_10 - let h_00 = self.indices[4 * j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_00[poly_idx][k as usize]); - let h_10 = self.indices[4 * j + 1] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_10[poly_idx][k as usize]); - let h_01 = self.indices[4 * j + 2] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_01[poly_idx][k as usize]); - let h_11 = self.indices[4 * j + 3] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_11[poly_idx][k as usize]); - h_00 + h_10 + h_01 + h_11 - } - } - } - - #[tracing::instrument(skip_all, name = "SharedRaRound3::bind")] - fn bind(self, r2: F::Challenge, order: BindingOrder) -> Vec> { - assert_eq!(order, self.binding_order); - - // Create 8 F tables: F_ABC where A=r0, B=r1, C=r2 - let eq_0_r2 = EqPolynomial::mle(&[F::zero()], &[r2]); - let eq_1_r2 = EqPolynomial::mle(&[F::one()], &[r2]); - - let mut tables_000 = self.tables_00.clone(); - let mut tables_001 = self.tables_00; - let mut tables_010 = self.tables_01.clone(); - let mut tables_011 = self.tables_01; - let mut tables_100 = self.tables_10.clone(); - let mut tables_101 = self.tables_10; - let mut tables_110 = self.tables_11.clone(); - let mut tables_111 = self.tables_11; - // Scale by eq(r2, bit) - rayon::join( - || { - [ - &mut tables_000, - &mut tables_010, - &mut tables_100, - &mut tables_110, - ] - .into_par_iter() - .for_each(|table| { - table - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r2)) - }) - }, - || { - [ - &mut tables_001, - &mut tables_011, - &mut tables_101, - &mut tables_111, - ] - .into_par_iter() - .for_each(|table| { - table - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r2)) - }) - }, - ); - - // Collect all 8 table groups for indexed access: group[offset][poly_idx][k] - let table_groups = [ - &tables_000, - &tables_100, - &tables_010, - &tables_110, - &tables_001, - &tables_101, - &tables_011, - &tables_111, - ]; - - // Materialize all polynomials in parallel - let num_polys = self.num_polys; + #[tracing::instrument(skip_all, name = "SharedRaTableRound::materialize")] + fn materialize(self, r: F::Challenge, order: BindingOrder) -> Vec> { + let binding_order = if self.n_groups() > 1 { + assert_eq!(order, self.binding_order); + self.binding_order + } else { + order + }; + let tables = Self::double_tables(self.tables, r); + let n_total = tables.len(); let indices = &self.indices; let one_hot_params = &self.one_hot_params; - let new_len = indices.len() / 8; + let new_len = indices.len() / n_total; - (0..num_polys) + (0..self.num_polys) .into_par_iter() .map(|poly_idx| { - let coeffs: Vec = match order { - BindingOrder::LowToHigh => { - (0..new_len) - .into_par_iter() - .map(|j| { - // Sum over 8 consecutive indices, each using appropriate F table - (0..8) - .map(|offset| { - indices[8 * j + offset] - .get_index(poly_idx, one_hot_params) - .map_or(F::zero(), |k| { - table_groups[offset][poly_idx][k as usize] - }) - }) - .sum() - }) - .collect() - } + let coeffs: Vec = match binding_order { + BindingOrder::LowToHigh => (0..new_len) + .into_par_iter() + .map(|j| { + (0..n_total) + .map(|offset| { + indices[n_total * j + offset] + .get_index(poly_idx, one_hot_params) + .map_or(F::zero(), |k| tables[offset][poly_idx][k as usize]) + }) + .sum() + }) + .collect(), BindingOrder::HighToLow => { - let eighth = indices.len() / 8; + let n_bits = n_total.trailing_zeros() as usize; + let segment = indices.len() / n_total; (0..new_len) .into_par_iter() .map(|j| { - (0..8) + (0..n_total) .map(|seg| { - indices[seg * eighth + j] + let table_idx = + crate::poly::ra_poly::bit_reverse(seg, n_bits); + indices[seg * segment + j] .get_index(poly_idx, one_hot_params) .map_or(F::zero(), |k| { - table_groups[seg][poly_idx][k as usize] + tables[table_idx][poly_idx][k as usize] }) }) .sum() @@ -864,20 +694,30 @@ impl SharedRaRound3 { }) .collect() } -} -/// Compute all RaIndices in parallel (non-transposed). -/// -/// Returns one `RaIndices` per cycle. -#[tracing::instrument(skip_all, name = "shared_ra_polys::compute_ra_indices")] -pub fn compute_ra_indices( - trace: &[Cycle], - bytecode: &BytecodePreprocessing, - memory_layout: &MemoryLayout, - one_hot_params: &OneHotParams, -) -> Vec { - trace - .par_iter() - .map(|cycle| RaIndices::from_cycle(cycle, bytecode, memory_layout, one_hot_params)) - .collect() + #[inline] + fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { + let n_groups = self.n_groups(); + match self.binding_order { + BindingOrder::HighToLow => { + let segment = self.indices.len() / n_groups; + let n_bits = n_groups.trailing_zeros() as usize; + (0..n_groups) + .map(|seg| { + let table_idx = crate::poly::ra_poly::bit_reverse(seg, n_bits); + self.indices[seg * segment + j] + .get_index(poly_idx, &self.one_hot_params) + .map_or(F::zero(), |k| self.tables[table_idx][poly_idx][k as usize]) + }) + .sum() + } + BindingOrder::LowToHigh => (0..n_groups) + .map(|offset| { + self.indices[n_groups * j + offset] + .get_index(poly_idx, &self.one_hot_params) + .map_or(F::zero(), |k| self.tables[offset][poly_idx][k as usize]) + }) + .sum(), + } + } } diff --git a/jolt-core/src/subprotocols/booleanity.rs b/jolt-core/src/subprotocols/booleanity.rs index 1f3bf6cadd..40ba1e721b 100644 --- a/jolt-core/src/subprotocols/booleanity.rs +++ b/jolt-core/src/subprotocols/booleanity.rs @@ -23,6 +23,7 @@ use allocative::FlameGraphBuilder; use ark_std::Zero; use rayon::prelude::*; use std::iter::zip; +use std::sync::Arc; use common::jolt_device::MemoryLayout; use tracer::instruction::Cycle; @@ -284,8 +285,8 @@ pub struct BooleanitySumcheckProver { F: ExpandingTable, /// eq(r_address, r_address) at end of phase 1 eq_r_r: F, - /// RA indices (non-transposed, one per cycle) - ra_indices: Vec, + /// RA indices (non-transposed, one per cycle), shared via Arc + ra_indices: Arc>, pub params: BooleanitySumcheckParams, } @@ -296,14 +297,14 @@ impl BooleanitySumcheckProver { /// - Compute G polynomials and RA indices in a single pass over the trace /// - Initialize split-eq polynomials for address (B) and cycle (D) variables /// - Initialize expanding table for phase 1 + /// Returns (prover, shared_ra_indices) — the Arc can be passed to other provers. #[tracing::instrument(skip_all, name = "BooleanitySumcheckProver::initialize")] pub fn initialize( params: BooleanitySumcheckParams, trace: &[Cycle], bytecode: &BytecodePreprocessing, memory_layout: &MemoryLayout, - ) -> Self { - // Compute G and RA indices in a single pass over the trace + ) -> (Self, Arc>) { let (G, ra_indices) = compute_all_G_and_ra_indices::( trace, bytecode, @@ -337,18 +338,22 @@ impl BooleanitySumcheckProver { rho_i *= gamma_f; } - Self { - gamma_powers, - gamma_powers_inv, - B, - D, - G, - ra_indices, - H: None, - F: F_table, - eq_r_r: F::zero(), - params, - } + let shared = Arc::clone(&ra_indices); + ( + Self { + gamma_powers, + gamma_powers_inv, + B, + D, + G, + ra_indices, + H: None, + F: F_table, + eq_r_r: F::zero(), + params, + }, + shared, + ) } fn compute_phase1_message(&self, round: usize, previous_claim: F) -> UniPoly { @@ -475,7 +480,7 @@ impl SumcheckInstanceProver for BooleanitySum // Initialize SharedRaPolynomials with per-poly pre-scaled eq tables (by rho_i) let F_table = std::mem::take(&mut self.F); - let ra_indices = std::mem::take(&mut self.ra_indices); + let ra_indices = std::mem::replace(&mut self.ra_indices, Arc::new(Vec::new())); let base_eq = F_table.clone_values(); let num_polys = self.params.polynomial_types.len(); debug_assert!( diff --git a/jolt-core/src/subprotocols/mles_product_sum.rs b/jolt-core/src/subprotocols/mles_product_sum.rs index d687fd0ee7..b4d63fea50 100644 --- a/jolt-core/src/subprotocols/mles_product_sum.rs +++ b/jolt-core/src/subprotocols/mles_product_sum.rs @@ -1,6 +1,9 @@ use crate::{ field::{BarrettReduce, FMAdd, JoltField}, - poly::{ra_poly::RaPolynomial, split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly}, + poly::{ + ra_poly::RaPolynomial, shared_ra_polys::SharedRaPolynomials, + split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly, + }, utils::accumulation::SmallAccumS, }; use core::{mem::MaybeUninit, ptr}; @@ -227,6 +230,67 @@ impl_mles_sum_of_products_evals_d!( eval_prod_16_assign ); +/// Like `impl_mles_sum_of_products_evals_d` but reads from SharedRaPolynomials +/// instead of &[RaPolynomial]. Eliminates the transposed per-poly index storage. +macro_rules! impl_shared_ra_sum_of_products_evals_d { + ($fn_name:ident, $d:expr, $eval_prod:ident) => { + #[inline] + pub fn $fn_name( + shared_ra: &SharedRaPolynomials, + n_products: usize, + eq_poly: &GruenSplitEqPolynomial, + ) -> Vec { + debug_assert!(n_products > 0); + + let current_scalar = eq_poly.get_current_scalar(); + + let sum_evals_arr: [F; $d] = eq_poly.par_fold_out_in_unreduced::<$d>(&|g| { + let mut sums = [F::zero(); $d]; + + for t in 0..n_products { + let base = t * $d; + + let pairs: [(F, F); $d] = core::array::from_fn(|i| { + let p0 = shared_ra.get_bound_coeff(base + i, 2 * g); + let p1 = shared_ra.get_bound_coeff(base + i, 2 * g + 1); + (p0, p1) + }); + + let mut endpoints = [F::zero(); $d]; + $eval_prod::(&pairs, &mut endpoints); + + for k in 0..$d { + sums[k] += endpoints[k]; + } + } + + sums + }); + + sum_evals_arr + .into_iter() + .map(|x| x * current_scalar) + .collect() + } + }; +} + +impl_shared_ra_sum_of_products_evals_d!( + compute_shared_ra_sum_of_products_evals_d4, + 4, + eval_prod_4_assign +); +impl_shared_ra_sum_of_products_evals_d!( + compute_shared_ra_sum_of_products_evals_d8, + 8, + eval_prod_8_assign +); +impl_shared_ra_sum_of_products_evals_d!( + compute_shared_ra_sum_of_products_evals_d16, + 16, + eval_prod_16_assign +); + /// Given the evaluations of `g(X) / eq(X, r[round])` on the grid /// `[1, 2, ..., d - 1, ∞]`, recover the full univariate polynomial /// `g(X) = eq(X, r[round]) * (interpolated quotient)` such that @@ -727,8 +791,6 @@ fn eval_prod_8_assign(p: &[(F, F); 8], outputs: &mut [F]) { /// correct alignment. /// - The sub-slices are non-overlapping. fn eval_prod_9_accumulate(p: &[(F, F); 9], outputs: &mut [F::UnreducedProductAccum]) { - // TODO: Implement more optimal way to do this. - // 5x4 split probably better than current 8x1 split. let p8 = p[0..8].try_into().unwrap(); let [a1, a2, a3, a4, a5, a6, a7, a8, a_inf] = eval_linear_prod_8_internal(p8); diff --git a/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs b/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs index 50802fc2e0..b6d2f4d527 100644 --- a/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs +++ b/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs @@ -10,20 +10,20 @@ use crate::{ field::JoltField, poly::{ eq_poly::EqPolynomial, - multilinear_polynomial::{BindingOrder, PolynomialBinding}, + multilinear_polynomial::BindingOrder, opening_proof::{ OpeningAccumulator, OpeningPoint, ProverOpeningAccumulator, SumcheckId, VerifierOpeningAccumulator, BIG_ENDIAN, LITTLE_ENDIAN, }, - ra_poly::RaPolynomial, + shared_ra_polys::{RaIndices, SharedRaPolynomials}, split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly, }, subprotocols::{ mles_product_sum::{ - compute_mles_product_sum_evals_sum_of_products_d16, - compute_mles_product_sum_evals_sum_of_products_d4, - compute_mles_product_sum_evals_sum_of_products_d8, finish_mles_product_sum_from_evals, + compute_shared_ra_sum_of_products_evals_d16, + compute_shared_ra_sum_of_products_evals_d4, compute_shared_ra_sum_of_products_evals_d8, + finish_mles_product_sum_from_evals, }, sumcheck_prover::SumcheckInstanceProver, sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}, @@ -31,15 +31,12 @@ use crate::{ transcripts::Transcript, zkvm::{ config::OneHotParams, - instruction::LookupQuery, instruction_lookups::LOG_K, witness::{CommittedPolynomial, VirtualPolynomial}, }, }; use allocative::Allocative; -use common::constants::XLEN; use rayon::prelude::*; -use tracer::instruction::Cycle; #[derive(Allocative, Clone)] pub struct InstructionRaSumcheckParams { @@ -185,43 +182,29 @@ impl SumcheckInstanceParams for InstructionRaSumcheckParams #[derive(Allocative)] pub struct InstructionRaSumcheckProver { - ra_i_polys: Vec>, + shared_ra: SharedRaPolynomials, eq_poly: GruenSplitEqPolynomial, pub params: InstructionRaSumcheckParams, } impl InstructionRaSumcheckProver { + /// Initialize using shared RA indices — no transposed per-poly index storage. #[tracing::instrument(skip_all, name = "InstructionRaSumcheckProver::initialize")] - pub fn initialize(params: InstructionRaSumcheckParams, trace: &[Cycle]) -> Self { - // Compute r_address_chunks with proper padding + pub fn initialize( + params: InstructionRaSumcheckParams, + shared_indices: &Arc>, + ) -> Self { let r_address_chunks = params .one_hot_params .compute_r_address_chunks::(¶ms.r_address.r); - let H_indices: Vec>> = (0..params.one_hot_params.instruction_d) - .map(|i| { - trace - .par_iter() - .map(|cycle| { - let lookup_index = LookupQuery::::to_lookup_index(cycle); - Some(params.one_hot_params.lookup_index_chunk(lookup_index, i)) - }) - .collect() - }) - .collect(); - let n_committed_per_virtual = params.n_committed_per_virtual; let gamma_powers = ¶ms.gamma_powers; + let instruction_d = params.one_hot_params.instruction_d; - let ra_i_polys = H_indices + let tables: Vec> = (0..instruction_d) .into_par_iter() - .enumerate() - .map(|(i, lookup_indices)| { - // Pre-scale the first committed polynomial in each virtual batch by γ^batch. - // - // This pushes the γ weight *inside* the product term so we can form - // (Σ γ^i · ∏ ra_{i,*}) before multiplying by split-eq's inner weights e_in, - // allowing a single split-eq fold for the whole sumcheck message. + .map(|i| { let scaling_factor = if i % n_committed_per_virtual == 0 { let batch = i / n_committed_per_virtual; let gamma = gamma_powers[batch]; @@ -233,14 +216,18 @@ impl InstructionRaSumcheckProver { } else { None }; - let eq_evals = - EqPolynomial::evals_with_scaling(&r_address_chunks[i], scaling_factor); - RaPolynomial::new(Arc::new(lookup_indices), eq_evals) + EqPolynomial::evals_with_scaling(&r_address_chunks[i], scaling_factor) }) .collect(); + let shared_ra = SharedRaPolynomials::new_instruction_only( + tables, + Arc::clone(shared_indices), + params.one_hot_params.clone(), + ); + Self { - ra_i_polys, + shared_ra, eq_poly: GruenSplitEqPolynomial::new(¶ms.r_cycle.r, BindingOrder::LowToHigh), params, } @@ -256,22 +243,19 @@ impl SumcheckInstanceProver for InstructionRa fn compute_message(&mut self, _round: usize, previous_claim: F) -> UniPoly { let eq_poly = &self.eq_poly; - // Compute q(X) = Σ_i ∏_j ra_{i,j}(X,·) on the U_D grid using a *single* - // split-eq fold. The per-batch γ^i weights have already been absorbed by - // pre-scaling the first polynomial in each batch (see `initialize`). let evals = match self.params.n_committed_per_virtual { - 4 => compute_mles_product_sum_evals_sum_of_products_d4( - &self.ra_i_polys, + 4 => compute_shared_ra_sum_of_products_evals_d4( + &self.shared_ra, self.params.n_virtual_ra_polys, eq_poly, ), - 8 => compute_mles_product_sum_evals_sum_of_products_d8( - &self.ra_i_polys, + 8 => compute_shared_ra_sum_of_products_evals_d8( + &self.shared_ra, self.params.n_virtual_ra_polys, eq_poly, ), - 16 => compute_mles_product_sum_evals_sum_of_products_d16( - &self.ra_i_polys, + 16 => compute_shared_ra_sum_of_products_evals_d16( + &self.shared_ra, self.params.n_virtual_ra_polys, eq_poly, ), @@ -283,9 +267,7 @@ impl SumcheckInstanceProver for InstructionRa #[tracing::instrument(skip_all, name = "InstructionRaSumcheckProver::ingest_challenge")] fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) { - self.ra_i_polys - .iter_mut() - .for_each(|p| p.bind_parallel(r_j, BindingOrder::LowToHigh)); + self.shared_ra.bind_in_place(r_j, BindingOrder::LowToHigh); self.eq_poly.bind(r_j); } @@ -296,16 +278,13 @@ impl SumcheckInstanceProver for InstructionRa ) { let r_cycle = self.params.normalize_opening_point(sumcheck_challenges); - // Compute r_address_chunks with proper padding let r_address_chunks = self .params .one_hot_params .compute_r_address_chunks::(&self.params.r_address.r); for (i, r_address) in r_address_chunks.into_iter().enumerate() { - // Undo the per-batch γ scaling applied in `initialize` before caching openings, - // so the claimed openings match the *actual* committed polynomials. - let mut claim = self.ra_i_polys[i].final_sumcheck_claim(); + let mut claim = self.shared_ra.final_sumcheck_claim(i); if i % self.params.n_committed_per_virtual == 0 { let batch = i / self.params.n_committed_per_virtual; let gamma = self.params.gamma_powers[batch]; diff --git a/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs b/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs index 70ff1bdc81..3db94a1e76 100644 --- a/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs +++ b/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs @@ -26,6 +26,7 @@ use crate::{ VerifierOpeningAccumulator, BIG_ENDIAN, }, prefix_suffix::{Prefix, PrefixRegistry, PrefixSuffixDecomposition}, + ra_poly::RaPolynomial, split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly, }, @@ -352,8 +353,10 @@ pub struct InstructionReadRafSumcheckProver { /// Gruen-split equality polynomial over cycle vars. Present only in the last log(T) rounds. eq_r_reduction: GruenSplitEqPolynomial, - /// Materialized `ra_i(k_i, j)` polynomials. Present only in the last log(T) rounds. - ra_polys: Option>>, + /// Lazy `ra_i(k_i, j)` polynomials using combined expanding-table lookups. + /// Starts as RaPolynomial (compact: 4 bytes/cycle + 2MB table per poly), + /// automatically materializes to dense after 3 cycle rounds. + ra_polys: Option>>, /// Materialized Val_j(k) + γ · RafVal_j(k) over (address, cycle) for final log T rounds. /// Combines lookup table values with γ-weighted RAF operand contributions. @@ -713,55 +716,97 @@ impl InstructionReadRafSumcheckProver { let m = 1 << log_m; let m_mask = m - 1; let num_cycles = self.lookup_indices.len(); - // Drop stuff that's no longer needed drop_in_background_thread(std::mem::take(&mut self.u_evals)); - let ra_polys: Vec> = { - let span = tracing::span!(tracing::Level::INFO, "Materialize ra polynomials"); + let n = LOG_K / self.params.ra_virtual_log_k_chunk; + let chunk_size = self.v.len() / n; + // Combined table size K^chunk_size. Only use lazy path when it fits in u16. + let combined_table_bits = log_m * chunk_size; + + let ra_polys: Vec> = if combined_table_bits <= 16 { + let span = tracing::span!(tracing::Level::INFO, "Build lazy ra polynomials"); let _guard = span.enter(); assert!(self.v.len().is_power_of_two()); - let n = LOG_K / self.params.ra_virtual_log_k_chunk; - let chunk_size = self.v.len() / n; + let combined_size = 1usize << combined_table_bits; + self.v .chunks(chunk_size) .enumerate() .map(|(chunk_i, v_chunk)| { let phase_offset = chunk_i * chunk_size; - let res = self + + // Build combined eq table: combined_table[key] = ∏ v[phase][chunk_val] + let combined_table: Vec = (0..combined_size) + .map(|combined_idx| { + let mut product = F::one(); + let mut remaining = combined_idx; + for table in v_chunk.iter().rev() { + product *= table[remaining & m_mask]; + remaining >>= log_m; + } + product + }) + .collect(); + + // Extract combined keys per cycle + let combined_keys: Vec> = self + .lookup_indices + .par_iter() + .with_min_len(1024) + .map(|bits| { + let v: u128 = (*bits).into(); + let mut key: usize = 0; + let mut shift = (self.params.phases - 1 - phase_offset) * log_m; + for p in 0..chunk_size { + let chunk_val = ((v >> shift) as usize) & m_mask; + key = (key << log_m) | chunk_val; + if p + 1 < chunk_size { + shift -= log_m; + } + } + Some(key as u16) + }) + .collect(); + + RaPolynomial::new(Arc::new(combined_keys), combined_table) + }) + .collect() + } else { + // Fallback: dense materialization for large combined tables + let span = tracing::span!(tracing::Level::INFO, "Materialize ra polynomials (dense)"); + let _guard = span.enter(); + assert!(self.v.len().is_power_of_two()); + + self.v + .chunks(chunk_size) + .enumerate() + .map(|(chunk_i, v_chunk)| { + let phase_offset = chunk_i * chunk_size; + let res: Vec = self .lookup_indices .par_iter() .with_min_len(1024) .map(|i| { - // Hot path: compute ra_i(k_i, j) as a product of per-phase expanding-table - // values. This is performance sensitive, so we: - // - Convert `LookupBits` -> `u128` once per cycle - // - Use a decrementing shift instead of recomputing `(phases-1-phase)*log_m` - // - Avoid an initial multiply-by-one by seeding `acc` with the first term let v: u128 = (*i).into(); - if v_chunk.is_empty() { return F::one(); } - - // shift(phase) = (phases - 1 - phase) * log_m - // For consecutive phases, this decreases by `log_m` each step. let mut shift = (self.params.phases - 1 - phase_offset) * log_m; - let mut iter = v_chunk.iter(); let first = iter.next().unwrap(); let first_idx = ((v >> shift) as usize) & m_mask; let mut acc = first[first_idx]; - for table in iter { shift -= log_m; let idx = ((v >> shift) as usize) & m_mask; acc *= table[idx]; } - acc }) - .collect::>(); - res.into() + .collect(); + // Wrap dense vec in RaPolynomial::RoundN so the type matches + let poly: MultilinearPolynomial = res.into(); + RaPolynomial::RoundN(poly) }) .collect() }; diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index 7e025a65a0..40b98bbf05 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -173,7 +173,7 @@ pub struct JoltCpuProver< > { pub preprocessing: &'a JoltProverPreprocessing, pub program_io: JoltDevice, - pub lazy_trace: LazyTraceIterator, + pub lazy_trace: Option, pub trace: Arc>, pub advice: JoltAdvice, /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). @@ -449,7 +449,7 @@ impl< Self { preprocessing, program_io, - lazy_trace, + lazy_trace: Some(lazy_trace), trace: trace.into(), advice: JoltAdvice { untrusted_advice_polynomial: None, @@ -503,6 +503,8 @@ impl< ); let (commitments, mut opening_proof_hints) = self.generate_and_commit_witness_polynomials(); + // Free emulator memory — lazy_trace only needed for streaming commit + self.lazy_trace.take(); let untrusted_advice_commitment = self.generate_and_commit_untrusted_advice(); self.generate_and_commit_trusted_advice(); @@ -679,9 +681,10 @@ impl< polys.len() ); - // Materialize the trace for non-streaming commit let trace: Vec = self .lazy_trace + .as_ref() + .expect("lazy_trace consumed") .clone() .pad_using(T, |_| Cycle::NoOp) .collect(); @@ -719,6 +722,8 @@ impl< let mut row_commitments: Vec> = vec![vec![]; num_rows]; self.lazy_trace + .as_ref() + .expect("lazy_trace consumed") .clone() .pad_using(T, |_| Cycle::NoOp) .iter_chunks(row_len) @@ -1303,7 +1308,7 @@ impl< let mut ram_hamming_booleanity = HammingBooleanitySumcheckProver::initialize(ram_hamming_booleanity_params, &self.trace); - let mut booleanity = BooleanitySumcheckProver::initialize( + let (mut booleanity, shared_ra_indices) = BooleanitySumcheckProver::initialize( booleanity_params, &self.trace, &self.preprocessing.shared.bytecode, @@ -1312,12 +1317,11 @@ impl< let mut ram_ra_virtual = RamRaVirtualSumcheckProver::initialize( ram_ra_virtual_params, - &self.trace, - &self.program_io.memory_layout, + &shared_ra_indices, &self.one_hot_params, ); let mut lookups_ra_virtual = - LookupsRaSumcheckProver::initialize(lookups_ra_virtual_params, &self.trace); + LookupsRaSumcheckProver::initialize(lookups_ra_virtual_params, &shared_ra_indices); let mut inc_reduction = IncClaimReductionSumcheckProver::initialize(inc_reduction_params, self.trace.clone()); diff --git a/jolt-core/src/zkvm/ram/ra_virtual.rs b/jolt-core/src/zkvm/ram/ra_virtual.rs index 7df6975ae9..7dbf697322 100644 --- a/jolt-core/src/zkvm/ram/ra_virtual.rs +++ b/jolt-core/src/zkvm/ram/ra_virtual.rs @@ -45,9 +45,7 @@ //! //! Variables are bound low-to-high, matching the polynomial layout. -use common::jolt_device::MemoryLayout; use std::sync::Arc; -use tracer::instruction::Cycle; #[cfg(feature = "zk")] use crate::poly::opening_proof::OpeningId; @@ -56,6 +54,7 @@ use crate::poly::opening_proof::{ VerifierOpeningAccumulator, BIG_ENDIAN, LITTLE_ENDIAN, }; use crate::poly::ra_poly::RaPolynomial; +use crate::poly::shared_ra_polys::RaIndices; use crate::poly::split_eq_poly::GruenSplitEqPolynomial; use crate::poly::unipoly::UniPoly; #[cfg(feature = "zk")] @@ -66,7 +65,6 @@ use crate::subprotocols::mles_product_sum::compute_mles_product_sum; use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; use crate::subprotocols::sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}; use crate::zkvm::config::OneHotParams; -use crate::zkvm::ram::remap_address; use crate::zkvm::witness::{CommittedPolynomial, VirtualPolynomial}; use crate::{ field::JoltField, @@ -199,34 +197,32 @@ pub struct RamRaVirtualSumcheckProver { } impl RamRaVirtualSumcheckProver { + /// Initialize from shared RA indices (avoids re-reading trace for RAM indices). #[tracing::instrument(skip_all, name = "RamRaVirtualSumcheckProver::initialize")] pub fn initialize( params: RamRaVirtualParams, - trace: &[Cycle], - memory_layout: &MemoryLayout, + shared_indices: &Arc>, one_hot_params: &OneHotParams, ) -> Self { - // Precompute EQ tables for each address chunk let eq_tables: Vec> = params .r_address_chunks .iter() .map(|chunk| EqPolynomial::evals(chunk)) .collect(); - // Create eq polynomial with Gruen optimization for r_cycle_reduced let eq_poly = GruenSplitEqPolynomial::new(¶ms.r_cycle.r, BindingOrder::LowToHigh); - // Create ra_i polynomials for each decomposition chunk + let instruction_d = one_hot_params.instruction_d; + let bytecode_d = one_hot_params.bytecode_d; + let ram_offset = instruction_d + bytecode_d; + let ra_i_polys: Vec> = (0..params.d) .into_par_iter() .zip(eq_tables.into_par_iter()) .map(|(i, eq_table)| { - let ra_i_indices: Vec> = trace + let ra_i_indices: Vec> = shared_indices .par_iter() - .map(|cycle| { - remap_address(cycle.ram_access().address() as u64, memory_layout) - .map(|address| one_hot_params.ram_address_chunk(address, i)) - }) + .map(|ra| ra.get_index(ram_offset + i, one_hot_params)) .collect(); RaPolynomial::new(Arc::new(ra_i_indices), eq_table) }) diff --git a/scripts/jolt_benchmarks.sh b/scripts/jolt_benchmarks.sh index d460104ffc..d84ba70ae1 100755 --- a/scripts/jolt_benchmarks.sh +++ b/scripts/jolt_benchmarks.sh @@ -52,7 +52,7 @@ fi # Set stack size for Rust export RUST_MIN_STACK=33554432 -export RUSTFLAGS="-C target-cpu=native -C opt-level=3 -C codegen-units=1 -C embed-bitcode=yes" +export RUSTFLAGS="-C opt-level=3 -C codegen-units=1 -C embed-bitcode=yes" export RUST_LOG="info" # Create output directories