From cfb3ce57b2a55812777c9e6baab88f7ed0d41828 Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Thu, 5 Mar 2026 16:27:09 -0500 Subject: [PATCH 01/11] hotfix: remove native Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> --- scripts/jolt_benchmarks.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/jolt_benchmarks.sh b/scripts/jolt_benchmarks.sh index d460104ffc..d84ba70ae1 100755 --- a/scripts/jolt_benchmarks.sh +++ b/scripts/jolt_benchmarks.sh @@ -52,7 +52,7 @@ fi # Set stack size for Rust export RUST_MIN_STACK=33554432 -export RUSTFLAGS="-C target-cpu=native -C opt-level=3 -C codegen-units=1 -C embed-bitcode=yes" +export RUSTFLAGS="-C opt-level=3 -C codegen-units=1 -C embed-bitcode=yes" export RUST_LOG="info" # Create output directories From c11bef26540e26495b133b7573a6c25f847627a6 Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Thu, 5 Mar 2026 17:55:19 -0500 Subject: [PATCH 02/11] feat: use jemalloc with aggressive purging for memory profiling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Behind the `monitor` feature, use tikv-jemallocator as the global allocator with dirty_decay_ms:0 and muzzy_decay_ms:0. This forces immediate page return to the OS so RSS accurately reflects live heap usage — the system allocator holds freed pages indefinitely, inflating RSS from ~2 GB (actual) to ~5 GB (watermark). 
--- Cargo.lock | 21 +++++++++++++++++++++ Cargo.toml | 1 + jolt-core/Cargo.toml | 3 ++- jolt-core/src/bin/jolt_core.rs | 9 +++++++++ 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index b8adb91691..ba671dd379 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2894,6 +2894,7 @@ dependencies = [ "strum_macros 0.28.0", "sysinfo", "thiserror 2.0.18", + "tikv-jemallocator", "tracer", "tracing", "tracing-chrome", @@ -5625,6 +5626,26 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "tikv-jemalloc-sys" +version = "0.6.1+5.3.0-1-ge13ca993e8ccb9ba9847cc330696e02839f328f7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd8aa5b2ab86a2cefa406d889139c162cbb230092f7d1d7cbc1716405d852a3b" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "tikv-jemallocator" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0359b4327f954e0567e69fb191cf1436617748813819c94b8cd4a431422d053a" +dependencies = [ + "libc", + "tikv-jemalloc-sys", +] + [[package]] name = "time" version = "0.3.47" diff --git a/Cargo.toml b/Cargo.toml index 67c8a9364a..9b90dde64e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -211,6 +211,7 @@ num-traits = { version = "0.2.19", default-features = false } sysinfo = "0.38" memory-stats = { version = "1.0.0", features = ["always_use_statm"] } mimalloc = "0.1" +tikv-jemallocator = { version = "0.6", features = ["unprefixed_malloc_on_supported_platforms"] } # Parallel Processing rayon = { version = "^1.8.0" } diff --git a/jolt-core/Cargo.toml b/jolt-core/Cargo.toml index ca30335e13..debb07c2b2 100644 --- a/jolt-core/Cargo.toml +++ b/jolt-core/Cargo.toml @@ -43,7 +43,7 @@ prover = [ # jolt-core needs std and rayon to compile, so these are the minimal set of features. 
minimal = ["ark-ec/std", "ark-ff/std", "ark-std/std", "ark-ff/asm", "rayon"] allocative = ["dep:inferno"] -monitor = ["dep:sysinfo"] +monitor = ["dep:sysinfo", "dep:tikv-jemallocator"] pprof = ["dep:pprof", "dep:prost"] test_incremental = [] challenge-254-bit = [] @@ -89,6 +89,7 @@ jolt-inlines-keccak256 = { workspace = true, features = [ "host", ], optional = true } sysinfo = { workspace = true, optional = true } +tikv-jemallocator = { workspace = true, optional = true } pprof = { version = "0.15", features = [ "prost-codec", "flamegraph", diff --git a/jolt-core/src/bin/jolt_core.rs b/jolt-core/src/bin/jolt_core.rs index 9c36aafd3c..e81bcbeaa7 100644 --- a/jolt-core/src/bin/jolt_core.rs +++ b/jolt-core/src/bin/jolt_core.rs @@ -1,3 +1,12 @@ +#[cfg(feature = "monitor")] +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + +#[cfg(feature = "monitor")] +#[allow(non_upper_case_globals)] +#[unsafe(export_name = "malloc_conf")] +pub static malloc_conf: &[u8] = b"dirty_decay_ms:0,muzzy_decay_ms:0\0"; + use clap::{Args, Parser, Subcommand, ValueEnum}; #[path = "../../benches/e2e_profiling.rs"] From 2699fdf70ae967ee4251b8d43a674340ceae3f22 Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Fri, 6 Mar 2026 00:53:04 -0500 Subject: [PATCH 03/11] perf: reduce prover peak memory in stages 5-6 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stage 5: Replace dense ra_polys materialization (8×T×32B = 2GB at 2^23) with lazy RaPolynomial using combined expanding-table lookups. Each virtual RA poly stores a 64K-entry combined table + per-cycle u16 keys (4B/cycle). Automatically materializes to dense after 3 cycle rounds at T/8 length (0.25GB). Falls back to dense for large K configs. Stage 6: Add SharedRaRound4 to SharedRaPolynomials state machine, delaying materialization by one extra round. 
This staggers the booleanity materialization (now at T/16) from InstructionRaSumcheck's materialization (at T/8), preventing simultaneous peak allocations. Booleanity dense polys: 1.31GB → 0.66GB. --- jolt-core/src/poly/shared_ra_polys.rs | 160 +++++++++++++----- .../instruction_lookups/read_raf_checking.rs | 89 +++++++--- 2 files changed, 188 insertions(+), 61 deletions(-) diff --git a/jolt-core/src/poly/shared_ra_polys.rs b/jolt-core/src/poly/shared_ra_polys.rs index 2eab8220e8..b051dd3148 100644 --- a/jolt-core/src/poly/shared_ra_polys.rs +++ b/jolt-core/src/poly/shared_ra_polys.rs @@ -460,6 +460,8 @@ pub enum SharedRaPolynomials { Round2(SharedRaRound2), /// Round 3: Split into F_00, F_01, F_10, F_11 Round3(SharedRaRound3), + /// Round 4: Split into 8 table groups (F_000 through F_111) + Round4(SharedRaRound4), /// Round N: Fully materialized multilinear polynomials RoundN(Vec>), } @@ -510,6 +512,19 @@ pub struct SharedRaRound3 { binding_order: BindingOrder, } +/// Round 4 state: 8 table groups per polynomial. +/// Delays materialization one more round to reduce peak memory. +#[derive(Allocative, Default)] +pub struct SharedRaRound4 { + /// tables[group][poly_idx][k] — 8 groups of per-poly eq tables + tables: [Vec>; 8], + indices: Vec, + num_polys: usize, + #[allocative(skip)] + one_hot_params: OneHotParams, + binding_order: BindingOrder, +} + impl SharedRaPolynomials { /// Create new SharedRaPolynomials from eq table and indices. 
pub fn new(tables: Vec>, indices: Vec, one_hot_params: OneHotParams) -> Self { @@ -535,6 +550,7 @@ impl SharedRaPolynomials { Self::Round1(r) => r.num_polys, Self::Round2(r) => r.num_polys, Self::Round3(r) => r.num_polys, + Self::Round4(r) => r.num_polys, Self::RoundN(polys) => polys.len(), } } @@ -545,6 +561,7 @@ impl SharedRaPolynomials { Self::Round1(r) => r.indices.len(), Self::Round2(r) => r.indices.len() / 2, Self::Round3(r) => r.indices.len() / 4, + Self::Round4(r) => r.indices.len() / 8, Self::RoundN(polys) => polys[0].len(), } } @@ -556,6 +573,7 @@ impl SharedRaPolynomials { Self::Round1(r) => r.get_bound_coeff(poly_idx, j), Self::Round2(r) => r.get_bound_coeff(poly_idx, j), Self::Round3(r) => r.get_bound_coeff(poly_idx, j), + Self::Round4(r) => r.get_bound_coeff(poly_idx, j), Self::RoundN(polys) => polys[poly_idx].get_bound_coeff(j), } } @@ -574,7 +592,8 @@ impl SharedRaPolynomials { match self { Self::Round1(r1) => Self::Round2(r1.bind(r, order)), Self::Round2(r2) => Self::Round3(r2.bind(r, order)), - Self::Round3(r3) => Self::RoundN(r3.bind(r, order)), + Self::Round3(r3) => Self::Round4(r3.bind(r, order)), + Self::Round4(r4) => Self::RoundN(r4.bind(r, order)), Self::RoundN(mut polys) => { polys.par_iter_mut().for_each(|p| p.bind_parallel(r, order)); Self::RoundN(polys) @@ -584,11 +603,11 @@ impl SharedRaPolynomials { /// Bind in place with a challenge, transitioning to next round state. 
pub fn bind_in_place(&mut self, r: F::Challenge, order: BindingOrder) { - // Use mem::take pattern (same as ra_poly.rs) for efficiency match self { Self::Round1(r1) => *self = Self::Round2(std::mem::take(r1).bind(r, order)), Self::Round2(r2) => *self = Self::Round3(std::mem::take(r2).bind(r, order)), - Self::Round3(r3) => *self = Self::RoundN(std::mem::take(r3).bind(r, order)), + Self::Round3(r3) => *self = Self::Round4(std::mem::take(r3).bind(r, order)), + Self::Round4(r4) => *self = Self::RoundN(std::mem::take(r4).bind(r, order)), Self::RoundN(polys) => { polys.par_iter_mut().for_each(|p| p.bind_parallel(r, order)); } @@ -755,10 +774,9 @@ impl SharedRaRound3 { } #[tracing::instrument(skip_all, name = "SharedRaRound3::bind")] - fn bind(self, r2: F::Challenge, order: BindingOrder) -> Vec> { + fn bind(self, r2: F::Challenge, order: BindingOrder) -> SharedRaRound4 { assert_eq!(order, self.binding_order); - // Create 8 F tables: F_ABC where A=r0, B=r1, C=r2 let eq_0_r2 = EqPolynomial::mle(&[F::zero()], &[r2]); let eq_1_r2 = EqPolynomial::mle(&[F::one()], &[r2]); @@ -771,7 +789,6 @@ impl SharedRaRound3 { let mut tables_110 = self.tables_11.clone(); let mut tables_111 = self.tables_11; - // Scale by eq(r2, bit) rayon::join( || { [ @@ -803,56 +820,121 @@ impl SharedRaRound3 { }, ); - // Collect all 8 table groups for indexed access: group[offset][poly_idx][k] - let table_groups = [ - &tables_000, - &tables_100, - &tables_010, - &tables_110, - &tables_001, - &tables_101, - &tables_011, - &tables_111, + SharedRaRound4 { + tables: [ + tables_000, tables_100, tables_010, tables_110, tables_001, tables_101, tables_011, + tables_111, + ], + indices: self.indices, + num_polys: self.num_polys, + one_hot_params: self.one_hot_params, + binding_order: order, + } + } +} + +impl SharedRaRound4 { + #[inline] + fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { + match self.binding_order { + BindingOrder::LowToHigh => (0..8) + .map(|offset| { + self.indices[8 * j + offset] + 
.get_index(poly_idx, &self.one_hot_params) + .map_or(F::zero(), |k| self.tables[offset][poly_idx][k as usize]) + }) + .sum(), + BindingOrder::HighToLow => { + let eighth = self.indices.len() / 8; + (0..8) + .map(|seg| { + self.indices[seg * eighth + j] + .get_index(poly_idx, &self.one_hot_params) + .map_or(F::zero(), |k| self.tables[seg][poly_idx][k as usize]) + }) + .sum() + } + } + } + + #[tracing::instrument(skip_all, name = "SharedRaRound4::bind")] + fn bind(self, r3: F::Challenge, order: BindingOrder) -> Vec> { + assert_eq!(order, self.binding_order); + + let eq_0_r3 = EqPolynomial::mle(&[F::zero()], &[r3]); + let eq_1_r3 = EqPolynomial::mle(&[F::one()], &[r3]); + + // 16 groups: [0..8) for bit3=0 (scale eq_0_r3), [8..16) for bit3=1 (scale eq_1_r3) + let [t0, t1, t2, t3, t4, t5, t6, t7] = self.tables; + let mut tables: [Vec>; 16] = [ + t0.clone(), + t1.clone(), + t2.clone(), + t3.clone(), + t4.clone(), + t5.clone(), + t6.clone(), + t7.clone(), + t0, + t1, + t2, + t3, + t4, + t5, + t6, + t7, ]; - // Materialize all polynomials in parallel + let (lo, hi) = tables.split_at_mut(8); + rayon::join( + || { + lo.par_iter_mut().for_each(|table| { + table + .par_iter_mut() + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r3)) + }) + }, + || { + hi.par_iter_mut().for_each(|table| { + table + .par_iter_mut() + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r3)) + }) + }, + ); + let num_polys = self.num_polys; let indices = &self.indices; let one_hot_params = &self.one_hot_params; - let new_len = indices.len() / 8; + let new_len = indices.len() / 16; (0..num_polys) .into_par_iter() .map(|poly_idx| { let coeffs: Vec = match order { - BindingOrder::LowToHigh => { - (0..new_len) - .into_par_iter() - .map(|j| { - // Sum over 8 consecutive indices, each using appropriate F table - (0..8) - .map(|offset| { - indices[8 * j + offset] - .get_index(poly_idx, one_hot_params) - .map_or(F::zero(), |k| { - table_groups[offset][poly_idx][k as usize] - }) - }) - .sum() - 
}) - .collect() - } + BindingOrder::LowToHigh => (0..new_len) + .into_par_iter() + .map(|j| { + (0..16) + .map(|offset| { + indices[16 * j + offset] + .get_index(poly_idx, one_hot_params) + .map_or(F::zero(), |k| tables[offset][poly_idx][k as usize]) + }) + .sum() + }) + .collect(), BindingOrder::HighToLow => { - let eighth = indices.len() / 8; + let sixteenth = indices.len() / 16; (0..new_len) .into_par_iter() .map(|j| { - (0..8) + (0..16) .map(|seg| { - indices[seg * eighth + j] + indices[seg * sixteenth + j] .get_index(poly_idx, one_hot_params) .map_or(F::zero(), |k| { - table_groups[seg][poly_idx][k as usize] + tables[seg][poly_idx][k as usize] }) }) .sum() diff --git a/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs b/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs index 70ff1bdc81..3db94a1e76 100644 --- a/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs +++ b/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs @@ -26,6 +26,7 @@ use crate::{ VerifierOpeningAccumulator, BIG_ENDIAN, }, prefix_suffix::{Prefix, PrefixRegistry, PrefixSuffixDecomposition}, + ra_poly::RaPolynomial, split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly, }, @@ -352,8 +353,10 @@ pub struct InstructionReadRafSumcheckProver { /// Gruen-split equality polynomial over cycle vars. Present only in the last log(T) rounds. eq_r_reduction: GruenSplitEqPolynomial, - /// Materialized `ra_i(k_i, j)` polynomials. Present only in the last log(T) rounds. - ra_polys: Option>>, + /// Lazy `ra_i(k_i, j)` polynomials using combined expanding-table lookups. + /// Starts as RaPolynomial (compact: 4 bytes/cycle + 2MB table per poly), + /// automatically materializes to dense after 3 cycle rounds. + ra_polys: Option>>, /// Materialized Val_j(k) + γ · RafVal_j(k) over (address, cycle) for final log T rounds. /// Combines lookup table values with γ-weighted RAF operand contributions. 
@@ -713,55 +716,97 @@ impl InstructionReadRafSumcheckProver { let m = 1 << log_m; let m_mask = m - 1; let num_cycles = self.lookup_indices.len(); - // Drop stuff that's no longer needed drop_in_background_thread(std::mem::take(&mut self.u_evals)); - let ra_polys: Vec> = { - let span = tracing::span!(tracing::Level::INFO, "Materialize ra polynomials"); + let n = LOG_K / self.params.ra_virtual_log_k_chunk; + let chunk_size = self.v.len() / n; + // Combined table size K^chunk_size. Only use lazy path when it fits in u16. + let combined_table_bits = log_m * chunk_size; + + let ra_polys: Vec> = if combined_table_bits <= 16 { + let span = tracing::span!(tracing::Level::INFO, "Build lazy ra polynomials"); let _guard = span.enter(); assert!(self.v.len().is_power_of_two()); - let n = LOG_K / self.params.ra_virtual_log_k_chunk; - let chunk_size = self.v.len() / n; + let combined_size = 1usize << combined_table_bits; + self.v .chunks(chunk_size) .enumerate() .map(|(chunk_i, v_chunk)| { let phase_offset = chunk_i * chunk_size; - let res = self + + // Build combined eq table: combined_table[key] = ∏ v[phase][chunk_val] + let combined_table: Vec = (0..combined_size) + .map(|combined_idx| { + let mut product = F::one(); + let mut remaining = combined_idx; + for table in v_chunk.iter().rev() { + product *= table[remaining & m_mask]; + remaining >>= log_m; + } + product + }) + .collect(); + + // Extract combined keys per cycle + let combined_keys: Vec> = self + .lookup_indices + .par_iter() + .with_min_len(1024) + .map(|bits| { + let v: u128 = (*bits).into(); + let mut key: usize = 0; + let mut shift = (self.params.phases - 1 - phase_offset) * log_m; + for p in 0..chunk_size { + let chunk_val = ((v >> shift) as usize) & m_mask; + key = (key << log_m) | chunk_val; + if p + 1 < chunk_size { + shift -= log_m; + } + } + Some(key as u16) + }) + .collect(); + + RaPolynomial::new(Arc::new(combined_keys), combined_table) + }) + .collect() + } else { + // Fallback: dense materialization for 
large combined tables + let span = tracing::span!(tracing::Level::INFO, "Materialize ra polynomials (dense)"); + let _guard = span.enter(); + assert!(self.v.len().is_power_of_two()); + + self.v + .chunks(chunk_size) + .enumerate() + .map(|(chunk_i, v_chunk)| { + let phase_offset = chunk_i * chunk_size; + let res: Vec = self .lookup_indices .par_iter() .with_min_len(1024) .map(|i| { - // Hot path: compute ra_i(k_i, j) as a product of per-phase expanding-table - // values. This is performance sensitive, so we: - // - Convert `LookupBits` -> `u128` once per cycle - // - Use a decrementing shift instead of recomputing `(phases-1-phase)*log_m` - // - Avoid an initial multiply-by-one by seeding `acc` with the first term let v: u128 = (*i).into(); - if v_chunk.is_empty() { return F::one(); } - - // shift(phase) = (phases - 1 - phase) * log_m - // For consecutive phases, this decreases by `log_m` each step. let mut shift = (self.params.phases - 1 - phase_offset) * log_m; - let mut iter = v_chunk.iter(); let first = iter.next().unwrap(); let first_idx = ((v >> shift) as usize) & m_mask; let mut acc = first[first_idx]; - for table in iter { shift -= log_m; let idx = ((v >> shift) as usize) & m_mask; acc *= table[idx]; } - acc }) - .collect::>(); - res.into() + .collect(); + // Wrap dense vec in RaPolynomial::RoundN so the type matches + let poly: MultilinearPolynomial = res.into(); + RaPolynomial::RoundN(poly) }) .collect() }; From 1f3b5388965159a1f31cbfa4326208818005559e Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:07:53 -0500 Subject: [PATCH 04/11] perf: add Round4 to RaPolynomial for smaller dense materialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delays RaPolynomial materialization by one more round (Round3→Round4 instead of Round3→RoundN). 
Dense polys materialize at T/16 instead of T/8, halving peak allocation during stage 6 InstructionRa transition. --- jolt-core/src/poly/ra_poly.rs | 188 ++++++++++++++++++++++------------ 1 file changed, 121 insertions(+), 67 deletions(-) diff --git a/jolt-core/src/poly/ra_poly.rs b/jolt-core/src/poly/ra_poly.rs index 2ef8d97c95..691a943a54 100644 --- a/jolt-core/src/poly/ra_poly.rs +++ b/jolt-core/src/poly/ra_poly.rs @@ -23,6 +23,7 @@ pub enum RaPolynomial + Copy + Default + Send + Sync + 'static, F Round1(RaPolynomialRound1), Round2(RaPolynomialRound2), Round3(RaPolynomialRound3), + Round4(RaPolynomialRound4), RoundN(MultilinearPolynomial), } @@ -41,6 +42,7 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> RaPo Self::Round1(mle) => mle.get_bound_coeff(j), Self::Round2(mle) => mle.get_bound_coeff(j), Self::Round3(mle) => mle.get_bound_coeff(j), + Self::Round4(mle) => mle.get_bound_coeff(j), Self::RoundN(mle) => mle.get_bound_coeff(j), } } @@ -51,6 +53,7 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> RaPo Self::Round1(mle) => mle.len(), Self::Round2(mle) => mle.len(), Self::Round3(mle) => mle.len(), + Self::Round4(mle) => mle.len(), Self::RoundN(mle) => mle.len(), } } @@ -72,7 +75,8 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> Poly Self::None => panic!("RaPolynomial::bind called on None"), Self::Round1(mle) => *self = Self::Round2(mem::take(mle).bind(r, order)), Self::Round2(mle) => *self = Self::Round3(mem::take(mle).bind(r, order)), - Self::Round3(mle) => *self = Self::RoundN(mem::take(mle).bind(r, order)), + Self::Round3(mle) => *self = Self::Round4(mem::take(mle).bind(r, order)), + Self::Round4(mle) => *self = Self::RoundN(mem::take(mle).bind(r, order)), Self::RoundN(mle) => mle.bind_parallel(r, order), }; } @@ -282,8 +286,7 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> } #[tracing::instrument(skip_all, name = "RaPolynomialRound3::bind")] - fn bind(self, r2: F::Challenge, _binding_order: 
BindingOrder) -> MultilinearPolynomial { - // Construct lookup tables. + fn bind(self, r2: F::Challenge, _binding_order: BindingOrder) -> RaPolynomialRound4 { let eq_0_r2 = EqPolynomial::mle(&[F::zero()], &[r2]); let eq_1_r2 = EqPolynomial::mle(&[F::one()], &[r2]); let mut F_000: Vec = self.F_00.clone(); @@ -304,34 +307,110 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> F_101.par_iter_mut().for_each(|f| *f *= eq_1_r2); F_111.par_iter_mut().for_each(|f| *f *= eq_1_r2); + RaPolynomialRound4 { + tables: [F_000, F_100, F_010, F_110, F_001, F_101, F_011, F_111], + lookup_indices: self.lookup_indices, + binding_order: self.binding_order, + } + } + + #[inline] + fn get_bound_coeff(&self, j: usize) -> F { + match self.binding_order { + BindingOrder::HighToLow => { + let n = self.lookup_indices.len() / 4; + let H_00 = self.lookup_indices[j].map_or(F::zero(), |i| self.F_00[i.into()]); + let H_01 = self.lookup_indices[j + n].map_or(F::zero(), |i| self.F_01[i.into()]); + let H_10 = + self.lookup_indices[j + n * 2].map_or(F::zero(), |i| self.F_10[i.into()]); + let H_11 = + self.lookup_indices[j + n * 3].map_or(F::zero(), |i| self.F_11[i.into()]); + H_00 + H_10 + H_01 + H_11 + } + BindingOrder::LowToHigh => { + let H_00 = self.lookup_indices[4 * j].map_or(F::zero(), |i| self.F_00[i.into()]); + let H_10 = + self.lookup_indices[4 * j + 1].map_or(F::zero(), |i| self.F_10[i.into()]); + let H_01 = + self.lookup_indices[4 * j + 2].map_or(F::zero(), |i| self.F_01[i.into()]); + let H_11 = + self.lookup_indices[4 * j + 3].map_or(F::zero(), |i| self.F_11[i.into()]); + H_00 + H_10 + H_01 + H_11 + } + } + } +} + +/// Round 4: 8 eq tables, delays materialization one more round. 
+#[derive(Allocative, Default, Clone, Debug, PartialEq)] +pub struct RaPolynomialRound4 + Copy + Default + Send + Sync + 'static, F: JoltField> +{ + tables: [Vec; 8], + lookup_indices: Arc>>, + binding_order: BindingOrder, +} + +impl + Copy + Default + Send + Sync + 'static, F: JoltField> + RaPolynomialRound4 +{ + fn len(&self) -> usize { + self.lookup_indices.len() / 8 + } + + #[tracing::instrument(skip_all, name = "RaPolynomialRound4::bind")] + fn bind(self, r3: F::Challenge, _binding_order: BindingOrder) -> MultilinearPolynomial { + let eq_0_r3 = EqPolynomial::mle(&[F::zero()], &[r3]); + let eq_1_r3 = EqPolynomial::mle(&[F::one()], &[r3]); + + // 16 groups: [0..8) for bit3=0 (eq_0_r3), [8..16) for bit3=1 (eq_1_r3) + let [t0, t1, t2, t3, t4, t5, t6, t7] = self.tables; + let mut tables: [Vec; 16] = [ + t0.clone(), + t1.clone(), + t2.clone(), + t3.clone(), + t4.clone(), + t5.clone(), + t6.clone(), + t7.clone(), + t0, + t1, + t2, + t3, + t4, + t5, + t6, + t7, + ]; + + let (lo, hi) = tables.split_at_mut(8); + rayon::join( + || { + lo.par_iter_mut() + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r3)) + }, + || { + hi.par_iter_mut() + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r3)) + }, + ); + let lookup_indices = &self.lookup_indices; - let n = lookup_indices.len() / 8; - let mut res = unsafe_allocate_zero_vec(n); + let n = lookup_indices.len() / 16; + let mut res: Vec = unsafe_allocate_zero_vec(n); let chunk_size = 1 << 16; - - // Eval ra_i(r, r0, r1, j) for all j in the hypercube. 
match self.binding_order { BindingOrder::HighToLow => { res.par_chunks_mut(chunk_size).enumerate().for_each( |(chunk_index, evals_chunk)| { for (j, eval) in zip(chunk_index * chunk_size.., evals_chunk) { - let H_000 = lookup_indices[j].map_or(F::zero(), |i| F_000[i.into()]); - let H_001 = - lookup_indices[j + n].map_or(F::zero(), |i| F_001[i.into()]); - let H_010 = - lookup_indices[j + n * 2].map_or(F::zero(), |i| F_010[i.into()]); - let H_011 = - lookup_indices[j + n * 3].map_or(F::zero(), |i| F_011[i.into()]); - let H_100 = - lookup_indices[j + n * 4].map_or(F::zero(), |i| F_100[i.into()]); - let H_101 = - lookup_indices[j + n * 5].map_or(F::zero(), |i| F_101[i.into()]); - let H_110 = - lookup_indices[j + n * 6].map_or(F::zero(), |i| F_110[i.into()]); - let H_111 = - lookup_indices[j + n * 7].map_or(F::zero(), |i| F_111[i.into()]); - *eval = H_000 + H_010 + H_100 + H_110 + H_001 + H_011 + H_101 + H_111; + *eval = (0..16) + .map(|seg| { + lookup_indices[seg * n + j] + .map_or(F::zero(), |i| tables[seg][i.into()]) + }) + .sum(); } }, ); @@ -340,23 +419,12 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> res.par_chunks_mut(chunk_size).enumerate().for_each( |(chunk_index, evals_chunk)| { for (j, eval) in zip(chunk_index * chunk_size.., evals_chunk) { - let H_000 = - lookup_indices[8 * j].map_or(F::zero(), |i| F_000[i.into()]); - let H_100 = - lookup_indices[8 * j + 1].map_or(F::zero(), |i| F_100[i.into()]); - let H_010 = - lookup_indices[8 * j + 2].map_or(F::zero(), |i| F_010[i.into()]); - let H_110 = - lookup_indices[8 * j + 3].map_or(F::zero(), |i| F_110[i.into()]); - let H_001 = - lookup_indices[8 * j + 4].map_or(F::zero(), |i| F_001[i.into()]); - let H_101 = - lookup_indices[8 * j + 5].map_or(F::zero(), |i| F_101[i.into()]); - let H_011 = - lookup_indices[8 * j + 6].map_or(F::zero(), |i| F_011[i.into()]); - let H_111 = - lookup_indices[8 * j + 7].map_or(F::zero(), |i| F_111[i.into()]); - *eval = H_000 + H_010 + H_100 + H_110 + H_001 + H_011 + H_101 
+ H_111; + *eval = (0..16) + .map(|offset| { + lookup_indices[16 * j + offset] + .map_or(F::zero(), |i| tables[offset][i.into()]) + }) + .sum(); } }, ); @@ -364,15 +432,6 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> } drop_in_background_thread(self.lookup_indices); - drop_in_background_thread(F_000); - drop_in_background_thread(F_100); - drop_in_background_thread(F_010); - drop_in_background_thread(F_110); - drop_in_background_thread(F_001); - drop_in_background_thread(F_101); - drop_in_background_thread(F_011); - drop_in_background_thread(F_111); - res.into() } @@ -380,25 +439,20 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> fn get_bound_coeff(&self, j: usize) -> F { match self.binding_order { BindingOrder::HighToLow => { - let n = self.lookup_indices.len() / 4; - let H_00 = self.lookup_indices[j].map_or(F::zero(), |i| self.F_00[i.into()]); - let H_01 = self.lookup_indices[j + n].map_or(F::zero(), |i| self.F_01[i.into()]); - let H_10 = - self.lookup_indices[j + n * 2].map_or(F::zero(), |i| self.F_10[i.into()]); - let H_11 = - self.lookup_indices[j + n * 3].map_or(F::zero(), |i| self.F_11[i.into()]); - H_00 + H_10 + H_01 + H_11 - } - BindingOrder::LowToHigh => { - let H_00 = self.lookup_indices[4 * j].map_or(F::zero(), |i| self.F_00[i.into()]); - let H_10 = - self.lookup_indices[4 * j + 1].map_or(F::zero(), |i| self.F_10[i.into()]); - let H_01 = - self.lookup_indices[4 * j + 2].map_or(F::zero(), |i| self.F_01[i.into()]); - let H_11 = - self.lookup_indices[4 * j + 3].map_or(F::zero(), |i| self.F_11[i.into()]); - H_00 + H_10 + H_01 + H_11 + let n = self.lookup_indices.len() / 8; + (0..8) + .map(|seg| { + self.lookup_indices[seg * n + j] + .map_or(F::zero(), |i| self.tables[seg][i.into()]) + }) + .sum() } + BindingOrder::LowToHigh => (0..8) + .map(|offset| { + self.lookup_indices[8 * j + offset] + .map_or(F::zero(), |i| self.tables[offset][i.into()]) + }) + .sum(), } } } From 1da3be7e2dc290b1d2004b1f67fc72ec4ec657b1 Mon Sep 17 
00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:25:41 -0500 Subject: [PATCH 05/11] refactor: share Arc<Vec<RaIndices>> between booleanity and InstructionRa SharedRaPolynomials now stores Arc<Vec<RaIndices>> instead of owned Vec<RaIndices>. BooleanitySumcheckProver::initialize returns the shared Arc so InstructionRaSumcheckProver can reuse it. Currently InstructionRa still creates transposed per-poly indices from the shared data; full deduplication requires changing the product-sum evaluation functions. --- jolt-core/src/poly/shared_ra_polys.rs | 38 ++++++++++++---- jolt-core/src/subprotocols/booleanity.rs | 39 +++++++++------- .../zkvm/instruction_lookups/ra_virtual.rs | 44 ++++++------------- jolt-core/src/zkvm/prover.rs | 4 +- 4 files changed, 67 insertions(+), 58 deletions(-) diff --git a/jolt-core/src/poly/shared_ra_polys.rs b/jolt-core/src/poly/shared_ra_polys.rs index b051dd3148..852a3148f2 100644 --- a/jolt-core/src/poly/shared_ra_polys.rs +++ b/jolt-core/src/poly/shared_ra_polys.rs @@ -27,6 +27,8 @@ use allocative::Allocative; use ark_std::Zero; use fixedbitset::FixedBitSet; +use std::sync::Arc; + use crate::field::JoltField; use crate::poly::eq_poly::EqPolynomial; use crate::poly::multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}; @@ -232,9 +234,8 @@ pub fn compute_all_G_and_ra_indices( memory_layout: &MemoryLayout, one_hot_params: &OneHotParams, r_cycle: &[F::Challenge], -) -> (Vec>, Vec) { +) -> (Vec>, Arc>) { let T = trace.len(); - // Pre-allocate ra_indices let mut ra_indices: Vec = unsafe_allocate_zero_vec(T); let G = compute_all_G_impl::( @@ -246,7 +247,7 @@ pub fn compute_all_G_and_ra_indices( Some(&mut ra_indices), ); - (G, ra_indices) + (G, Arc::new(ra_indices)) } /// Core implementation for computing G evaluations. @@ -475,7 +476,7 @@ pub struct SharedRaRound1 { /// constant (e.g. a batching coefficient). 
tables: Vec>, /// RA indices for all cycles (non-transposed) - indices: Vec, + indices: Arc>, /// Number of polynomials num_polys: usize, /// OneHotParams for index extraction @@ -491,7 +492,7 @@ pub struct SharedRaRound2 { /// Per-polynomial tables for the 1-branch: tables_1[poly_idx][k] tables_1: Vec>, /// RA indices for all cycles - indices: Vec, + indices: Arc>, num_polys: usize, #[allocative(skip)] one_hot_params: OneHotParams, @@ -505,7 +506,7 @@ pub struct SharedRaRound3 { tables_01: Vec>, tables_10: Vec>, tables_11: Vec>, - indices: Vec, + indices: Arc>, num_polys: usize, #[allocative(skip)] one_hot_params: OneHotParams, @@ -518,7 +519,7 @@ pub struct SharedRaRound3 { pub struct SharedRaRound4 { /// tables[group][poly_idx][k] — 8 groups of per-poly eq tables tables: [Vec>; 8], - indices: Vec, + indices: Arc>, num_polys: usize, #[allocative(skip)] one_hot_params: OneHotParams, @@ -527,7 +528,11 @@ pub struct SharedRaRound4 { impl SharedRaPolynomials { /// Create new SharedRaPolynomials from eq table and indices. - pub fn new(tables: Vec>, indices: Vec, one_hot_params: OneHotParams) -> Self { + pub fn new( + tables: Vec>, + indices: Arc>, + one_hot_params: OneHotParams, + ) -> Self { let num_polys = one_hot_params.instruction_d + one_hot_params.bytecode_d + one_hot_params.ram_d; debug_assert!( @@ -544,6 +549,23 @@ impl SharedRaPolynomials { }) } + /// Create SharedRaPolynomials that only uses instruction polys from the shared indices. + /// `tables` should have exactly `instruction_d` entries. 
+ pub fn new_instruction_only( + tables: Vec>, + indices: Arc>, + one_hot_params: OneHotParams, + ) -> Self { + let num_polys = one_hot_params.instruction_d; + debug_assert_eq!(tables.len(), num_polys); + Self::Round1(SharedRaRound1 { + tables, + indices, + num_polys, + one_hot_params, + }) + } + /// Get the number of polynomials pub fn num_polys(&self) -> usize { match self { diff --git a/jolt-core/src/subprotocols/booleanity.rs b/jolt-core/src/subprotocols/booleanity.rs index 1f3bf6cadd..40ba1e721b 100644 --- a/jolt-core/src/subprotocols/booleanity.rs +++ b/jolt-core/src/subprotocols/booleanity.rs @@ -23,6 +23,7 @@ use allocative::FlameGraphBuilder; use ark_std::Zero; use rayon::prelude::*; use std::iter::zip; +use std::sync::Arc; use common::jolt_device::MemoryLayout; use tracer::instruction::Cycle; @@ -284,8 +285,8 @@ pub struct BooleanitySumcheckProver { F: ExpandingTable, /// eq(r_address, r_address) at end of phase 1 eq_r_r: F, - /// RA indices (non-transposed, one per cycle) - ra_indices: Vec, + /// RA indices (non-transposed, one per cycle), shared via Arc + ra_indices: Arc>, pub params: BooleanitySumcheckParams, } @@ -296,14 +297,14 @@ impl BooleanitySumcheckProver { /// - Compute G polynomials and RA indices in a single pass over the trace /// - Initialize split-eq polynomials for address (B) and cycle (D) variables /// - Initialize expanding table for phase 1 + /// Returns (prover, shared_ra_indices) — the Arc can be passed to other provers. 
#[tracing::instrument(skip_all, name = "BooleanitySumcheckProver::initialize")] pub fn initialize( params: BooleanitySumcheckParams, trace: &[Cycle], bytecode: &BytecodePreprocessing, memory_layout: &MemoryLayout, - ) -> Self { - // Compute G and RA indices in a single pass over the trace + ) -> (Self, Arc>) { let (G, ra_indices) = compute_all_G_and_ra_indices::( trace, bytecode, @@ -337,18 +338,22 @@ impl BooleanitySumcheckProver { rho_i *= gamma_f; } - Self { - gamma_powers, - gamma_powers_inv, - B, - D, - G, - ra_indices, - H: None, - F: F_table, - eq_r_r: F::zero(), - params, - } + let shared = Arc::clone(&ra_indices); + ( + Self { + gamma_powers, + gamma_powers_inv, + B, + D, + G, + ra_indices, + H: None, + F: F_table, + eq_r_r: F::zero(), + params, + }, + shared, + ) } fn compute_phase1_message(&self, round: usize, previous_claim: F) -> UniPoly { @@ -475,7 +480,7 @@ impl SumcheckInstanceProver for BooleanitySum // Initialize SharedRaPolynomials with per-poly pre-scaled eq tables (by rho_i) let F_table = std::mem::take(&mut self.F); - let ra_indices = std::mem::take(&mut self.ra_indices); + let ra_indices = std::mem::replace(&mut self.ra_indices, Arc::new(Vec::new())); let base_eq = F_table.clone_values(); let num_polys = self.params.polynomial_types.len(); debug_assert!( diff --git a/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs b/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs index 50802fc2e0..d2aa1645e3 100644 --- a/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs +++ b/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs @@ -16,6 +16,7 @@ use crate::{ VerifierOpeningAccumulator, BIG_ENDIAN, LITTLE_ENDIAN, }, ra_poly::RaPolynomial, + shared_ra_polys::RaIndices, split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly, }, @@ -31,15 +32,12 @@ use crate::{ transcripts::Transcript, zkvm::{ config::OneHotParams, - instruction::LookupQuery, instruction_lookups::LOG_K, witness::{CommittedPolynomial, VirtualPolynomial}, }, }; use 
allocative::Allocative; -use common::constants::XLEN; use rayon::prelude::*; -use tracer::instruction::Cycle; #[derive(Allocative, Clone)] pub struct InstructionRaSumcheckParams { @@ -191,37 +189,27 @@ pub struct InstructionRaSumcheckProver { } impl InstructionRaSumcheckProver { + /// Initialize from shared RA indices (avoids re-reading the trace). #[tracing::instrument(skip_all, name = "InstructionRaSumcheckProver::initialize")] - pub fn initialize(params: InstructionRaSumcheckParams, trace: &[Cycle]) -> Self { - // Compute r_address_chunks with proper padding + pub fn initialize( + params: InstructionRaSumcheckParams, + shared_indices: &Arc>, + ) -> Self { let r_address_chunks = params .one_hot_params .compute_r_address_chunks::(¶ms.r_address.r); - let H_indices: Vec>> = (0..params.one_hot_params.instruction_d) - .map(|i| { - trace - .par_iter() - .map(|cycle| { - let lookup_index = LookupQuery::::to_lookup_index(cycle); - Some(params.one_hot_params.lookup_index_chunk(lookup_index, i)) - }) - .collect() - }) - .collect(); - let n_committed_per_virtual = params.n_committed_per_virtual; let gamma_powers = ¶ms.gamma_powers; - let ra_i_polys = H_indices + let ra_i_polys = (0..params.one_hot_params.instruction_d) .into_par_iter() - .enumerate() - .map(|(i, lookup_indices)| { - // Pre-scale the first committed polynomial in each virtual batch by γ^batch. - // - // This pushes the γ weight *inside* the product term so we can form - // (Σ γ^i · ∏ ra_{i,*}) before multiplying by split-eq's inner weights e_in, - // allowing a single split-eq fold for the whole sumcheck message. 
+ .map(|i| { + let lookup_indices: Vec> = shared_indices + .par_iter() + .map(|ra| Some(ra.instruction[i])) + .collect(); + let scaling_factor = if i % n_committed_per_virtual == 0 { let batch = i / n_committed_per_virtual; let gamma = gamma_powers[batch]; @@ -256,9 +244,6 @@ impl SumcheckInstanceProver for InstructionRa fn compute_message(&mut self, _round: usize, previous_claim: F) -> UniPoly { let eq_poly = &self.eq_poly; - // Compute q(X) = Σ_i ∏_j ra_{i,j}(X,·) on the U_D grid using a *single* - // split-eq fold. The per-batch γ^i weights have already been absorbed by - // pre-scaling the first polynomial in each batch (see `initialize`). let evals = match self.params.n_committed_per_virtual { 4 => compute_mles_product_sum_evals_sum_of_products_d4( &self.ra_i_polys, @@ -296,15 +281,12 @@ impl SumcheckInstanceProver for InstructionRa ) { let r_cycle = self.params.normalize_opening_point(sumcheck_challenges); - // Compute r_address_chunks with proper padding let r_address_chunks = self .params .one_hot_params .compute_r_address_chunks::(&self.params.r_address.r); for (i, r_address) in r_address_chunks.into_iter().enumerate() { - // Undo the per-batch γ scaling applied in `initialize` before caching openings, - // so the claimed openings match the *actual* committed polynomials. 
let mut claim = self.ra_i_polys[i].final_sumcheck_claim(); if i % self.params.n_committed_per_virtual == 0 { let batch = i / self.params.n_committed_per_virtual; diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index 7e025a65a0..136239f5fd 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -1303,7 +1303,7 @@ impl< let mut ram_hamming_booleanity = HammingBooleanitySumcheckProver::initialize(ram_hamming_booleanity_params, &self.trace); - let mut booleanity = BooleanitySumcheckProver::initialize( + let (mut booleanity, shared_ra_indices) = BooleanitySumcheckProver::initialize( booleanity_params, &self.trace, &self.preprocessing.shared.bytecode, @@ -1317,7 +1317,7 @@ impl< &self.one_hot_params, ); let mut lookups_ra_virtual = - LookupsRaSumcheckProver::initialize(lookups_ra_virtual_params, &self.trace); + LookupsRaSumcheckProver::initialize(lookups_ra_virtual_params, &shared_ra_indices); let mut inc_reduction = IncClaimReductionSumcheckProver::initialize(inc_reduction_params, self.trace.clone()); From 624db26e0fb183b70ed8b568a5da6cbaf7b0f72a Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:29:16 -0500 Subject: [PATCH 06/11] perf: eliminate transposed instruction RA indices (saves 0.50 GB) InstructionRaSumcheckProver now uses SharedRaPolynomials directly instead of creating 32 separate transposed Vec> arrays. The shared non-transposed Vec is accessed via get_bound_coeff(poly_idx, j) during the sumcheck. New compute_shared_ra_sum_of_products_evals_d{4,8,16} functions in mles_product_sum.rs provide the same eval_prod pattern but read from SharedRaPolynomials. 
--- .../src/subprotocols/mles_product_sum.rs | 66 ++++++++++++++++++- .../zkvm/instruction_lookups/ra_virtual.rs | 53 +++++++-------- 2 files changed, 90 insertions(+), 29 deletions(-) diff --git a/jolt-core/src/subprotocols/mles_product_sum.rs b/jolt-core/src/subprotocols/mles_product_sum.rs index d687fd0ee7..f885652b69 100644 --- a/jolt-core/src/subprotocols/mles_product_sum.rs +++ b/jolt-core/src/subprotocols/mles_product_sum.rs @@ -1,6 +1,9 @@ use crate::{ field::{BarrettReduce, FMAdd, JoltField}, - poly::{ra_poly::RaPolynomial, split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly}, + poly::{ + ra_poly::RaPolynomial, shared_ra_polys::SharedRaPolynomials, + split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly, + }, utils::accumulation::SmallAccumS, }; use core::{mem::MaybeUninit, ptr}; @@ -227,6 +230,67 @@ impl_mles_sum_of_products_evals_d!( eval_prod_16_assign ); +/// Like `impl_mles_sum_of_products_evals_d` but reads from SharedRaPolynomials +/// instead of &[RaPolynomial]. Eliminates the transposed per-poly index storage. +macro_rules! 
impl_shared_ra_sum_of_products_evals_d { + ($fn_name:ident, $d:expr, $eval_prod:ident) => { + #[inline] + pub fn $fn_name( + shared_ra: &SharedRaPolynomials, + n_products: usize, + eq_poly: &GruenSplitEqPolynomial, + ) -> Vec { + debug_assert!(n_products > 0); + + let current_scalar = eq_poly.get_current_scalar(); + + let sum_evals_arr: [F; $d] = eq_poly.par_fold_out_in_unreduced::<$d>(&|g| { + let mut sums = [F::zero(); $d]; + + for t in 0..n_products { + let base = t * $d; + + let pairs: [(F, F); $d] = core::array::from_fn(|i| { + let p0 = shared_ra.get_bound_coeff(base + i, 2 * g); + let p1 = shared_ra.get_bound_coeff(base + i, 2 * g + 1); + (p0, p1) + }); + + let mut endpoints = [F::zero(); $d]; + $eval_prod::(&pairs, &mut endpoints); + + for k in 0..$d { + sums[k] += endpoints[k]; + } + } + + sums + }); + + sum_evals_arr + .into_iter() + .map(|x| x * current_scalar) + .collect() + } + }; +} + +impl_shared_ra_sum_of_products_evals_d!( + compute_shared_ra_sum_of_products_evals_d4, + 4, + eval_prod_4_assign +); +impl_shared_ra_sum_of_products_evals_d!( + compute_shared_ra_sum_of_products_evals_d8, + 8, + eval_prod_8_assign +); +impl_shared_ra_sum_of_products_evals_d!( + compute_shared_ra_sum_of_products_evals_d16, + 16, + eval_prod_16_assign +); + /// Given the evaluations of `g(X) / eq(X, r[round])` on the grid /// `[1, 2, ..., d - 1, ∞]`, recover the full univariate polynomial /// `g(X) = eq(X, r[round]) * (interpolated quotient)` such that diff --git a/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs b/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs index d2aa1645e3..b6d2f4d527 100644 --- a/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs +++ b/jolt-core/src/zkvm/instruction_lookups/ra_virtual.rs @@ -10,21 +10,20 @@ use crate::{ field::JoltField, poly::{ eq_poly::EqPolynomial, - multilinear_polynomial::{BindingOrder, PolynomialBinding}, + multilinear_polynomial::BindingOrder, opening_proof::{ OpeningAccumulator, OpeningPoint, 
ProverOpeningAccumulator, SumcheckId, VerifierOpeningAccumulator, BIG_ENDIAN, LITTLE_ENDIAN, }, - ra_poly::RaPolynomial, - shared_ra_polys::RaIndices, + shared_ra_polys::{RaIndices, SharedRaPolynomials}, split_eq_poly::GruenSplitEqPolynomial, unipoly::UniPoly, }, subprotocols::{ mles_product_sum::{ - compute_mles_product_sum_evals_sum_of_products_d16, - compute_mles_product_sum_evals_sum_of_products_d4, - compute_mles_product_sum_evals_sum_of_products_d8, finish_mles_product_sum_from_evals, + compute_shared_ra_sum_of_products_evals_d16, + compute_shared_ra_sum_of_products_evals_d4, compute_shared_ra_sum_of_products_evals_d8, + finish_mles_product_sum_from_evals, }, sumcheck_prover::SumcheckInstanceProver, sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}, @@ -183,13 +182,13 @@ impl SumcheckInstanceParams for InstructionRaSumcheckParams #[derive(Allocative)] pub struct InstructionRaSumcheckProver { - ra_i_polys: Vec>, + shared_ra: SharedRaPolynomials, eq_poly: GruenSplitEqPolynomial, pub params: InstructionRaSumcheckParams, } impl InstructionRaSumcheckProver { - /// Initialize from shared RA indices (avoids re-reading the trace). + /// Initialize using shared RA indices — no transposed per-poly index storage. 
#[tracing::instrument(skip_all, name = "InstructionRaSumcheckProver::initialize")] pub fn initialize( params: InstructionRaSumcheckParams, @@ -201,15 +200,11 @@ impl InstructionRaSumcheckProver { let n_committed_per_virtual = params.n_committed_per_virtual; let gamma_powers = ¶ms.gamma_powers; + let instruction_d = params.one_hot_params.instruction_d; - let ra_i_polys = (0..params.one_hot_params.instruction_d) + let tables: Vec> = (0..instruction_d) .into_par_iter() .map(|i| { - let lookup_indices: Vec> = shared_indices - .par_iter() - .map(|ra| Some(ra.instruction[i])) - .collect(); - let scaling_factor = if i % n_committed_per_virtual == 0 { let batch = i / n_committed_per_virtual; let gamma = gamma_powers[batch]; @@ -221,14 +216,18 @@ impl InstructionRaSumcheckProver { } else { None }; - let eq_evals = - EqPolynomial::evals_with_scaling(&r_address_chunks[i], scaling_factor); - RaPolynomial::new(Arc::new(lookup_indices), eq_evals) + EqPolynomial::evals_with_scaling(&r_address_chunks[i], scaling_factor) }) .collect(); + let shared_ra = SharedRaPolynomials::new_instruction_only( + tables, + Arc::clone(shared_indices), + params.one_hot_params.clone(), + ); + Self { - ra_i_polys, + shared_ra, eq_poly: GruenSplitEqPolynomial::new(¶ms.r_cycle.r, BindingOrder::LowToHigh), params, } @@ -245,18 +244,18 @@ impl SumcheckInstanceProver for InstructionRa let eq_poly = &self.eq_poly; let evals = match self.params.n_committed_per_virtual { - 4 => compute_mles_product_sum_evals_sum_of_products_d4( - &self.ra_i_polys, + 4 => compute_shared_ra_sum_of_products_evals_d4( + &self.shared_ra, self.params.n_virtual_ra_polys, eq_poly, ), - 8 => compute_mles_product_sum_evals_sum_of_products_d8( - &self.ra_i_polys, + 8 => compute_shared_ra_sum_of_products_evals_d8( + &self.shared_ra, self.params.n_virtual_ra_polys, eq_poly, ), - 16 => compute_mles_product_sum_evals_sum_of_products_d16( - &self.ra_i_polys, + 16 => compute_shared_ra_sum_of_products_evals_d16( + &self.shared_ra, 
self.params.n_virtual_ra_polys, eq_poly, ), @@ -268,9 +267,7 @@ impl SumcheckInstanceProver for InstructionRa #[tracing::instrument(skip_all, name = "InstructionRaSumcheckProver::ingest_challenge")] fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) { - self.ra_i_polys - .iter_mut() - .for_each(|p| p.bind_parallel(r_j, BindingOrder::LowToHigh)); + self.shared_ra.bind_in_place(r_j, BindingOrder::LowToHigh); self.eq_poly.bind(r_j); } @@ -287,7 +284,7 @@ impl SumcheckInstanceProver for InstructionRa .compute_r_address_chunks::(&self.params.r_address.r); for (i, r_address) in r_address_chunks.into_iter().enumerate() { - let mut claim = self.ra_i_polys[i].final_sumcheck_claim(); + let mut claim = self.shared_ra.final_sumcheck_claim(i); if i % self.params.n_committed_per_virtual == 0 { let batch = i / self.params.n_committed_per_virtual; let gamma = self.params.gamma_powers[batch]; From 0f7a5e55f94b326c71f33a31db86185a13f17e11 Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:38:54 -0500 Subject: [PATCH 07/11] perf: RamRaVirtual also uses shared RA indices RamRaVirtualSumcheckProver now reads RAM RA indices from the shared Arc<Vec<RaIndices>> instead of re-reading the trace. Saves 0.13 GB of transposed RAM index storage and avoids a full trace iteration.
--- jolt-core/src/zkvm/prover.rs | 3 +-- jolt-core/src/zkvm/ram/ra_virtual.rs | 22 +++++++++------------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index 136239f5fd..b86fca429c 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -1312,8 +1312,7 @@ impl< let mut ram_ra_virtual = RamRaVirtualSumcheckProver::initialize( ram_ra_virtual_params, - &self.trace, - &self.program_io.memory_layout, + &shared_ra_indices, &self.one_hot_params, ); let mut lookups_ra_virtual = diff --git a/jolt-core/src/zkvm/ram/ra_virtual.rs b/jolt-core/src/zkvm/ram/ra_virtual.rs index 7df6975ae9..7dbf697322 100644 --- a/jolt-core/src/zkvm/ram/ra_virtual.rs +++ b/jolt-core/src/zkvm/ram/ra_virtual.rs @@ -45,9 +45,7 @@ //! //! Variables are bound low-to-high, matching the polynomial layout. -use common::jolt_device::MemoryLayout; use std::sync::Arc; -use tracer::instruction::Cycle; #[cfg(feature = "zk")] use crate::poly::opening_proof::OpeningId; @@ -56,6 +54,7 @@ use crate::poly::opening_proof::{ VerifierOpeningAccumulator, BIG_ENDIAN, LITTLE_ENDIAN, }; use crate::poly::ra_poly::RaPolynomial; +use crate::poly::shared_ra_polys::RaIndices; use crate::poly::split_eq_poly::GruenSplitEqPolynomial; use crate::poly::unipoly::UniPoly; #[cfg(feature = "zk")] @@ -66,7 +65,6 @@ use crate::subprotocols::mles_product_sum::compute_mles_product_sum; use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; use crate::subprotocols::sumcheck_verifier::{SumcheckInstanceParams, SumcheckInstanceVerifier}; use crate::zkvm::config::OneHotParams; -use crate::zkvm::ram::remap_address; use crate::zkvm::witness::{CommittedPolynomial, VirtualPolynomial}; use crate::{ field::JoltField, @@ -199,34 +197,32 @@ pub struct RamRaVirtualSumcheckProver { } impl RamRaVirtualSumcheckProver { + /// Initialize from shared RA indices (avoids re-reading trace for RAM indices). 
#[tracing::instrument(skip_all, name = "RamRaVirtualSumcheckProver::initialize")] pub fn initialize( params: RamRaVirtualParams, - trace: &[Cycle], - memory_layout: &MemoryLayout, + shared_indices: &Arc>, one_hot_params: &OneHotParams, ) -> Self { - // Precompute EQ tables for each address chunk let eq_tables: Vec> = params .r_address_chunks .iter() .map(|chunk| EqPolynomial::evals(chunk)) .collect(); - // Create eq polynomial with Gruen optimization for r_cycle_reduced let eq_poly = GruenSplitEqPolynomial::new(¶ms.r_cycle.r, BindingOrder::LowToHigh); - // Create ra_i polynomials for each decomposition chunk + let instruction_d = one_hot_params.instruction_d; + let bytecode_d = one_hot_params.bytecode_d; + let ram_offset = instruction_d + bytecode_d; + let ra_i_polys: Vec> = (0..params.d) .into_par_iter() .zip(eq_tables.into_par_iter()) .map(|(i, eq_table)| { - let ra_i_indices: Vec> = trace + let ra_i_indices: Vec> = shared_indices .par_iter() - .map(|cycle| { - remap_address(cycle.ram_access().address() as u64, memory_layout) - .map(|address| one_hot_params.ram_address_chunk(address, i)) - }) + .map(|ra| ra.get_index(ram_offset + i, one_hot_params)) .collect(); RaPolynomial::new(Arc::new(ra_i_indices), eq_table) }) From b364c053c5d76c2f3d6043b1ce703a670fcf64c0 Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:48:33 -0500 Subject: [PATCH 08/11] perf: add Round5 to SharedRaPolynomials (T/32 materialization) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Further delays booleanity materialization to T/32, staggering it from InstructionRa's Round4→RoundN at T/16. This prevents simultaneous dense poly allocation spikes in stage 6. 
--- jolt-core/src/poly/shared_ra_polys.rs | 123 +++++++++++++++++++++++--- 1 file changed, 111 insertions(+), 12 deletions(-) diff --git a/jolt-core/src/poly/shared_ra_polys.rs b/jolt-core/src/poly/shared_ra_polys.rs index 852a3148f2..5fea66dba7 100644 --- a/jolt-core/src/poly/shared_ra_polys.rs +++ b/jolt-core/src/poly/shared_ra_polys.rs @@ -463,6 +463,8 @@ pub enum SharedRaPolynomials { Round3(SharedRaRound3), /// Round 4: Split into 8 table groups (F_000 through F_111) Round4(SharedRaRound4), + /// Round 5: Split into 16 table groups + Round5(SharedRaRound5), /// Round N: Fully materialized multilinear polynomials RoundN(Vec>), } @@ -526,6 +528,29 @@ pub struct SharedRaRound4 { binding_order: BindingOrder, } +/// Round 5 state: 16 table groups per polynomial. +#[derive(Allocative)] +pub struct SharedRaRound5 { + tables: Vec>>, // tables[group_idx][poly_idx][k], 16 groups + indices: Arc>, + num_polys: usize, + #[allocative(skip)] + one_hot_params: OneHotParams, + binding_order: BindingOrder, +} + +impl Default for SharedRaRound5 { + fn default() -> Self { + Self { + tables: Vec::new(), + indices: Arc::new(Vec::new()), + num_polys: 0, + one_hot_params: OneHotParams::default(), + binding_order: BindingOrder::LowToHigh, + } + } +} + impl SharedRaPolynomials { /// Create new SharedRaPolynomials from eq table and indices. 
pub fn new( @@ -573,6 +598,7 @@ impl SharedRaPolynomials { Self::Round2(r) => r.num_polys, Self::Round3(r) => r.num_polys, Self::Round4(r) => r.num_polys, + Self::Round5(r) => r.num_polys, Self::RoundN(polys) => polys.len(), } } @@ -584,6 +610,7 @@ impl SharedRaPolynomials { Self::Round2(r) => r.indices.len() / 2, Self::Round3(r) => r.indices.len() / 4, Self::Round4(r) => r.indices.len() / 8, + Self::Round5(r) => r.indices.len() / 16, Self::RoundN(polys) => polys[0].len(), } } @@ -596,6 +623,7 @@ impl SharedRaPolynomials { Self::Round2(r) => r.get_bound_coeff(poly_idx, j), Self::Round3(r) => r.get_bound_coeff(poly_idx, j), Self::Round4(r) => r.get_bound_coeff(poly_idx, j), + Self::Round5(r) => r.get_bound_coeff(poly_idx, j), Self::RoundN(polys) => polys[poly_idx].get_bound_coeff(j), } } @@ -615,7 +643,8 @@ impl SharedRaPolynomials { Self::Round1(r1) => Self::Round2(r1.bind(r, order)), Self::Round2(r2) => Self::Round3(r2.bind(r, order)), Self::Round3(r3) => Self::Round4(r3.bind(r, order)), - Self::Round4(r4) => Self::RoundN(r4.bind(r, order)), + Self::Round4(r4) => Self::Round5(r4.bind(r, order)), + Self::Round5(r5) => Self::RoundN(r5.bind(r, order)), Self::RoundN(mut polys) => { polys.par_iter_mut().for_each(|p| p.bind_parallel(r, order)); Self::RoundN(polys) @@ -629,7 +658,8 @@ impl SharedRaPolynomials { Self::Round1(r1) => *self = Self::Round2(std::mem::take(r1).bind(r, order)), Self::Round2(r2) => *self = Self::Round3(std::mem::take(r2).bind(r, order)), Self::Round3(r3) => *self = Self::Round4(std::mem::take(r3).bind(r, order)), - Self::Round4(r4) => *self = Self::RoundN(std::mem::take(r4).bind(r, order)), + Self::Round4(r4) => *self = Self::Round5(std::mem::take(r4).bind(r, order)), + Self::Round5(r5) => *self = Self::RoundN(std::mem::take(r5).bind(r, order)), Self::RoundN(polys) => { polys.par_iter_mut().for_each(|p| p.bind_parallel(r, order)); } @@ -880,15 +910,14 @@ impl SharedRaRound4 { } #[tracing::instrument(skip_all, name = "SharedRaRound4::bind")] - fn 
bind(self, r3: F::Challenge, order: BindingOrder) -> Vec> { + fn bind(self, r3: F::Challenge, order: BindingOrder) -> SharedRaRound5 { assert_eq!(order, self.binding_order); let eq_0_r3 = EqPolynomial::mle(&[F::zero()], &[r3]); let eq_1_r3 = EqPolynomial::mle(&[F::one()], &[r3]); - // 16 groups: [0..8) for bit3=0 (scale eq_0_r3), [8..16) for bit3=1 (scale eq_1_r3) let [t0, t1, t2, t3, t4, t5, t6, t7] = self.tables; - let mut tables: [Vec>; 16] = [ + let mut tables_vec: Vec>> = vec![ t0.clone(), t1.clone(), t2.clone(), @@ -907,7 +936,7 @@ impl SharedRaRound4 { t7, ]; - let (lo, hi) = tables.split_at_mut(8); + let (lo, hi) = tables_vec.split_at_mut(8); rayon::join( || { lo.par_iter_mut().for_each(|table| { @@ -925,10 +954,80 @@ impl SharedRaRound4 { }, ); + SharedRaRound5 { + tables: tables_vec, + indices: self.indices, + num_polys: self.num_polys, + one_hot_params: self.one_hot_params, + binding_order: order, + } + } +} + +impl SharedRaRound5 { + #[inline] + fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { + match self.binding_order { + BindingOrder::LowToHigh => (0..16) + .map(|offset| { + self.indices[16 * j + offset] + .get_index(poly_idx, &self.one_hot_params) + .map_or(F::zero(), |k| self.tables[offset][poly_idx][k as usize]) + }) + .sum(), + BindingOrder::HighToLow => { + let sixteenth = self.indices.len() / 16; + (0..16) + .map(|seg| { + self.indices[seg * sixteenth + j] + .get_index(poly_idx, &self.one_hot_params) + .map_or(F::zero(), |k| self.tables[seg][poly_idx][k as usize]) + }) + .sum() + } + } + } + + #[tracing::instrument(skip_all, name = "SharedRaRound5::bind")] + fn bind(self, r4: F::Challenge, order: BindingOrder) -> Vec> { + assert_eq!(order, self.binding_order); + + let eq_0_r4 = EqPolynomial::mle(&[F::zero()], &[r4]); + let eq_1_r4 = EqPolynomial::mle(&[F::one()], &[r4]); + + // 32 groups from 16: [0..16) for bit4=0, [16..32) for bit4=1 + let n_groups = self.tables.len(); + let mut tables: Vec>> = Vec::with_capacity(n_groups * 2); + 
for t in &self.tables { + tables.push(t.clone()); + } + for t in self.tables { + tables.push(t); + } + + let (lo, hi) = tables.split_at_mut(n_groups); + rayon::join( + || { + lo.par_iter_mut().for_each(|table| { + table + .par_iter_mut() + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r4)) + }) + }, + || { + hi.par_iter_mut().for_each(|table| { + table + .par_iter_mut() + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r4)) + }) + }, + ); + let num_polys = self.num_polys; let indices = &self.indices; let one_hot_params = &self.one_hot_params; - let new_len = indices.len() / 16; + let n_total = tables.len(); // 32 + let new_len = indices.len() / n_total; (0..num_polys) .into_par_iter() @@ -937,9 +1036,9 @@ impl SharedRaRound4 { BindingOrder::LowToHigh => (0..new_len) .into_par_iter() .map(|j| { - (0..16) + (0..n_total) .map(|offset| { - indices[16 * j + offset] + indices[n_total * j + offset] .get_index(poly_idx, one_hot_params) .map_or(F::zero(), |k| tables[offset][poly_idx][k as usize]) }) @@ -947,13 +1046,13 @@ impl SharedRaRound4 { }) .collect(), BindingOrder::HighToLow => { - let sixteenth = indices.len() / 16; + let segment = indices.len() / n_total; (0..new_len) .into_par_iter() .map(|j| { - (0..16) + (0..n_total) .map(|seg| { - indices[seg * sixteenth + j] + indices[seg * segment + j] .get_index(poly_idx, one_hot_params) .map_or(F::zero(), |k| { tables[seg][poly_idx][k as usize] From f35a1107698ca24d54a1dc0f6bb101064eb2988a Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Fri, 6 Mar 2026 02:00:17 -0500 Subject: [PATCH 09/11] perf: drop lazy_trace after witness gen to free emulator memory The LazyTraceIterator holds the emulator state (including HashMap-based memory) which persists until the prover is dropped. After streaming witness commitment finishes, this data is never used again. Drop it immediately to free the emulator's memory map. 
--- jolt-core/src/zkvm/prover.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/jolt-core/src/zkvm/prover.rs b/jolt-core/src/zkvm/prover.rs index b86fca429c..40b98bbf05 100644 --- a/jolt-core/src/zkvm/prover.rs +++ b/jolt-core/src/zkvm/prover.rs @@ -173,7 +173,7 @@ pub struct JoltCpuProver< > { pub preprocessing: &'a JoltProverPreprocessing, pub program_io: JoltDevice, - pub lazy_trace: LazyTraceIterator, + pub lazy_trace: Option, pub trace: Arc>, pub advice: JoltAdvice, /// The advice claim reduction sumcheck effectively spans two stages (6 and 7). @@ -449,7 +449,7 @@ impl< Self { preprocessing, program_io, - lazy_trace, + lazy_trace: Some(lazy_trace), trace: trace.into(), advice: JoltAdvice { untrusted_advice_polynomial: None, @@ -503,6 +503,8 @@ impl< ); let (commitments, mut opening_proof_hints) = self.generate_and_commit_witness_polynomials(); + // Free emulator memory — lazy_trace only needed for streaming commit + self.lazy_trace.take(); let untrusted_advice_commitment = self.generate_and_commit_untrusted_advice(); self.generate_and_commit_trusted_advice(); @@ -679,9 +681,10 @@ impl< polys.len() ); - // Materialize the trace for non-streaming commit let trace: Vec = self .lazy_trace + .as_ref() + .expect("lazy_trace consumed") .clone() .pad_using(T, |_| Cycle::NoOp) .collect(); @@ -719,6 +722,8 @@ impl< let mut row_commitments: Vec> = vec![vec![]; num_rows]; self.lazy_trace + .as_ref() + .expect("lazy_trace consumed") .clone() .pad_using(T, |_| Cycle::NoOp) .iter_chunks(row_len) From 73e1d5456680fc832e56c12046347322c1316577 Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:48:56 -0500 Subject: [PATCH 10/11] feat: use jemalloc for all prover builds Move tikv-jemallocator from monitor-only to the prover feature so jemalloc with aggressive page purging is used in all builds, not just profiling. This ensures RSS accurately reflects live heap. 
--- jolt-core/Cargo.toml | 3 ++- jolt-core/src/bin/jolt_core.rs | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/jolt-core/Cargo.toml b/jolt-core/Cargo.toml index debb07c2b2..3d43d0d3ca 100644 --- a/jolt-core/Cargo.toml +++ b/jolt-core/Cargo.toml @@ -38,12 +38,13 @@ prover = [ "dory/backends", "dory/cache", # This includes `parallel` feature "dory/disk-persistence", + "dep:tikv-jemallocator", ] # This is for building jolt-core without prover capabilities, e.g. for recursion # jolt-core needs std and rayon to compile, so these are the minimal set of features. minimal = ["ark-ec/std", "ark-ff/std", "ark-std/std", "ark-ff/asm", "rayon"] allocative = ["dep:inferno"] -monitor = ["dep:sysinfo", "dep:tikv-jemallocator"] +monitor = ["dep:sysinfo"] pprof = ["dep:pprof", "dep:prost"] test_incremental = [] challenge-254-bit = [] diff --git a/jolt-core/src/bin/jolt_core.rs b/jolt-core/src/bin/jolt_core.rs index e81bcbeaa7..f727633d0c 100644 --- a/jolt-core/src/bin/jolt_core.rs +++ b/jolt-core/src/bin/jolt_core.rs @@ -1,8 +1,6 @@ -#[cfg(feature = "monitor")] #[global_allocator] static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; -#[cfg(feature = "monitor")] #[allow(non_upper_case_globals)] #[unsafe(export_name = "malloc_conf")] pub static malloc_conf: &[u8] = b"dirty_decay_ms:0,muzzy_decay_ms:0\0"; From 78fdee861ea065ab8a02969f39e98842ea53e22b Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Fri, 6 Mar 2026 20:39:54 -0500 Subject: [PATCH 11/11] refactor: unify RA polynomial round types into generic table-doubling state machine Replace hand-unrolled Round1-4 (RaPolynomial) and Round1-5 (SharedRaPolynomials) with a single TableRound type per enum. Each bind doubles table groups; materialization triggers at a configurable threshold (8 / 16 groups respectively). -515 lines, no behavioral change. 
--- CLAUDE.md | 4 +- jolt-core/src/poly/ra_poly.rs | 366 ++++------- jolt-core/src/poly/shared_ra_polys.rs | 575 ++++-------------- .../src/subprotocols/mles_product_sum.rs | 2 - 4 files changed, 216 insertions(+), 731 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index ff455863a2..19bc990f07 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -106,8 +106,8 @@ ProofTranscript: Transcript — Fiat-Shamir transcript (Blake2bTra - `DensePolynomial`: Full field-element coefficients - `CompactPolynomial`: Small scalar coefficients (u8–i128), promoted to field on bind -- `RaPolynomial`: Lazy materialization via Round1→Round2→Round3→RoundN state machine -- `SharedRaPolynomials`: Shares eq tables across N polynomials for memory efficiency +- `RaPolynomial`: Lazy materialization via table-doubling state machine (TableRound → RoundN, materializes at 8 groups) +- `SharedRaPolynomials`: Shares eq tables across N polynomials via table-doubling (TableRound → RoundN, materializes at 16 groups) - `PrefixSuffixDecomposition`: Splits polynomial as `Σ P_i(prefix) · Q_i(suffix)` for efficient sumcheck - `MultilinearPolynomial`: Enum dispatching over all scalar types + OneHot/RLC variants diff --git a/jolt-core/src/poly/ra_poly.rs b/jolt-core/src/poly/ra_poly.rs index 691a943a54..f49e8adb75 100644 --- a/jolt-core/src/poly/ra_poly.rs +++ b/jolt-core/src/poly/ra_poly.rs @@ -14,24 +14,27 @@ use crate::{ utils::thread::{drop_in_background_thread, unsafe_allocate_zero_vec}, }; -/// Represents the state of an `ra_i` polynomial during the last log(T) sumcheck rounds. +/// When the table round reaches this many groups, the next bind materializes to dense. +const MATERIALIZE_THRESHOLD: usize = 8; + +/// RA polynomial with lazy materialization via a table-doubling state machine. /// -/// The first two rounds are specialized to reduce the amount of allocated memory. +/// Starts with 1 eq table group. Each bind doubles the tables (splitting on a new +/// challenge bit). 
After reaching `MATERIALIZE_THRESHOLD` groups, the next bind +/// materializes to a dense `MultilinearPolynomial`. #[derive(Allocative, Clone, Debug, PartialEq)] pub enum RaPolynomial + Copy + Default + Send + Sync + 'static, F: JoltField> { None, - Round1(RaPolynomialRound1), - Round2(RaPolynomialRound2), - Round3(RaPolynomialRound3), - Round4(RaPolynomialRound4), + TableRound(RaPolynomialTableRound), RoundN(MultilinearPolynomial), } impl + Copy + Default + Send + Sync + 'static, F: JoltField> RaPolynomial { pub fn new(lookup_indices: Arc>>, eq_evals: Vec) -> Self { - Self::Round1(RaPolynomialRound1 { - F: eq_evals, + Self::TableRound(RaPolynomialTableRound { + tables: vec![eq_evals], lookup_indices, + binding_order: BindingOrder::LowToHigh, }) } @@ -39,10 +42,7 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> RaPo pub fn get_bound_coeff(&self, j: usize) -> F { match self { Self::None => panic!("RaPolynomial::get_bound_coeff called on None"), - Self::Round1(mle) => mle.get_bound_coeff(j), - Self::Round2(mle) => mle.get_bound_coeff(j), - Self::Round3(mle) => mle.get_bound_coeff(j), - Self::Round4(mle) => mle.get_bound_coeff(j), + Self::TableRound(t) => t.get_bound_coeff(j), Self::RoundN(mle) => mle.get_bound_coeff(j), } } @@ -50,10 +50,7 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> RaPo pub fn len(&self) -> usize { match self { Self::None => panic!("RaPolynomial::len called on None"), - Self::Round1(mle) => mle.len(), - Self::Round2(mle) => mle.len(), - Self::Round3(mle) => mle.len(), - Self::Round4(mle) => mle.len(), + Self::TableRound(t) => t.len(), Self::RoundN(mle) => mle.len(), } } @@ -63,7 +60,11 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> Poly for RaPolynomial { fn is_bound(&self) -> bool { - !matches!(self, Self::Round1(_)) + match self { + Self::TableRound(t) => t.n_groups() > 1, + Self::RoundN(_) => true, + Self::None => false, + } } fn bind(&mut self, _r: F::Challenge, _order: BindingOrder) { @@ 
-73,10 +74,13 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> Poly fn bind_parallel(&mut self, r: F::Challenge, order: BindingOrder) { match self { Self::None => panic!("RaPolynomial::bind called on None"), - Self::Round1(mle) => *self = Self::Round2(mem::take(mle).bind(r, order)), - Self::Round2(mle) => *self = Self::Round3(mem::take(mle).bind(r, order)), - Self::Round3(mle) => *self = Self::Round4(mem::take(mle).bind(r, order)), - Self::Round4(mle) => *self = Self::RoundN(mem::take(mle).bind(r, order)), + Self::TableRound(t) => { + if t.n_groups() >= MATERIALIZE_THRESHOLD { + *self = Self::RoundN(mem::take(t).materialize(r, order)); + } else { + *self = Self::TableRound(mem::take(t).bind(r, order)); + } + } Self::RoundN(mle) => mle.bind_parallel(r, order), }; } @@ -145,270 +149,98 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> Poly } } -/// Represents MLE `ra_i` during the 1st round of the last log(T) sumcheck rounds. -#[derive(Allocative, Default, Clone, Debug, PartialEq)] -pub struct RaPolynomialRound1 + Copy + Default + Send + Sync + 'static, F: JoltField> -{ - // Index `x` stores `eq(x, r)`. - F: Vec, - lookup_indices: Arc>>, -} - -impl + Copy + Default + Send + Sync + 'static, F: JoltField> - RaPolynomialRound1 -{ - fn len(&self) -> usize { - self.lookup_indices.len() - } - - #[tracing::instrument(skip_all, name = "RaPolynomialRound1::bind")] - fn bind(self, r0: F::Challenge, binding_order: BindingOrder) -> RaPolynomialRound2 { - // Construct lookup tables. - let eq_0_r0 = EqPolynomial::mle(&[F::zero()], &[r0]); - let eq_1_r0 = EqPolynomial::mle(&[F::one()], &[r0]); - let F_0 = self.F.iter().map(|v| eq_0_r0 * v).collect(); - let F_1 = self.F.iter().map(|v| eq_1_r0 * v).collect(); - drop_in_background_thread(self.F); - RaPolynomialRound2 { - F_0, - F_1, - lookup_indices: self.lookup_indices, - r0, - binding_order, - } - } - - #[inline] - fn get_bound_coeff(&self, j: usize) -> F { - // Lookup ra_i(r, j). 
- self.lookup_indices - .get(j) - .expect("j out of bounds") - .map_or(F::zero(), |i| self.F[i.into()]) - } -} - -/// Represents `ra_i` during the 2nd of the last log(T) sumcheck rounds. +/// Generic table round for RaPolynomial with `n_groups` eq table groups. /// -/// i.e. represents MLE `ra_i(r, r0, x)` +/// Tables are stored in LowToHigh interleaving order: after k binds, table at +/// index `i` corresponds to the bit pattern where bit_0 = r0_val, bit_1 = r1_val, etc. +/// (LSB-first encoding of the bound challenge values). #[derive(Allocative, Default, Clone, Debug, PartialEq)] -pub struct RaPolynomialRound2 + Copy + Default + Send + Sync + 'static, F: JoltField> -{ - // Index `x` stores `eq(x, r_address_chunk_i) * eq(0, r0)`. - F_0: Vec, - // Index `x` stores `eq(x, r_address_chunk_i) * eq(1, r0)`. - F_1: Vec, +pub struct RaPolynomialTableRound< + I: Into + Copy + Default + Send + Sync + 'static, + F: JoltField, +> { + tables: Vec>, lookup_indices: Arc>>, - r0: F::Challenge, binding_order: BindingOrder, } impl + Copy + Default + Send + Sync + 'static, F: JoltField> - RaPolynomialRound2 + RaPolynomialTableRound { - fn len(&self) -> usize { - self.lookup_indices.len() / 2 - } - - #[tracing::instrument(skip_all, name = "RaPolynomialRound2::bind")] - fn bind(self, r1: F::Challenge, binding_order: BindingOrder) -> RaPolynomialRound3 { - assert_eq!(binding_order, self.binding_order); - // Construct lookup tables. 
- let eq_0_r1 = EqPolynomial::mle(&[F::zero()], &[r1]); - let eq_1_r1 = EqPolynomial::mle(&[F::one()], &[r1]); - let mut F_00: Vec = self.F_0.clone(); - let mut F_01: Vec = self.F_0; - let mut F_10: Vec = self.F_1.clone(); - let mut F_11: Vec = self.F_1; - - F_00.par_iter_mut().for_each(|f| *f *= eq_0_r1); - F_01.par_iter_mut().for_each(|f| *f *= eq_1_r1); - F_10.par_iter_mut().for_each(|f| *f *= eq_0_r1); - F_11.par_iter_mut().for_each(|f| *f *= eq_1_r1); - - RaPolynomialRound3 { - F_00, - F_01, - F_10, - F_11, - lookup_indices: self.lookup_indices, - r1, - binding_order: self.binding_order, - } - } - #[inline] - fn get_bound_coeff(&self, j: usize) -> F { - let mid = self.lookup_indices.len() / 2; - match self.binding_order { - BindingOrder::HighToLow => { - let H_0 = self.lookup_indices[j].map_or(F::zero(), |i| self.F_0[i.into()]); - let H_1 = self.lookup_indices[mid + j].map_or(F::zero(), |i| self.F_1[i.into()]); - // Compute ra_i(r, r0, j) = eq(0, r0) * ra_i(r, 0, j) + - // eq(1, r0) * ra_i(r, 1, j) - H_0 + H_1 - } - BindingOrder::LowToHigh => { - let H_0 = self.lookup_indices[2 * j].map_or(F::zero(), |i| self.F_0[i.into()]); - let H_1 = self.lookup_indices[2 * j + 1].map_or(F::zero(), |i| self.F_1[i.into()]); - // Compute ra_i(r, r0, j) = eq(0, r0) * ra_i(r, 0, j) + - // eq(1, r0) * ra_i(r, 1, j) - H_0 + H_1 - } - } + fn n_groups(&self) -> usize { + self.tables.len() } -} -/// Represents `ra_i` during the 3nd of the last log(T) sumcheck rounds. -/// -/// i.e. represents MLE `ra_i(r, r0, x)` -#[derive(Allocative, Default, Clone, Debug, PartialEq)] -pub struct RaPolynomialRound3 + Copy + Default + Send + Sync + 'static, F: JoltField> -{ - // Index `x` stores `eq(x, r_address_chunk_i) * eq(00, r0 r1)`. - F_00: Vec, - // Index `x` stores `eq(x, r_address_chunk_i) * eq(01, r0 r1)`. - F_01: Vec, - // Index `x` stores `eq(x, r_address_chunk_i) * eq(10, r0 r1)`. - F_10: Vec, - // Index `x` stores `eq(x, r_address_chunk_i) * eq(11, r0 r1)`. 
- F_11: Vec, - lookup_indices: Arc>>, - r1: F::Challenge, - binding_order: BindingOrder, -} - -impl + Copy + Default + Send + Sync + 'static, F: JoltField> - RaPolynomialRound3 -{ fn len(&self) -> usize { - self.lookup_indices.len() / 4 + self.lookup_indices.len() / self.n_groups() } - #[tracing::instrument(skip_all, name = "RaPolynomialRound3::bind")] - fn bind(self, r2: F::Challenge, _binding_order: BindingOrder) -> RaPolynomialRound4 { - let eq_0_r2 = EqPolynomial::mle(&[F::zero()], &[r2]); - let eq_1_r2 = EqPolynomial::mle(&[F::one()], &[r2]); - let mut F_000: Vec = self.F_00.clone(); - let mut F_001: Vec = self.F_00; - let mut F_010: Vec = self.F_01.clone(); - let mut F_011: Vec = self.F_01; - let mut F_100: Vec = self.F_10.clone(); - let mut F_101: Vec = self.F_10; - let mut F_110: Vec = self.F_11.clone(); - let mut F_111: Vec = self.F_11; - - F_000.par_iter_mut().for_each(|f| *f *= eq_0_r2); - F_010.par_iter_mut().for_each(|f| *f *= eq_0_r2); - F_100.par_iter_mut().for_each(|f| *f *= eq_0_r2); - F_110.par_iter_mut().for_each(|f| *f *= eq_0_r2); - F_001.par_iter_mut().for_each(|f| *f *= eq_1_r2); - F_011.par_iter_mut().for_each(|f| *f *= eq_1_r2); - F_101.par_iter_mut().for_each(|f| *f *= eq_1_r2); - F_111.par_iter_mut().for_each(|f| *f *= eq_1_r2); - - RaPolynomialRound4 { - tables: [F_000, F_100, F_010, F_110, F_001, F_101, F_011, F_111], - lookup_indices: self.lookup_indices, - binding_order: self.binding_order, + /// Double tables from N to 2N groups by splitting on a new challenge. + /// First N groups get scaled by eq(0, r), second N by eq(1, r). 
+ fn double_tables(tables: Vec>, r: F::Challenge) -> Vec> { + let eq_0 = EqPolynomial::mle(&[F::zero()], &[r]); + let eq_1 = EqPolynomial::mle(&[F::one()], &[r]); + let n = tables.len(); + let mut doubled: Vec> = Vec::with_capacity(2 * n); + for t in &tables { + doubled.push(t.clone()); } - } - - #[inline] - fn get_bound_coeff(&self, j: usize) -> F { - match self.binding_order { - BindingOrder::HighToLow => { - let n = self.lookup_indices.len() / 4; - let H_00 = self.lookup_indices[j].map_or(F::zero(), |i| self.F_00[i.into()]); - let H_01 = self.lookup_indices[j + n].map_or(F::zero(), |i| self.F_01[i.into()]); - let H_10 = - self.lookup_indices[j + n * 2].map_or(F::zero(), |i| self.F_10[i.into()]); - let H_11 = - self.lookup_indices[j + n * 3].map_or(F::zero(), |i| self.F_11[i.into()]); - H_00 + H_10 + H_01 + H_11 - } - BindingOrder::LowToHigh => { - let H_00 = self.lookup_indices[4 * j].map_or(F::zero(), |i| self.F_00[i.into()]); - let H_10 = - self.lookup_indices[4 * j + 1].map_or(F::zero(), |i| self.F_10[i.into()]); - let H_01 = - self.lookup_indices[4 * j + 2].map_or(F::zero(), |i| self.F_01[i.into()]); - let H_11 = - self.lookup_indices[4 * j + 3].map_or(F::zero(), |i| self.F_11[i.into()]); - H_00 + H_10 + H_01 + H_11 - } + for t in tables { + doubled.push(t); } - } -} - -/// Round 4: 8 eq tables, delays materialization one more round. 
-#[derive(Allocative, Default, Clone, Debug, PartialEq)] -pub struct RaPolynomialRound4 + Copy + Default + Send + Sync + 'static, F: JoltField> -{ - tables: [Vec; 8], - lookup_indices: Arc>>, - binding_order: BindingOrder, -} - -impl + Copy + Default + Send + Sync + 'static, F: JoltField> - RaPolynomialRound4 -{ - fn len(&self) -> usize { - self.lookup_indices.len() / 8 - } - - #[tracing::instrument(skip_all, name = "RaPolynomialRound4::bind")] - fn bind(self, r3: F::Challenge, _binding_order: BindingOrder) -> MultilinearPolynomial { - let eq_0_r3 = EqPolynomial::mle(&[F::zero()], &[r3]); - let eq_1_r3 = EqPolynomial::mle(&[F::one()], &[r3]); - - // 16 groups: [0..8) for bit3=0 (eq_0_r3), [8..16) for bit3=1 (eq_1_r3) - let [t0, t1, t2, t3, t4, t5, t6, t7] = self.tables; - let mut tables: [Vec; 16] = [ - t0.clone(), - t1.clone(), - t2.clone(), - t3.clone(), - t4.clone(), - t5.clone(), - t6.clone(), - t7.clone(), - t0, - t1, - t2, - t3, - t4, - t5, - t6, - t7, - ]; - - let (lo, hi) = tables.split_at_mut(8); + let (lo, hi) = doubled.split_at_mut(n); rayon::join( || { lo.par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r3)) + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0)) }, || { hi.par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r3)) + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1)) }, ); + doubled + } + #[tracing::instrument(skip_all, name = "RaPolynomialTableRound::bind")] + fn bind(self, r: F::Challenge, order: BindingOrder) -> Self { + if self.n_groups() > 1 { + assert_eq!(order, self.binding_order); + } + Self { + tables: Self::double_tables(self.tables, r), + lookup_indices: self.lookup_indices, + binding_order: order, + } + } + + #[tracing::instrument(skip_all, name = "RaPolynomialTableRound::materialize")] + fn materialize(self, r: F::Challenge, order: BindingOrder) -> MultilinearPolynomial { + let binding_order = if self.n_groups() > 1 { + assert_eq!(order, self.binding_order); + 
self.binding_order + } else { + order + }; + let tables = Self::double_tables(self.tables, r); + let n_groups = tables.len(); let lookup_indices = &self.lookup_indices; - let n = lookup_indices.len() / 16; + let n = lookup_indices.len() / n_groups; let mut res: Vec = unsafe_allocate_zero_vec(n); let chunk_size = 1 << 16; - match self.binding_order { + match binding_order { BindingOrder::HighToLow => { + let n_bits = n_groups.trailing_zeros() as usize; res.par_chunks_mut(chunk_size).enumerate().for_each( |(chunk_index, evals_chunk)| { for (j, eval) in zip(chunk_index * chunk_size.., evals_chunk) { - *eval = (0..16) + *eval = (0..n_groups) .map(|seg| { + let table_idx = bit_reverse(seg, n_bits); lookup_indices[seg * n + j] - .map_or(F::zero(), |i| tables[seg][i.into()]) + .map_or(F::zero(), |i| tables[table_idx][i.into()]) }) .sum(); } @@ -419,9 +251,9 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> res.par_chunks_mut(chunk_size).enumerate().for_each( |(chunk_index, evals_chunk)| { for (j, eval) in zip(chunk_index * chunk_size.., evals_chunk) { - *eval = (0..16) + *eval = (0..n_groups) .map(|offset| { - lookup_indices[16 * j + offset] + lookup_indices[n_groups * j + offset] .map_or(F::zero(), |i| tables[offset][i.into()]) }) .sum(); @@ -437,22 +269,40 @@ impl + Copy + Default + Send + Sync + 'static, F: JoltField> #[inline] fn get_bound_coeff(&self, j: usize) -> F { + let n_groups = self.n_groups(); match self.binding_order { BindingOrder::HighToLow => { - let n = self.lookup_indices.len() / 8; - (0..8) + let segment = self.lookup_indices.len() / n_groups; + let n_bits = n_groups.trailing_zeros() as usize; + (0..n_groups) .map(|seg| { - self.lookup_indices[seg * n + j] - .map_or(F::zero(), |i| self.tables[seg][i.into()]) + let table_idx = bit_reverse(seg, n_bits); + self.lookup_indices[seg * segment + j] + .map_or(F::zero(), |i| self.tables[table_idx][i.into()]) }) .sum() } - BindingOrder::LowToHigh => (0..8) + BindingOrder::LowToHigh => (0..n_groups) 
.map(|offset| { - self.lookup_indices[8 * j + offset] + self.lookup_indices[n_groups * j + offset] .map_or(F::zero(), |i| self.tables[offset][i.into()]) }) .sum(), } } } + +/// Reverse the lowest `bits` bits of `x`. +#[inline] +pub(crate) fn bit_reverse(x: usize, bits: usize) -> usize { + if bits == 0 { + return 0; + } + let mut result = 0; + let mut x = x; + for _ in 0..bits { + result = (result << 1) | (x & 1); + x >>= 1; + } + result +} diff --git a/jolt-core/src/poly/shared_ra_polys.rs b/jolt-core/src/poly/shared_ra_polys.rs index 5fea66dba7..2a46b4f28d 100644 --- a/jolt-core/src/poly/shared_ra_polys.rs +++ b/jolt-core/src/poly/shared_ra_polys.rs @@ -32,7 +32,6 @@ use std::sync::Arc; use crate::field::JoltField; use crate::poly::eq_poly::EqPolynomial; use crate::poly::multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding}; -use crate::utils::thread::drop_in_background_thread; use crate::utils::thread::unsafe_allocate_zero_vec; use crate::zkvm::bytecode::BytecodePreprocessing; use crate::zkvm::config::OneHotParams; @@ -445,93 +444,27 @@ fn compute_all_G_impl( ) } -/// Shared RA polynomials that use a single eq table for all polynomials. +/// When the table round reaches this many groups, the next bind materializes to dense. +const SHARED_MATERIALIZE_THRESHOLD: usize = 16; + +/// Shared RA polynomials using table-doubling state machine. /// /// Instead of N separate `RaPolynomial` each with their own eq table copy, -/// this stores: -/// - ONE (small) eq table per polynomial (or split tables for later rounds) -/// - `Vec` (size T, non-transposed) -/// -/// This saves memory and improves cache locality. +/// stores per-polynomial eq tables and a single shared `Vec`. +/// Each bind doubles the table groups; once the threshold is reached, +/// the next bind materializes to dense `MultilinearPolynomial`s. 
#[derive(Allocative)] pub enum SharedRaPolynomials { - /// Round 1: Single shared eq table - Round1(SharedRaRound1), - /// Round 2: Split into F_0, F_1 - Round2(SharedRaRound2), - /// Round 3: Split into F_00, F_01, F_10, F_11 - Round3(SharedRaRound3), - /// Round 4: Split into 8 table groups (F_000 through F_111) - Round4(SharedRaRound4), - /// Round 5: Split into 16 table groups - Round5(SharedRaRound5), - /// Round N: Fully materialized multilinear polynomials + TableRound(SharedRaTableRound), RoundN(Vec>), } -/// Round 1 state: single shared eq table -#[derive(Allocative, Default)] -pub struct SharedRaRound1 { - /// Per-polynomial eq tables: tables[poly_idx][k] for k in 0..K - /// - /// In the booleanity sumcheck, these tables may already be pre-scaled by a per-polynomial - /// constant (e.g. a batching coefficient). - tables: Vec>, - /// RA indices for all cycles (non-transposed) - indices: Arc>, - /// Number of polynomials - num_polys: usize, - /// OneHotParams for index extraction - #[allocative(skip)] - one_hot_params: OneHotParams, -} - -/// Round 2 state: split eq tables -#[derive(Allocative, Default)] -pub struct SharedRaRound2 { - /// Per-polynomial tables for the 0-branch: tables_0[poly_idx][k] - tables_0: Vec>, - /// Per-polynomial tables for the 1-branch: tables_1[poly_idx][k] - tables_1: Vec>, - /// RA indices for all cycles - indices: Arc>, - num_polys: usize, - #[allocative(skip)] - one_hot_params: OneHotParams, - binding_order: BindingOrder, -} - -/// Round 3 state: further split eq tables -#[derive(Allocative, Default)] -pub struct SharedRaRound3 { - tables_00: Vec>, - tables_01: Vec>, - tables_10: Vec>, - tables_11: Vec>, - indices: Arc>, - num_polys: usize, - #[allocative(skip)] - one_hot_params: OneHotParams, - binding_order: BindingOrder, -} - -/// Round 4 state: 8 table groups per polynomial. -/// Delays materialization one more round to reduce peak memory. 
-#[derive(Allocative, Default)] -pub struct SharedRaRound4 { - /// tables[group][poly_idx][k] — 8 groups of per-poly eq tables - tables: [Vec>; 8], - indices: Arc>, - num_polys: usize, - #[allocative(skip)] - one_hot_params: OneHotParams, - binding_order: BindingOrder, -} - -/// Round 5 state: 16 table groups per polynomial. +/// Generic table round for SharedRaPolynomials with `n_groups` eq table groups. +/// +/// `tables[group_idx][poly_idx][k]` — tables are in LowToHigh interleaving order. #[derive(Allocative)] -pub struct SharedRaRound5 { - tables: Vec>>, // tables[group_idx][poly_idx][k], 16 groups +pub struct SharedRaTableRound { + tables: Vec>>, indices: Arc>, num_polys: usize, #[allocative(skip)] @@ -539,7 +472,7 @@ pub struct SharedRaRound5 { binding_order: BindingOrder, } -impl Default for SharedRaRound5 { +impl Default for SharedRaTableRound { fn default() -> Self { Self { tables: Vec::new(), @@ -552,7 +485,6 @@ impl Default for SharedRaRound5 { } impl SharedRaPolynomials { - /// Create new SharedRaPolynomials from eq table and indices. pub fn new( tables: Vec>, indices: Arc>, @@ -566,16 +498,15 @@ impl SharedRaPolynomials { tables.len(), num_polys ); - Self::Round1(SharedRaRound1 { - tables, + Self::TableRound(SharedRaTableRound { + tables: vec![tables], indices, num_polys, one_hot_params, + binding_order: BindingOrder::LowToHigh, }) } - /// Create SharedRaPolynomials that only uses instruction polys from the shared indices. - /// `tables` should have exactly `instruction_d` entries. 
pub fn new_instruction_only( tables: Vec>, indices: Arc>, @@ -583,52 +514,37 @@ impl SharedRaPolynomials { ) -> Self { let num_polys = one_hot_params.instruction_d; debug_assert_eq!(tables.len(), num_polys); - Self::Round1(SharedRaRound1 { - tables, + Self::TableRound(SharedRaTableRound { + tables: vec![tables], indices, num_polys, one_hot_params, + binding_order: BindingOrder::LowToHigh, }) } - /// Get the number of polynomials pub fn num_polys(&self) -> usize { match self { - Self::Round1(r) => r.num_polys, - Self::Round2(r) => r.num_polys, - Self::Round3(r) => r.num_polys, - Self::Round4(r) => r.num_polys, - Self::Round5(r) => r.num_polys, + Self::TableRound(t) => t.num_polys, Self::RoundN(polys) => polys.len(), } } - /// Get the current length (number of cycles / 2^rounds_so_far) pub fn len(&self) -> usize { match self { - Self::Round1(r) => r.indices.len(), - Self::Round2(r) => r.indices.len() / 2, - Self::Round3(r) => r.indices.len() / 4, - Self::Round4(r) => r.indices.len() / 8, - Self::Round5(r) => r.indices.len() / 16, + Self::TableRound(t) => t.len(), Self::RoundN(polys) => polys[0].len(), } } - /// Get bound coefficient for polynomial `poly_idx` at position `j` #[inline] pub fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { match self { - Self::Round1(r) => r.get_bound_coeff(poly_idx, j), - Self::Round2(r) => r.get_bound_coeff(poly_idx, j), - Self::Round3(r) => r.get_bound_coeff(poly_idx, j), - Self::Round4(r) => r.get_bound_coeff(poly_idx, j), - Self::Round5(r) => r.get_bound_coeff(poly_idx, j), + Self::TableRound(t) => t.get_bound_coeff(poly_idx, j), Self::RoundN(polys) => polys[poly_idx].get_bound_coeff(j), } } - /// Get final sumcheck claim for polynomial `poly_idx` pub fn final_sumcheck_claim(&self, poly_idx: usize) -> F { match self { Self::RoundN(polys) => polys[poly_idx].final_sumcheck_claim(), @@ -636,15 +552,15 @@ impl SharedRaPolynomials { } } - /// Bind with a challenge, transitioning to next round state. 
- /// Consumes self and returns the new state. pub fn bind(self, r: F::Challenge, order: BindingOrder) -> Self { match self { - Self::Round1(r1) => Self::Round2(r1.bind(r, order)), - Self::Round2(r2) => Self::Round3(r2.bind(r, order)), - Self::Round3(r3) => Self::Round4(r3.bind(r, order)), - Self::Round4(r4) => Self::Round5(r4.bind(r, order)), - Self::Round5(r5) => Self::RoundN(r5.bind(r, order)), + Self::TableRound(t) => { + if t.n_groups() >= SHARED_MATERIALIZE_THRESHOLD { + Self::RoundN(t.materialize(r, order)) + } else { + Self::TableRound(t.bind(r, order)) + } + } Self::RoundN(mut polys) => { polys.par_iter_mut().for_each(|p| p.bind_parallel(r, order)); Self::RoundN(polys) @@ -652,14 +568,15 @@ impl SharedRaPolynomials { } } - /// Bind in place with a challenge, transitioning to next round state. pub fn bind_in_place(&mut self, r: F::Challenge, order: BindingOrder) { match self { - Self::Round1(r1) => *self = Self::Round2(std::mem::take(r1).bind(r, order)), - Self::Round2(r2) => *self = Self::Round3(std::mem::take(r2).bind(r, order)), - Self::Round3(r3) => *self = Self::Round4(std::mem::take(r3).bind(r, order)), - Self::Round4(r4) => *self = Self::Round5(std::mem::take(r4).bind(r, order)), - Self::Round5(r5) => *self = Self::RoundN(std::mem::take(r5).bind(r, order)), + Self::TableRound(t) => { + if t.n_groups() >= SHARED_MATERIALIZE_THRESHOLD { + *self = Self::RoundN(std::mem::take(t).materialize(r, order)); + } else { + *self = Self::TableRound(std::mem::take(t).bind(r, order)); + } + } Self::RoundN(polys) => { polys.par_iter_mut().for_each(|p| p.bind_parallel(r, order)); } @@ -667,372 +584,79 @@ impl SharedRaPolynomials { } } -impl SharedRaRound1 { +impl SharedRaTableRound { #[inline] - fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { - self.indices[j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables[poly_idx][k as usize]) - } - - fn bind(self, r0: F::Challenge, order: BindingOrder) -> SharedRaRound2 { - let 
eq_0_r0 = EqPolynomial::mle(&[F::zero()], &[r0]); - let eq_1_r0 = EqPolynomial::mle(&[F::one()], &[r0]); - let (tables_0, tables_1) = rayon::join( - || { - self.tables - .par_iter() - .map(|t| t.iter().map(|v| eq_0_r0 * v).collect::>()) - .collect::>>() - }, - || { - self.tables - .par_iter() - .map(|t| t.iter().map(|v| eq_1_r0 * v).collect::>()) - .collect::>>() - }, - ); - drop_in_background_thread(self.tables); - - SharedRaRound2 { - tables_0, - tables_1, - indices: self.indices, - num_polys: self.num_polys, - one_hot_params: self.one_hot_params, - binding_order: order, - } + fn n_groups(&self) -> usize { + self.tables.len() } -} -impl SharedRaRound2 { - #[inline] - fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { - match self.binding_order { - BindingOrder::HighToLow => { - let mid = self.indices.len() / 2; - let h_0 = self.indices[j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_0[poly_idx][k as usize]); - let h_1 = self.indices[mid + j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_1[poly_idx][k as usize]); - h_0 + h_1 - } - BindingOrder::LowToHigh => { - let h_0 = self.indices[2 * j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_0[poly_idx][k as usize]); - let h_1 = self.indices[2 * j + 1] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_1[poly_idx][k as usize]); - h_0 + h_1 - } - } + fn len(&self) -> usize { + self.indices.len() / self.n_groups() } - fn bind(self, r1: F::Challenge, order: BindingOrder) -> SharedRaRound3 { - assert_eq!(order, self.binding_order); - let eq_0_r1 = EqPolynomial::mle(&[F::zero()], &[r1]); - let eq_1_r1 = EqPolynomial::mle(&[F::one()], &[r1]); - - let mut tables_00 = self.tables_0.clone(); - let mut tables_01 = self.tables_0; - let mut tables_10 = self.tables_1.clone(); - let mut tables_11 = self.tables_1; - - // Scale all four groups in parallel. 
- rayon::join( - || { - rayon::join( - || { - tables_00 - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r1)) - }, - || { - tables_01 - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r1)) - }, - ) - }, - || { - rayon::join( - || { - tables_10 - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r1)) - }, - || { - tables_11 - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r1)) - }, - ) - }, - ); - - SharedRaRound3 { - tables_00, - tables_01, - tables_10, - tables_11, - indices: self.indices, - num_polys: self.num_polys, - one_hot_params: self.one_hot_params, - binding_order: order, + fn double_tables(tables: Vec>>, r: F::Challenge) -> Vec>> { + let eq_0 = EqPolynomial::mle(&[F::zero()], &[r]); + let eq_1 = EqPolynomial::mle(&[F::one()], &[r]); + let n = tables.len(); + let mut doubled: Vec>> = Vec::with_capacity(2 * n); + for t in &tables { + doubled.push(t.clone()); } - } -} - -impl SharedRaRound3 { - #[inline] - fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { - match self.binding_order { - BindingOrder::HighToLow => { - let quarter = self.indices.len() / 4; - let h_00 = self.indices[j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_00[poly_idx][k as usize]); - let h_01 = self.indices[quarter + j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_01[poly_idx][k as usize]); - let h_10 = self.indices[2 * quarter + j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_10[poly_idx][k as usize]); - let h_11 = self.indices[3 * quarter + j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_11[poly_idx][k as usize]); - h_00 + h_01 + h_10 + h_11 - } - BindingOrder::LowToHigh => { - // Bit pattern for offset: (r1, r0), so offset 1 = r0=1,r1=0 → F_10 - let h_00 = self.indices[4 * j] - .get_index(poly_idx, &self.one_hot_params) - 
.map_or(F::zero(), |k| self.tables_00[poly_idx][k as usize]); - let h_10 = self.indices[4 * j + 1] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_10[poly_idx][k as usize]); - let h_01 = self.indices[4 * j + 2] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_01[poly_idx][k as usize]); - let h_11 = self.indices[4 * j + 3] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables_11[poly_idx][k as usize]); - h_00 + h_10 + h_01 + h_11 - } + for t in tables { + doubled.push(t); } - } - - #[tracing::instrument(skip_all, name = "SharedRaRound3::bind")] - fn bind(self, r2: F::Challenge, order: BindingOrder) -> SharedRaRound4 { - assert_eq!(order, self.binding_order); - - let eq_0_r2 = EqPolynomial::mle(&[F::zero()], &[r2]); - let eq_1_r2 = EqPolynomial::mle(&[F::one()], &[r2]); - - let mut tables_000 = self.tables_00.clone(); - let mut tables_001 = self.tables_00; - let mut tables_010 = self.tables_01.clone(); - let mut tables_011 = self.tables_01; - let mut tables_100 = self.tables_10.clone(); - let mut tables_101 = self.tables_10; - let mut tables_110 = self.tables_11.clone(); - let mut tables_111 = self.tables_11; - + let (lo, hi) = doubled.split_at_mut(n); rayon::join( || { - [ - &mut tables_000, - &mut tables_010, - &mut tables_100, - &mut tables_110, - ] - .into_par_iter() - .for_each(|table| { - table + lo.par_iter_mut().for_each(|group| { + group .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r2)) + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0)) }) }, || { - [ - &mut tables_001, - &mut tables_011, - &mut tables_101, - &mut tables_111, - ] - .into_par_iter() - .for_each(|table| { - table + hi.par_iter_mut().for_each(|group| { + group .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r2)) + .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1)) }) }, ); - - SharedRaRound4 { - tables: [ - tables_000, tables_100, 
tables_010, tables_110, tables_001, tables_101, tables_011, - tables_111, - ], - indices: self.indices, - num_polys: self.num_polys, - one_hot_params: self.one_hot_params, - binding_order: order, - } + doubled } -} -impl SharedRaRound4 { - #[inline] - fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { - match self.binding_order { - BindingOrder::LowToHigh => (0..8) - .map(|offset| { - self.indices[8 * j + offset] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables[offset][poly_idx][k as usize]) - }) - .sum(), - BindingOrder::HighToLow => { - let eighth = self.indices.len() / 8; - (0..8) - .map(|seg| { - self.indices[seg * eighth + j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables[seg][poly_idx][k as usize]) - }) - .sum() - } + #[tracing::instrument(skip_all, name = "SharedRaTableRound::bind")] + fn bind(self, r: F::Challenge, order: BindingOrder) -> Self { + if self.n_groups() > 1 { + assert_eq!(order, self.binding_order); } - } - - #[tracing::instrument(skip_all, name = "SharedRaRound4::bind")] - fn bind(self, r3: F::Challenge, order: BindingOrder) -> SharedRaRound5 { - assert_eq!(order, self.binding_order); - - let eq_0_r3 = EqPolynomial::mle(&[F::zero()], &[r3]); - let eq_1_r3 = EqPolynomial::mle(&[F::one()], &[r3]); - - let [t0, t1, t2, t3, t4, t5, t6, t7] = self.tables; - let mut tables_vec: Vec>> = vec![ - t0.clone(), - t1.clone(), - t2.clone(), - t3.clone(), - t4.clone(), - t5.clone(), - t6.clone(), - t7.clone(), - t0, - t1, - t2, - t3, - t4, - t5, - t6, - t7, - ]; - - let (lo, hi) = tables_vec.split_at_mut(8); - rayon::join( - || { - lo.par_iter_mut().for_each(|table| { - table - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r3)) - }) - }, - || { - hi.par_iter_mut().for_each(|table| { - table - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r3)) - }) - }, - ); - - SharedRaRound5 { - tables: tables_vec, + Self { + tables: 
Self::double_tables(self.tables, r), indices: self.indices, num_polys: self.num_polys, one_hot_params: self.one_hot_params, binding_order: order, } } -} - -impl SharedRaRound5 { - #[inline] - fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { - match self.binding_order { - BindingOrder::LowToHigh => (0..16) - .map(|offset| { - self.indices[16 * j + offset] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables[offset][poly_idx][k as usize]) - }) - .sum(), - BindingOrder::HighToLow => { - let sixteenth = self.indices.len() / 16; - (0..16) - .map(|seg| { - self.indices[seg * sixteenth + j] - .get_index(poly_idx, &self.one_hot_params) - .map_or(F::zero(), |k| self.tables[seg][poly_idx][k as usize]) - }) - .sum() - } - } - } - #[tracing::instrument(skip_all, name = "SharedRaRound5::bind")] - fn bind(self, r4: F::Challenge, order: BindingOrder) -> Vec> { - assert_eq!(order, self.binding_order); - - let eq_0_r4 = EqPolynomial::mle(&[F::zero()], &[r4]); - let eq_1_r4 = EqPolynomial::mle(&[F::one()], &[r4]); - - // 32 groups from 16: [0..16) for bit4=0, [16..32) for bit4=1 - let n_groups = self.tables.len(); - let mut tables: Vec>> = Vec::with_capacity(n_groups * 2); - for t in &self.tables { - tables.push(t.clone()); - } - for t in self.tables { - tables.push(t); - } - - let (lo, hi) = tables.split_at_mut(n_groups); - rayon::join( - || { - lo.par_iter_mut().for_each(|table| { - table - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_0_r4)) - }) - }, - || { - hi.par_iter_mut().for_each(|table| { - table - .par_iter_mut() - .for_each(|t| t.par_iter_mut().for_each(|f| *f *= eq_1_r4)) - }) - }, - ); - - let num_polys = self.num_polys; + #[tracing::instrument(skip_all, name = "SharedRaTableRound::materialize")] + fn materialize(self, r: F::Challenge, order: BindingOrder) -> Vec> { + let binding_order = if self.n_groups() > 1 { + assert_eq!(order, self.binding_order); + self.binding_order + } else { + order + }; + let 
tables = Self::double_tables(self.tables, r); + let n_total = tables.len(); let indices = &self.indices; let one_hot_params = &self.one_hot_params; - let n_total = tables.len(); // 32 let new_len = indices.len() / n_total; - (0..num_polys) + (0..self.num_polys) .into_par_iter() .map(|poly_idx| { - let coeffs: Vec = match order { + let coeffs: Vec = match binding_order { BindingOrder::LowToHigh => (0..new_len) .into_par_iter() .map(|j| { @@ -1046,16 +670,19 @@ impl SharedRaRound5 { }) .collect(), BindingOrder::HighToLow => { + let n_bits = n_total.trailing_zeros() as usize; let segment = indices.len() / n_total; (0..new_len) .into_par_iter() .map(|j| { (0..n_total) .map(|seg| { + let table_idx = + crate::poly::ra_poly::bit_reverse(seg, n_bits); indices[seg * segment + j] .get_index(poly_idx, one_hot_params) .map_or(F::zero(), |k| { - tables[seg][poly_idx][k as usize] + tables[table_idx][poly_idx][k as usize] }) }) .sum() @@ -1067,20 +694,30 @@ impl SharedRaRound5 { }) .collect() } -} -/// Compute all RaIndices in parallel (non-transposed). -/// -/// Returns one `RaIndices` per cycle. 
-#[tracing::instrument(skip_all, name = "shared_ra_polys::compute_ra_indices")] -pub fn compute_ra_indices( - trace: &[Cycle], - bytecode: &BytecodePreprocessing, - memory_layout: &MemoryLayout, - one_hot_params: &OneHotParams, -) -> Vec { - trace - .par_iter() - .map(|cycle| RaIndices::from_cycle(cycle, bytecode, memory_layout, one_hot_params)) - .collect() + #[inline] + fn get_bound_coeff(&self, poly_idx: usize, j: usize) -> F { + let n_groups = self.n_groups(); + match self.binding_order { + BindingOrder::HighToLow => { + let segment = self.indices.len() / n_groups; + let n_bits = n_groups.trailing_zeros() as usize; + (0..n_groups) + .map(|seg| { + let table_idx = crate::poly::ra_poly::bit_reverse(seg, n_bits); + self.indices[seg * segment + j] + .get_index(poly_idx, &self.one_hot_params) + .map_or(F::zero(), |k| self.tables[table_idx][poly_idx][k as usize]) + }) + .sum() + } + BindingOrder::LowToHigh => (0..n_groups) + .map(|offset| { + self.indices[n_groups * j + offset] + .get_index(poly_idx, &self.one_hot_params) + .map_or(F::zero(), |k| self.tables[offset][poly_idx][k as usize]) + }) + .sum(), + } + } } diff --git a/jolt-core/src/subprotocols/mles_product_sum.rs b/jolt-core/src/subprotocols/mles_product_sum.rs index f885652b69..b4d63fea50 100644 --- a/jolt-core/src/subprotocols/mles_product_sum.rs +++ b/jolt-core/src/subprotocols/mles_product_sum.rs @@ -791,8 +791,6 @@ fn eval_prod_8_assign(p: &[(F, F); 8], outputs: &mut [F]) { /// correct alignment. /// - The sub-slices are non-overlapping. fn eval_prod_9_accumulate(p: &[(F, F); 9], outputs: &mut [F::UnreducedProductAccum]) { - // TODO: Implement more optimal way to do this. - // 5x4 split probably better than current 8x1 split. let p8 = p[0..8].try_into().unwrap(); let [a1, a2, a3, a4, a5, a6, a7, a8, a_inf] = eval_linear_prod_8_internal(p8);